from fastapi import FastAPI
from pydantic import BaseModel
from contextlib import asynccontextmanager
import os

# Global variables
model = None
tokenizer = None
device = None
model_loaded = False

def load_model():
    global model, tokenizer, device, model_loaded
    try:
        print("🚀 Starting model loading...")
        # Set cache directory
        os.environ["HF_HOME"] = "/tmp"
        # Import here to avoid startup issues
        from transformers import T5ForConditionalGeneration, T5Tokenizer
        import torch

        print("📦 Loading tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained("chalana2001/quiz_guru_chatbot")
        print("🤖 Loading model...")
        model = T5ForConditionalGeneration.from_pretrained(
            "chalana2001/quiz_guru_chatbot",
            trust_remote_code=True
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model_loaded = True
        print(f"✅ Model loaded successfully on {device}")
        return True
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return False

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load the model once before serving requests
    print("🚀 Starting up...")
    load_model()
    yield
    # Shutdown (if needed)
    print("👋 Shutting down...")

app = FastAPI(title="Quiz Guru Chatbot", version="1.0.0", lifespan=lifespan)

class PromptRequest(BaseModel):
    prompt: str

@app.get("/")
def read_root():
    return {
        "message": "Quiz Guru Chatbot API",
        "status": "running",
        "model_loaded": model_loaded
    }

@app.get("/health")
def health():
    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "device": str(device) if device else "unknown"
    }

@app.post("/predict")
def predict(request: PromptRequest):
    if not model_loaded:
        return {"error": "Model not loaded. Please check /health endpoint."}
    try:
        # Deferred import, matching the lazy import in load_model
        import torch

        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            output = model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True)
        decoded = tokenizer.decode(output[0], skip_special_tokens=True)
        return {"result": decoded, "status": "success"}
    except Exception as e:
        return {"error": str(e), "status": "error"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
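
# Usage sketch: a minimal client call, assuming the server is running locally
# on the port configured above (7860) and the /predict route defined in this
# file. The example prompt text is illustrative only.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={"prompt": "Generate a quiz question about photosynthesis"},
#   )
#   print(resp.json())  # expected shape: {"result": "...", "status": "success"}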