from fastapi import FastAPI
from pydantic import BaseModel
from contextlib import asynccontextmanager
import os
# Global variables
model = None
tokenizer = None
device = None
model_loaded = False
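
# load_model() fills in the globals above; if loading fails the API still
# starts, and the /health endpoint reports model_loaded = False.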
def load_model():
    global model, tokenizer, device, model_loaded
    try:
        print("Starting model loading...")
        # Set cache directory
        os.environ["HF_HOME"] = "/tmp"
        # Import here to avoid startup issues
        from transformers import T5ForConditionalGeneration, T5Tokenizer
        import torch

        print("Loading tokenizer...")
        tokenizer = T5Tokenizer.from_pretrained("chalana2001/quiz_guru_chatbot")
        print("Loading model...")
        model = T5ForConditionalGeneration.from_pretrained(
            "chalana2001/quiz_guru_chatbot",
            trust_remote_code=True
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model_loaded = True
        print(f"Model loaded successfully on {device}")
        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False
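
# FastAPI lifespan handler: everything before `yield` runs once at startup,
# everything after it runs at shutdown.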
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    print("Starting up...")
    load_model()
    yield
    # Shutdown (if needed)
    print("Shutting down...")

app = FastAPI(title="Quiz Guru Chatbot", version="1.0.0", lifespan=lifespan)
class PromptRequest(BaseModel):
    prompt: str

@app.get("/")
def read_root():
    return {
        "message": "Quiz Guru Chatbot API",
        "status": "running",
        "model_loaded": model_loaded
    }

@app.get("/health")
def health():
    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "device": str(device) if device else "unknown"
    }
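
# POST /predict expects a JSON body matching PromptRequest, e.g. {"prompt": "..."}.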
@app.post("/predict")
def predict(request: PromptRequest):
    if not model_loaded:
        return {"error": "Model not loaded. Please check /health endpoint."}
    try:
        # Import torch here
        import torch

        # Tokenize the prompt and generate a reply with beam search
        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            output = model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True)
        decoded = tokenizer.decode(output[0], skip_special_tokens=True)
        return {"result": decoded, "status": "success"}
    except Exception as e:
        return {"error": str(e), "status": "error"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
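
# A minimal client sketch for the /predict endpoint (not part of the app; the
# URL assumes the server above is running locally on port 7860, and the prompt
# is just an illustrative placeholder):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={"prompt": "What is photosynthesis?"},
#   )
#   print(resp.json())  # e.g. {"result": "...", "status": "success"}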