"""Quiz Guru Chatbot: a FastAPI service wrapping a T5 text-generation model.

The model is loaded once at application startup (see ``lifespan``) and kept
in module-level globals that the request handlers read.
"""

from contextlib import asynccontextmanager
import os
import traceback

from fastapi import FastAPI
from pydantic import BaseModel

# Hugging Face Hub repository holding both the tokenizer and the model weights.
MODEL_REPO = "chalana2001/quiz_guru_chatbot"

# Global state populated by load_model() at startup.
model = None
tokenizer = None
device = None
model_loaded = False


def load_model() -> bool:
    """Load the tokenizer and model into the module-level globals.

    Returns:
        True if the tokenizer and model were loaded and the model was moved
        to ``device``; False if any step raised. On failure the error and
        its traceback are printed and ``model_loaded`` stays False.
    """
    global model, tokenizer, device, model_loaded
    try:
        print("🔄 Starting model loading...")

        # Default the Hugging Face cache to /tmp, but respect a value the
        # deployer already configured. Set before importing transformers,
        # which presumably reads it when first imported.
        os.environ.setdefault("HF_HOME", "/tmp")

        # Imported here (inside the try) so a missing/broken install is
        # reported as a load failure instead of crashing the app at import.
        from transformers import T5ForConditionalGeneration, T5Tokenizer
        import torch

        print("📦 Loading tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained(MODEL_REPO)

        print("🤖 Loading model...")
        # NOTE(review): trust_remote_code=True permits executing code from the
        # model repo; it should be unnecessary for the concrete T5 class —
        # confirm and drop it.
        model = T5ForConditionalGeneration.from_pretrained(
            MODEL_REPO, trust_remote_code=True
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        model_loaded = True
        print(f"✅ Model loaded successfully on {device}")
        return True
    except Exception as e:
        # Startup boundary: report the failure, including the traceback that
        # the bare message alone loses, and let the API keep running so
        # /health can show model_loaded=False.
        print(f"❌ Error loading model: {e}")
        traceback.print_exc()
        return False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup; nothing needs releasing on shutdown."""
    # Startup
    print("🚀 Starting up...")
    # The return value is ignored: a failed load only leaves model_loaded
    # False, and /predict refuses requests until it is True.
    load_model()
    yield
    # Shutdown (if needed)
    print("🛑 Shutting down...")


app = FastAPI(title="Quiz Guru Chatbot", version="1.0.0", lifespan=lifespan)


class PromptRequest(BaseModel):
    """Request body for the /predict endpoint."""

    prompt: str  # Input text passed to the tokenizer and model.


@app.get("/")
def read_root():
    """Return a banner with the service status and the model load flag."""
    return {
        "message": "Quiz Guru Chatbot API",
        "status": "running",
        "model_loaded": model_loaded,
    }


@app.get("/health")
def health():
    """Report whether the model is loaded and which torch device it uses."""
    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "device": str(device) if device is not None else "unknown",
    }


@app.post("/predict")
def predict(request: PromptRequest):
    """Generate text for ``request.prompt`` using beam search.

    Returns ``{"result": <text>, "status": "success"}`` on success. If the
    model is not loaded, returns ``{"error": ...}``; if generation raises,
    returns ``{"error": <message>, "status": "error"}``.
    """
    # NOTE(review): error cases are returned with HTTP 200; consider a
    # JSONResponse with 503/500 if clients can tolerate the status change.
    if not model_loaded:
        return {"error": "Model not loaded. Please check /health endpoint."}
    try:
        # model_loaded is only set after load_model() imported torch
        # successfully, so this import cannot fail here.
        import torch

        inputs = tokenizer(
            request.prompt, return_tensors="pt", padding=True
        ).to(device)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            output = model.generate(
                **inputs, max_length=256, num_beams=4, early_stopping=True
            )
        decoded = tokenizer.decode(output[0], skip_special_tokens=True)
        return {"result": decoded, "status": "success"}
    except Exception as e:
        # NOTE(review): str(e) is sent to the client and may expose internals.
        return {"error": str(e), "status": "error"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)