Sp2503 committed on
Commit
7bf985b
Β·
verified Β·
1 Parent(s): 2256365

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +17 -64
main.py CHANGED
@@ -1,77 +1,38 @@
 
1
  import os
2
  import torch
3
  import pandas as pd
4
  from fastapi import FastAPI
5
- from pydantic import BaseModel
6
  from sentence_transformers import SentenceTransformer, util
7
- from langdetect import detect
8
  from huggingface_hub import hf_hub_download
9
- import threading
10
 
11
- # --- Cache Configuration ---
12
  os.environ["HF_HOME"] = "/app/hf_cache"
13
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
14
  os.environ["TORCH_DISABLE_CUDA"] = "1"
15
 
16
- # --- Paths ---
17
  MODEL_PATH = './muril_combined_multilingual_model'
18
  CSV_PATH = './muril_multilingual_dataset.csv'
19
  HF_REPO = "Sp2503/muril-dataset"
20
  HF_FILE = "answer_embeddings.pt"
21
 
22
- # --- FastAPI Setup ---
23
- app = FastAPI(title="MuRIL Multilingual QA API")
24
 
25
- # Global variables
26
- model = None
27
- df = None
28
- answer_embeddings = None
29
- is_model_ready = False
30
- loading_lock = threading.Lock()
31
 
32
- # --- Helper: Load embeddings from Hugging Face ---
33
- def load_embeddings():
34
- print("πŸ“₯ Downloading embeddings from Hugging Face...")
35
- hf_path = hf_hub_download(
36
- repo_id=HF_REPO,
37
- filename=HF_FILE,
38
- repo_type="dataset",
39
- cache_dir="/tmp"
40
- )
41
- print(f"βœ… Embeddings available at {hf_path}")
42
- return torch.load(hf_path, map_location="cpu")
43
 
44
- # --- Resource Loader ---
45
- def load_resources():
46
- global model, df, answer_embeddings, is_model_ready
47
- with loading_lock:
48
- if is_model_ready:
49
- return
50
- try:
51
- print("βš™οΈ Loading model and dataset...")
52
- model = SentenceTransformer(MODEL_PATH)
53
- df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
54
- answer_embeddings = load_embeddings()
55
- is_model_ready = True
56
- print("βœ… Model and embeddings ready.")
57
- except Exception as e:
58
- print(f"❌ Error loading resources: {e}")
59
 
60
- # --- Health Check ---
61
- @app.get("/healthz")
62
- def health_check():
63
- # Always return 200 for Cloud Run health checks
64
- return {"status": "ok", "model_loaded": is_model_ready}
65
 
66
- # --- Root Endpoint ---
67
- @app.get("/")
68
- def root():
69
- return {
70
- "status": "βœ… MuRIL QA API running",
71
- "model_loaded": is_model_ready
72
- }
73
 
74
- # --- Request Models ---
75
  class QueryRequest(BaseModel):
76
  question: str
77
  lang: str = None
@@ -79,15 +40,12 @@ class QueryRequest(BaseModel):
79
  class QAResponse(BaseModel):
80
  answer: str
81
 
82
- # --- Question Answer Endpoint ---
 
 
 
83
  @app.post("/get-answer", response_model=QAResponse)
84
  def get_answer_endpoint(request: QueryRequest):
85
- if not is_model_ready:
86
- # Lazy-load the model if first request
87
- load_resources()
88
- if not is_model_ready:
89
- return {"answer": "⏳ Model still loading, please try again shortly."}
90
-
91
  question_text = request.question.strip()
92
  lang_filter = request.lang or detect(question_text)
93
 
@@ -106,8 +64,3 @@ def get_answer_endpoint(request: QueryRequest):
106
  best_idx = torch.argmax(cosine_scores).item()
107
  answer = filtered_df.iloc[best_idx]['answer']
108
  return {"answer": answer}
109
-
110
- # --- Run app directly ---
111
- if __name__ == "__main__":
112
- import uvicorn
113
- uvicorn.run("main:app", host="0.0.0.0", port=8080)
 
1
+ # main.py
2
import os
from typing import Optional

import pandas as pd
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer, util
 
8
 
 
9
  os.environ["HF_HOME"] = "/app/hf_cache"
10
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
11
  os.environ["TORCH_DISABLE_CUDA"] = "1"
12
 
 
13
  MODEL_PATH = './muril_combined_multilingual_model'
14
  CSV_PATH = './muril_multilingual_dataset.csv'
15
  HF_REPO = "Sp2503/muril-dataset"
16
  HF_FILE = "answer_embeddings.pt"
17
 
18
+ print("βš™οΈ Loading model and embeddings...")
 
19
 
20
+ # Load model
21
+ model = SentenceTransformer(MODEL_PATH)
22
+ df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
 
 
 
23
 
24
+ # Load embeddings from HF
25
+ hf_path = hf_hub_download(repo_id=HF_REPO, filename=HF_FILE, repo_type="dataset", cache_dir="/tmp")
26
+ answer_embeddings = torch.load(hf_path, map_location="cpu")
 
 
 
 
 
 
 
 
27
 
28
+ print("βœ… Model and embeddings loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ from fastapi import FastAPI
31
+ from pydantic import BaseModel
32
+ from langdetect import detect
 
 
33
 
34
+ app = FastAPI(title="MuRIL QA API")
 
 
 
 
 
 
35
 
 
36
  class QueryRequest(BaseModel):
37
  question: str
38
  lang: str = None
 
40
  class QAResponse(BaseModel):
41
  answer: str
42
 
43
+ @app.get("/")
44
+ def root():
45
+ return {"status": "βœ… Running", "model_loaded": True}
46
+
47
  @app.post("/get-answer", response_model=QAResponse)
48
  def get_answer_endpoint(request: QueryRequest):
 
 
 
 
 
 
49
  question_text = request.question.strip()
50
  lang_filter = request.lang or detect(question_text)
51
 
 
64
  best_idx = torch.argmax(cosine_scores).item()
65
  answer = filtered_df.iloc[best_idx]['answer']
66
  return {"answer": answer}