Heng2004 committed
Commit bcbd5e8 · verified · 1 Parent(s): 076631b

Update model_utils.py

Files changed (1)
  1. model_utils.py +75 -55
model_utils.py CHANGED
@@ -4,14 +4,17 @@ import re
 
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.util import cos_sim
 
 import qa_store
 from loader import load_curriculum, load_manual_qa, rebuild_combined_qa
 
 # -----------------------------
-# Model
+# Base chat model
 # -----------------------------
 MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
+MAX_CONTEXT_ENTRIES = 3  # how many textbook chunks to retrieve per question
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
@@ -21,74 +24,87 @@ model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
     torch_dtype=torch.float32,
 ).to(device)
-
 model.eval()
 
+# -----------------------------
+# Embedding model for retrieval
+# -----------------------------
+EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
+embed_model = SentenceTransformer(EMBED_MODEL_NAME)
+# move embedding model to the same device (optional, but faster on GPU)
+embed_model = embed_model.to(device)
+
 
+def _build_entry_embeddings() -> None:
+    """
+    Build embeddings for each textbook entry using title + summary + text
+    and store them in qa_store.TEXT_EMBEDDINGS.
+    """
+    if not qa_store.ENTRIES:
+        qa_store.TEXT_EMBEDDINGS = None
+        return
+
+    texts = []
+    for e in qa_store.ENTRIES:
+        title = e.get("title", "") or ""
+        summary = e.get("summary", "") or ""
+        text = e.get("text", "") or ""
+        combined = f"{title}\n{summary}\n{text}"
+        texts.append(combined)
+
+    qa_store.TEXT_EMBEDDINGS = embed_model.encode(
+        texts,
+        convert_to_tensor=True,
+        show_progress_bar=False,
+    )
+
+
+# -----------------------------
 # Load data once at import time
+# -----------------------------
 load_curriculum()
 load_manual_qa()
 rebuild_combined_qa()
+_build_entry_embeddings()
 
 SYSTEM_PROMPT = (
     "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ "
     "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. "
     "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
     "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
-    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
+    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືຍັງບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
 )
 
 
-def retrieve_context(question: str, max_entries: int = 2) -> str:
+def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
     """
-    Simple keyword retrieval over textbook entries.
+    Embedding-based retrieval over textbook entries.
+    Falls back to the first entries if embeddings are missing.
     """
     if not qa_store.ENTRIES:
        return qa_store.RAW_KNOWLEDGE
 
-    q = question.lower().strip()
-    terms = [t for t in re.split(r"\s+", q) if len(t) > 1]
-
-    if not terms:
-        chosen = qa_store.ENTRIES[:max_entries]
-        return "\n\n".join(
-            f"[ຊັ້ນ {e.get('grade','')}, ບົດ {e.get('chapter','')}, "
-            f"ຫົວຂໍ້ {e.get('section','')} – {e.get('title','')}]\n{e['text']}"
-            for e in chosen
+    if qa_store.TEXT_EMBEDDINGS is None:
+        top_entries = qa_store.ENTRIES[:max_entries]
+    else:
+        # 1) Encode the question
+        q_vec = embed_model.encode(
+            question,
+            convert_to_tensor=True,
+            show_progress_bar=False,
        )
 
-    scored = []
-
-    for e in qa_store.ENTRIES:
-        text = e.get("text", "")
-        title = e.get("title", "")
-        kws = e.get("keywords", [])
-        topic = e.get("topic", "")
-
-        base = (text + " " + title).lower()
-        score = 0
-
-        for t in terms:
-            score += base.count(t)
+        # 2) Cosine similarity with all entry embeddings
+        sims = cos_sim(q_vec, qa_store.TEXT_EMBEDDINGS)[0]  # shape [N]
 
-        for kw in kws:
-            kw_lower = kw.lower()
-            for t in terms:
-                if t in kw_lower:
-                    score += 2
+        # 3) Pick top-k indices
+        k = min(max_entries, len(qa_store.ENTRIES))
+        _, top_indices = torch.topk(sims, k=k)
 
-        if topic and any(t in topic for t in terms):
-            score += 1
-
-        if score > 0:
-            scored.append((score, e))
-
-    scored.sort(key=lambda x: x[0], reverse=True)
-    top_entries = [e for _, e in scored[:max_entries]]
-
-    if not top_entries:
-        top_entries = qa_store.ENTRIES[:max_entries]
+        # 4) Map indices back to entries
+        top_entries = [qa_store.ENTRIES[i] for i in top_indices.tolist()]
 
+    # Build context string for the prompt
     context_blocks = []
     for e in top_entries:
        header = (
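
Note on the hunk above: it swaps keyword counting for dense-vector retrieval. A minimal standalone sketch of the same encode / cos_sim / topk pattern, using the embedding model named in the commit but a toy corpus and question invented for illustration:

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

embed_model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)

# Toy corpus standing in for qa_store.ENTRIES (hypothetical examples)
entries = [
    "Lan Xang was founded by Fa Ngum in 1353.",
    "The capital moved to Vientiane in 1560.",
]
entry_vecs = embed_model.encode(entries, convert_to_tensor=True)

q_vec = embed_model.encode("When was Lan Xang founded?", convert_to_tensor=True)
sims = cos_sim(q_vec, entry_vecs)[0]  # cosine scores, shape [len(entries)]
_, top = torch.topk(sims, k=1)        # index of the best-matching entry
print(entries[top.item()])            # -> the Lan Xang entry

Because the encoder is multilingual, a Lao question can land on the right entry even with zero shared keywords, which the removed base.count(t) scoring could never do.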
@@ -113,7 +129,7 @@ def _format_history(history: Optional[List]) -> str:
     # keep only the last 3 turns to avoid very long prompts
     recent = history[-3:]
 
-    lines = []
+    lines: List[str] = []
     for turn in recent:
         if not isinstance(turn, (list, tuple)) or len(turn) != 2:
             continue
@@ -129,7 +145,7 @@ def _format_history(history: Optional[List]) -> str:
 
 
 def build_prompt(question: str, history: Optional[List] = None) -> str:
-    context = retrieve_context(question)
+    context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
     history_block = _format_history(history)
 
     return f"""{SYSTEM_PROMPT}
@@ -144,7 +160,8 @@ def build_prompt(question: str, history: Optional[List] = None) -> str:
 
 def generate_answer(question: str, history: Optional[List] = None) -> str:
     prompt = build_prompt(question, history)
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
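
The replaced line moves the tokenized prompt to the module-level device instead of model.device; both resolve to the same device here, so this is a cosmetic simplification. The unchanged lines after this hunk then strip the echoed prompt by slicing at the input length — a small sketch of that slicing, using a tiny public test checkpoint (the model name is an assumption; any causal LM behaves the same):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# sshleifer/tiny-gpt2 is a tiny test checkpoint (assumption; any causal LM works)
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tok("Laos is", return_tensors="pt")
with torch.no_grad():
    out = lm.generate(**inputs, max_new_tokens=8)

# generate() returns prompt + continuation; drop the prompt tokens
new_ids = out[0][inputs["input_ids"].shape[1]:]
print(tok.decode(new_ids, skip_special_tokens=True))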
@@ -155,24 +172,26 @@ def generate_answer(question: str, history: Optional[List] = None) -> str:
     generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
     answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
 
-    # (your 2–3 sentence enforcement can stay here)
+    # Enforce 2–3 sentence answers for M.1 students
     sentences = re.split(r"(?<=[\.?!…])\s+", answer)
     short_answer = " ".join(sentences[:3]).strip()
     return short_answer if short_answer else answer
 
 
 def answer_from_qa(question: str) -> Optional[str]:
     """
-    1) exact match in QA_INDEX
-    2) fuzzy match via word overlap with ALL_QA_KNOWLEDGE
+    1) Exact match in QA_INDEX
+    2) Fuzzy match via word overlap with ALL_QA_KNOWLEDGE
     """
     norm_q = qa_store.normalize_question(question)
     if not norm_q:
         return None
 
+    # Exact match
     if norm_q in qa_store.QA_INDEX:
         return qa_store.QA_INDEX[norm_q]
 
+    # Fuzzy match
     q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
     if not q_terms:
         return None
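
The sentence cap in the hunk above splits on whitespace that follows ., ?, ! or … and keeps the first three pieces. A quick check of the regex (the sample text is hypothetical):

import re

answer = "ປະໂຫຍກທີ 1. ປະໂຫຍກທີ 2? ປະໂຫຍກທີ 3! ປະໂຫຍກທີ 4."  # hypothetical Lao answer
sentences = re.split(r"(?<=[\.?!…])\s+", answer)
print(" ".join(sentences[:3]))  # keeps only the first three sentences

Lao prose often omits sentence-final punctuation entirely, so an answer without these markers survives as a single chunk; the cap is a heuristic rather than a guarantee.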
@@ -181,14 +200,14 @@ def answer_from_qa(question: str) -> Optional[str]:
     best_answer: Optional[str] = None
 
     for item in qa_store.ALL_QA_KNOWLEDGE:
         stored_terms = [t for t in item["norm_q"].split(" ") if len(t) > 1]
         overlap = sum(1 for t in q_terms if t in stored_terms)
         if overlap > best_score:
             best_score = overlap
             best_answer = item["a"]
 
     # require at least 2 overlapping words to accept fuzzy match
-    if best_score >= 2:
+    if best_score >= 2 and best_answer is not None:
         # optional: log when fuzzy match is used
         print(f"[FUZZY MATCH] score={best_score} -> {best_answer[:50]!r}")
         return best_answer
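
A worked example of the overlap score and the >= 2 threshold (the normalized terms are hypothetical):

# Terms as qa_store.normalize_question might produce them (hypothetical)
q_terms = ["ລ້ານຊ້າງ", "ສ້າງຕັ້ງ", "ປີໃດ"]
stored_terms = ["ອານາຈັກ", "ລ້ານຊ້າງ", "ສ້າງຕັ້ງ", "ແມ່ນຫຍັງ"]

overlap = sum(1 for t in q_terms if t in stored_terms)
print(overlap)  # 2 -> passes `best_score >= 2`, so the stored answer is returned

The membership test is exact, so spelling variants or different tokenization never count toward the score; that is why the threshold stays low. The added `and best_answer is not None` guard is mostly for type-checkers, since a score of 2 implies an answer was recorded.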
@@ -198,17 +217,18 @@ def answer_from_qa(question: str) -> Optional[str]:
 
 def laos_history_bot(message: str, history: List) -> str:
     """
-    Main chatbot function for Student tab.
+    Main chatbot function for Student tab (Gradio ChatInterface).
     """
     if not message.strip():
         return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."
 
+    # 1) Try exact / fuzzy Q&A first
     direct = answer_from_qa(message)
     if direct:
         return direct
 
+    # 2) Fall back to LLM + retrieved context
     try:
-        # ✅ pass history to let LLM understand follow-up questions
         answer = generate_answer(message, history)
     except Exception as e:  # noqa: BLE001
         return f"ລະບົບມີບັນຫາ: {e}"
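
laos_history_bot takes (message, history) and returns a string, which matches what Gradio's ChatInterface passes when history is the pair style ([user, assistant] lists) that _format_history parses. A minimal wiring sketch (the app file name and title are assumptions; newer Gradio releases default to message-dict history, which _format_history would silently skip):

# app.py (hypothetical) — wire the bot into a chat UI
import gradio as gr

from model_utils import laos_history_bot

demo = gr.ChatInterface(fn=laos_history_bot, title="Laos History Helper (ມ.1)")
demo.launch()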