Sp2503 committed on
Commit
725631e
Β·
verified Β·
1 Parent(s): 0cf3edb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +34 -93
main.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache" # cache before importing model
3
-
4
  import torch
5
  import pandas as pd
6
  from fastapi import FastAPI
@@ -8,136 +6,79 @@ from pydantic import BaseModel
8
  from sentence_transformers import SentenceTransformer, util
9
  from langdetect import detect
10
  from huggingface_hub import hf_hub_download
 
 
 
11
 
12
- # --- Configuration ---
13
  MODEL_PATH = './muril_combined_multilingual_model'
14
  CSV_PATH = './muril_multilingual_dataset.csv'
15
- EMBEDDINGS_PATH = './answer_embeddings.pt'
16
- HF_DATASET_REPO = "Sp2503/muril-dataset" # your HF dataset repo
17
- HF_FILE_NAME = "answer_embeddings.pt"
18
-
19
-
20
 
 
21
 
 
 
 
22
 
23
- def load_or_download_embeddings():
24
- CACHE_DIR = "/tmp"
25
- EMBEDDING_FILENAME = "answer_embeddings.pt"
26
- LOCAL_PATH = os.path.join(CACHE_DIR, EMBEDDING_FILENAME)
27
 
 
28
  print("πŸ“₯ Downloading embeddings from Hugging Face...")
 
 
 
29
 
30
- try:
31
- # Download (stays in cache_dir)
32
- hf_path = hf_hub_download(
33
- repo_id="Sp2503/muril-dataset",
34
- filename=EMBEDDING_FILENAME,
35
- repo_type="dataset",
36
- token=os.getenv("HF_TOKEN"),
37
- cache_dir=CACHE_DIR
38
- )
39
-
40
- print(f"βœ… Embeddings available at {hf_path}")
41
-
42
- # Load directly from hf_path β€” no rename, no copy
43
- if not os.path.exists(hf_path):
44
- raise FileNotFoundError(f"{hf_path} not found after download!")
45
-
46
- embeddings = torch.load(hf_path, map_location="cpu")
47
- print("βœ… Embeddings loaded successfully.")
48
- return embeddings
49
-
50
- except Exception as e:
51
- print(f"❌ Failed to load embeddings: {e}")
52
- print("βš™οΈ Computing new embeddings from scratch...")
53
-
54
- # === Compute your embeddings here ===
55
- # Example:
56
- # from sentence_transformers import SentenceTransformer
57
- # model = SentenceTransformer("muril_combined_multilingual_model")
58
- # embeddings = model.encode(sentences)
59
- #
60
- # torch.save(embeddings, LOCAL_PATH)
61
- # =====================================
62
-
63
- raise RuntimeError("Embeddings not available and could not be regenerated.") from e
64
-
65
-
66
- # === Call this during app startup ===
67
- answer_embeddings = load_or_download_embeddings()
68
 
69
- # --- Load Model + Data ---
70
  def load_resources():
 
71
  try:
72
- # Load model
73
  model = SentenceTransformer(MODEL_PATH)
74
- # Load dataset
75
  df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
76
-
77
- # Use already loaded embeddings from HF
78
- if answer_embeddings is None:
79
- print("βš™οΈ Computing new embeddings from scratch...")
80
- answers = df['answer'].tolist()
81
- embeddings = model.encode(answers, convert_to_tensor=True)
82
- torch.save(embeddings, EMBEDDINGS_PATH)
83
- print("βœ… Computed and saved embeddings")
84
- else:
85
- embeddings = answer_embeddings
86
- print("βœ… Using embeddings loaded from Hugging Face")
87
-
88
- return model, df, embeddings
89
-
90
  except Exception as e:
91
  print(f"❌ Error loading resources: {e}")
92
- return None, None, None
93
 
94
 
95
- model, df, answer_embeddings = load_resources()
 
 
 
 
96
 
97
- # --- FastAPI Setup ---
98
- app = FastAPI(title="MuRIL Multilingual QA API")
99
 
100
  class QueryRequest(BaseModel):
101
  question: str
102
- lang: str = None # optional: en, hi, mr, etc.
103
 
104
  class QAResponse(BaseModel):
105
  answer: str
106
 
107
- # --- API Endpoint ---
 
 
 
 
 
 
108
  @app.post("/get-answer", response_model=QAResponse)
109
  def get_answer_endpoint(request: QueryRequest):
110
- if not model:
111
- return {"answer": "❌ Model not loaded properly."}
112
 
113
  question_text = request.question.strip()
114
- lang_filter = request.lang
115
 
116
- # Detect language if not given
117
- if not lang_filter:
118
- try:
119
- lang_filter = detect(question_text)
120
- except Exception:
121
- lang_filter = None
122
-
123
- # Filter dataframe by language if column exists
124
  filtered_df = df
125
  filtered_embeddings = answer_embeddings
126
- if lang_filter and 'lang' in df.columns:
127
  mask = df['lang'] == lang_filter
128
  filtered_df = df[mask].reset_index(drop=True)
129
- if len(filtered_df) == 0:
130
- return {"answer": f"⚠️ No data found for language '{lang_filter}'."}
131
  filtered_embeddings = answer_embeddings[mask.values]
132
 
133
- # Encode question + find best match
134
  question_emb = model.encode(question_text, convert_to_tensor=True)
135
  cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
136
  best_idx = torch.argmax(cosine_scores).item()
137
  answer = filtered_df.iloc[best_idx]['answer']
138
-
139
  return {"answer": answer}
140
-
141
- @app.get("/")
142
- def root():
143
- return {"status": "βœ… MuRIL Multilingual QA API running successfully!"}
 
1
import os

# Must be set BEFORE importing transformers/sentence_transformers: the cache
# location is read at import time, so setting it afterwards has no effect.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"

import threading

import pandas as pd
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from langdetect import detect
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util

# --- Configuration ---
MODEL_PATH = './muril_combined_multilingual_model'
CSV_PATH = './muril_multilingual_dataset.csv'
HF_REPO = "Sp2503/muril-dataset"
HF_FILE = "answer_embeddings.pt"

app = FastAPI(title="MuRIL Multilingual QA API")

# Populated by load_resources() in a background thread; None until ready.
model = None
df = None
answer_embeddings = None
23
 
 
 
 
 
24
 
25
def load_embeddings():
    """Download the precomputed answer embeddings from the HF dataset repo.

    Returns the tensor loaded onto CPU. Raises whatever hf_hub_download or
    torch.load raise on failure (handled by the caller, load_resources).
    """
    print("πŸ“₯ Downloading embeddings from Hugging Face...")
    hf_path = hf_hub_download(
        repo_id=HF_REPO,
        filename=HF_FILE,
        repo_type="dataset",
        cache_dir="/tmp",
        # Needed for private repos; harmless (None) when HF_TOKEN is unset.
        # The previous revision passed this — dropping it broke private access.
        token=os.getenv("HF_TOKEN"),
    )
    print(f"βœ… Embeddings available at {hf_path}")
    # weights_only=True: the artifact is a plain tensor, so avoid unpickling
    # arbitrary objects from a downloaded file (requires torch >= 1.13).
    return torch.load(hf_path, map_location="cpu", weights_only=True)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
 
32
def load_resources():
    """Load the model, QA dataset and answer embeddings into module globals.

    Runs in a background thread (see schedule_background_load). Until it
    completes, the globals stay None and the API answers with a
    "still loading" message. Never raises: failures are logged and the app
    keeps serving the loading/health responses.
    """
    global model, df, answer_embeddings
    try:
        print("βš™οΈ Loading model and dataset in background...")
        model = SentenceTransformer(MODEL_PATH)
        # Rows missing either field are unusable for retrieval.
        df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
        answer_embeddings = load_embeddings()
        print("βœ… Model and embeddings ready.")
    except Exception as e:
        # Log the full traceback — a bare message from a background thread is
        # nearly impossible to debug. Globals stay None, so endpoints keep
        # reporting "still loading".
        import traceback
        traceback.print_exc()
        print(f"❌ Error loading resources: {e}")
 
42
 
43
 
44
# --- Fast startup: don't block the server while the model downloads ---
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers — confirm the pinned FastAPI version before migrating.
@app.on_event("startup")
def schedule_background_load():
    """Kick off model/data loading in a background thread at startup."""
    # daemon=True: a hung download must not prevent process shutdown.
    thread = threading.Thread(
        target=load_resources, name="resource-loader", daemon=True
    )
    thread.start()
49
 
 
 
50
 
51
class QueryRequest(BaseModel):
    # Request payload for POST /get-answer.
    question: str
    # Optional language filter (e.g. "en", "hi", "mr"); auto-detected when omitted.
    # NOTE(review): pydantic v2 rejects a None default on a plain `str` field —
    # this should be Optional[str] = None; confirm the pinned pydantic version.
    lang: str = None
54
 
55
class QAResponse(BaseModel):
    # Response payload: best-matching stored answer, or a status message.
    answer: str
57
 
58
+
59
@app.get("/")
def root():
    """Health check; also reports whether the background load has finished."""
    loaded = all(part is not None for part in (model, df, answer_embeddings))
    return {"status": "βœ… Running", "model_loaded": loaded}
63
+
64
+
65
@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(request: QueryRequest):
    """Return the stored answer most similar (cosine) to the question.

    Candidates are optionally restricted to rows whose 'lang' column matches
    the request's language (explicit `lang`, else auto-detected).
    """
    # Resources load in a background thread; answer gracefully until ready.
    if model is None or df is None or answer_embeddings is None:
        return {"answer": "⏳ Model still loading, please try again shortly."}

    question_text = request.question.strip()

    # langdetect raises LangDetectException on input it cannot classify
    # (empty strings, digits, emoji) — fall back to no filter, not a 500.
    lang_filter = request.lang
    if not lang_filter:
        try:
            lang_filter = detect(question_text)
        except Exception:
            lang_filter = None

    filtered_df = df
    filtered_embeddings = answer_embeddings
    if 'lang' in df.columns and lang_filter:
        mask = df['lang'] == lang_filter
        filtered_df = df[mask].reset_index(drop=True)
        # Guard: argmax over an empty tensor raises; report instead.
        if len(filtered_df) == 0:
            return {"answer": f"⚠️ No data found for language '{lang_filter}'."}
        filtered_embeddings = answer_embeddings[mask.values]

    question_emb = model.encode(question_text, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
    best_idx = torch.argmax(cosine_scores).item()
    answer = filtered_df.iloc[best_idx]['answer']
    return {"answer": answer}