Heng2004 committed on
Commit
c5298d8
·
verified ·
1 Parent(s): b2e6f17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -6
app.py CHANGED
@@ -24,8 +24,11 @@ QA_INDEX = {} # fast lookup: normalized question -> answer
24
 
25
 
26
  def _normalize_question(q: str) -> str:
27
- # simple normalization: collapse spaces and strip
28
- return re.sub(r"\s+", " ", q).strip()
 
 
 
29
 
30
 
31
  if os.path.exists(DATA_PATH):
@@ -159,7 +162,7 @@ def generate_answer(question: str) -> str:
159
  with torch.no_grad():
160
  outputs = model.generate(
161
  **inputs,
162
- max_new_tokens=120, # shorter answers = faster
163
  do_sample=False, # greedy decoding → more stable & a bit faster
164
  )
165
 
@@ -171,11 +174,37 @@ def generate_answer(question: str) -> str:
171
 
172
  def answer_from_qa(question: str) -> str | None:
173
  """
174
- Fast path: if the question exactly matches a QA pair from the dataset,
175
- return that answer immediately (no model call).
 
176
  """
177
  norm_q = _normalize_question(question)
178
- return QA_INDEX.get(norm_q)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
 
181
  # 3. Gradio chat function
 
24
 
25
 
26
  def _normalize_question(q: str) -> str:
27
+ # lowercase, remove basic punctuation, collapse spaces
28
+ q = q.lower()
29
+ q = re.sub(r"[?!?!\.\,\:\;\"“”'‘’]", " ", q)
30
+ q = re.sub(r"\s+", " ", q)
31
+ return q.strip()
32
 
33
 
34
  if os.path.exists(DATA_PATH):
 
162
  with torch.no_grad():
163
  outputs = model.generate(
164
  **inputs,
165
+ max_new_tokens=160, # cap generated tokens to bound latency
166
  do_sample=False, # greedy decoding → more stable & a bit faster
167
  )
168
 
 
174
 
175
def answer_from_qa(question: str) -> str | None:
    """Answer a question from the preloaded QA dataset, if possible.

    1) Try an exact match of the normalized question in QA_INDEX.
    2) Otherwise fall back to a simple fuzzy match: return the answer
       of the stored question that shares the most words with the query.

    Returns None when no stored question shares at least one word
    (longer than one character) with the query, signalling the caller
    to fall back to the model.
    """
    norm_q = _normalize_question(question)

    # 1) exact match first — a single dict lookup avoids the model call
    if norm_q in QA_INDEX:
        return QA_INDEX[norm_q]

    # 2) fuzzy match on word overlap; single-character tokens are too
    #    noisy to count as evidence
    q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
    if not q_terms:
        return None

    best_score = 0
    best_answer = None

    for stored_q, answer in QA_INDEX.items():
        # Build a set so each membership test is O(1); the original
        # scanned a list, making the loop O(terms × stored_terms).
        # Membership results are identical, so behavior is unchanged.
        stored_terms = {t for t in stored_q.split(" ") if len(t) > 1}
        overlap = sum(1 for t in q_terms if t in stored_terms)
        if overlap > best_score:
            best_score = overlap
            best_answer = answer

    # Require at least one overlapping word (e.g. Lao terms such as
    # "ປະຫວັດສາດ" (history) or "ຄວາມສໍາຄັນ" (importance)).
    # NOTE(review): ties keep the first entry in dict insertion order —
    # confirm that is the intended tie-break.
    if best_score >= 1:
        return best_answer

    return None
207
+
208
 
209
 
210
  # 3. Gradio chat function