# Muril-Model / app.py
# Source: Hugging Face Space "Sp2503/Muril-Model"
# (commit d3dad7b, verified — "Rename main.py to app.py")
import os

# --- Cache Config ---
# NOTE(review): cache env vars must be set BEFORE the HF libraries are
# imported — some versions of transformers/huggingface_hub read
# HF_HOME / TRANSFORMERS_CACHE at import time, so setting them afterwards
# can leave the default cache location in use.
os.environ["HF_HOME"] = "/app/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
os.environ["TORCH_DISABLE_CUDA"] = "1"

from typing import Optional

import pandas as pd
import torch

from fastapi import FastAPI
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
# --- Download Model & Embeddings from Hub ---
# Repo that bundles the fine-tuned model, the QA dataset CSV and the
# precomputed answer embeddings in a single snapshot.
HF_REPO = "Sp2503/Muril-Model"
print("πŸ“¦ Downloading model & embeddings from Hugging Face Hub...")
# snapshot_download returns the local directory of the cached snapshot.
model_dir = snapshot_download(repo_id=HF_REPO, repo_type="model")
print(f"βœ… Model snapshot available at: {model_dir}")
# All artifacts are resolved relative to the snapshot directory.
MODEL_PATH = model_dir
CSV_PATH = os.path.join(model_dir, "muril_multilingual_dataset.csv")
EMBED_PATH = os.path.join(model_dir, "answer_embeddings.pt")
# --- Load Model ---
print("βš™οΈ Loading model and embeddings...")
model = SentenceTransformer(MODEL_PATH)
# Drop rows missing either field so df rows stay aligned with the
# precomputed embedding tensor used by the /get-answer endpoint.
df = pd.read_csv(CSV_PATH).dropna(subset=['question', 'answer'])
# NOTE(review): torch.load unpickles arbitrary objects — safe only because
# the file comes from our own repo; consider weights_only=True if the
# installed torch version supports it.
answer_embeddings = torch.load(EMBED_PATH, map_location="cpu")
print("βœ… Model and embeddings loaded successfully.")
# --- FastAPI App ---
# Application instance; endpoints are registered on it below.
app = FastAPI(title="MuRIL Multilingual QA API")
class QueryRequest(BaseModel):
    """Request body for /get-answer.

    Attributes:
        question: The user's question text.
        lang: Optional language code used to filter the dataset
            (matched against the CSV's ``lang`` column when present).
    """

    question: str
    # `lang: str = None` is the implicit-Optional anti-pattern and is
    # rejected by pydantic v2; declare the optionality explicitly so a
    # missing or null "lang" field validates cleanly.
    lang: Optional[str] = None
class QAResponse(BaseModel):
    """Response body for /get-answer: the best-matching answer text."""
    answer: str
@app.get("/")
def root():
    """Health-check endpoint: confirms the API is up and the model loaded."""
    payload = {"status": "βœ… API is running", "model_loaded": True}
    return payload
@app.post("/get-answer", response_model=QAResponse)
def get_answer_endpoint(request: QueryRequest):
    """Return the dataset answer whose precomputed embedding is most
    similar (cosine) to the encoded question.

    When the CSV has a ``lang`` column and ``request.lang`` is given,
    rows and embeddings are filtered with the same boolean mask so
    their indices stay aligned.

    Returns:
        dict matching QAResponse: {"answer": <best-matching answer>},
        or an explanatory message when no candidate rows exist.
    """
    question_text = request.question.strip()
    lang_filter = request.lang

    # Default to the full dataset; narrow only when a language filter applies.
    filtered_df = df
    filtered_embeddings = answer_embeddings
    if 'lang' in df.columns and lang_filter:
        mask = df['lang'] == lang_filter
        filtered_df = df[mask].reset_index(drop=True)
        # Same row mask on the embedding tensor keeps row/embedding alignment.
        filtered_embeddings = answer_embeddings[mask.values]
        if len(filtered_df) == 0:
            return {"answer": f"No data found for language '{lang_filter}'."}

    # Guard against an empty dataset (e.g. every row removed by dropna):
    # torch.argmax on an empty tensor raises and would surface as a 500.
    if len(filtered_df) == 0:
        return {"answer": "No data available."}

    question_emb = model.encode(question_text, convert_to_tensor=True)
    # Shape (1, N): similarity of the single question against all candidates.
    cosine_scores = util.pytorch_cos_sim(question_emb, filtered_embeddings)
    best_idx = torch.argmax(cosine_scores).item()
    answer = filtered_df.iloc[best_idx]['answer']
    return {"answer": answer}