test / app.py
chalana2001's picture
Update app.py
28df82e verified
from fastapi import FastAPI
from pydantic import BaseModel
from contextlib import asynccontextmanager
import os
# Process-wide inference state. Everything starts out empty and is filled in
# by load_model().
model_loaded = False  # flipped to True once load_model() succeeds
model = tokenizer = device = None  # model, tokenizer and torch device
def load_model(model_id: str = "chalana2001/quiz_guru_chatbot") -> bool:
    """Load the tokenizer and T5 model and publish them in the module globals.

    Everything is loaded into local variables first; the globals ``tokenizer``,
    ``model``, ``device`` and ``model_loaded`` are only assigned once every
    step has succeeded, so a failure never leaves them half-initialised.

    Args:
        model_id: Checkpoint identifier handed to ``from_pretrained`` for both
            the tokenizer and the model.

    Returns:
        True when the model is loaded and moved to the selected device,
        False when any step raised (the error is printed, not re-raised).
    """
    global model, tokenizer, device, model_loaded
    try:
        print("πŸ”„ Starting model loading...")
        # Set cache directory (done before the transformers import below).
        os.environ["HF_HOME"] = "/tmp"
        # Import here to avoid startup issues
        from transformers import T5ForConditionalGeneration, T5Tokenizer
        import torch

        print("πŸ“¦ Loading tokenizer...")
        loaded_tokenizer = T5Tokenizer.from_pretrained(model_id)

        print("πŸ€– Loading model...")
        # trust_remote_code is intentionally not passed: the concrete
        # T5ForConditionalGeneration class is used, so no code shipped in the
        # model repository has to be executed to build the model.
        loaded_model = T5ForConditionalGeneration.from_pretrained(model_id)

        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        loaded_model.to(target_device)

        # Publish the fully initialised objects in one go.
        tokenizer = loaded_tokenizer
        model = loaded_model
        device = target_device
        model_loaded = True
        print(f"βœ… Model loaded successfully on {device}")
        return True
    except Exception as e:
        # Best effort: the caller decides what to do with a failed load.
        print(f"❌ Error loading model: {e}")
        return False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan handler: load the model before yielding, log on the way out.

    The result of load_model() is not inspected here, so the application
    keeps starting even when loading fails.
    """
    # --- startup phase ---
    print("πŸš€ Starting up...")
    load_model()

    yield

    # --- shutdown phase (nothing to release, just log) ---
    print("πŸ›‘ Shutting down...")
# Application object; model loading is wired in through the lifespan handler.
app = FastAPI(
    title="Quiz Guru Chatbot",
    version="1.0.0",
    lifespan=lifespan,
)
class PromptRequest(BaseModel):
    """Request body accepted by the /predict endpoint."""

    # Text that is tokenized and fed to the model.
    prompt: str
@app.get("/")
def read_root():
    """Return the API banner together with the current model-loaded flag."""
    banner = {"message": "Quiz Guru Chatbot API", "status": "running"}
    banner["model_loaded"] = model_loaded
    return banner
@app.get("/health")
def health():
    """Report service health, the model-loaded flag and the device in use.

    The device is rendered with str(); "unknown" is reported while no device
    has been set.
    """
    device_name = "unknown"
    if device:
        device_name = str(device)
    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "device": device_name,
    }
@app.post("/predict")
def predict(request: PromptRequest):
    """Generate a model answer for the prompt in the request body.

    Failures are reported in the returned dict rather than raised: every
    error response carries an "error" message and "status": "error", and a
    successful response carries "result" and "status": "success".

    Args:
        request: Parsed request body; only ``request.prompt`` is used.

    Returns:
        A dict with either the decoded generation or an error description.
    """
    if not model_loaded:
        # Same shape as the error response below, so clients can rely on
        # the "status" key being present on every reply.
        return {
            "error": "Model not loaded. Please check /health endpoint.",
            "status": "error",
        }
    try:
        # Import torch here
        import torch

        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True).to(device)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True)
        decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        return {"error": str(e), "status": "error"}
    return {"result": decoded, "status": "success"}
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)