Sp2503 committed on
Commit
69e5845
·
verified ·
1 Parent(s): 725631e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +34 -9
main.py CHANGED
@@ -8,67 +8,89 @@ from langdetect import detect
8
  from huggingface_hub import hf_hub_download
9
  import threading
10
 
 
 
11
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
12
 
 
13
  MODEL_PATH = './muril_combined_multilingual_model'
14
  CSV_PATH = './muril_multilingual_dataset.csv'
15
  HF_REPO = "Sp2503/muril-dataset"
16
  HF_FILE = "answer_embeddings.pt"
17
 
 
18
  app = FastAPI(title="MuRIL Multilingual QA API")
19
 
20
  model = None
21
  df = None
22
  answer_embeddings = None
 
23
 
24
 
 
25
  def load_embeddings():
26
  print("📥 Downloading embeddings from Hugging Face...")
27
- hf_path = hf_hub_download(repo_id=HF_REPO, filename=HF_FILE, repo_type="dataset", cache_dir="/tmp")
 
 
 
 
 
28
  print(f"✅ Embeddings available at {hf_path}")
29
  return torch.load(hf_path, map_location="cpu")
30
 
31
 
 
32
  def load_resources():
33
- global model, df, answer_embeddings
34
  try:
35
  print("⚙️ Loading model and dataset in background...")
36
  model = SentenceTransformer(MODEL_PATH)
37
  df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
38
  answer_embeddings = load_embeddings()
 
39
  print("✅ Model and embeddings ready.")
40
  except Exception as e:
 
41
  print(f"❌ Error loading resources: {e}")
42
 
43
 
44
- # --- Fast startup ---
45
  @app.on_event("startup")
46
  def schedule_background_load():
47
- thread = threading.Thread(target=load_resources)
 
48
  thread.start()
49
 
50
 
 
51
  class QueryRequest(BaseModel):
52
  question: str
53
  lang: str = None
54
 
 
55
  class QAResponse(BaseModel):
56
  answer: str
57
 
58
 
 
59
  @app.get("/")
60
  def root():
61
- ready = model is not None and df is not None and answer_embeddings is not None
62
- return {"status": "✅ Running", "model_loaded": ready}
 
63
 
64
 
 
65
  @app.post("/get-answer", response_model=QAResponse)
66
  def get_answer_endpoint(request: QueryRequest):
67
- if model is None or df is None or answer_embeddings is None:
68
- return {"answer": "⏳ Model still loading, please try again shortly."}
69
 
70
  question_text = request.question.strip()
71
- lang_filter = request.lang or detect(question_text)
 
 
 
72
 
73
  filtered_df = df
74
  filtered_embeddings = answer_embeddings
@@ -76,7 +98,10 @@ def get_answer_endpoint(request: QueryRequest):
76
  mask = df['lang'] == lang_filter
77
  filtered_df = df[mask].reset_index(drop=True)
78
  filtered_embeddings = answer_embeddings[mask.values]
 
 
79
 
 
80
  question_emb = model.encode(question_text, convert_to_tensor=True)
81
  cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
82
  best_idx = torch.argmax(cosine_scores).item()
 
8
  from huggingface_hub import hf_hub_download
9
  import threading
10
 
11
+ # --- Hugging Face cache settings ---
12
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
13
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
14
 
15
+ # --- Configuration ---
16
  MODEL_PATH = './muril_combined_multilingual_model'
17
  CSV_PATH = './muril_multilingual_dataset.csv'
18
  HF_REPO = "Sp2503/muril-dataset"
19
  HF_FILE = "answer_embeddings.pt"
20
 
21
+ # --- FastAPI app setup ---
22
  app = FastAPI(title="MuRIL Multilingual QA API")
23
 
24
  model = None
25
  df = None
26
  answer_embeddings = None
27
+ load_status = {"ready": False, "error": None}
28
 
29
 
30
+ # --- Load embeddings from Hugging Face ---
31
  def load_embeddings():
32
  print("📥 Downloading embeddings from Hugging Face...")
33
+ hf_path = hf_hub_download(
34
+ repo_id=HF_REPO,
35
+ filename=HF_FILE,
36
+ repo_type="dataset",
37
+ cache_dir="/tmp"
38
+ )
39
  print(f"✅ Embeddings available at {hf_path}")
40
  return torch.load(hf_path, map_location="cpu")
41
 
42
 
43
+ # --- Background resource loading ---
44
  def load_resources():
45
+ global model, df, answer_embeddings, load_status
46
  try:
47
  print("⚙️ Loading model and dataset in background...")
48
  model = SentenceTransformer(MODEL_PATH)
49
  df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
50
  answer_embeddings = load_embeddings()
51
+ load_status["ready"] = True
52
  print("✅ Model and embeddings ready.")
53
  except Exception as e:
54
+ load_status["error"] = str(e)
55
  print(f"❌ Error loading resources: {e}")
56
 
57
 
 
58
  @app.on_event("startup")
59
  def schedule_background_load():
60
+ """Run model load in a background thread to prevent startup timeout"""
61
+ thread = threading.Thread(target=load_resources, daemon=True)
62
  thread.start()
63
 
64
 
65
+ # --- API Models ---
66
  class QueryRequest(BaseModel):
67
  question: str
68
  lang: str = None
69
 
70
+
71
  class QAResponse(BaseModel):
72
  answer: str
73
 
74
 
75
+ # --- Root Endpoint ---
76
  @app.get("/")
77
  def root():
78
+ if load_status["error"]:
79
+ return {"status": "❌ Error", "details": load_status["error"]}
80
+ return {"status": "✅ Running", "model_ready": load_status["ready"]}
81
 
82
 
83
+ # --- QA Endpoint ---
84
  @app.post("/get-answer", response_model=QAResponse)
85
  def get_answer_endpoint(request: QueryRequest):
86
+ if not load_status["ready"]:
87
+ return {"answer": "⏳ Model still loading, please try again in a few seconds."}
88
 
89
  question_text = request.question.strip()
90
+ try:
91
+ lang_filter = request.lang or detect(question_text)
92
+ except Exception:
93
+ lang_filter = None
94
 
95
  filtered_df = df
96
  filtered_embeddings = answer_embeddings
 
98
  mask = df['lang'] == lang_filter
99
  filtered_df = df[mask].reset_index(drop=True)
100
  filtered_embeddings = answer_embeddings[mask.values]
101
+ if filtered_df.empty:
102
+ return {"answer": f"⚠️ No answers available for language '{lang_filter}'."}
103
 
104
+ # Semantic similarity search
105
  question_emb = model.encode(question_text, convert_to_tensor=True)
106
  cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
107
  best_idx = torch.argmax(cosine_scores).item()