Sp2503 committed on
Commit
3d71fc5
Β·
verified Β·
1 Parent(s): 5b88ecf

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -30
main.py CHANGED
@@ -6,28 +6,29 @@ from pydantic import BaseModel
6
  from sentence_transformers import SentenceTransformer, util
7
  from langdetect import detect
8
  from huggingface_hub import hf_hub_download
9
- import threading
10
 
11
- # --- Hugging Face cache settings ---
12
- os.environ["HF_HOME"] = "/tmp/hf_cache"
13
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
 
14
 
15
- # --- Configuration ---
16
  MODEL_PATH = './muril_combined_multilingual_model'
17
  CSV_PATH = './muril_multilingual_dataset.csv'
18
  HF_REPO = "Sp2503/muril-dataset"
19
  HF_FILE = "answer_embeddings.pt"
20
 
21
- # --- FastAPI app setup ---
22
  app = FastAPI(title="MuRIL Multilingual QA API")
23
 
 
24
  model = None
25
  df = None
26
  answer_embeddings = None
27
- load_status = {"ready": False, "error": None}
28
 
29
 
30
- # --- Load embeddings from Hugging Face ---
31
  def load_embeddings():
32
  print("πŸ“₯ Downloading embeddings from Hugging Face...")
33
  hf_path = hf_hub_download(
@@ -40,29 +41,28 @@ def load_embeddings():
40
  return torch.load(hf_path, map_location="cpu")
41
 
42
 
43
- # --- Background resource loading ---
44
  def load_resources():
45
- global model, df, answer_embeddings, load_status
46
  try:
47
  print("βš™οΈ Loading model and dataset in background...")
48
  model = SentenceTransformer(MODEL_PATH)
49
  df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
50
  answer_embeddings = load_embeddings()
51
- load_status["ready"] = True
52
  print("βœ… Model and embeddings ready.")
53
  except Exception as e:
54
- load_status["error"] = str(e)
55
  print(f"❌ Error loading resources: {e}")
56
 
57
 
 
58
  @app.on_event("startup")
59
- def schedule_background_load():
60
- """Run model load in a background thread to prevent startup timeout"""
61
- thread = threading.Thread(target=load_resources, daemon=True)
62
- thread.start()
63
 
64
 
65
- # --- API Models ---
66
  class QueryRequest(BaseModel):
67
  question: str
68
  lang: str = None
@@ -75,22 +75,21 @@ class QAResponse(BaseModel):
75
  # --- Root Endpoint ---
76
  @app.get("/")
77
  def root():
78
- if load_status["error"]:
79
- return {"status": "❌ Error", "details": load_status["error"]}
80
- return {"status": "βœ… Running", "model_ready": load_status["ready"]}
 
 
81
 
82
 
83
- # --- QA Endpoint ---
84
  @app.post("/get-answer", response_model=QAResponse)
85
  def get_answer_endpoint(request: QueryRequest):
86
- if not load_status["ready"]:
87
- return {"answer": "⏳ Model still loading, please try again in a few seconds."}
88
 
89
  question_text = request.question.strip()
90
- try:
91
- lang_filter = request.lang or detect(question_text)
92
- except Exception:
93
- lang_filter = None
94
 
95
  filtered_df = df
96
  filtered_embeddings = answer_embeddings
@@ -98,12 +97,18 @@ def get_answer_endpoint(request: QueryRequest):
98
  mask = df['lang'] == lang_filter
99
  filtered_df = df[mask].reset_index(drop=True)
100
  filtered_embeddings = answer_embeddings[mask.values]
101
- if filtered_df.empty:
102
- return {"answer": f"⚠️ No answers available for language '{lang_filter}'."}
103
 
104
- # Semantic similarity search
 
 
105
  question_emb = model.encode(question_text, convert_to_tensor=True)
106
  cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
107
  best_idx = torch.argmax(cosine_scores).item()
108
  answer = filtered_df.iloc[best_idx]['answer']
109
  return {"answer": answer}
 
 
 
 
 
 
 
6
import asyncio
from typing import Optional

from huggingface_hub import hf_hub_download
from langdetect import detect
from sentence_transformers import SentenceTransformer, util
10
 
11
# --- Cache configuration ---
# Point the Hugging Face caches at a writable app directory and keep
# everything on CPU. NOTE(review): these env vars are set after the HF
# imports above — confirm the libraries read them lazily in this deployment.
os.environ["HF_HOME"] = "/app/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
os.environ["TORCH_DISABLE_CUDA"] = "1"  # disable GPU

# --- Paths ---
MODEL_PATH = './muril_combined_multilingual_model'
CSV_PATH = './muril_multilingual_dataset.csv'
HF_REPO = "Sp2503/muril-dataset"
HF_FILE = "answer_embeddings.pt"

# --- FastAPI Setup ---
app = FastAPI(title="MuRIL Multilingual QA API")

# Global state, filled in by load_resources() in a background worker.
# All three stay None until loading finishes; the endpoints treat any
# None as "not ready yet".
model = None
df = None
answer_embeddings = None
 
29
 
30
 
31
+ # --- Helper: Load embeddings from Hugging Face ---
32
  def load_embeddings():
33
  print("πŸ“₯ Downloading embeddings from Hugging Face...")
34
  hf_path = hf_hub_download(
 
41
  return torch.load(hf_path, map_location="cpu")
42
 
43
 
44
# --- Resource Loader ---
def load_resources():
    """Populate the module-level model, dataframe and answer embeddings.

    Intended to run in a background worker at startup. On failure the
    error is printed and whichever globals were not yet assigned remain
    None, so the endpoints keep reporting "still loading".
    """
    global model, df, answer_embeddings
    try:
        print("βš™οΈ Loading model and dataset in background...")
        # Order matters for partial-failure state: model, then data,
        # then the precomputed embeddings.
        model = SentenceTransformer(MODEL_PATH)
        df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
        answer_embeddings = load_embeddings()
        print("βœ… Model and embeddings ready.")
    except Exception as exc:
        # Broad catch is deliberate: a failed load must not kill the server.
        print(f"❌ Error loading resources: {exc}")
55
 
56
 
57
# --- Async Background Loading ---
@app.on_event("startup")
async def startup_event():
    """Kick off model/dataset loading without blocking startup.

    The blocking load_resources() is handed to the default thread-pool
    executor so the event loop (and the platform's startup health check)
    is not stalled by model loading.
    """
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() is deprecated here since Python 3.10.
    loop = asyncio.get_running_loop()
    loop.run_in_executor(None, load_resources)
    print("πŸš€ Background model loading started.")
63
 
64
 
65
# --- Request Models ---
class QueryRequest(BaseModel):
    """Body of POST /get-answer."""
    # The user's question, in any supported language.
    question: str
    # Optional language filter; when omitted the language is auto-detected.
    # `str = None` was a type defect — None is a legal default, so the
    # annotation must be Optional[str].
    lang: Optional[str] = None
 
75
# --- Root Endpoint ---
@app.get("/")
def root():
    """Health probe: reports whether background loading has finished."""
    loaded = not (model is None or df is None or answer_embeddings is None)
    return {"status": "βœ… Running MuRIL QA API", "model_loaded": loaded}
83
 
84
 
85
# --- Question Answer Endpoint ---
@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(request: QueryRequest):
    """Return the best-matching answer for a question.

    Filters the dataset by language (explicit `request.lang`, else
    auto-detected), embeds the question with the MuRIL model, and returns
    the answer whose precomputed embedding has the highest cosine
    similarity.
    """
    # Resources load in the background at startup; refuse politely until
    # all of them are in place.
    if model is None or df is None or answer_embeddings is None:
        return {"answer": "⏳ Model still loading, please try again shortly."}

    question_text = request.question.strip()

    # langdetect raises on empty/ambiguous text (e.g. digits only); fall
    # back to "no language filter" instead of a 500. This guard existed
    # before this commit and its removal was a regression.
    try:
        lang_filter = request.lang or detect(question_text)
    except Exception:
        lang_filter = None

    # Default to the full corpus; narrow only when a language is known.
    filtered_df = df
    filtered_embeddings = answer_embeddings
    if lang_filter:
        mask = df['lang'] == lang_filter
        filtered_df = df[mask].reset_index(drop=True)
        filtered_embeddings = answer_embeddings[mask.values]

    if len(filtered_df) == 0:
        return {"answer": f"⚠️ No data found for language '{lang_filter}'."}

    # Semantic similarity search against the (possibly filtered) corpus.
    question_emb = model.encode(question_text, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
    best_idx = torch.argmax(cosine_scores).item()
    answer = filtered_df.iloc[best_idx]['answer']
    return {"answer": answer}
109
+
110
+
111
# --- Local entry point ---
if __name__ == "__main__":
    # Imported lazily so the module stays importable without uvicorn.
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8080)