Sp2503 committed on
Commit
e0eaa41
·
verified ·
1 Parent(s): cdbc6f5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +27 -12
main.py CHANGED
@@ -1,15 +1,18 @@
1
- # main.py
2
  import os
3
  import torch
4
  import pandas as pd
5
  from fastapi import FastAPI
 
6
  from sentence_transformers import SentenceTransformer, util
 
7
  from huggingface_hub import hf_hub_download
8
 
 
9
  os.environ["HF_HOME"] = "/app/hf_cache"
10
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
11
- os.environ["TORCH_DISABLE_CUDA"] = "1"
12
 
 
13
  MODEL_PATH = './muril_combined_multilingual_model'
14
  CSV_PATH = './muril_multilingual_dataset.csv'
15
  HF_REPO = "Sp2503/muril-dataset"
@@ -17,22 +20,27 @@ HF_FILE = "answer_embeddings.pt"
17
 
18
  print("⚙️ Loading model and embeddings...")
19
 
20
- # Load model
21
  model = SentenceTransformer(MODEL_PATH)
 
 
22
  df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
23
 
24
- # Load embeddings from HF
25
- hf_path = hf_hub_download(repo_id=HF_REPO, filename=HF_FILE, repo_type="dataset", cache_dir="/tmp")
 
 
 
 
 
26
  answer_embeddings = torch.load(hf_path, map_location="cpu")
27
 
28
- print("✅ Model and embeddings loaded.")
29
-
30
- from fastapi import FastAPI
31
- from pydantic import BaseModel
32
- from langdetect import detect
33
 
34
- app = FastAPI(title="MuRIL QA API")
 
35
 
 
36
  class QueryRequest(BaseModel):
37
  question: str
38
  lang: str = None
@@ -40,10 +48,12 @@ class QueryRequest(BaseModel):
40
  class QAResponse(BaseModel):
41
  answer: str
42
 
 
43
  @app.get("/")
44
  def root():
45
- return {"status": "✅ Running", "model_loaded": True}
46
 
 
47
  @app.post("/get-answer", response_model=QAResponse)
48
  def get_answer_endpoint(request: QueryRequest):
49
  question_text = request.question.strip()
@@ -64,3 +74,8 @@ def get_answer_endpoint(request: QueryRequest):
64
  best_idx = torch.argmax(cosine_scores).item()
65
  answer = filtered_df.iloc[best_idx]['answer']
66
  return {"answer": answer}
 
 
 
 
 
 
 
# main.py — MuRIL multilingual QA service.
#
# Cache configuration MUST happen before the Hugging Face libraries are
# imported: HF_HOME / TRANSFORMERS_CACHE are read when those libraries are
# first imported, so setting them after
# `from sentence_transformers import ...` (as the original code did)
# silently has no effect.
import os

os.environ["HF_HOME"] = "/app/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
os.environ["TORCH_DISABLE_CUDA"] = "1"  # CPU only

import torch
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from langdetect import detect
from huggingface_hub import hf_hub_download

# --- Paths ---
MODEL_PATH = './muril_combined_multilingual_model'
CSV_PATH = './muril_multilingual_dataset.csv'
HF_REPO = "Sp2503/muril-dataset"
 
print("⚙️ Loading model and embeddings...")

# Load the MuRIL model (sentence-transformers format) from the local
# directory configured above.
model = SentenceTransformer(MODEL_PATH)

# Load the CSV dataset; drop rows missing either a question or an answer.
# NOTE(review): assumes the rows of `answer_embeddings` below align with
# this filtered DataFrame's row order — confirm how the embeddings file
# was produced.
df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])

# Download the precomputed answer embeddings from the Hugging Face
# dataset repo. HF_FILE is defined with HF_REPO near the top of the file.
hf_path = hf_hub_download(
    repo_id=HF_REPO,
    filename=HF_FILE,
    repo_type="dataset",
    cache_dir="/tmp"  # download into /tmp rather than the HF_HOME cache
)
# map_location="cpu" forces the tensors onto CPU regardless of the
# device they were saved from.
answer_embeddings = torch.load(hf_path, map_location="cpu")

print("✅ Model and embeddings loaded successfully.")
 
 
 
 
39
 
40
# --- FastAPI app ---
# Single module-level app instance; the route decorators below register
# their handlers on it.
app = FastAPI(title="MuRIL Multilingual QA API")
42
 
43
from typing import Optional

# --- Request/Response models ---
class QueryRequest(BaseModel):
    """Request body for POST /get-answer."""

    question: str
    # Optional language override. The original annotation `str = None`
    # is type-incorrect for a None default; declare it Optional[str]
    # instead (default and caller-visible behavior unchanged).
    # NOTE(review): presumably the handler falls back to langdetect when
    # lang is omitted — confirm against get_answer_endpoint.
    lang: Optional[str] = None
 
48
class QAResponse(BaseModel):
    """Response body for POST /get-answer."""

    # The best-matching answer text taken from the dataset.
    answer: str
50
 
51
# --- Root endpoint ---
@app.get("/")
def root():
    """Liveness probe: reports that the service is up with model loaded."""
    payload = {
        "status": "✅ Running MuRIL QA API",
        "model_loaded": True,
    }
    return payload
55
 
56
+ # --- QA endpoint ---
57
  @app.post("/get-answer", response_model=QAResponse)
58
  def get_answer_endpoint(request: QueryRequest):
59
  question_text = request.question.strip()
 
74
  best_idx = torch.argmax(cosine_scores).item()
75
  answer = filtered_df.iloc[best_idx]['answer']
76
  return {"answer": answer}
77
+
78
# --- Run app ---
# Dev entry point: `python main.py` serves on all interfaces, port 8080.
if __name__ == "__main__":
    import uvicorn  # imported lazily: only needed when run as a script
    uvicorn.run("main:app", host="0.0.0.0", port=8080)