ravish5 committed on
Commit
0712854
·
verified ·
1 Parent(s): 293c269

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, pathlib, json
2
+ import numpy as np
3
+ import pandas as pd
4
+
5
+ import torch
6
+ from transformers import pipeline, AutoTokenizer
7
+ from sentence_transformers import SentenceTransformer
8
+ from transformers import AutoModelForSeq2SeqLM
9
+ import gradio as gr
10
+
11
# Everything the app reads or writes lives under <directory of this file>/data.
PROJECT_DIR = pathlib.Path(__file__).parent.resolve()
DATA_DIR = PROJECT_DIR / "data"
DATA_DIR.mkdir(parents=True, exist_ok=True)  # created at import time; no error if present
CSV_PATH = DATA_DIR / "sample_telugu.csv"

# Built-in Telugu QA examples; each row has id, language, context, question
# and answer_text. Used to seed the sample CSV.
SAMPLE_ROWS = [
    {
        "id": "te1",
        "language": "te",
        "context": "తెలంగాణ రాష్ట్ర రాజధాని హైదరాబాదు. ఈ నగరం ఐటి పరిశ్రమకు ప్రసిద్ధి.",
        "question": "తెలంగాణ రాష్ట్ర రాజధాని ఏది?",
        "answer_text": "హైదరాబాదు",
    },
    {
        "id": "te2",
        "language": "te",
        "context": "తెలుగు భాష ద్రావిడ భాషా కుటుంబానికి చెందినది. దాని లిపి తెలుగు లిపి.",
        "question": "తెలుగు భాష ఏ లిపిని ఉపయోగిస్తుంది?",
        "answer_text": "తెలుగు లిపి",
    },
    {
        "id": "te3",
        "language": "te",
        "context": "సీతాకోక చిలుకలకు రెండు రెక్కలు ఉంటాయి. ఇవి పూల మకరందం తాగుతాయి.",
        "question": "సీతాకోక చిలుకకు ఎన్ని రెక్కలు ఉన్నాయి?",
        "answer_text": "రెండు",
    },
    {
        "id": "te4",
        "language": "te",
        "context": "విశాఖపట్నం ఒక తీర నగరం. ఇది ఆంధ్రప్రదేశ్‌లోని ప్రముఖ నౌకాశ్రయం.",
        "question": "విశాఖపట్నం ఏ రకమైన నగరం?",
        "answer_text": "తీర నగరం",
    },
    {
        "id": "te5",
        "language": "te",
        "context": "చార్మినార్ హైదరాబాద్ లో ఉంది. ఇది చారిత్రక స్మారక చిహ్నం.",
        "question": "చార్మినార్ ఎక్కడ ఉంది?",
        "answer_text": "హైదరాబాద్",
    },
]
23
+
24
def ensure_sample_csv(path: pathlib.Path):
    """Create the sample CSV at ``path`` from ``SAMPLE_ROWS`` if it is missing.

    An existing file is left untouched. When the file is written, a one-line
    ``[init]`` message is printed. Returns ``None`` either way.
    """
    if path.exists():
        return
    sample_frame = pd.DataFrame(SAMPLE_ROWS)
    sample_frame.to_csv(path, index=False, encoding="utf-8")
    print(f"[init] Wrote sample Telugu data to {path}")


# Seed the sample data at import time, before the CSV is read below.
ensure_sample_csv(CSV_PATH)
31
+
32
# Characters stripped during normalization: U+200B, U+200C, U+200D and U+FEFF.
# The escapes stay in a raw string and are interpreted by the regex engine.
_ZW = r"\u200b\u200c\u200d\ufeff"
ZW_RE = re.compile(f"[{_ZW}]")


def normalize_text(s: str) -> str:
    """Return ``s`` cleaned up for retrieval and reading.

    Non-string input yields ``""``. For strings, the characters matched by
    ``ZW_RE`` are removed first, then every run of whitespace is collapsed to
    a single space and the result is stripped at both ends.
    """
    if isinstance(s, str):
        # NOTE(review): the replacement literal appears to be U+0964 itself,
        # which would make this a no-op; confirm the intended target character.
        without_zw = ZW_RE.sub("", s.replace("\u0964", "।"))
        return re.sub(r"\s+", " ", without_zw).strip()
    return ""
42
+
43
# Load the sample data and keep a normalized copy of every context passage.
df = pd.read_csv(CSV_PATH, encoding="utf-8")
df["context_norm"] = df["context"].apply(normalize_text)
CORPUS = df["context_norm"].tolist()  # the passages searched by the retriever

# Sentence-embedding model used for both queries and passages.
EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
emb_model = SentenceTransformer(EMB_MODEL_NAME)
emb_model.eval()
50
+
51
def encode_queries(texts):
    """Embed search queries with ``emb_model``.

    Each text is passed through ``normalize_text`` and prefixed with
    ``"query: "`` before encoding. Encoding runs under
    ``torch.inference_mode()`` with ``normalize_embeddings=True``; the value
    returned by ``emb_model.encode`` is handed back unchanged.
    """
    prepared = [f"query: {normalize_text(raw)}" for raw in texts]
    with torch.inference_mode():
        return emb_model.encode(prepared, normalize_embeddings=True)
57
+
58
def encode_passages(texts):
    """Embed corpus passages with ``emb_model``.

    Each text is passed through ``normalize_text`` and prefixed with
    ``"passage: "`` before encoding. Encoding runs under
    ``torch.inference_mode()`` with ``normalize_embeddings=True``; the value
    returned by ``emb_model.encode`` is handed back unchanged.
    """
    prepared = [f"passage: {normalize_text(raw)}" for raw in texts]
    with torch.inference_mode():
        return emb_model.encode(prepared, normalize_embeddings=True)


# Passage embeddings are computed once at import time and reused per query.
PASSAGE_EMBS = encode_passages(CORPUS)
66
+
67
def retrieve_top_k(query: str, k: int = 3):
    """Return the ``k`` corpus passages most similar to ``query``.

    Args:
        query: The question text. Empty or whitespace-only queries return
            an empty list.
        k: Maximum number of passages to return. Values below 1 return an
            empty list.

    Returns:
        A list of dicts, best match first, each with ``rank`` (1-based int),
        ``similarity`` (float dot product of the query embedding with the
        passage embedding) and ``context`` (the passage text from ``CORPUS``).
    """
    if not query or not query.strip():
        return []
    # Without this guard a negative k would make `[:k]` slice from the END
    # and return almost the whole corpus instead of nothing.
    if k < 1:
        return []

    query_vec = encode_queries([query])[0]
    # Both sides are encoded with normalize_embeddings=True, so this dot
    # product is presumably a cosine similarity.
    sims = np.dot(PASSAGE_EMBS, query_vec)
    best_first = np.argsort(-sims)[:k]  # negate to sort in descending order
    return [
        {"rank": rank, "similarity": float(sims[i]), "context": CORPUS[i]}
        for rank, i in enumerate(best_first, start=1)
    ]
77
+
78
# Extractive question-answering reader.
READER_MODEL = "deepset/xlm-roberta-large-squad2"

# Pipeline device index: 0 when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

tokenizer = AutoTokenizer.from_pretrained(READER_MODEL, use_fast=True)
qa = pipeline(
    "question-answering",
    model=READER_MODEL,
    tokenizer=tokenizer,
    device=device,
)
83
+
84
# --- Telugu -> English translator (NLLB-200) ---
# Checkpoint: facebook/nllb-200-distilled-600M
# Language codes passed to the pipeline: 'tel_Telu' (source), 'eng_Latn' (target)
NLLB_ID = "facebook/nllb-200-distilled-600M"
nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_ID)
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_ID)

# Runs on the same `device` as the QA reader.
trans_te_en = pipeline(
    "translation",
    model=nllb_model,
    tokenizer=nllb_tokenizer,
    src_lang="tel_Telu",
    tgt_lang="eng_Latn",
    device=device,
)
98
+
99
def te_to_en(text: str) -> str:
    """Translate ``text`` with the ``trans_te_en`` pipeline.

    ``None``, empty and whitespace-only input return ``""`` without calling
    the translator. Otherwise the stripped text is translated with
    ``max_length=256`` and the stripped ``translation_text`` of the first
    result is returned.
    """
    cleaned = (text or "").strip()
    if cleaned:
        results = trans_te_en(cleaned, max_length=256)
        return results[0]["translation_text"].strip()
    return ""
105
+
106
+
107
def answer_with_context(question: str, context: str):
    """Run the extractive ``qa`` reader on one question/context pair.

    Both inputs are passed through ``normalize_text`` first. If either is
    empty afterwards, the reader is skipped and an empty answer with score
    ``0.0`` is returned.

    Returns:
        ``{"answer": str, "score": float}``; the answer is stripped, and
        missing keys in the reader output fall back to ``""`` / ``0.0``.
    """
    q_norm = normalize_text(question)
    ctx_norm = normalize_text(context)
    if not (q_norm and ctx_norm):
        return {"answer": "", "score": 0.0}

    prediction = qa(question=q_norm, context=ctx_norm)
    return {
        "answer": prediction.get("answer", "").strip(),
        "score": float(prediction.get("score", 0.0)),
    }
116
+
117
def no_context_flow(question: str, top_k: int = 3):
    """Answer ``question`` by retrieving passages and reading each of them.

    ``retrieve_top_k`` supplies up to ``top_k`` candidates and
    ``answer_with_context`` is run on every one. The answer with the
    strictly highest score wins, so on a tie the earlier candidate is kept.

    Returns:
        A dict with ``answer``, ``score``, ``used_context`` (the passage the
        winning answer came from) and ``retrieved`` (all candidates). When
        nothing is retrieved, the answer and context are empty, the score is
        ``0.0`` and ``retrieved`` is ``[]``.
    """
    candidates = retrieve_top_k(question, k=top_k)
    if not candidates:
        return {"answer": "", "score": 0.0, "used_context": "", "retrieved": []}

    # Start below any score the loop is expected to see.
    top_answer, top_score, top_context = "", -1.0, ""
    for cand in candidates:
        reading = answer_with_context(question, cand["context"])
        if reading["score"] > top_score:
            top_answer = reading["answer"]
            top_score = reading["score"]
            top_context = cand["context"]

    return {
        "answer": top_answer,
        "score": top_score,
        "used_context": top_context,
        "retrieved": candidates,
    }
127
+
128
# Markdown shown at the top of the UI: describes the two answer modes and
# names the retrieval and QA models.
INTRO_MD = """
### ShabdaAI
- **మోడ్ 1:** నేను ఇచ్చే ప్యాసేజ్ (context) పై సమాధానం ఇవ్వు
- **మోడ్ 2:** ప్యాసేజ్ ఇవ్వకపోతే — చిన్న తెలుగు కార్పస్‌లో *సెర్చ్ → రీడ్* చేసి సమాధానం ఇవ్వు

> Models: **intfloat/multilingual-e5-base** (retrieval) + **deepset/xlm-roberta-large-squad2** (extractive QA)
"""
135
+
136
def ui_answer(mode, translate_outputs_en, translate_inputs_en, question, user_context, top_k):
    """Gradio click handler: answer the question and build all output fields.

    Args:
        mode: ``"With my context"`` reads ``user_context``; any other value
            searches the sample corpus instead.
        translate_outputs_en: When truthy, the answer is also translated.
        translate_inputs_en: When truthy, the question and the context that
            was used are also translated.
        question: Question text (``None`` is treated as ``""``).
        user_context: User-supplied passage (``None`` is treated as ``""``).
        top_k: Number of passages to retrieve in corpus-search mode.

    Returns:
        A 7-tuple in the order of the click handler's outputs: answer,
        answer in English, score formatted to three decimals, used context,
        used context in English, question in English, and the retrieved
        passage listing. Empty translations and an empty listing are shown
        as ``"—"``.
    """
    question = question or ""
    user_context = user_context or ""

    def _translate_if(enabled, text):
        # Translate only when requested and there is something to translate.
        return te_to_en(text) if enabled and text else ""

    q_en = _translate_if(translate_inputs_en, question)

    if mode == "With my context":
        res = answer_with_context(question, user_context)
        used_context = user_context
        retrieved_tbl = "—"
    else:
        res = no_context_flow(question, top_k=int(top_k))
        used_context = res["used_context"]
        retrieved_tbl = "\n".join(
            f"{r['rank']}. (sim={r['similarity']:.3f}) {r['context']}"
            for r in res.get("retrieved", [])
        ) or "—"

    ans_te = res["answer"]
    ans_en = _translate_if(translate_outputs_en, ans_te)

    # Translate the context that was actually used. Previously corpus-search
    # mode translated `user_context`, which that mode ignores, so the English
    # context did not match the Telugu context shown beside it.
    ctx_en = _translate_if(translate_inputs_en, used_context)

    return (
        ans_te,
        ans_en,
        f"{res['score']:.3f}",
        used_context,
        ctx_en or "—",
        q_en or "—",
        retrieved_tbl,
    )
158
+
159
+
160
# UI definition. Components are created in the same order as before.
# NOTE(review): the original indentation was not available here; the two Row
# groupings below (mode + slider, then the two checkboxes) are inferred from
# the blank-line layout and should be confirmed.
with gr.Blocks() as demo:
    gr.Markdown(INTRO_MD)

    # Mode selection and retrieval depth.
    with gr.Row():
        mode = gr.Radio(
            choices=["With my context", "No context (search sample data)"],
            value="With my context",
            label="Mode",
        )
        top_k = gr.Slider(1, 5, value=3, step=1, label="Top-K passages (for No-context mode)")

    # Translation toggles.
    with gr.Row():
        translate_outputs_en = gr.Checkbox(value=True, label="Translate ANSWER (Telugu → English)")
        translate_inputs_en = gr.Checkbox(value=True, label="Translate INPUTS (Question/Context → English)")

    # User inputs.
    question = gr.Textbox(label="ప్రశ్న (Question)", placeholder="ఉదా: చార్మినార్ ఎక్కడ ఉంది?")
    user_context = gr.Textbox(label="ప్యాసేజ్ / కాంటెక్స్ట్ (optional)", lines=4)

    btn = gr.Button("Answer")

    # Answer in both languages.
    answer_te = gr.Textbox(label="Answer (Telugu)")
    answer_en = gr.Textbox(label="Answer (English)")

    # Confidence score, the context used, and the translated question.
    score = gr.Textbox(label="Confidence score")
    used_ctx = gr.Textbox(label="Used context (Telugu)")
    ctx_en_box = gr.Textbox(label="Used context (English)")
    q_en_box = gr.Textbox(label="Question (English)")

    retrieved = gr.Textbox(label="Top-K retrieved passages (Telugu)", lines=4)

    # The order of inputs/outputs matches ui_answer's parameters and return tuple.
    btn.click(
        fn=ui_answer,
        inputs=[mode, translate_outputs_en, translate_inputs_en, question, user_context, top_k],
        outputs=[answer_te, answer_en, score, used_ctx, ctx_en_box, q_en_box, retrieved],
    )
196
+
197
if __name__ == "__main__":
    # NOTE(review): this is set only after the Hugging Face libraries were
    # imported and the models loaded above, which may be too late for it to
    # take effect; confirm whether it should be set before the imports.
    os.environ.update({"HF_HUB_DISABLE_TELEMETRY": "1"})

    # Listen on all interfaces at port 7860, with show_api disabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": False,
    }
    demo.launch(**launch_options)