Sp2503 committed on
Commit
d38f9c4
·
verified ·
1 Parent(s): 0a41dbe

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +105 -47
main.py CHANGED
@@ -1,66 +1,124 @@
1
- from fastapi import FastAPI
2
- from fastapi.middleware.cors import CORSMiddleware
3
- import torch
4
- from transformers import AutoTokenizer, AutoModel
5
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- app = FastAPI(title="MuRIL QA Demo")
8
-
9
- # Allow cross-origin requests
10
- app.add_middleware(
11
- CORSMiddleware,
12
- allow_origins=["*"],
13
- allow_credentials=True,
14
- allow_methods=["*"],
15
- allow_headers=["*"],
16
- )
17
 
18
- MODEL_NAME = "google/muril-base-cased"
19
- EMBED_PATH = "/tmp/datasets--Sp2503--muril-dataset/snapshots/b768e5a3a401589f25b723c20f9674e88717db1b/answer_embeddings.pt"
 
 
 
 
20
 
21
- model = None
22
- tokenizer = None
23
- answer_embeddings = None
24
 
25
- def load_model():
26
- global model, tokenizer, answer_embeddings
 
27
 
28
- print("⚙️ Loading model and dataset...")
29
 
30
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
31
- model = AutoModel.from_pretrained(MODEL_NAME)
32
 
33
- if os.path.exists(EMBED_PATH):
34
- answer_embeddings = torch.load(EMBED_PATH, map_location="cpu")
35
- print(f"✅ Embeddings loaded from {EMBED_PATH}")
36
- else:
37
- print("⚠️ Embeddings not found! Please check dataset path.")
38
 
39
- print(" Model and embeddings ready.")
 
 
 
40
 
41
- # 🚀 Load everything before starting FastAPI
42
- print("🚀 Starting app...")
43
- load_model()
44
 
45
- @app.get("/")
46
- def health_check():
47
- return {"status": "ok"}
 
 
 
48
 
49
- @app.get("/ask")
50
- def ask(question: str):
51
- if model is None or tokenizer is None or answer_embeddings is None:
52
- return {"error": "Model not loaded yet"}
53
 
54
- inputs = tokenizer(question, return_tensors="pt")
55
- with torch.no_grad():
56
- q_emb = model(**inputs).last_hidden_state.mean(dim=1)
 
 
57
 
58
- similarities = torch.nn.functional.cosine_similarity(q_emb, answer_embeddings)
59
- top_idx = torch.argmax(similarities).item()
60
 
61
- return {"question": question, "answer_id": top_idx, "score": similarities[top_idx].item()}
 
 
 
 
 
 
62
 
63
 
64
  if __name__ == "__main__":
65
  import uvicorn
66
- uvicorn.run("main:app", host="0.0.0.0", port=8080)
 
 
 
 
 
 
 
# Standard library
import os
import threading
import time
from typing import Optional

# Third-party
import pandas as pd
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from langdetect import detect
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
11
+
12
+ # --- Cache Configuration ---
13
+ os.environ["HF_HOME"] = "/app/hf_cache"
14
+ os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
15
+ os.environ["TORCH_DISABLE_CUDA"] = "1"
16
+
17
+ # --- Paths ---
18
+ MODEL_PATH = './muril_combined_multilingual_model'
19
+ CSV_PATH = './muril_multilingual_dataset.csv'
20
+ HF_REPO = "Sp2503/muril-dataset"
21
+ HF_FILE = "answer_embeddings.pt"
22
+
23
+ # --- FastAPI Setup ---
24
+ app = FastAPI(title="MuRIL Multilingual QA API")
25
+
26
+ # Global variables
27
+ model = None
28
+ df = None
29
+ answer_embeddings = None
30
+ is_model_ready = False
31
+
32
+
33
+ # --- Helper: Load embeddings from Hugging Face ---
34
+ def load_embeddings():
35
+ print("📥 Downloading embeddings from Hugging Face...")
36
+ hf_path = hf_hub_download(
37
+ repo_id=HF_REPO,
38
+ filename=HF_FILE,
39
+ repo_type="dataset",
40
+ cache_dir="/tmp"
41
+ )
42
+ print(f"✅ Embeddings available at {hf_path}")
43
+ return torch.load(hf_path, map_location="cpu")
44
+
45
+
46
+ # --- Resource Loader ---
47
+ def load_resources():
48
+ global model, df, answer_embeddings, is_model_ready
49
+ try:
50
+ print("⚙️ Loading model and dataset...")
51
+ model = SentenceTransformer(MODEL_PATH)
52
+ df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
53
+ answer_embeddings = load_embeddings()
54
+ is_model_ready = True
55
+ print("✅ Model and embeddings ready.")
56
+ except Exception as e:
57
+ print(f"❌ Error loading resources: {e}")
58
+ is_model_ready = False
59
+
60
+
61
+ # --- Background Loader Thread ---
62
+ @app.on_event("startup")
63
+ def startup_event():
64
+ print("🚀 Starting background model loader thread...")
65
+ thread = threading.Thread(target=load_resources, daemon=True)
66
+ thread.start()
67
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ @app.get("/")
70
+ def root():
71
+ return {
72
+ "status": "✅ Running MuRIL QA API",
73
+ "model_loaded": is_model_ready
74
+ }
75
 
 
 
 
76
 
77
+ class QueryRequest(BaseModel):
78
+ question: str
79
+ lang: str = None
80
 
 
81
 
82
+ class QAResponse(BaseModel):
83
+ answer: str
84
 
 
 
 
 
 
85
 
86
+ @app.post("/get-answer", response_model=QAResponse)
87
+ def get_answer_endpoint(request: QueryRequest):
88
+ if not is_model_ready:
89
+ return {"answer": "⏳ Model still loading, please try again shortly."}
90
 
91
+ question_text = request.question.strip()
92
+ lang_filter = request.lang or detect(question_text)
 
93
 
94
+ filtered_df = df
95
+ filtered_embeddings = answer_embeddings
96
+ if 'lang' in df.columns and lang_filter:
97
+ mask = df['lang'] == lang_filter
98
+ filtered_df = df[mask].reset_index(drop=True)
99
+ filtered_embeddings = answer_embeddings[mask.values]
100
 
101
+ if len(filtered_df) == 0:
102
+ return {"answer": f"⚠️ No data found for language '{lang_filter}'."}
 
 
103
 
104
+ question_emb = model.encode(question_text, convert_to_tensor=True)
105
+ cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
106
+ best_idx = torch.argmax(cosine_scores).item()
107
+ answer = filtered_df.iloc[best_idx]['answer']
108
+ return {"answer": answer}
109
 
 
 
110
 
111
+ # --- Keep-alive thread for Spaces ---
112
+ def keep_alive():
113
+ while True:
114
+ # This ensures the app doesn’t shut down for inactivity
115
+ time.sleep(60)
116
+ if not is_model_ready:
117
+ print("🕒 Model still loading...")
118
 
119
 
120
  if __name__ == "__main__":
121
  import uvicorn
122
+ threading.Thread(target=keep_alive, daemon=True).start()
123
+ # Run with fewer workers for Spaces (prevents timeout)
124
+ uvicorn.run("main:app", host="0.0.0.0", port=8080, workers=1)