Update app.py
Browse files
app.py
CHANGED
|
@@ -24,8 +24,11 @@ QA_INDEX = {} # fast lookup: normalized question -> answer
|
|
| 24 |
|
| 25 |
|
| 26 |
def _normalize_question(q: str) -> str:
|
| 27 |
-
#
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
if os.path.exists(DATA_PATH):
|
|
@@ -159,7 +162,7 @@ def generate_answer(question: str) -> str:
|
|
| 159 |
with torch.no_grad():
|
| 160 |
outputs = model.generate(
|
| 161 |
**inputs,
|
| 162 |
-
max_new_tokens=
|
| 163 |
do_sample=False, # greedy decoding → more stable & a bit faster
|
| 164 |
)
|
| 165 |
|
|
@@ -171,11 +174,37 @@ def generate_answer(question: str) -> str:
|
|
| 171 |
|
| 172 |
def answer_from_qa(question: str) -> str | None:
|
| 173 |
"""
|
| 174 |
-
|
| 175 |
-
|
|
|
|
| 176 |
"""
|
| 177 |
norm_q = _normalize_question(question)
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
|
| 181 |
# 3. Gradio chat function
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def _normalize_question(q: str) -> str:
|
| 27 |
+
# lowercase, remove basic punctuation, collapse spaces
|
| 28 |
+
q = q.lower()
|
| 29 |
+
q = re.sub(r"[?!?!\.\,\:\;\"“”'‘’]", " ", q)
|
| 30 |
+
q = re.sub(r"\s+", " ", q)
|
| 31 |
+
return q.strip()
|
| 32 |
|
| 33 |
|
| 34 |
if os.path.exists(DATA_PATH):
|
|
|
|
| 162 |
with torch.no_grad():
|
| 163 |
outputs = model.generate(
|
| 164 |
**inputs,
|
| 165 |
+
max_new_tokens=160, # shorter answers = faster
|
| 166 |
do_sample=False, # greedy decoding → more stable & a bit faster
|
| 167 |
)
|
| 168 |
|
|
|
|
| 174 |
|
| 175 |
def answer_from_qa(question: str) -> str | None:
|
| 176 |
"""
|
| 177 |
+
1) Try exact match in QA_INDEX.
|
| 178 |
+
2) If not found, use simple fuzzy match:
|
| 179 |
+
pick the stored question that shares the most words.
|
| 180 |
"""
|
| 181 |
norm_q = _normalize_question(question)
|
| 182 |
+
|
| 183 |
+
# 1) exact match first
|
| 184 |
+
if norm_q in QA_INDEX:
|
| 185 |
+
return QA_INDEX[norm_q]
|
| 186 |
+
|
| 187 |
+
# 2) fuzzy match
|
| 188 |
+
q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
|
| 189 |
+
if not q_terms:
|
| 190 |
+
return None
|
| 191 |
+
|
| 192 |
+
best_score = 0
|
| 193 |
+
best_answer = None
|
| 194 |
+
|
| 195 |
+
for stored_q, a in QA_INDEX.items():
|
| 196 |
+
stored_terms = [t for t in stored_q.split(" ") if len(t) > 1]
|
| 197 |
+
overlap = sum(1 for t in q_terms if t in stored_terms)
|
| 198 |
+
if overlap > best_score:
|
| 199 |
+
best_score = overlap
|
| 200 |
+
best_answer = a
|
| 201 |
+
|
| 202 |
+
# require at least 1 overlapping word (e.g. ປະຫວັດສາດ or ຄວາມສໍາຄັນ)
|
| 203 |
+
if best_score >= 1:
|
| 204 |
+
return best_answer
|
| 205 |
+
|
| 206 |
+
return None
|
| 207 |
+
|
| 208 |
|
| 209 |
|
| 210 |
# 3. Gradio chat function
|