ravish5 commited on
Commit
4082cf2
·
verified ·
1 Parent(s): f604ef6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -74
app.py CHANGED
@@ -1,34 +1,55 @@
1
- import os, re, pathlib
2
  import numpy as np
3
  import pandas as pd
 
4
  import torch
5
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
6
  from sentence_transformers import SentenceTransformer
 
7
  import gradio as gr
8
 
9
- # --- Setup paths ---
10
  PROJECT_DIR = pathlib.Path(__file__).parent.resolve()
11
  DATA_DIR = PROJECT_DIR / "data"
12
  DATA_DIR.mkdir(parents=True, exist_ok=True)
13
  CSV_PATH = DATA_DIR / "sample_indic.csv"
14
 
15
- # --- Load dataset ---
16
- df = pd.read_csv(CSV_PATH, encoding="utf-8")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  _ZW = r"\u200b\u200c\u200d\ufeff"
19
  ZW_RE = re.compile(f"[{_ZW}]")
20
-
21
  def normalize_text(s: str) -> str:
22
  if not isinstance(s, str):
23
  return ""
24
- s = s.replace("\u0964", "।") # danda fix
25
  s = ZW_RE.sub("", s)
26
  s = re.sub(r"\s+", " ", s).strip()
27
  return s
28
 
 
29
  df["context_norm"] = df["context"].apply(normalize_text)
 
 
30
 
31
- # --- Embedding model ---
32
  EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
33
  emb_model = SentenceTransformer(EMB_MODEL_NAME)
34
  emb_model.eval()
@@ -37,127 +58,199 @@ def encode_queries(texts):
37
  texts = [normalize_text(t) for t in texts]
38
  prefixed = [f"query: {t}" for t in texts]
39
  with torch.inference_mode():
40
- return emb_model.encode(prefixed, normalize_embeddings=True)
 
41
 
42
  def encode_passages(texts):
43
  texts = [normalize_text(t) for t in texts]
44
  prefixed = [f"passage: {t}" for t in texts]
45
  with torch.inference_mode():
46
- return emb_model.encode(prefixed, normalize_embeddings=True)
 
 
 
 
 
47
 
48
- # --- Build embeddings for whole dataset ---
49
- PASSAGE_EMBS = encode_passages(df["context_norm"].tolist())
50
 
51
- # --- Retriever ---
52
- def retrieve_top_k(query: str, lang_code: str, k: int = 3):
53
- if not query.strip():
54
  return []
55
  qv = encode_queries([query])[0]
56
  sims = np.dot(PASSAGE_EMBS, qv)
57
- mask = (df["language"] == lang_code).to_numpy()
58
- sims = np.where(mask, sims, -1e9)
59
  idxs = np.argsort(-sims)[:k]
60
  results = []
61
  for rank, i in enumerate(idxs):
62
- if sims[i] < -1e8:
63
- continue
64
- results.append({"rank": int(rank+1), "similarity": float(sims[i]), "context": df.iloc[i]["context_norm"]})
65
  return results
66
 
67
- # --- QA reader ---
68
  READER_MODEL = "deepset/xlm-roberta-large-squad2"
69
  device = 0 if torch.cuda.is_available() else -1
70
- qa = pipeline("question-answering", model=READER_MODEL, tokenizer=AutoTokenizer.from_pretrained(READER_MODEL), device=device)
71
 
72
- def answer_with_context(question, context):
73
- out = qa(question=normalize_text(question), context=normalize_text(context))
74
- return {"answer": out.get("answer","").strip(), "score": float(out.get("score",0.0))}
75
 
76
- # --- Translators (NLLB-200) ---
 
 
 
 
 
 
77
  NLLB_ID = "facebook/nllb-200-distilled-600M"
78
- nllb_tok = AutoTokenizer.from_pretrained(NLLB_ID)
79
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_ID)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- def build_translator(src, tgt):
82
- return pipeline("translation", model=nllb_model, tokenizer=nllb_tok, src_lang=src, tgt_lang=tgt, device=device)
83
 
84
- trans_te_en = build_translator("tel_Telu", "eng_Latn")
85
- trans_kn_en = build_translator("kan_Knda", "eng_Latn")
86
 
87
- def te_to_en(text): return trans_te_en(text, max_length=256)[0]["translation_text"].strip() if text else ""
88
- def kn_to_en(text): return trans_kn_en(text, max_length=256)[0]["translation_text"].strip() if text else ""
89
 
90
- # --- Gradio App ---
91
  INTRO_MD = """
92
- ### ShabdaAI (Telugu + Kannada ↔ English)
93
- - **Mode 1:** Answer using provided context passage
94
- - **Mode 2:** If no passage, retrieve from small Telugu+Kannada corpus
 
 
 
 
 
95
 
96
- > Retrieval: **intfloat/multilingual-e5-base**
97
- > Reader: **deepset/xlm-roberta-large-squad2**
98
- > Translation: **NLLB-200**
99
  """
100
 
101
- def ui_answer(mode, lang_choice, translate_outputs_en, translate_inputs_en, question, user_context, top_k):
102
- if not question:
103
- return "", "", "0.000", "", "—", "—", "—"
104
 
105
- # Pick language + translator
106
  if lang_choice == "Telugu":
107
- to_en = te_to_en; lang_code = "te"
108
  else:
109
- to_en = kn_to_en; lang_code = "kn"
110
 
111
- # Input translations
112
- q_en = to_en(question) if translate_inputs_en else ""
113
- ctx_en = to_en(user_context) if translate_inputs_en and user_context else ""
114
 
115
  if mode == "With my context":
116
  res = answer_with_context(question, user_context)
117
- ans = res["answer"]; score = res["score"]
118
  ans_en = to_en(ans) if translate_outputs_en and ans else ""
119
- return ans, ans_en, f"{score:.3f}", user_context, ctx_en, q_en, "—"
 
120
  else:
121
- cands = retrieve_top_k(question, lang_code, k=top_k)
122
- best = {"answer":"", "score":-1.0, "context":""}
123
- for c in cands:
124
- out = answer_with_context(question, c["context"])
125
- if out["score"] > best["score"]:
126
- best = {"answer":out["answer"], "score":out["score"], "context":c["context"]}
127
- ans = best["answer"]; ans_en = to_en(ans) if translate_outputs_en and ans else ""
128
- tbl = "\n".join([f"{r['rank']}. (sim={r['similarity']:.3f}) {r['context']}" for r in cands]) or "—"
129
- return ans, ans_en, f"{best['score']:.3f}", best["context"], ctx_en, q_en, tbl
 
130
 
131
  with gr.Blocks() as demo:
132
  gr.Markdown(INTRO_MD)
133
 
134
  with gr.Row():
135
- mode = gr.Radio(choices=["With my context","No context (search sample data)"], value="With my context", label="Mode")
136
- lang_choice = gr.Dropdown(choices=["Telugu","Kannada"], value="Telugu", label="Language")
137
- top_k = gr.Slider(1,5,value=3,step=1,label="Top-K passages (for No-context mode)")
138
-
 
 
139
  with gr.Row():
140
- translate_outputs_en = gr.Checkbox(value=True, label="Translate Answer → English")
141
- translate_inputs_en = gr.Checkbox(value=True, label="Translate Inputs → English")
 
 
 
142
 
143
- question = gr.Textbox(label="Question", placeholder="e.g. హైదరాబాద్ ఎక్కడ ఉంది? / ಬೆಂಗಳೂರು ಯಾವ ರಾಜ್ಯದ ರಾಜಧಾನಿ?")
144
- user_context = gr.Textbox(label="Passage / Context (optional)", lines=4)
 
 
 
145
 
146
  btn = gr.Button("Answer")
147
 
148
- # Outputs
149
- answer_lang = gr.Textbox(label="Answer (Original Language)")
150
  answer_en = gr.Textbox(label="Answer (English)")
 
 
151
  score = gr.Textbox(label="Confidence score")
152
- used_ctx = gr.Textbox(label="Used context (Original)")
153
  ctx_en_box = gr.Textbox(label="Used context (English)")
154
  q_en_box = gr.Textbox(label="Question (English)")
155
- retrieved = gr.Textbox(label="Top-K retrieved passages", lines=4)
 
156
 
157
  btn.click(
158
  fn=ui_answer,
159
- inputs=[mode, lang_choice, translate_outputs_en, translate_inputs_en, question, user_context, top_k],
160
- outputs=[answer_lang, answer_en, score, used_ctx, ctx_en_box, q_en_box, retrieved]
161
  )
162
 
163
  if __name__ == "__main__":
 
1
+ import os, re, pathlib, json
2
  import numpy as np
3
  import pandas as pd
4
+
5
  import torch
6
+ from transformers import pipeline, AutoTokenizer
7
  from sentence_transformers import SentenceTransformer
8
+ from transformers import AutoModelForSeq2SeqLM
9
  import gradio as gr
10
 
11
+
12
  PROJECT_DIR = pathlib.Path(__file__).parent.resolve()
13
  DATA_DIR = PROJECT_DIR / "data"
14
  DATA_DIR.mkdir(parents=True, exist_ok=True)
15
  CSV_PATH = DATA_DIR / "sample_indic.csv"
16
 
17
+ SAMPLE_ROWS = [
18
+ {"id":"kn1","language":"kn","context":"ಬೆಂಗಳೂರು ಕರ್ನಾಟಕದ ರಾಜಧಾನಿ.","question":"ಕರ್ನಾಟಕದ ರಾಜಧಾನಿ ಯಾವುದು?","answer_text":"ಬೆಂಗಳೂರು"},
19
+ {"id":"kn2","language":"kn","context":"ಕನ್ನಡ ಒಂದು ದ್ರಾವಿಡ ಭಾಷೆ.","question":"ಕನ್ನಡ ಯಾವ ಭಾಷಾ ಕುಟುಂಬಕ್ಕೆ ಸೇರಿದೆ?","answer_text":"ದ್ರಾವಿಡ"},
20
+ {"id":"kn3","language":"kn","context":"ಮೈಸೂರು ಅರಮನೆ ಕರ್ನಾಟಕದ ಪ್ರಸಿದ್ಧ ತಾಣ.","question":"ಮೈಸೂರು ಅರಮನೆ ಎಲ್ಲಿದೆ?","answer_text":"ಕರ್ನಾಟಕ"},
21
+ {"id":"kn4","language":"kn","context":"ಟಿಪ್ಪು ಸುಲ್ತಾನ್ ಮೈಸೂರು ಸಾಮ್ರಾಜ್ಯದ ರಾಜನಾಗಿದ್ದನು.","question":"ಮೈಸೂರು ಸಾಮ್ರಾಜ್ಯದ ರಾಜ ಯಾರು?","answer_text":"ಟಿಪ್ಪು ಸುಲ್ತಾನ್"},
22
+ {"id":"kn5","language":"kn","context":"ಹಂಪಿ ಯುನೆಸ್ಕೋ ವಿಶ್ವ ಪರಂಪರೆ ತಾಣವಾಗಿದೆ.","question":"ಹಂಪಿ ಯಾವ ರೀತಿಯ ತಾಣ?","answer_text":"ವಿಶ್ವ ಪರಂಪರೆ ತಾಣ"},
23
+ {"id":"te1","language":"te","context":"తెలంగాణ రాష్ట్ర రాజధాని హైదరాబాదు. ఈ నగరం ఐటి పరిశ్రమకు ప్రసిద్ధి.","question":"తెలంగాణ రాష్ట్ర రాజధాని ఏది?","answer_text":"హైదరాబాదు"},
24
+ {"id":"te2","language":"te","context":"తెలుగు భాష ద్రావిడ భాషా కుటుంబానికి చెందినది. దాని లిపి తెలుగు లిపి.","question":"తెలుగు భాష ఏ లిపిని ఉపయోగిస్తుంది?","answer_text":"తెలుగు లిపి"},
25
+ {"id":"te3","language":"te","context":"సీతాకోక చిలుకలకు రెండు రెక్కలు ఉంటాయి. ఇవి పూల మకరందం తాగుతాయి.","question":"సీతాకోక చిలుకకు ఎన్ని రెక్కలు ఉన్నాయి?","answer_text":"రెండు"},
26
+ {"id":"te4","language":"te","context":"విశాఖపట్నం ఒక తీర నగరం. ఇది ఆంధ్రప్రదేశ్‌లోని ప్రముఖ నౌకాశ్రయం.","question":"విశాఖపట్నం ఏ రకమైన నగరం?","answer_text":"తీర నగరం"},
27
+ {"id":"te5","language":"te","context":"చార్మినార్ హైదరాబాద్ లో ఉంది. ఇది చారిత్రక స్మారక చిహ్నం.","question":"చార్మినార్ ఎక్కడ ఉంది?","answer_text":"హైదరాబాద్"},
28
+ ]
29
+
30
+ def ensure_sample_csv(path: pathlib.Path):
31
+ if not path.exists():
32
+ df = pd.DataFrame(SAMPLE_ROWS)
33
+ df.to_csv(path, index=False, encoding="utf-8")
34
+ print(f"[init] Wrote sample Kannada data to {path}")
35
+
36
+ ensure_sample_csv(CSV_PATH)
37
 
38
  _ZW = r"\u200b\u200c\u200d\ufeff"
39
  ZW_RE = re.compile(f"[{_ZW}]")
 
40
  def normalize_text(s: str) -> str:
41
  if not isinstance(s, str):
42
  return ""
43
+ s = s.replace("\u0964", "।")
44
  s = ZW_RE.sub("", s)
45
  s = re.sub(r"\s+", " ", s).strip()
46
  return s
47
 
48
+ df = pd.read_csv(CSV_PATH, encoding="utf-8")
49
  df["context_norm"] = df["context"].apply(normalize_text)
50
+ CORPUS = df["context_norm"].tolist()
51
+
52
 
 
53
  EMB_MODEL_NAME = "intfloat/multilingual-e5-base"
54
  emb_model = SentenceTransformer(EMB_MODEL_NAME)
55
  emb_model.eval()
 
58
  texts = [normalize_text(t) for t in texts]
59
  prefixed = [f"query: {t}" for t in texts]
60
  with torch.inference_mode():
61
+ vecs = emb_model.encode(prefixed, normalize_embeddings=True)
62
+ return vecs
63
 
64
  def encode_passages(texts):
65
  texts = [normalize_text(t) for t in texts]
66
  prefixed = [f"passage: {t}" for t in texts]
67
  with torch.inference_mode():
68
+ vecs = emb_model.encode(prefixed, normalize_embeddings=True)
69
+ return vecs
70
+
71
+ PASSAGE_EMBS = encode_passages(CORPUS)
72
+
73
+
74
 
75
+ def retrieve_top_k(query: str, k: int = 3):
76
+ if not query or not query.strip():
77
 
 
 
 
78
  return []
79
  qv = encode_queries([query])[0]
80
  sims = np.dot(PASSAGE_EMBS, qv)
81
+
82
+
83
  idxs = np.argsort(-sims)[:k]
84
  results = []
85
  for rank, i in enumerate(idxs):
86
+ results.append({"rank": int(rank+1), "similarity": float(sims[i]), "context": CORPUS[i]})
87
+
88
+
89
  return results
90
 
91
+
92
  READER_MODEL = "deepset/xlm-roberta-large-squad2"
93
  device = 0 if torch.cuda.is_available() else -1
 
94
 
 
 
 
95
 
96
+ tokenizer = AutoTokenizer.from_pretrained(READER_MODEL, use_fast=True)
97
+ qa = pipeline("question-answering", model=READER_MODEL, tokenizer=tokenizer, device=device)
98
+
99
+
100
+ # --- Kannada -> English translator (offline, NLLB-200) ---
101
+ # Model: facebook/nllb-200-distilled-600M
102
+ # Kannada = 'kan_Knda', English = 'eng_Latn'
103
  NLLB_ID = "facebook/nllb-200-distilled-600M"
104
+ nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_ID)
105
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_ID)
106
+ # Telugu -> English
107
+ trans_te_en = pipeline(
108
+ "translation",
109
+ model=nllb_model,
110
+ tokenizer=nllb_tokenizer,
111
+ src_lang="tel_Telu",
112
+ tgt_lang="eng_Latn",
113
+ device=device
114
+ )
115
+
116
+ def te_to_en(text: str) -> str:
117
+ text = (text or "").strip()
118
+ if not text: return ""
119
+ return trans_te_en(text, max_length=256)[0]["translation_text"].strip()
120
+
121
+ # Kannada -> English
122
+ trans_kn_en = pipeline(
123
+ "translation",
124
+ model=nllb_model,
125
+ tokenizer=nllb_tokenizer,
126
+ src_lang="kan_Knda",
127
+ tgt_lang="eng_Latn",
128
+ device=device
129
+ )
130
+
131
+ def kn_to_en(text: str) -> str:
132
+ text = (text or "").strip()
133
+ if not text: return ""
134
+ return trans_kn_en(text, max_length=256)[0]["translation_text"].strip()
135
+
136
+
137
+
138
+ def answer_with_context(question: str, context: str):
139
+ question = normalize_text(question)
140
+ context = normalize_text(context)
141
+ if not question or not context:
142
+ return {"answer": "", "score": 0.0}
143
+ out = qa(question=question, context=context)
144
+ ans = out.get("answer", "").strip()
145
+ score = float(out.get("score", 0.0))
146
+ return {"answer": ans, "score": score}
147
+
148
+ def no_context_flow(question: str, top_k: int = 3):
149
+ cands = retrieve_top_k(question, k=top_k)
150
+ if not cands:
151
+ return {"answer": "", "score": 0.0, "used_context": "", "retrieved": []}
152
+ best = {"answer": "", "score": -1.0, "used_context": ""}
153
+ for c in cands:
154
+ out = answer_with_context(question, c["context"])
155
+ if out["score"] > best["score"]:
156
+ best = {"answer": out["answer"], "score": out["score"], "used_context": c["context"]}
157
+ return {"answer": best["answer"], "score": best["score"], "used_context": best["used_context"], "retrieved": cands}
158
+
159
+
160
+
161
+
162
+
163
+
164
+
165
 
 
 
166
 
 
 
167
 
 
 
168
 
 
169
  INTRO_MD = """
170
+ ### ShabdaAI (Kannada ↔ English)
171
+ - **ಮೋಡ್ 1:** ನಾನು ನೀಡುವ ಪ್ಯಾಸೇಜ್ (context) ಆಧರಿಸಿ ಉತ್ತರಿಸು
172
+ - **ಮೋಡ್ 2:** ಪ್ಯಾಸೇಜ್ ಇಲ್ಲದಿದ್ದರೆ ಸಣ್ಣ ಕನ್ನಡ ಕಾರ್ಪಸ್‌ನಿಂದ *ಹುಡುಕು → ಓದು* ಮಾಡಿ ಉತ್ತರಿಸು
173
+ - **మోడ్ 1:** నేను ఇచ్చే ప్యాసేజ్ (context) పై సమాధానం ఇవ్వు
174
+ - **మోడ్ 2:** ప్యాసేజ్ ఇవ్వకపోతే — చిన్న తెలుగు కార్పస్‌లో *సెర్చ్ → రీడ్* చేసి సమాధానం ఇవ్వు
175
+
176
+ > Models: **intfloat/multilingual-e5-base** (retrieval) + **deepset/xlm-roberta-large-squad2** (extractive QA)
177
+
178
 
 
 
 
179
  """
180
 
181
+ def ui_answer(mode, translate_outputs_en, translate_inputs_en, question, user_context, top_k, lang_choice):
182
+ question = question or ""
183
+ user_context = user_context or ""
184
 
185
+ # Choose translator
186
  if lang_choice == "Telugu":
187
+ to_en = te_to_en
188
  else:
189
+ to_en = kn_to_en
190
 
191
+ # Optional translations
192
+ q_en = to_en(question) if translate_inputs_en and question else ""
193
+ ctx_en = to_en(user_context) if translate_inputs_en and user_context else ""
194
 
195
  if mode == "With my context":
196
  res = answer_with_context(question, user_context)
197
+ ans = res["answer"]
198
  ans_en = to_en(ans) if translate_outputs_en and ans else ""
199
+ return ans, ans_en, f"{res['score']:.3f}", user_context, ctx_en or "—", q_en or "—", "—"
200
+
201
  else:
202
+ res = no_context_flow(question, top_k=int(top_k))
203
+ ans = res["answer"]
204
+ ans_en = to_en(ans) if translate_outputs_en and ans else ""
205
+ retrieved_tbl = "\n".join(
206
+ [f"{r['rank']}. (sim={r['similarity']:.3f}) {r['context']}" for r in res.get("retrieved", [])]
207
+ ) or ""
208
+ return ans, ans_en, f"{res['score']:.3f}", res["used_context"], ctx_en or "—", q_en or "", retrieved_tbl
209
+
210
+
211
+
212
 
213
  with gr.Blocks() as demo:
214
  gr.Markdown(INTRO_MD)
215
 
216
  with gr.Row():
217
+ mode = gr.Radio(
218
+ choices=["With my context", "No context (search sample data)"],
219
+ value="With my context",
220
+ label="Mode"
221
+ )
222
+ top_k = gr.Slider(1, 5, value=3, step=1, label="Top-K passages (for No-context mode)")
223
  with gr.Row():
224
+ translate_outputs_en = gr.Checkbox(value=True, label="Translate ANSWER (Kannada → English)")
225
+ translate_inputs_en = gr.Checkbox(value=True, label="Translate INPUTS (Question/Context → English)")
226
+
227
+ question = gr.Textbox(label="ಪ್ರಶ್ನೆ (Question)", placeholder="ಉದಾ: ಬೆಂಗಳೂರು ಯಾವ ರಾಜ್ಯದ ರಾಜಧಾನಿ?")
228
+ user_context = gr.Textbox(label="ಪ್ಯಾಸೇಜ್ / ಸಂದರ್ಭ (optional)", lines=4)
229
 
230
+ lang_choice = gr.Dropdown(
231
+ choices=["Telugu", "Kannada"],
232
+ value="Kannada",
233
+ label="Language"
234
+ )
235
 
236
  btn = gr.Button("Answer")
237
 
238
+ # Answers
239
+ answer_local = gr.Textbox(label="Answer (Telugu/Kannada)")
240
  answer_en = gr.Textbox(label="Answer (English)")
241
+
242
+ # Confidence + contexts
243
  score = gr.Textbox(label="Confidence score")
244
+ used_ctx = gr.Textbox(label="Used context (Telugu/Kannada)")
245
  ctx_en_box = gr.Textbox(label="Used context (English)")
246
  q_en_box = gr.Textbox(label="Question (English)")
247
+
248
+ retrieved = gr.Textbox(label="Top-K retrieved passages (Telugu/Kannada)", lines=4)
249
 
250
  btn.click(
251
  fn=ui_answer,
252
+ inputs=[mode, translate_outputs_en, translate_inputs_en, question, user_context, top_k, lang_choice],
253
+ outputs=[answer_local, answer_en, score, used_ctx, ctx_en_box, q_en_box, retrieved]
254
  )
255
 
256
  if __name__ == "__main__":