Sp2503 committed · Commit c0e90e0 · verified · 1 Parent(s): c0b6243

Update main.py

Files changed (1):
  main.py  +20 -31
main.py CHANGED
@@ -4,43 +4,35 @@ import pandas as pd
 from fastapi import FastAPI
 from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer, util
-from langdetect import detect
-from huggingface_hub import hf_hub_download
+from huggingface_hub import snapshot_download

-# --- Cache configuration ---
+# --- Cache Configuration ---
 os.environ["HF_HOME"] = "/app/hf_cache"
 os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
-os.environ["TORCH_DISABLE_CUDA"] = "1"  # CPU only
+os.environ["TORCH_DISABLE_CUDA"] = "1"

-# --- Paths ---
-MODEL_PATH = './muril_combined_multilingual_model'
-CSV_PATH = './muril_multilingual_dataset.csv'
-HF_REPO = "Sp2503/muril-dataset"
-HF_FILE = "answer_embeddings.pt"
+# --- Hugging Face Repo ---
+HF_REPO = "Sp2503/Muril-Model"

-print("⚙️ Loading model and embeddings...")
+# --- Download model & embeddings from Hugging Face Hub ---
+print("📦 Downloading model & embeddings from Hugging Face Hub...")
+model_dir = snapshot_download(repo_id=HF_REPO, repo_type="model")
+print(f"✅ Model snapshot available at: {model_dir}")

-# Load MuRIL model
-model = SentenceTransformer(MODEL_PATH)
+MODEL_PATH = model_dir
+CSV_PATH = os.path.join(model_dir, "muril_multilingual_dataset.csv")
+EMBED_PATH = os.path.join(model_dir, "answer_embeddings.pt")

-# Load CSV dataset
+# --- Load resources ---
+print("⚙️ Loading model and embeddings...")
+model = SentenceTransformer(MODEL_PATH)
 df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
-
-# Load precomputed embeddings from Hugging Face
-hf_path = hf_hub_download(
-    repo_id=HF_REPO,
-    filename=HF_FILE,
-    repo_type="dataset",
-    cache_dir="/tmp"
-)
-answer_embeddings = torch.load(hf_path, map_location="cpu")
-
+answer_embeddings = torch.load(EMBED_PATH, map_location="cpu")
 print("✅ Model and embeddings loaded successfully.")

-# --- FastAPI app ---
+# --- FastAPI Setup ---
 app = FastAPI(title="MuRIL Multilingual QA API")

-# --- Request/Response models ---
 class QueryRequest(BaseModel):
     question: str
     lang: str = None
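The loading path above replaces two sources (a local model directory plus a single hf_hub_download from a dataset repo) with one snapshot_download of the Sp2503/Muril-Model repo, so the weights, the CSV, and the precomputed embeddings all ship together and land in the same cache directory. A minimal sketch of that download step, runnable on its own; the existence check at the end is an illustrative addition, not part of the commit:

import os
from huggingface_hub import snapshot_download

# Pull the whole model repo (weights, CSV, precomputed embeddings) in one call;
# snapshot_download returns the local directory holding the snapshot.
model_dir = snapshot_download(repo_id="Sp2503/Muril-Model", repo_type="model")

# Illustrative sanity check (not in the diff): fail fast if an expected asset
# is missing from the snapshot instead of erroring later at load time.
for fname in ("muril_multilingual_dataset.csv", "answer_embeddings.pt"):
    path = os.path.join(model_dir, fname)
    if not os.path.exists(path):
        raise FileNotFoundError(f"snapshot is missing {fname} at {path}")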
@@ -48,16 +40,14 @@ class QueryRequest(BaseModel):
 class QAResponse(BaseModel):
     answer: str

-# --- Root endpoint ---
 @app.get("/")
 def root():
-    return {"status": "✅ Running MuRIL QA API", "model_loaded": True}
+    return {"status": "✅ API ready", "model_loaded": True}

-# --- QA endpoint ---
 @app.post("/get-answer", response_model=QAResponse)
 def get_answer_endpoint(request: QueryRequest):
     question_text = request.question.strip()
-    lang_filter = request.lang or detect(question_text)
+    lang_filter = request.lang

     filtered_df = df
     filtered_embeddings = answer_embeddings
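Dropping langdetect makes lang purely optional: when the field is omitted, lang_filter is None and the similarity search runs over the whole corpus instead of a detected-language slice. A hypothetical client call against a locally running instance (host and port taken from the uvicorn line at the bottom of the file; the requests dependency is an assumption):

import requests

# 'lang' restricts the search to one language slice; omit it to search everything.
resp = requests.post(
    "http://localhost:8080/get-answer",
    json={"question": "What are the opening hours?", "lang": "en"},
)
print(resp.json()["answer"])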
@@ -67,7 +57,7 @@ def get_answer_endpoint(request: QueryRequest):
     filtered_embeddings = answer_embeddings[mask.values]

     if len(filtered_df) == 0:
-        return {"answer": f"⚠️ No data found for language '{lang_filter}'."}
+        return {"answer": f"No data found for language '{lang_filter}'."}

     question_emb = model.encode(question_text, convert_to_tensor=True)
     cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
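The length guard above exists because filtering indexes the embedding tensor with a pandas boolean mask (filtered_embeddings = answer_embeddings[mask.values]), which only works while the tensor rows stay aligned one-to-one with the DataFrame rows. The mask construction itself sits in context lines the diff does not show; the sketch below assumes a plain df['lang'] == lang_filter comparison and uses made-up data:

import pandas as pd
import torch

df = pd.DataFrame({"lang": ["en", "hi", "en"], "answer": ["a", "b", "c"]})
answer_embeddings = torch.randn(3, 4)  # one embedding row per df row, same order

mask = df["lang"] == "en"  # pandas boolean Series
filtered_df = df[mask]
# .values yields a NumPy bool array, which PyTorch accepts as a row mask,
# keeping embedding rows aligned with the filtered DataFrame.
filtered_embeddings = answer_embeddings[mask.values]
assert len(filtered_df) == filtered_embeddings.shape[0]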
@@ -75,7 +65,6 @@ def get_answer_endpoint(request: QueryRequest):
     answer = filtered_df.iloc[best_idx]['answer']
     return {"answer": answer}

-# --- Run app ---
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("main:app", host="0.0.0.0", port=8080)
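The retrieval core is untouched by this commit: encode the question, score it against the precomputed answer embeddings with cosine similarity, and return the best row. The diff hides the line computing best_idx, so the argmax below is the standard choice rather than a quote of the file; the model name is a stand-in for the fine-tuned MuRIL snapshot:

import torch
from sentence_transformers import SentenceTransformer, util

# Stand-in multilingual encoder; the service loads its own snapshot instead.
model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")

answers = ["The library opens at 9am.", "Tickets cost ten euros."]
answer_embeddings = model.encode(answers, convert_to_tensor=True)

question_emb = model.encode("When does the library open?", convert_to_tensor=True)
cosine_scores = util.pytorch_cos_sim(question_emb, answer_embeddings)  # shape (1, N)
best_idx = int(torch.argmax(cosine_scores))  # index of the highest-scoring answer
print(answers[best_idx])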
 