import logging
import time
from threading import Thread

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# -----------------------
# Basic Logging Setup
# -----------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# -----------------------
# CORS (allow Netlify)
# -----------------------
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with your Netlify domain later
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_ID = "AshokGakr/model-tiny"

logger.info("Loading model...")
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(device)
model.eval()
logger.info(f"Model loaded on {device}")

# -----------------------
# Root Health Check
# -----------------------
@app.get("/")
def root():
    return {"status": "API is running"}

# -----------------------
# Streaming Generator
# -----------------------
def generate_stream(prompt: str):
    logger.info("Starting generation...")
    start_time = time.time()

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # skip_prompt=True keeps the echoed prompt out of the streamed output
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        streamer=streamer,
    )

    # model.generate() blocks, so it runs in a background thread while the
    # streamer yields decoded text chunks to the HTTP response as they arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for new_text in streamer:
        yield new_text

    thread.join()  # ensure the generation thread has fully finished
    duration = round(time.time() - start_time, 2)
    logger.info(f"Generation finished in {duration} seconds.")

# -----------------------
# Chat Endpoint
# -----------------------
@app.post("/chat")
async def chat(data: dict):
    system_prompt = data.get("system", "You are a helpful AI assistant.")
    history = data.get("history", "")
    message = data.get("message", "")

    # Trim history if too large (prevents memory overflow)
    max_history_chars = 2000
    if len(history) > max_history_chars:
        history = history[-max_history_chars:]

    logger.info("----- NEW REQUEST -----")
    logger.info(f"User message: {message}")
    logger.info(f"History length: {len(history)}")

    full_prompt = f"{system_prompt}\n{history}\nUser: {message}\nAssistant:"

    return StreamingResponse(
        generate_stream(full_prompt),
        media_type="text/plain",
    )
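# -----------------------
# Example Client (sketch)
# -----------------------
# A minimal sketch of consuming the /chat stream from Python, assuming the
# server is reachable at http://localhost:8000 (host/port are assumptions,
# not part of this file). Uses the `requests` library with stream=True so
# chunks print as the model produces them.
#
# import requests
#
# with requests.post(
#     "http://localhost:8000/chat",
#     json={"message": "Hello!", "history": ""},
#     stream=True,
# ) as resp:
#     for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#         print(chunk, end="", flush=True)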
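# -----------------------
# Local Entry Point (optional)
# -----------------------
# Hedged convenience runner for local testing; the uvicorn invocation and
# port 8000 below are assumptions, not part of the original deployment.
# Equivalent streaming request from the shell (-N disables curl buffering):
#   curl -N -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!", "history": ""}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)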