# ShabdaAI / app.py
# (Hugging Face Spaces page header: uploaded by ravish5 — "Update app.py", commit 13c9333, verified)
import os
import pathlib
import re
from collections import Counter

import gradio as gr
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
# Resolve every path relative to this file so the app works from any CWD.
PROJECT_DIR = pathlib.Path(__file__).parent.resolve()
DATA_DIR = PROJECT_DIR / "data"
DATA_DIR.mkdir(parents=True, exist_ok=True)
CSV_PATH = DATA_DIR / "sample_indic.csv"

# Tiny built-in QA corpus: 5 Kannada ("kn") and 5 Hindi ("hi") rows, each with
# a context passage, a question about it, and the gold answer span.
SAMPLE_ROWS = [
    {"id": "kn1", "language": "kn", "context": "ಬೆಂಗಳೂರು ಕರ್ನಾಟಕದ ರಾಜಧಾನಿ.", "question": "ಕರ್ನಾಟಕದ ರಾಜಧಾನಿ ಯಾವುದು?", "answer_text": "ಬೆಂಗಳೂರು"},
    {"id": "kn2", "language": "kn", "context": "ಕನ್ನಡ ಒಂದು ದ್ರಾವಿಡ ಭಾಷೆ.", "question": "ಕನ್ನಡ ಯಾವ ಭಾಷಾ ಕುಟುಂಬಕ್ಕೆ ಸೇರಿದೆ?", "answer_text": "ದ್ರಾವಿಡ"},
    {"id": "kn3", "language": "kn", "context": "ಮೈಸೂರು ಅರಮನೆ ಕರ್ನಾಟಕದ ಪ್ರಸಿದ್ಧ ತಾಣ.", "question": "ಮೈಸೂರು ಅರಮನೆ ಎಲ್ಲಿದೆ?", "answer_text": "ಕರ್ನಾಟಕ"},
    {"id": "kn4", "language": "kn", "context": "ಟಿಪ್ಪು ಸುಲ್ತಾನ್ ಮೈಸೂರು ಸಾಮ್ರಾಜ್ಯದ ರಾಜನಾಗಿದ್ದನು.", "question": "ಮೈಸೂರು ಸಾಮ್ರಾಜ್ಯದ ರಾಜ ಯಾರು?", "answer_text": "ಟಿಪ್ಪು ಸುಲ್ತಾನ್"},
    {"id": "kn5", "language": "kn", "context": "ಹಂಪಿ ಯುನೆಸ್ಕೋ ವಿಶ್ವ ಪರಂಪರೆ ತಾಣವಾಗಿದೆ.", "question": "ಹಂಪಿ ಯಾವ ರೀತಿಯ ತಾಣ?", "answer_text": "ವಿಶ್ವ ಪರಂಪರೆ ತಾಣ"},
    {"id": "hi1", "language": "hi", "context": "दिल्ली भारत की राजधानी है।", "question": "भारत की राजधानी क्या है?", "answer_text": "दिल्ली"},
    {"id": "hi2", "language": "hi", "context": "हिंदी एक इंडो-आर्यन भाषा है।", "question": "हिंदी किस भाषा परिवार से संबंधित है?", "answer_text": "इंडो-आर्यन"},
    {"id": "hi3", "language": "hi", "context": "ताजमहल आगरा में स्थित है।", "question": "ताजमहल कहाँ स्थित है?", "answer_text": "आगरा"},
    {"id": "hi4", "language": "hi", "context": "गंगा भारत की एक प्रमुख नदी है।", "question": "गंगा क्या है?", "answer_text": "नदी"},
    {"id": "hi5", "language": "hi", "context": "मुंबई भारत का एक प्रमुख शहर है।", "question": "मुंबई किस देश में है?", "answer_text": "भारत"},
]


def ensure_sample_csv(path):
    """Write the bundled sample rows to *path* unless the file already exists."""
    if path.exists():
        return
    pd.DataFrame(SAMPLE_ROWS).to_csv(path, index=False, encoding="utf-8")


ensure_sample_csv(CSV_PATH)
# Zero-width characters (ZWSP, ZWNJ, ZWJ, BOM) that frequently pollute
# copy-pasted Indic text.
_ZW = r"\u200b\u200c\u200d\ufeff"
ZW_RE = re.compile(f"[{_ZW}]")


def normalize_text(s):
    """Return *s* with zero-width chars removed and whitespace collapsed.

    Non-string input (NaN from pandas, None, ...) normalizes to "".
    """
    if not isinstance(s, str):
        return ""
    cleaned = ZW_RE.sub("", s)
    return re.sub(r"\s+", " ", cleaned).strip()


def normalize_answer(text):
    """Normalize an answer for scoring: drop punctuation, lowercase, squeeze spaces."""
    cleaned = normalize_text(text)
    cleaned = re.sub(r"[^\w\s]", "", cleaned)
    return cleaned.lower().strip()
# Load the sample dataset (auto-created above) and pre-normalize every context.
# CORPUS is the retrieval passage list, index-aligned with df's rows.
df = pd.read_csv(CSV_PATH)
df["context_norm"] = df["context"].apply(normalize_text)
CORPUS = df["context_norm"].tolist()
# Multilingual dense retriever. E5 models expect "query: " / "passage: "
# prefixes on their inputs, hence the two separate encode helpers.
EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
emb_model = SentenceTransformer(EMB_MODEL_NAME)
emb_model.eval()


def encode_queries(texts):
    """Embed *texts* as E5 queries; returns L2-normalized vectors."""
    prefixed = [f"query: {normalize_text(t)}" for t in texts]
    return emb_model.encode(prefixed, normalize_embeddings=True)


def encode_passages(texts):
    """Embed *texts* as E5 passages; returns L2-normalized vectors."""
    prefixed = [f"passage: {normalize_text(t)}" for t in texts]
    return emb_model.encode(prefixed, normalize_embeddings=True)


# Embed the whole corpus once at startup; reused for every retrieval.
PASSAGE_EMBS = encode_passages(CORPUS)
def retrieve_top_k(query, k=3):
    """Return the *k* corpus passages most similar to *query*.

    Each result dict carries: rank (1-based), cosine similarity, the
    normalized context text, and the row's language code.
    """
    # Gradio sliders deliver floats; slicing with a float raises TypeError,
    # so coerce (and clamp to at least 1) before use.
    k = max(1, int(k))
    qv = encode_queries([query])[0]
    # Embeddings are L2-normalized, so the dot product IS cosine similarity.
    sims = np.dot(PASSAGE_EMBS, qv)
    idxs = np.argsort(-sims)[:k]
    return [
        {
            "rank": rank + 1,
            "similarity": float(sims[i]),
            "context": CORPUS[i],
            "language": df.iloc[i]["language"],
        }
        for rank, i in enumerate(idxs)
    ]
# Extractive multilingual QA reader (SQuAD2-style span prediction).
READER_MODEL="deepset/xlm-roberta-large-squad2"
# HF pipeline device convention: GPU index 0 when available, -1 for CPU.
device=0 if torch.cuda.is_available() else -1
tokenizer=AutoTokenizer.from_pretrained(READER_MODEL)
qa=pipeline(
"question-answering",
model=READER_MODEL,
tokenizer=tokenizer,
device=device
)
def answer_with_context(question, context):
    """Run the extractive QA reader and return its answer span and confidence."""
    result = qa(question=question, context=context)
    return {"answer": result["answer"], "score": float(result["score"])}
# UI language names -> ISO codes used in the dataset's "language" column.
# NOTE: the previous lang_choice[:2].lower() mapped "Kannada" to "ka", not
# "kn", so the filter silently discarded every Kannada passage.
LANG_CODE = {"hindi": "hi", "kannada": "kn"}


def no_context_flow(question, lang_choice, top_k=3):
    """Retrieve top_k passages and answer from the best one in the chosen language.

    Returns a dict with the answer, its confidence score, the context actually
    used, and the full retrieval list. If no retrieved passage matches
    lang_choice, falls back to the top-ranked passage of any language instead
    of returning an empty answer with score -1.
    """
    cands = retrieve_top_k(question, k=top_k)
    code = LANG_CODE.get(str(lang_choice).strip().lower(), str(lang_choice)[:2].lower())
    best = {"answer": "", "score": -1, "used_context": ""}
    for c in cands:
        if c["language"] != code:
            continue
        out = answer_with_context(question, c["context"])
        if out["score"] > best["score"]:
            best = {
                "answer": out["answer"],
                "score": out["score"],
                "used_context": c["context"],
            }
    # Robustness: nothing matched the language filter -> use the best-ranked
    # retrieved passage rather than surfacing an empty answer.
    if best["score"] < 0 and cands:
        out = answer_with_context(question, cands[0]["context"])
        best = {
            "answer": out["answer"],
            "score": out["score"],
            "used_context": cands[0]["context"],
        }
    return {
        "answer": best["answer"],
        "score": best["score"],
        "used_context": best["used_context"],
        "retrieved": cands,
    }
# NLLB-200 translation model; one model instance shared by both pipelines.
NLLB_ID="facebook/nllb-200-distilled-600M"
nllb_tokenizer=AutoTokenizer.from_pretrained(NLLB_ID)
nllb_model=AutoModelForSeq2SeqLM.from_pretrained(NLLB_ID)
# Hindi (Devanagari script) -> English.
trans_hi_en=pipeline(
"translation",
model=nllb_model,
tokenizer=nllb_tokenizer,
src_lang="hin_Deva",
tgt_lang="eng_Latn",
device=device
)
# Kannada (Kannada script) -> English.
trans_kn_en=pipeline(
"translation",
model=nllb_model,
tokenizer=nllb_tokenizer,
src_lang="kan_Knda",
tgt_lang="eng_Latn",
device=device
)
def hi_to_en(text):
if not text:
return ""
return trans_hi_en(text)[0]["translation_text"]
def kn_to_en(text):
if not text:
return ""
return trans_kn_en(text)[0]["translation_text"]
def exact_match(pred, gold):
    """1 if prediction and gold match after answer normalization, else 0."""
    return int(normalize_answer(pred) == normalize_answer(gold))


def token_f1(pred, gold):
    """Token-level F1 between the normalized prediction and gold answer.

    Uses multiset (Counter) overlap as in the standard SQuAD metric: the
    previous set-based intersection over-credited predictions containing
    repeated tokens and computed precision/recall against the wrong counts.
    """
    pred_tokens = normalize_answer(pred).split()
    gold_tokens = normalize_answer(gold).split()
    # Shared token count including duplicates (sum of per-token min counts).
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
def semantic_similarity(pred, gold):
    """Cosine similarity between the query embeddings of *pred* and *gold*."""
    vectors = encode_queries([pred, gold])
    score = cosine_similarity([vectors[0]], [vectors[1]])[0][0]
    return float(score)
def evaluate_answer(question, lang_choice):
    """Score the pipeline's answer for *question* against the dataset's gold answer.

    Returns {} when the question is not one of the bundled sample questions;
    otherwise a dict with the prediction, gold answer, and EM/F1/similarity.
    """
    matches = df[df["question"] == question]
    if matches.empty:
        return {}
    gold = matches.iloc[0]["answer_text"]
    pred = no_context_flow(question, lang_choice, 3)["answer"]
    return {
        "prediction": pred,
        "gold": gold,
        "em": exact_match(pred, gold),
        "f1": token_f1(pred, gold),
        "sim": semantic_similarity(pred, gold),
    }
# Markdown banner rendered at the top of the Gradio UI (runtime string).
INTRO_MD="""
### ShabdaAI Multilingual QA
Supports
Kannada
Hindi
Models
multilingual-e5-base (retrieval)
xlm-roberta-large-squad2 (QA)
nllb-200 (translation)
"""
def ui_answer(mode, question, user_context, top_k, lang_choice):
    """Gradio callback: answer the question, translate the answer to English,
    and (for known sample questions) attach EM/F1/similarity scores.

    Returns the 8-tuple expected by the output widgets:
    (answer, answer_en, score, used_context, retrieved, em, f1, sim).
    """
    question = (question or "").strip()
    user_context = (user_context or "").strip()
    # Guard empty input: an empty question/context can make the QA pipeline
    # raise (presumably surfacing as an opaque Gradio error), and the default
    # UI state is "With context" with a blank context box.
    if not question or (mode == "With context" and not user_context):
        return ("", "", 0.0, "", "", "", "", "")
    if mode == "With context":
        res = answer_with_context(question, user_context)
        ans = res["answer"]
        used = user_context
        retrieved_list = []
    else:
        # Sliders deliver floats; downstream slicing needs an int.
        res = no_context_flow(question, lang_choice, int(top_k))
        ans = res["answer"]
        used = res["used_context"]
        retrieved_list = res.get("retrieved", [])
    translate = hi_to_en if lang_choice == "Hindi" else kn_to_en
    ans_en = translate(ans)
    ev = evaluate_answer(question, lang_choice)
    retrieved = "\n".join(
        f"{r['rank']}. {r['context']} ({r['similarity']:.3f})" for r in retrieved_list
    )
    return (
        ans,
        ans_en,
        float(res.get("score", 0)),
        used,
        retrieved,
        ev.get("em", ""),
        ev.get("f1", ""),
        ev.get("sim", ""),
    )
with gr.Blocks() as demo:
    gr.Markdown(INTRO_MD)
    # --- inputs ---
    mode = gr.Radio(["With context", "No context"], value="With context")
    question = gr.Textbox(label="Question")
    user_context = gr.Textbox(label="Context")
    top_k = gr.Slider(1, 5, 3)
    lang_choice = gr.Dropdown(["Hindi", "Kannada"], value="Kannada")
    btn = gr.Button("Answer")
    # --- outputs ---
    ans_local = gr.Textbox(label="Answer")
    ans_en = gr.Textbox(label="Answer English")
    score = gr.Textbox(label="Confidence")
    used = gr.Textbox(label="Used Context")
    retrieved = gr.Textbox(label="Retrieved Contexts")
    em = gr.Textbox(label="Exact Match")
    f1 = gr.Textbox(label="F1 Score")
    sim = gr.Textbox(label="Semantic Similarity")
    # Wire the button to the QA callback.
    btn.click(
        ui_answer,
        inputs=[mode, question, user_context, top_k, lang_choice],
        outputs=[ans_local, ans_en, score, used, retrieved, em, f1, sim],
    )
if __name__ == "__main__":
    # NOTE(review): the models above are downloaded at import time, so setting
    # this env var here is too late to affect those downloads — move it above
    # the imports if telemetry must be disabled for them.
    os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
    demo.launch(server_name="0.0.0.0", server_port=7860)