Sp2503 committed on
Commit
0a41dbe
Β·
verified Β·
1 Parent(s): 10947b0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +56 -97
main.py CHANGED
@@ -1,107 +1,66 @@
1
- import os
2
- import torch
3
- import pandas as pd
4
  from fastapi import FastAPI
5
- from pydantic import BaseModel
6
- from sentence_transformers import SentenceTransformer, util
7
- from langdetect import detect
8
- from huggingface_hub import hf_hub_download
9
- import threading
10
- import time
11
-
12
- # --- Cache Configuration ---
13
- os.environ["HF_HOME"] = "/app/hf_cache"
14
- os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
15
- os.environ["TORCH_DISABLE_CUDA"] = "1"
16
-
17
- # --- Paths ---
18
- MODEL_PATH = './muril_combined_multilingual_model'
19
- CSV_PATH = './muril_multilingual_dataset.csv'
20
- HF_REPO = "Sp2503/muril-dataset"
21
- HF_FILE = "answer_embeddings.pt"
22
-
23
- # --- FastAPI Setup ---
24
- app = FastAPI(title="MuRIL Multilingual QA API")
25
-
26
- # Global variables
27
  model = None
28
- df = None
29
  answer_embeddings = None
30
 
31
# --- Helper: Load embeddings from Hugging Face ---
def load_embeddings():
    """Fetch the precomputed answer-embedding tensor from the HF dataset repo.

    Downloads (or reuses a cached copy of) the embeddings file and loads it
    onto the CPU.
    """
    print("πŸ“₯ Downloading embeddings from Hugging Face...")
    local_path = hf_hub_download(
        repo_id=HF_REPO,
        filename=HF_FILE,
        repo_type="dataset",
        cache_dir="/tmp",
    )
    print(f"βœ… Embeddings available at {local_path}")
    return torch.load(local_path, map_location="cpu")
42
-
43
# --- Resource Loader ---
def load_resources():
    """Populate the module-level model, dataframe and embeddings.

    Best-effort: any failure is logged and the globals that were not yet
    assigned stay None, so the API reports itself as not ready instead of
    crashing the process.
    """
    global model, df, answer_embeddings
    print("βš™οΈ Loading model and dataset...")
    try:
        model = SentenceTransformer(MODEL_PATH)
        df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
        answer_embeddings = load_embeddings()
    except Exception as e:
        print(f"❌ Error loading resources: {e}")
        return
    print("βœ… Model and embeddings ready.")
54
-
55
# --- Background Loader Thread ---
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases;
# lifespan handlers are the preferred replacement — confirm target version.
@app.on_event("startup")
def startup_event():
    """Kick off resource loading on a daemon thread so the server starts
    accepting requests immediately while the model loads."""
    print("πŸš€ Starting background model loader thread...")
    threading.Thread(target=load_resources, daemon=True).start()
62
 
63
@app.get("/")
def root():
    """Liveness endpoint; reports whether all resources finished loading."""
    ready = all(obj is not None for obj in (model, df, answer_embeddings))
    return {"status": "βœ… Running MuRIL QA API", "model_loaded": ready}
67
-
68
class QueryRequest(BaseModel):
    """Incoming question payload; `lang` is optional and auto-detected
    from the question text when omitted."""
    question: str
    lang: str = None
71
-
72
class QAResponse(BaseModel):
    """Response payload carrying the matched answer text."""
    answer: str
74
-
75
@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(request: QueryRequest):
    """Return the stored answer whose embedding is most similar to the question.

    The dataset is optionally narrowed to the requested (or auto-detected)
    language before the cosine-similarity search.

    Fix: langdetect's detect() raises LangDetectException on empty or
    featureless text (e.g. a question that strips to "" or digits-only),
    which previously surfaced as a 500. Detection is now guarded and falls
    back to no language filter.
    """
    if model is None or df is None or answer_embeddings is None:
        return {"answer": "⏳ Model still loading, please try again shortly."}

    question_text = request.question.strip()

    # Guarded auto-detection: never let langdetect take the endpoint down.
    lang_filter = request.lang
    if not lang_filter and question_text:
        try:
            lang_filter = detect(question_text)
        except Exception:
            lang_filter = None  # fall back to searching all languages

    filtered_df = df
    filtered_embeddings = answer_embeddings
    if 'lang' in df.columns and lang_filter:
        mask = df['lang'] == lang_filter
        filtered_df = df[mask].reset_index(drop=True)
        # Row-align the embedding tensor with the filtered dataframe.
        filtered_embeddings = answer_embeddings[mask.values]

    if len(filtered_df) == 0:
        return {"answer": f"⚠️ No data found for language '{lang_filter}'."}

    question_emb = model.encode(question_text, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
    best_idx = torch.argmax(cosine_scores).item()
    answer = filtered_df.iloc[best_idx]['answer']
    return {"answer": answer}
98
-
99
# --- Keep-alive thread for Spaces ---
def keep_alive():
    """Idle forever, sleeping in one-minute intervals.

    Runs on a daemon thread purely to keep the Space from idling out."""
    while True:
        time.sleep(60)
103
 
104
if __name__ == "__main__":
    # Local entry point: start the keep-alive daemon, then serve the app.
    import uvicorn

    threading.Thread(target=keep_alive, daemon=True).start()
    uvicorn.run("main:app", host="0.0.0.0", port=8080)
 
 
 
 
1
  from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import os
6
+
7
app = FastAPI(title="MuRIL QA Demo")

# Allow cross-origin requests from any origin.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive CORS setup — tighten before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_NAME = "google/muril-base-cased"
# NOTE(review): path is pinned to one HF cache snapshot hash — brittle;
# confirm it matches the revision actually downloaded on this host.
EMBED_PATH = "/tmp/datasets--Sp2503--muril-dataset/snapshots/b768e5a3a401589f25b723c20f9674e88717db1b/answer_embeddings.pt"
20
+
 
 
 
21
# Lazily populated globals — filled in by load_model() during startup.
model = None
tokenizer = None
answer_embeddings = None
24
 
25
def load_model():
    """Load the MuRIL tokenizer/encoder and the precomputed answer
    embeddings into the module-level globals.

    Missing embeddings are reported but not fatal — the /ask endpoint
    answers with an error until they are available.
    """
    global model, tokenizer, answer_embeddings

    print("βš™οΈ Loading model and dataset...")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)

    if not os.path.exists(EMBED_PATH):
        print("⚠️ Embeddings not found! Please check dataset path.")
    else:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted files (consider weights_only=True if the file is tensors).
        answer_embeddings = torch.load(EMBED_PATH, map_location="cpu")
        print(f"βœ… Embeddings loaded from {EMBED_PATH}")

    print("βœ… Model and embeddings ready.")
40
+
41
# πŸš€ Load everything before starting FastAPI
# (blocking on purpose: the server only comes up once the model is ready)
print("πŸš€ Starting app...")
load_model()
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
@app.get("/")
def health_check():
    """Simple liveness probe."""
    return {"status": "ok"}
48
+
49
@app.get("/ask")
def ask(question: str):
    """Embed the query with MuRIL (mean-pooled last hidden state) and
    return the index and cosine score of the nearest answer embedding.

    Returns an error payload while the model/embeddings are still loading.
    """
    if model is None or tokenizer is None or answer_embeddings is None:
        return {"error": "Model not loaded yet"}

    encoded = tokenizer(question, return_tensors="pt")
    with torch.no_grad():
        query_vec = model(**encoded).last_hidden_state.mean(dim=1)

    scores = torch.nn.functional.cosine_similarity(query_vec, answer_embeddings)
    best = torch.argmax(scores).item()

    return {"question": question, "answer_id": best, "score": scores[best].item()}
62
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
if __name__ == "__main__":
    # Local entry point: serve the app with uvicorn.
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8080)