chalana2001 commited on
Commit
e61bf1e
Β·
verified Β·
1 Parent(s): 29fe708

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -7
app.py CHANGED
@@ -1,25 +1,94 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
 
3
 
 
4
  app = FastAPI(title="Quiz Guru Chatbot", version="1.0.0")
5
 
 
 
 
 
 
 
6
  class PromptRequest(BaseModel):
7
  prompt: str
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  @app.get("/")
10
  def read_root():
11
- return {"message": "Quiz Guru Chatbot API", "status": "running"}
 
 
 
 
12
 
13
- @app.get("/test")
14
- def test():
15
- return {"message": "Test endpoint working!"}
 
 
 
 
16
 
17
  @app.post("/predict")
18
  def predict(request: PromptRequest):
19
- # For now, just echo back - we'll add the model next
20
- return {"result": f"Echo: {request.prompt}", "status": "working"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # This will run when you use: python app.py
23
  if __name__ == "__main__":
24
  import uvicorn
25
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ import os
4
 
5
+ # Only import transformers when we need it
6
  app = FastAPI(title="Quiz Guru Chatbot", version="1.0.0")
7
 
8
+ # Global variables
9
+ model = None
10
+ tokenizer = None
11
+ device = None
12
+ model_loaded = False
13
+
14
  class PromptRequest(BaseModel):
15
  prompt: str
16
 
17
+ def load_model():
18
+ global model, tokenizer, device, model_loaded
19
+ try:
20
+ print("πŸ”„ Starting model loading...")
21
+
22
+ # Set cache directory
23
+ os.environ["HF_HOME"] = "/tmp"
24
+
25
+ # Import here to avoid startup issues
26
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
27
+ import torch
28
+
29
+ print("πŸ“¦ Loading tokenizer...")
30
+ tokenizer = T5Tokenizer.from_pretrained("chalana2001/quiz_guru_chatbot")
31
+
32
+ print("πŸ€– Loading model...")
33
+ model = T5ForConditionalGeneration.from_pretrained(
34
+ "chalana2001/quiz_guru_chatbot",
35
+ trust_remote_code=True
36
+ )
37
+
38
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ model.to(device)
40
+ model_loaded = True
41
+
42
+ print(f"βœ… Model loaded successfully on {device}")
43
+ return True
44
+
45
+ except Exception as e:
46
+ print(f"❌ Error loading model: {e}")
47
+ return False
48
+
49
+ @app.on_event("startup")
50
+ async def startup_event():
51
+ print("πŸš€ Starting up...")
52
+ # Don't block startup if model fails to load
53
+ load_model()
54
+
55
  @app.get("/")
56
  def read_root():
57
+ return {
58
+ "message": "Quiz Guru Chatbot API",
59
+ "status": "running",
60
+ "model_loaded": model_loaded
61
+ }
62
 
63
+ @app.get("/health")
64
+ def health():
65
+ return {
66
+ "status": "healthy",
67
+ "model_loaded": model_loaded,
68
+ "device": str(device) if device else "unknown"
69
+ }
70
 
71
  @app.post("/predict")
72
  def predict(request: PromptRequest):
73
+ if not model_loaded:
74
+ return {"error": "Model not loaded. Please check /health endpoint."}
75
+
76
+ try:
77
+ # Import torch here
78
+ import torch
79
+
80
+ inputs = tokenizer(request.prompt, return_tensors="pt", padding=True).to(device)
81
+
82
+ with torch.no_grad():
83
+ output = model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True)
84
+
85
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True)
86
+
87
+ return {"result": decoded, "status": "success"}
88
+
89
+ except Exception as e:
90
+ return {"error": str(e), "status": "error"}
91
 
 
92
  if __name__ == "__main__":
93
  import uvicorn
94
  uvicorn.run(app, host="0.0.0.0", port=7860)