Update app.py
app.py (CHANGED)
@@ -3,6 +3,8 @@
 import os
 import re
 import json
+from difflib import SequenceMatcher  # 👈 for better fuzzy matching
+
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -20,17 +22,31 @@ model = AutoModelForCausalLM.from_pretrained(
 DATA_PATH = "data/1_Year_U_1.jsonl"
 
 ENTRIES = []  # each entry is one JSON object (your schema)
-QA_INDEX = {}  # fast lookup: normalized question -> answer
+QA_INDEX: dict[str, str] = {}  # fast lookup: normalized question -> answer
 
 
 def _normalize_question(q: str) -> str:
+    """
+    Normalize Lao questions for matching:
+    - lowercase
+    - remove common punctuation
+    - collapse spaces
+    """
     q = q.lower()
     q = re.sub(r"[?!?!\.\,\:\;\"“”'‘’]", " ", q)
     q = re.sub(r"\s+", " ", q)
     return q.strip()
 
 
+def _similarity(a: str, b: str) -> float:
+    """
+    Character-level similarity between two normalized strings.
+    Works OK for Lao because we’re still matching on shared sequences.
+    """
+    return SequenceMatcher(None, a, b).ratio()
+
+
+# Load dataset and build QA index
 if os.path.exists(DATA_PATH):
     with open(DATA_PATH, "r", encoding="utf-8") as f:
         for line in f:
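For intuition about the new helper pair: SequenceMatcher.ratio() returns 2*M / (len(a) + len(b)), where M counts the characters inside matching blocks, so near-paraphrases of a normalized question score high while unrelated ones score low. A quick sanity check of that behavior (illustrative only; English strings stand in for the Lao questions the app actually compares):

from difflib import SequenceMatcher

a = "what is the capital of laos"
b = "what is the capital city of laos"    # close paraphrase of a
c = "when did the lan xang kingdom fall"  # unrelated question

# Matching blocks cover most of both strings -> high ratio (~0.92).
print(round(SequenceMatcher(None, a, b).ratio(), 2))

# Far less shared character structure -> below the 0.55 cutoff used later.
print(SequenceMatcher(None, a, c).ratio() < 0.55)  # True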
@@ -87,7 +103,7 @@ def retrieve_context(question: str, max_entries: int = 2) -> str:
         for e in chosen
     )
 
-    scored = []
+    scored: list[tuple[int, dict]] = []
 
     for e in ENTRIES:
         text = e.get("text", "")
@@ -157,13 +173,17 @@ def build_prompt(question: str) -> str:
 
 
 def generate_answer(question: str) -> str:
+    """
+    Use SeaLLM + retrieved context to generate an answer.
+    Kept fairly short for speed and to avoid rambling.
+    """
     prompt = build_prompt(question)
     inputs = tokenizer(prompt, return_tensors="pt")
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=160,  # shorter answers = faster
-            do_sample=False,  # greedy decoding → more stable &
+            max_new_tokens=160,  # shorter answers = faster, less chance to cut mid-sentence
+            do_sample=False,  # greedy decoding → more stable & deterministic
         )
 
     # slice off the prompt part
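The hunk ends just before the decode step, so for reference: with return_tensors="pt", outputs[0] holds the prompt tokens followed by the new tokens, and the usual way to "slice off the prompt part" is shown below. This is a sketch of the standard transformers pattern, assuming the app decodes only the newly generated tokens (the actual lines sit outside this diff):

# Sketch (assumption): decode only the tokens generated after the prompt.
prompt_len = inputs["input_ids"].shape[1]
new_tokens = outputs[0][prompt_len:]
answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()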
@@ -174,37 +194,56 @@ def generate_answer(question: str) -> str:
 
 def answer_from_qa(question: str) -> str | None:
     """
+    Try to answer directly from:
+    1) Exact QA pairs.
+    2) Fuzzy QA question similarity.
+    3) Fuzzy match to entry summaries/titles (good for 'ຄວາມສໍາຄັນ...' type questions).
+
+    If nothing is good enough, return None so the model will answer.
     """
     norm_q = _normalize_question(question)
+    if not norm_q:
+        return None
 
     # 1) exact match first
     if norm_q in QA_INDEX:
         return QA_INDEX[norm_q]
 
-    # 2) fuzzy match
-    if not q_terms:
-        return None
-
-    best_score = 0
+    # 2) fuzzy match over QA questions
+    best_ratio = 0.0
     best_answer = None
 
     for stored_q, a in QA_INDEX.items():
-            best_score = overlap
+        r = _similarity(norm_q, stored_q)
+        if r > best_ratio:
+            best_ratio = r
             best_answer = a
 
+    # threshold tuned so that very close questions (wording a bit different)
+    # still return the textbook QA answer
+    if best_ratio >= 0.55 and best_answer:
         return best_answer
 
+    # 3) fallback: fuzzy match over entry summaries / titles / keywords
+    best_ratio = 0.0
+    best_summary = None
+
+    for e in ENTRIES:
+        combined = f"{e.get('title','')} {e.get('summary','')} {' '.join(e.get('keywords', []))}"
+        combined_norm = _normalize_question(combined)
+        if not combined_norm:
+            continue
+
+        r = _similarity(norm_q, combined_norm)
+        if r > best_ratio:
+            best_ratio = r
+            best_summary = e.get("summary") or e.get("text")
+
+    # lower threshold here because we’re matching against shorter summaries
+    if best_ratio >= 0.35 and best_summary:
+        return best_summary
+
+    return None
 
 
 # 3. Gradio chat function
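(In the docstring, the Lao fragment 'ຄວາມສໍາຄັນ...' means roughly "the importance of ...".) The 0.55 cutoff is strict enough to reject unrelated questions yet loose enough to absorb rewordings; the summary fallback uses 0.35 because a short question never covers much of a long title+summary string. A minimal check of the stricter threshold, with a hypothetical stored question (mirroring what _similarity computes in stage 2):

from difflib import SequenceMatcher

stored = "when was the lan xang kingdom founded"  # hypothetical QA_INDEX key
asked = "when was lan xang founded"               # user's paraphrase

# Same computation as _similarity(norm_q, stored_q) above: ratio ~0.81.
print(SequenceMatcher(None, asked, stored).ratio() >= 0.55)  # True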
@@ -212,7 +251,7 @@ def laos_history_bot(message: str, history: list):
     if not message.strip():
         return "ກະລຸນາພິມຄຳຖາມກ່ອນ."
 
-    # 1) Try to answer directly from QA pairs (instant)
+    # 1) Try to answer directly from QA pairs or summaries (instant)
     direct = answer_from_qa(message)
    if direct:
         return direct
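The guard string "ກະລຸນາພິມຄຳຖາມກ່ອນ." means "Please type a question first." Since laos_history_bot(message, history) matches the callback signature gr.ChatInterface expects, the app presumably wires it up along these lines; the UI code is outside this diff, so this is only a sketch and the title string is hypothetical:

import gradio as gr

# Sketch (assumption): hook the chat function into a standard ChatInterface.
demo = gr.ChatInterface(
    fn=laos_history_bot,
    title="Laos History QA",  # hypothetical title, not taken from the diff
)

if __name__ == "__main__":
    demo.launch()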