| import os, re, pathlib |
| import numpy as np |
| import pandas as pd |
|
|
| import torch |
| from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
| from sentence_transformers import SentenceTransformer |
| import gradio as gr |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
|
# Filesystem layout: the sample CSV lives in a "data" folder next to this file.
PROJECT_DIR = pathlib.Path(__file__).parent.resolve()
DATA_DIR = PROJECT_DIR.joinpath("data")
# Create the folder up front so writing the CSV cannot fail on a missing directory.
DATA_DIR.mkdir(parents=True, exist_ok=True)
CSV_PATH = DATA_DIR.joinpath("sample_indic.csv")
|
|
|
|
# Column names of the bundled sample dataset, in the order the dict keys are created.
_SAMPLE_COLUMNS = ("id", "language", "context", "question", "answer_text")

# One tuple per QA example: (id, language code, context, question, gold answer).
_SAMPLE_RECORDS = (
    # Kannada examples
    ("kn1", "kn", "ಬೆಂಗಳೂರು ಕರ್ನಾಟಕದ ರಾಜಧಾನಿ.", "ಕರ್ನಾಟಕದ ರಾಜಧಾನಿ ಯಾವುದು?", "ಬೆಂಗಳೂರು"),
    ("kn2", "kn", "ಕನ್ನಡ ಒಂದು ದ್ರಾವಿಡ ಭಾಷೆ.", "ಕನ್ನಡ ಯಾವ ಭಾಷಾ ಕುಟುಂಬಕ್ಕೆ ಸೇರಿದೆ?", "ದ್ರಾವಿಡ"),
    ("kn3", "kn", "ಮೈಸೂರು ಅರಮನೆ ಕರ್ನಾಟಕದ ಪ್ರಸಿದ್ಧ ತಾಣ.", "ಮೈಸೂರು ಅರಮನೆ ಎಲ್ಲಿದೆ?", "ಕರ್ನಾಟಕ"),
    ("kn4", "kn", "ಟಿಪ್ಪು ಸುಲ್ತಾನ್ ಮೈಸೂರು ಸಾಮ್ರಾಜ್ಯದ ರಾಜನಾಗಿದ್ದನು.", "ಮೈಸೂರು ಸಾಮ್ರಾಜ್ಯದ ರಾಜ ಯಾರು?", "ಟಿಪ್ಪು ಸುಲ್ತಾನ್"),
    ("kn5", "kn", "ಹಂಪಿ ಯುನೆಸ್ಕೋ ವಿಶ್ವ ಪರಂಪರೆ ತಾಣವಾಗಿದೆ.", "ಹಂಪಿ ಯಾವ ರೀತಿಯ ತಾಣ?", "ವಿಶ್ವ ಪರಂಪರೆ ತಾಣ"),
    # Hindi examples
    ("hi1", "hi", "दिल्ली भारत की राजधानी है।", "भारत की राजधानी क्या है?", "दिल्ली"),
    ("hi2", "hi", "हिंदी एक इंडो-आर्यन भाषा है।", "हिंदी किस भाषा परिवार से संबंधित है?", "इंडो-आर्यन"),
    ("hi3", "hi", "ताजमहल आगरा में स्थित है।", "ताजमहल कहाँ स्थित है?", "आगरा"),
    ("hi4", "hi", "गंगा भारत की एक प्रमुख नदी है।", "गंगा क्या है?", "नदी"),
    ("hi5", "hi", "मुंबई भारत का एक प्रमुख शहर है।", "मुंबई किस देश में है?", "भारत"),
)

# Sample rows as a list of dicts, one dict per QA example.
SAMPLE_ROWS = [dict(zip(_SAMPLE_COLUMNS, record)) for record in _SAMPLE_RECORDS]
|
|
|
|
def ensure_sample_csv(path):
    """Write the bundled sample rows to *path* as a UTF-8 CSV if it does not exist.

    An existing file is left untouched.
    """
    if path.exists():
        return
    sample_frame = pd.DataFrame(SAMPLE_ROWS)
    sample_frame.to_csv(path, index=False, encoding="utf-8")


# Seed the dataset at import time so the CSV is present before it is read.
ensure_sample_csv(CSV_PATH)
|
|
|
|
# Zero-width space, ZWNJ, ZWJ and BOM, written as regex escapes for a character class.
_ZW = r"\u200b\u200c\u200d\ufeff"
ZW_RE = re.compile(f"[{_ZW}]")


def normalize_text(s):
    """Return *s* with zero-width characters removed and whitespace collapsed.

    Runs of whitespace become a single space and the result is stripped.
    Non-string input (for example ``None`` or a NaN cell) yields ``""``.
    """
    if not isinstance(s, str):
        return ""
    s = ZW_RE.sub("", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


def normalize_answer(text):
    """Return a punctuation-free, lower-cased form of *text* for answer comparison.

    The text is first passed through ``normalize_text``. Letters, digits,
    underscores, whitespace and Unicode combining marks are kept; everything
    else (punctuation, symbols, the danda, ...) is dropped.

    Combining marks must be kept explicitly: Python's ``\\w`` only matches
    alphanumerics, so a ``[^\\w\\s]`` filter silently deletes the vowel signs
    and viramas of Devanagari and Kannada, turning "दिल्ली" into "दलल" and
    making different words compare equal.
    """
    import unicodedata  # local import keeps this block self-contained

    def _keep(ch):
        return (
            ch.isalnum()
            or ch == "_"
            or ch.isspace()
            or unicodedata.category(ch).startswith("M")
        )

    text = normalize_text(text)
    text = "".join(ch for ch in text if _keep(ch))
    # Dropping punctuation can leave doubled spaces ("a - b" -> "a  b"),
    # so collapse whitespace once more before comparing.
    return " ".join(text.lower().split())
|
|
|
|
# Load the QA dataset and add a normalized copy of every context.
df = pd.read_csv(CSV_PATH)
df["context_norm"] = df["context"].map(normalize_text)

# Retrieval corpus: the normalized contexts, in the same row order as ``df``.
CORPUS = df["context_norm"].tolist()
|
|
|
|
# Sentence-embedding model used for retrieval and for the similarity metric.
EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
emb_model = SentenceTransformer(EMB_MODEL_NAME)
emb_model.eval()


def encode_queries(texts):
    """Embed *texts* as queries.

    Each text is normalized with ``normalize_text`` and prefixed with
    ``"query: "`` before encoding; the model is asked for normalized
    embeddings.
    """
    prefixed = ["query: " + normalize_text(t) for t in texts]
    return emb_model.encode(prefixed, normalize_embeddings=True)
|
|
|
|
def encode_passages(texts):
    """Embed *texts* as passages.

    Each text is normalized with ``normalize_text`` and prefixed with
    ``"passage: "`` before encoding; the model is asked for normalized
    embeddings.
    """
    prefixed = ["passage: " + normalize_text(t) for t in texts]
    return emb_model.encode(prefixed, normalize_embeddings=True)


# Embeddings of the whole corpus, computed once at import time.
PASSAGE_EMBS = encode_passages(CORPUS)
|
|
|
|
def retrieve_top_k(query, k=3):
    """Return the *k* corpus contexts most similar to *query*, best first.

    Args:
        query: Question text to search for.
        k: Number of contexts to return. Non-integer values (for example a
            float coming from a UI slider) are truncated to an int, and
            negative values are treated as 0.

    Returns:
        A list of dicts with keys ``rank`` (1-based), ``similarity``,
        ``context`` and ``language``.
    """
    # A float here would make the slice below raise TypeError.
    k = max(int(k), 0)

    query_vec = encode_queries([query])[0]
    # Both sides are encoded with normalize_embeddings=True, so the dot
    # product is the cosine similarity.
    sims = np.dot(PASSAGE_EMBS, query_vec)
    top_idxs = np.argsort(-sims)[:k]

    return [
        {
            "rank": rank,
            "similarity": float(sims[i]),
            "context": CORPUS[i],
            "language": df.iloc[i]["language"],
        }
        for rank, i in enumerate(top_idxs, start=1)
    ]
|
|
|
|
# Checkpoint used by the extractive question-answering reader.
READER_MODEL = "deepset/xlm-roberta-large-squad2"

# Pipeline device index: 0 when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

tokenizer = AutoTokenizer.from_pretrained(READER_MODEL)

# Question-answering pipeline shared by every request.
qa = pipeline(
    "question-answering",
    model=READER_MODEL,
    tokenizer=tokenizer,
    device=device,
)
|
|
|
|
def answer_with_context(question, context):
    """Extract an answer to *question* from *context* with the QA pipeline.

    Returns:
        ``{"answer": str, "score": float}``. When the question or the
        context is missing or blank, an empty answer with score ``0.0`` is
        returned without calling the model.
    """
    # The question-answering pipeline rejects empty strings with a
    # ValueError, which would surface as a crash when the UI is submitted
    # with an empty box.
    has_question = isinstance(question, str) and question.strip()
    has_context = isinstance(context, str) and context.strip()
    if not has_question or not has_context:
        return {"answer": "", "score": 0.0}

    out = qa(question=question, context=context)
    return {"answer": out["answer"], "score": float(out["score"])}
|
|
|
|
def no_context_flow(question, lang_choice, top_k=3):
    """Answer *question* from retrieved contexts in the chosen language.

    The top-*top_k* contexts are retrieved, those whose language differs
    from *lang_choice* are skipped, and the reader is run on the rest. The
    highest-scoring answer wins.

    Args:
        question: Question text.
        lang_choice: UI language label ("Hindi" / "Kannada") or a dataset
            language code such as "hi" / "kn".
        top_k: Number of contexts to retrieve.

    Returns:
        A dict with ``answer``, ``score``, ``used_context`` and
        ``retrieved`` (all retrieved candidates, including skipped ones).
        When no candidate matches the language, ``answer`` is ``""`` and
        ``score`` is ``-1``.
    """
    # Taking the first two letters of the label maps "Kannada" to "ka",
    # which never equals the dataset code "kn"; use an explicit table.
    lang_codes = {"hindi": "hi", "kannada": "kn"}
    label = (lang_choice or "").strip().lower()
    # Unknown labels fall back to their first two letters, so plain codes
    # such as "hi" or "kn" keep working.
    wanted_lang = lang_codes.get(label, label[:2])

    cands = retrieve_top_k(question, k=top_k)
    best = {"answer": "", "score": -1, "used_context": ""}

    for cand in cands:
        if cand["language"] != wanted_lang:
            continue
        out = answer_with_context(question, cand["context"])
        if out["score"] > best["score"]:
            best = {
                "answer": out["answer"],
                "score": out["score"],
                "used_context": cand["context"],
            }

    return {
        "answer": best["answer"],
        "score": best["score"],
        "used_context": best["used_context"],
        "retrieved": cands,
    }
|
|
|
|
# Translation checkpoint; one tokenizer/model pair backs both pipelines below.
NLLB_ID = "facebook/nllb-200-distilled-600M"

nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_ID)
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_ID)


def _to_english_pipeline(src_lang):
    """Build a translation pipeline from *src_lang* into ``eng_Latn``."""
    return pipeline(
        "translation",
        model=nllb_model,
        tokenizer=nllb_tokenizer,
        src_lang=src_lang,
        tgt_lang="eng_Latn",
        device=device,
    )


# Hindi -> English and Kannada -> English translators.
trans_hi_en = _to_english_pipeline("hin_Deva")
trans_kn_en = _to_english_pipeline("kan_Knda")
|
|
|
|
def hi_to_en(text):
    """Translate Hindi *text* to English; empty or ``None`` input returns ``""``."""
    return trans_hi_en(text)[0]["translation_text"] if text else ""
|
|
|
|
def kn_to_en(text):
    """Translate Kannada *text* to English; empty or ``None`` input returns ``""``."""
    return trans_kn_en(text)[0]["translation_text"] if text else ""
|
|
|
|
def exact_match(pred, gold):
    """Return 1 if *pred* and *gold* are equal after ``normalize_answer``, else 0."""
    normalized_pred = normalize_answer(pred)
    normalized_gold = normalize_answer(gold)
    return 1 if normalized_pred == normalized_gold else 0
|
|
|
|
def token_f1(pred, gold):
    """Return the token-level F1 between *pred* and *gold*.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. Overlap is counted as a multiset intersection, so repeated
    tokens are matched as many times as they occur on both sides. Returns
    ``0.0`` when there is no overlap (including when either side is empty).
    """
    from collections import Counter  # local import keeps this block self-contained

    pred_tokens = normalize_answer(pred).split()
    gold_tokens = normalize_answer(gold).split()

    # A set intersection would count a repeated token once while the
    # denominators below count it every time, so identical answers such as
    # "a a" vs "a a" would score 0.5 instead of 1.0.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
|
|
|
|
def semantic_similarity(pred, gold):
    """Return the cosine similarity between the embeddings of *pred* and *gold*.

    Both texts are embedded with ``encode_queries``. If either text is
    missing or blank the result is ``0.0``: embedding an empty string would
    only compare the bare prefix against the other text and report a
    similarity that says nothing about the answer.
    """
    has_pred = isinstance(pred, str) and pred.strip()
    has_gold = isinstance(gold, str) and gold.strip()
    if not has_pred or not has_gold:
        return 0.0

    emb = encode_queries([pred, gold])
    return float(cosine_similarity([emb[0]], [emb[1]])[0][0])
|
|
|
|
def evaluate_answer(question, lang_choice):
    """Score the no-context pipeline on *question* against the dataset's gold answer.

    The question must match a ``question`` cell of the dataset exactly;
    otherwise an empty dict is returned. The prediction is produced by
    ``no_context_flow`` with three retrieved contexts.

    Returns:
        A dict with ``prediction``, ``gold``, ``em``, ``f1`` and ``sim``,
        or ``{}`` when the question is not in the dataset.
    """
    matches = df[df["question"] == question]
    if matches.empty:
        return {}

    gold = matches.iloc[0]["answer_text"]
    pred = no_context_flow(question, lang_choice, 3)["answer"]

    scores = {"prediction": pred, "gold": gold}
    scores["em"] = exact_match(pred, gold)
    scores["f1"] = token_f1(pred, gold)
    scores["sim"] = semantic_similarity(pred, gold)
    return scores
|
|
|
|
# Markdown shown at the top of the demo page, assembled line by line.
INTRO_MD = "\n".join(
    [
        "",
        "### ShabdaAI Multilingual QA",
        "",
        "Supports",
        "",
        "Kannada",
        "Hindi",
        "",
        "Models",
        "",
        "multilingual-e5-base (retrieval)",
        "xlm-roberta-large-squad2 (QA)",
        "nllb-200 (translation)",
        "",
    ]
)
|
|
|
|
def _score_displayed_answer(question, pred):
    """Score *pred* against the dataset's gold answer for *question*; ``{}`` if unknown."""
    rows = df[df["question"] == question]
    if rows.empty:
        return {}
    gold = rows.iloc[0]["answer_text"]
    return {
        "em": exact_match(pred, gold),
        "f1": token_f1(pred, gold),
        "sim": semantic_similarity(pred, gold),
    }


def ui_answer(mode, question, user_context, top_k, lang_choice):
    """Gradio callback: answer *question* and report metrics for that answer.

    Args:
        mode: ``"With context"`` to read from *user_context*; any other
            value retrieves contexts from the corpus.
        question: Question text (``None`` is treated as empty).
        user_context: Context supplied by the user (``None`` is treated as empty).
        top_k: Number of contexts to retrieve in no-context mode.
        lang_choice: ``"Hindi"`` or ``"Kannada"``; also selects the translator.

    Returns:
        An 8-tuple ``(answer, english_answer, score, used_context,
        retrieved_text, em, f1, sim)``. The three metrics are ``""`` when
        the question is not in the dataset.
    """
    question = question or ""
    user_context = user_context or ""

    # Nothing to answer: return blanks instead of sending empty strings to
    # the reader model.
    if not question.strip() or (mode == "With context" and not user_context.strip()):
        return ("", "", 0.0, user_context, "", "", "", "")

    if mode == "With context":
        res = answer_with_context(question, user_context)
        used = user_context
        retrieved_list = []
    else:
        # The slider can deliver a float; retrieval needs an integer count.
        res = no_context_flow(question, lang_choice, int(top_k))
        used = res["used_context"]
        retrieved_list = res.get("retrieved", [])
    ans = res["answer"]

    ans_en = hi_to_en(ans) if lang_choice == "Hindi" else kn_to_en(ans)

    # Score the answer that is actually shown. Re-running the whole
    # retrieval + reader pipeline for evaluation doubled the inference cost
    # and could report metrics for a different prediction.
    ev = _score_displayed_answer(question, ans)

    retrieved = "\n".join(
        f"{r['rank']}. {r['context']} ({r['similarity']:.3f})" for r in retrieved_list
    )

    return (
        ans,
        ans_en,
        float(res.get("score", 0)),
        used,
        retrieved,
        ev.get("em", ""),
        ev.get("f1", ""),
        ev.get("sim", ""),
    )
|
|
|
|
# Gradio layout: inputs first, then the answer button, then the outputs.
with gr.Blocks() as demo:
    gr.Markdown(INTRO_MD)

    # --- Inputs ---
    mode = gr.Radio(["With context", "No context"], value="With context", label="Mode")
    question = gr.Textbox(label="Question")
    user_context = gr.Textbox(label="Context")
    # step=1 keeps the slider on whole numbers; without an explicit step it
    # can emit fractional values, which are not valid as a context count.
    top_k = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Top-k contexts")
    lang_choice = gr.Dropdown(["Hindi", "Kannada"], value="Kannada", label="Language")

    btn = gr.Button("Answer")

    # --- Outputs, in the order ui_answer returns them ---
    ans_local = gr.Textbox(label="Answer")
    ans_en = gr.Textbox(label="Answer English")
    score = gr.Textbox(label="Confidence")
    used = gr.Textbox(label="Used Context")
    retrieved = gr.Textbox(label="Retrieved Contexts")
    em = gr.Textbox(label="Exact Match")
    f1 = gr.Textbox(label="F1 Score")
    sim = gr.Textbox(label="Semantic Similarity")

    btn.click(
        ui_answer,
        inputs=[mode, question, user_context, top_k, lang_choice],
        outputs=[ans_local, ans_en, score, used, retrieved, em, f1, sim],
    )
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): the Hugging Face libraries are already imported and the
    # models already loaded by the time this runs, so the setting may take
    # effect too late -- confirm, and consider setting it before the imports.
    os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

    # Serve on all network interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)