# gakrbot / main.py (commit 093ce27)
import torch
import time
import logging
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
# -----------------------
# Basic Logging Setup
# -----------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# -----------------------
# CORS (allow Netlify)
# -----------------------
# NOTE(review): the CORS spec forbids a wildcard origin together with
# credentials, so browsers will reject credentialed cross-site requests
# under this config — pin the real Netlify origin before relying on
# cookies/auth headers.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Replace with your Netlify domain later
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# -----------------------
# Model Load (runs once, at import/startup time)
# -----------------------
MODEL_ID = "AshokGakr/model-tiny"
logger.info("Loading model...")
# Prefer GPU when available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# float32 + low_cpu_mem_usage keeps the load viable on small CPU instances.
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float32,
low_cpu_mem_usage=True
).to(device)
# Inference-only: disable dropout/batch-norm training behavior.
model.eval()
logger.info(f"Model loaded on {device}")
# -----------------------
# Root Health Check
# -----------------------
@app.get("/")
def root():
    """Liveness probe: report that the API process is up."""
    payload = {"status": "API is running"}
    return payload
# -----------------------
# Streaming Generator
# -----------------------
def generate_stream(prompt: str):
    """Yield generated text chunks for *prompt* as the model produces them.

    Runs ``model.generate`` in a background thread and streams decoded
    text pieces through a ``TextIteratorStreamer``. The worker thread is
    always joined — even if the consumer closes the generator early
    (e.g. the HTTP client disconnects) — so no thread is leaked and the
    timing log always fires.
    """
    logger.info("Starting generation...")
    start_time = time.time()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        # Models without a dedicated pad token warn (or fail) in generate();
        # padding with EOS is the standard workaround for causal LMs.
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    try:
        for new_text in streamer:
            yield new_text
    finally:
        # Reap the worker even when the generator is closed mid-stream;
        # previously the thread was never joined and the timing log was
        # skipped on early exit.
        thread.join()
        duration = round(time.time() - start_time, 2)
        logger.info(f"Generation finished in {duration} seconds.")
# -----------------------
# Chat Endpoint
# -----------------------
@app.post("/chat")
async def chat(data: dict):
    """Stream a model reply for a posted chat payload.

    Expects a JSON object with optional string fields ``system``,
    ``history`` and ``message``; responds with a plain-text token stream.
    """
    persona = data.get("system", "You are a helpful AI assistant.")
    transcript = data.get("history", "")
    user_text = data.get("message", "")
    # Bound the transcript so the prompt stays small on long conversations
    # (prevents memory overflow).
    limit = 2000
    if len(transcript) > limit:
        transcript = transcript[-limit:]
    logger.info("----- NEW REQUEST -----")
    logger.info(f"User message: {user_text}")
    logger.info(f"History length: {len(transcript)}")
    prompt = f"{persona}\n{transcript}\nUser: {user_text}\nAssistant:"
    return StreamingResponse(generate_stream(prompt), media_type="text/plain")