Heng2004 commited on
Commit
42ccf80
·
verified ·
1 Parent(s): 921357f

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +39 -17
model_utils.py CHANGED
@@ -101,11 +101,40 @@ def retrieve_context(question: str, max_entries: int = 2) -> str:
101
  return "\n\n".join(context_blocks)
102
 
103
 
104
- def build_prompt(question: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  context = retrieve_context(question)
 
 
106
  return f"""{SYSTEM_PROMPT}
107
 
108
- ຂໍ້ມູນອ້າງອີງ:
109
  {context}
110
 
111
  ຄຳຖາມ: {question}
@@ -113,8 +142,8 @@ def build_prompt(question: str) -> str:
113
  ຄຳຕອບດ້ວຍພາສາລາວ:"""
114
 
115
 
116
- def generate_answer(question: str) -> str:
117
- prompt = build_prompt(question)
118
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
119
  with torch.no_grad():
120
  outputs = model.generate(
@@ -126,14 +155,11 @@ def generate_answer(question: str) -> str:
126
  generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
127
  answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
128
 
129
- # Enforce 2–3 short sentences
130
- # `re` is already imported at the top of this file
131
  sentences = re.split(r"(?<=[\.?!…])\s+", answer)
132
  short_answer = " ".join(sentences[:3]).strip()
133
-
134
  return short_answer if short_answer else answer
135
-
136
-
137
 
138
  def answer_from_qa(question: str) -> Optional[str]:
139
  """
@@ -179,16 +205,12 @@ def laos_history_bot(message: str, history: List) -> str:
179
 
180
  direct = answer_from_qa(message)
181
  if direct:
182
- # later you can make this dynamic from the dataset
183
- meta = "[ຊັ້ນ M1, ບົດ 1]"
184
- return f"{meta} {direct}"
185
-
186
 
187
  try:
188
- answer = generate_answer(message)
 
189
  except Exception as e: # noqa: BLE001
190
  return f"ລະບົບມີບັນຫາ: {e}"
191
 
192
- meta = "[ຊັ້ນ M1, ບົດ 1]" # placeholder, later make dynamic
193
- return f"{meta} {answer}"
194
-
 
101
  return "\n\n".join(context_blocks)
102
 
103
 
104
+ def _format_history(history: Optional[List]) -> str:
105
+ """
106
+ Convert last few chat turns into a Lao conversation snippet
107
+ to give the model context for follow-up questions.
108
+ Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
109
+ """
110
+ if not history:
111
+ return ""
112
+
113
+ # keep only the last 3 turns to avoid very long prompts
114
+ recent = history[-3:]
115
+
116
+ lines = []
117
+ for turn in recent:
118
+ if not isinstance(turn, (list, tuple)) or len(turn) != 2:
119
+ continue
120
+ user_msg, bot_msg = turn
121
+ lines.append(f"ນັກຮຽນ: {user_msg}")
122
+ lines.append(f"ອາຈານ AI: {bot_msg}")
123
+
124
+ if not lines:
125
+ return ""
126
+
127
+ joined = "\n".join(lines)
128
+ return f"ປະຫວັດການສົນທະນາກ່ອນໜ້າ:\n{joined}\n\n"
129
+
130
+
131
+ def build_prompt(question: str, history: Optional[List] = None) -> str:
132
  context = retrieve_context(question)
133
+ history_block = _format_history(history)
134
+
135
  return f"""{SYSTEM_PROMPT}
136
 
137
+ {history_block}ຂໍ້ມູນອ້າງອີງ:
138
  {context}
139
 
140
  ຄຳຖາມ: {question}
 
142
  ຄຳຕອບດ້ວຍພາສາລາວ:"""
143
 
144
 
145
+ def generate_answer(question: str, history: Optional[List] = None) -> str:
146
+ prompt = build_prompt(question, history)
147
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
148
  with torch.no_grad():
149
  outputs = model.generate(
 
155
  generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
156
  answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
157
 
158
+ # (your 2–3 sentence enforcement can stay here)
 
159
  sentences = re.split(r"(?<=[\.?!…])\s+", answer)
160
  short_answer = " ".join(sentences[:3]).strip()
 
161
  return short_answer if short_answer else answer
162
+
 
163
 
164
  def answer_from_qa(question: str) -> Optional[str]:
165
  """
 
205
 
206
  direct = answer_from_qa(message)
207
  if direct:
208
+ return direct
 
 
 
209
 
210
  try:
211
+ # pass history to let LLM understand follow-up questions
212
+ answer = generate_answer(message, history)
213
  except Exception as e: # noqa: BLE001
214
  return f"ລະບົບມີບັນຫາ: {e}"
215
 
216
+ return answer