Heng2004 committed
Commit a254b10 · verified · 1 Parent(s): c5298d8

Update app.py

Files changed (1):
  app.py  +61 -22
app.py CHANGED
@@ -3,6 +3,8 @@
 import os
 import re
 import json
+from difflib import SequenceMatcher  # 👈 for better fuzzy matching
+
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
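
A note on the new import: difflib.SequenceMatcher ships with the Python standard library, and its ratio() returns a similarity score in [0, 1]. A minimal sketch (the Lao strings are invented examples, not dataset content):

    from difflib import SequenceMatcher

    # ratio() = 2*M / T, where M = matched characters and T = combined length.
    # Character-level matching needs no word segmentation, which suits Lao
    # since it is written without spaces between words.
    print(SequenceMatcher(None, "ປະຫວັດສາດລາວ", "ປະຫວັດສາດ ລາວ").ratio())  # 0.96
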
@@ -20,17 +22,31 @@ model = AutoModelForCausalLM.from_pretrained(
 DATA_PATH = "data/1_Year_U_1.jsonl"
 
 ENTRIES = []  # each entry is one JSON object (your schema)
-QA_INDEX = {}  # fast lookup: normalized question -> answer
+QA_INDEX: dict[str, str] = {}  # fast lookup: normalized question -> answer
 
 
 def _normalize_question(q: str) -> str:
-    # lowercase, remove basic punctuation, collapse spaces
+    """
+    Normalize Lao questions for matching:
+    - lowercase
+    - remove common punctuation
+    - collapse spaces
+    """
     q = q.lower()
     q = re.sub(r"[?!?!\.\,\:\;\"“”'‘’]", " ", q)
     q = re.sub(r"\s+", " ", q)
     return q.strip()
 
 
+def _similarity(a: str, b: str) -> float:
+    """
+    Character-level similarity between two normalized strings.
+    Works OK for Lao because we’re still matching on shared sequences.
+    """
+    return SequenceMatcher(None, a, b).ratio()
+
+
+# Load dataset and build QA index
 if os.path.exists(DATA_PATH):
     with open(DATA_PATH, "r", encoding="utf-8") as f:
         for line in f:
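
To see how the two helpers combine, a quick hypothetical check (both question strings are made up and mean roughly "what is Lao history?"):

    q1 = _normalize_question("ປະຫວັດສາດລາວແມ່ນຫຍັງ?")
    q2 = _normalize_question("ປະຫວັດສາດລາວ ແມ່ນຫຍັງ")
    print(q1)                   # question mark stripped, whitespace collapsed
    print(_similarity(q1, q2))  # ≈ 0.98, close enough for a fuzzy hit
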
@@ -87,7 +103,7 @@ def retrieve_context(question: str, max_entries: int = 2) -> str:
         for e in chosen
     )
 
-    scored = []
+    scored: list[tuple[int, dict]] = []
 
     for e in ENTRIES:
         text = e.get("text", "")
@@ -157,13 +173,17 @@ def build_prompt(question: str) -> str:
 
 
 def generate_answer(question: str) -> str:
+    """
+    Use SeaLLM + retrieved context to generate an answer.
+    Kept fairly short for speed and to avoid rambling.
+    """
     prompt = build_prompt(question)
     inputs = tokenizer(prompt, return_tensors="pt")
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=160,  # shorter answers = faster
-            do_sample=False,     # greedy decoding → more stable & a bit faster
+            max_new_tokens=160,  # shorter answers = faster, less chance to cut mid-sentence
+            do_sample=False,     # greedy decoding → more stable & deterministic
         )
 
     # slice off the prompt part
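
The "slice off the prompt part" step sits just outside this hunk; for readers skimming the diff, the usual transformers pattern looks like this (a sketch, not necessarily the exact code in app.py):

    # model.generate() returns prompt + completion; keep only the new tail.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
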
@@ -174,37 +194,56 @@ def generate_answer(question: str) -> str:
 
 def answer_from_qa(question: str) -> str | None:
     """
-    1) Try exact match in QA_INDEX.
-    2) If not found, use simple fuzzy match:
-       pick the stored question that shares the most words.
+    Try to answer directly from:
+      1) Exact QA pairs.
+      2) Fuzzy QA question similarity.
+      3) Fuzzy match to entry summaries/titles (good for 'ຄວາມສໍາຄັນ...' type questions).
+
+    If nothing is good enough, return None so the model will answer.
     """
     norm_q = _normalize_question(question)
+    if not norm_q:
+        return None
 
     # 1) exact match first
     if norm_q in QA_INDEX:
         return QA_INDEX[norm_q]
 
-    # 2) fuzzy match
-    q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
-    if not q_terms:
-        return None
-
-    best_score = 0
+    # 2) fuzzy match over QA questions
+    best_ratio = 0.0
     best_answer = None
 
     for stored_q, a in QA_INDEX.items():
-        stored_terms = [t for t in stored_q.split(" ") if len(t) > 1]
-        overlap = sum(1 for t in q_terms if t in stored_terms)
-        if overlap > best_score:
-            best_score = overlap
+        r = _similarity(norm_q, stored_q)
+        if r > best_ratio:
+            best_ratio = r
             best_answer = a
 
-    # require at least 1 overlapping word (e.g. ປະຫວັດສາດ or ຄວາມສໍາຄັນ)
-    if best_score >= 1:
+    # threshold tuned so that very close questions (wording a bit different)
+    # still return the textbook QA answer
+    if best_ratio >= 0.55 and best_answer:
         return best_answer
 
-    return None
+    # 3) fallback: fuzzy match over entry summaries / titles / keywords
+    best_ratio = 0.0
+    best_summary = None
+
+    for e in ENTRIES:
+        combined = f"{e.get('title','')} {e.get('summary','')} {' '.join(e.get('keywords', []))}"
+        combined_norm = _normalize_question(combined)
+        if not combined_norm:
+            continue
+
+        r = _similarity(norm_q, combined_norm)
+        if r > best_ratio:
+            best_ratio = r
+            best_summary = e.get("summary") or e.get("text")
+
+    # lower threshold here because we’re matching against shorter summaries
+    if best_ratio >= 0.35 and best_summary:
+        return best_summary
+
+    return None
 
 
 # 3. Gradio chat function
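
A tiny smoke test of the new three-step cascade (the QA pair and entry below are invented placeholders, not dataset content):

    QA_INDEX[_normalize_question("ປະຫວັດສາດລາວແມ່ນຫຍັງ?")] = "ຄຳຕອບຈາກປຶ້ມແບບຮຽນ"
    ENTRIES.append({"title": "ປະຫວັດສາດລາວ", "summary": "ສະຫຼຸບຫຍໍ້", "keywords": []})

    print(answer_from_qa("ປະຫວັດສາດລາວແມ່ນຫຍັງ"))    # step 1: exact hit
    print(answer_from_qa("ປະຫວັດສາດ ລາວ ແມ່ນຫຍັງ?"))  # step 2: fuzzy hit, ratio well above 0.55
    print(answer_from_qa("abc"))                       # no match anywhere -> None, model takes over
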
@@ -212,7 +251,7 @@ def laos_history_bot(message: str, history: list):
     if not message.strip():
         return "ກະລຸນາພິມຄຳຖາມກ່ອນ."
 
-    # 1) Try to answer directly from QA pairs (instant)
+    # 1) Try to answer directly from QA pairs or summaries (instant)
     direct = answer_from_qa(message)
    if direct:
         return direct
 