Update model_utils.py
Browse files — model_utils.py (+39, −17)
model_utils.py
CHANGED
|
@@ -101,11 +101,40 @@ def retrieve_context(question: str, max_entries: int = 2) -> str:
|
|
| 101 |
return "\n\n".join(context_blocks)
|
| 102 |
|
| 103 |
|
| 104 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
context = retrieve_context(question)
|
|
|
|
|
|
|
| 106 |
return f"""{SYSTEM_PROMPT}
|
| 107 |
|
| 108 |
-
ຂໍ້ມູນອ້າງອີງ:
|
| 109 |
{context}
|
| 110 |
|
| 111 |
ຄຳຖາມ: {question}
|
|
@@ -113,8 +142,8 @@ def build_prompt(question: str) -> str:
|
|
| 113 |
ຄຳຕອບດ້ວຍພາສາລາວ:"""
|
| 114 |
|
| 115 |
|
| 116 |
-
def generate_answer(question: str) -> str:
|
| 117 |
-
prompt = build_prompt(question)
|
| 118 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 119 |
with torch.no_grad():
|
| 120 |
outputs = model.generate(
|
|
@@ -126,14 +155,11 @@ def generate_answer(question: str) -> str:
|
|
| 126 |
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
|
| 127 |
answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 128 |
|
| 129 |
-
#
|
| 130 |
-
# `re` is already imported at the top of this file
|
| 131 |
sentences = re.split(r"(?<=[\.?!…])\s+", answer)
|
| 132 |
short_answer = " ".join(sentences[:3]).strip()
|
| 133 |
-
|
| 134 |
return short_answer if short_answer else answer
|
| 135 |
-
|
| 136 |
-
|
| 137 |
|
| 138 |
def answer_from_qa(question: str) -> Optional[str]:
|
| 139 |
"""
|
|
@@ -179,16 +205,12 @@ def laos_history_bot(message: str, history: List) -> str:
|
|
| 179 |
|
| 180 |
direct = answer_from_qa(message)
|
| 181 |
if direct:
|
| 182 |
-
|
| 183 |
-
meta = "[ຊັ້ນ M1, ບົດ 1]"
|
| 184 |
-
return f"{meta} {direct}"
|
| 185 |
-
|
| 186 |
|
| 187 |
try:
|
| 188 |
-
|
|
|
|
| 189 |
except Exception as e: # noqa: BLE001
|
| 190 |
return f"ລະບົບມີບັນຫາ: {e}"
|
| 191 |
|
| 192 |
-
|
| 193 |
-
return f"{meta} {answer}"
|
| 194 |
-
|
|
|
|
| 101 |
return "\n\n".join(context_blocks)
|
| 102 |
|
| 103 |
|
| 104 |
+
def _format_history(history: Optional[List]) -> str:
|
| 105 |
+
"""
|
| 106 |
+
Convert last few chat turns into a Lao conversation snippet
|
| 107 |
+
to give the model context for follow-up questions.
|
| 108 |
+
Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
|
| 109 |
+
"""
|
| 110 |
+
if not history:
|
| 111 |
+
return ""
|
| 112 |
+
|
| 113 |
+
# keep only the last 3 turns to avoid very long prompts
|
| 114 |
+
recent = history[-3:]
|
| 115 |
+
|
| 116 |
+
lines = []
|
| 117 |
+
for turn in recent:
|
| 118 |
+
if not isinstance(turn, (list, tuple)) or len(turn) != 2:
|
| 119 |
+
continue
|
| 120 |
+
user_msg, bot_msg = turn
|
| 121 |
+
lines.append(f"ນັກຮຽນ: {user_msg}")
|
| 122 |
+
lines.append(f"ອາຈານ AI: {bot_msg}")
|
| 123 |
+
|
| 124 |
+
if not lines:
|
| 125 |
+
return ""
|
| 126 |
+
|
| 127 |
+
joined = "\n".join(lines)
|
| 128 |
+
return f"ປະຫວັດການສົນທະນາກ່ອນໜ້າ:\n{joined}\n\n"
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def build_prompt(question: str, history: Optional[List] = None) -> str:
|
| 132 |
context = retrieve_context(question)
|
| 133 |
+
history_block = _format_history(history)
|
| 134 |
+
|
| 135 |
return f"""{SYSTEM_PROMPT}
|
| 136 |
|
| 137 |
+
{history_block}ຂໍ້ມູນອ້າງອີງ:
|
| 138 |
{context}
|
| 139 |
|
| 140 |
ຄຳຖາມ: {question}
|
|
|
|
| 142 |
ຄຳຕອບດ້ວຍພາສາລາວ:"""
|
| 143 |
|
| 144 |
|
| 145 |
+
def generate_answer(question: str, history: Optional[List] = None) -> str:
|
| 146 |
+
prompt = build_prompt(question, history)
|
| 147 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 148 |
with torch.no_grad():
|
| 149 |
outputs = model.generate(
|
|
|
|
| 155 |
generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
|
| 156 |
answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 157 |
|
| 158 |
+
# (your 2–3 sentence enforcement can stay here)
|
|
|
|
| 159 |
sentences = re.split(r"(?<=[\.?!…])\s+", answer)
|
| 160 |
short_answer = " ".join(sentences[:3]).strip()
|
|
|
|
| 161 |
return short_answer if short_answer else answer
|
| 162 |
+
|
|
|
|
| 163 |
|
| 164 |
def answer_from_qa(question: str) -> Optional[str]:
|
| 165 |
"""
|
|
|
|
| 205 |
|
| 206 |
direct = answer_from_qa(message)
|
| 207 |
if direct:
|
| 208 |
+
return direct
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
try:
|
| 211 |
+
# ✅ pass history to let LLM understand follow-up questions
|
| 212 |
+
answer = generate_answer(message, history)
|
| 213 |
except Exception as e: # noqa: BLE001
|
| 214 |
return f"ລະບົບມີບັນຫາ: {e}"
|
| 215 |
|
| 216 |
+
return answer
|
|
|
|
|
|