Heng2004 committed on
Commit
c053894
·
verified ·
1 Parent(s): 1b5b80c

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +137 -54
model_utils.py CHANGED
@@ -2,56 +2,57 @@
2
  from typing import List, Optional
3
  import re
4
 
 
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from sentence_transformers import SentenceTransformer
8
  from sentence_transformers.util import cos_sim
9
 
10
  import qa_store
11
- from loader import load_curriculum, load_manual_qa, rebuild_combined_qa
12
 
13
  # -----------------------------
14
  # Base chat model
15
  # -----------------------------
16
  MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
17
- MAX_CONTEXT_ENTRIES = 3 # how many textbook chunks to retrieve per question
18
-
19
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
- model = AutoModelForCausalLM.from_pretrained(
24
- MODEL_NAME,
25
- torch_dtype=torch.float32,
26
- ).to(device)
 
27
  model.eval()
28
 
29
- # -----------------------------
30
- # Embedding model for retrieval
31
- # -----------------------------
32
- EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
33
  embed_model = SentenceTransformer(EMBED_MODEL_NAME)
34
- # (optional) move embedding model to same device; OK to leave on CPU if you want
35
  embed_model = embed_model.to(device)
36
 
 
 
 
37
 
38
- # NOTE: called once after load_curriculum() to precompute embeddings.
39
- # If you ever reload curriculum at runtime, call _build_entry_embeddings() again.
 
40
  def _build_entry_embeddings() -> None:
41
  """
42
- Build embeddings for each textbook entry using title + summary + text
43
  and store them in qa_store.TEXT_EMBEDDINGS.
 
 
44
  """
45
- if not qa_store.ENTRIES:
46
  qa_store.TEXT_EMBEDDINGS = None
47
  return
48
 
49
- texts = []
50
  for e in qa_store.ENTRIES:
51
- title = e.get("title", "") or ""
52
- summary = e.get("summary", "") or ""
53
  text = e.get("text", "") or ""
54
- combined = f"{title}\n{summary}\n{text}"
55
  texts.append(combined)
56
 
57
  qa_store.TEXT_EMBEDDINGS = embed_model.encode(
@@ -61,30 +62,91 @@ def _build_entry_embeddings() -> None:
61
  )
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # -----------------------------
65
  # Load data once at import time
66
  # -----------------------------
67
  load_curriculum()
68
  load_manual_qa()
 
69
  rebuild_combined_qa()
70
  _build_entry_embeddings()
 
71
 
 
 
 
72
  SYSTEM_PROMPT = (
73
- "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ "
74
- "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. "
75
  "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
76
- "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
77
  "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
78
  )
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
82
  """
83
  Embedding-based retrieval over textbook entries.
84
- Falls back to first entries if embeddings are missing.
85
  """
86
- if not qa_store.ENTRIES:
87
- return qa_store.RAW_KNOWLEDGE
 
88
 
89
  if qa_store.TEXT_EMBEDDINGS is None:
90
  top_entries = qa_store.ENTRIES[:max_entries]
@@ -99,11 +161,8 @@ def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> s
99
  # 2) Cosine similarity with all entry embeddings
100
  sims = cos_sim(q_vec, qa_store.TEXT_EMBEDDINGS)[0] # shape [N]
101
 
102
- # 3) Pick top-k indices
103
- k = min(max_entries, len(qa_store.ENTRIES))
104
- _, top_indices = torch.topk(sims, k=k)
105
-
106
- # 4) Map indices back to entries
107
  top_entries = [qa_store.ENTRIES[i] for i in top_indices.tolist()]
108
 
109
  # Build context string for the prompt
@@ -111,41 +170,54 @@ def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> s
111
  for e in top_entries:
112
  header = (
113
  f"[ຊັ້ນ {e.get('grade','')}, "
114
- f"ບົດ {e.get('chapter','')}, "
115
- f"ຫົວຂໍ້ {e.get('section','')} – {e.get('title','')}]"
 
116
  )
117
  context_blocks.append(f"{header}\n{e.get('text','')}")
118
 
119
  return "\n\n".join(context_blocks)
120
 
121
 
122
- def _format_history(history: Optional[List]) -> str:
 
 
 
123
  """
124
- Convert last few chat turns into a Lao conversation snippet
125
- to give the model context for follow-up questions.
126
- Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
127
  """
128
- if not history:
129
- return ""
130
 
131
- # keep only the last 3 turns to avoid very long prompts
132
- recent = history[-3:]
 
 
 
 
133
 
134
- lines: List[str] = []
135
- for turn in recent:
136
- if not isinstance(turn, (list, tuple)) or len(turn) != 2:
137
- continue
138
- user_msg, bot_msg = turn
139
- lines.append(f"ນັກຮຽນ: {user_msg}")
140
- lines.append(f"ອາຈານ AI: {bot_msg}")
141
 
142
- if not lines:
143
- return ""
 
 
 
 
 
144
 
145
- joined = "\n".join(lines)
146
- return f"ປະຫວັດການສົນທະນາກ່ອນໜ້າ:\n{joined}\n\n"
 
 
147
 
148
 
 
 
 
149
  def build_prompt(question: str, history: Optional[List] = None) -> str:
150
  context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
151
  history_block = _format_history(history)
@@ -174,12 +246,15 @@ def generate_answer(question: str, history: Optional[List] = None) -> str:
174
  generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
175
  answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
176
 
177
- # Enforce 2–3 sentence answers for M.1 students
178
  sentences = re.split(r"(?<=[\.?!…])\s+", answer)
179
  short_answer = " ".join(sentences[:3]).strip()
180
  return short_answer if short_answer else answer
181
 
182
 
 
 
 
183
  def answer_from_qa(question: str) -> Optional[str]:
184
  """
185
  1) Exact match in QA_INDEX
@@ -217,6 +292,9 @@ def answer_from_qa(question: str) -> Optional[str]:
217
  return None
218
 
219
 
 
 
 
220
  def laos_history_bot(message: str, history: List) -> str:
221
  """
222
  Main chatbot function for Student tab (Gradio ChatInterface).
@@ -224,6 +302,11 @@ def laos_history_bot(message: str, history: List) -> str:
224
  if not message.strip():
225
  return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."
226
 
 
 
 
 
 
227
  # 1) Try exact / fuzzy Q&A first
228
  direct = answer_from_qa(message)
229
  if direct:
 
2
  from typing import List, Optional
3
  import re
4
 
5
+ import numpy as np
6
  import torch
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  from sentence_transformers import SentenceTransformer
9
  from sentence_transformers.util import cos_sim
10
 
11
  import qa_store
12
+ from loader import load_curriculum, load_manual_qa, rebuild_combined_qa, load_glossary
13
 
14
  # -----------------------------
15
  # Base chat model
16
  # -----------------------------
17
  MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
18
+ EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
 
 
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
+ # Use float16 on GPU to save memory, float32 on CPU
24
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
25
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype)
26
+ model.to(device)
27
  model.eval()
28
 
 
 
 
 
29
  embed_model = SentenceTransformer(EMBED_MODEL_NAME)
 
30
  embed_model = embed_model.to(device)
31
 
32
+ # Number of textbook entries to include in the RAG context
33
+ MAX_CONTEXT_ENTRIES = 4
34
+
35
 
36
+ # -----------------------------
37
+ # Embedding builders
38
+ # -----------------------------
39
  def _build_entry_embeddings() -> None:
40
  """
41
+ Build embeddings for each textbook entry using chapter + section + text
42
  and store them in qa_store.TEXT_EMBEDDINGS.
43
+
44
+ Call this after loading / reloading curriculum.
45
  """
46
+ if not getattr(qa_store, "ENTRIES", None):
47
  qa_store.TEXT_EMBEDDINGS = None
48
  return
49
 
50
+ texts: List[str] = []
51
  for e in qa_store.ENTRIES:
52
+ chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
53
+ section = e.get("section_title", "") or e.get("section", "") or ""
54
  text = e.get("text", "") or ""
55
+ combined = f"{chapter}\n{section}\n{text}"
56
  texts.append(combined)
57
 
58
  qa_store.TEXT_EMBEDDINGS = embed_model.encode(
 
62
  )
63
 
64
 
65
def _build_glossary_embeddings() -> None:
    """
    Create embeddings for glossary terms + definitions and store them in
    qa_store.GLOSSARY_EMBEDDINGS.

    Each glossary item is embedded as "<term> :: <definition>". When
    qa_store.GLOSSARY is missing or empty, qa_store.GLOSSARY_EMBEDDINGS is
    set to None instead.
    """
    glossary = getattr(qa_store, "GLOSSARY", None)
    if not glossary:
        qa_store.GLOSSARY_EMBEDDINGS = None
        print("[INFO] No glossary terms to embed.")
        return

    # Embed term + definition together. `or ''` keeps a field that is present
    # but None from being embedded as the literal text "None".
    texts = [
        f"{item.get('term', '') or ''} :: {item.get('definition', '') or ''}"
        for item in glossary
    ]

    # Vectors are requested as normalized NumPy arrays, so a dot product
    # against another normalized vector is a cosine similarity.
    qa_store.GLOSSARY_EMBEDDINGS = embed_model.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    print(f"[INFO] Built glossary embeddings for {len(texts)} terms.")
85
+
86
+
87
  # -----------------------------
88
  # Load data once at import time
89
  # -----------------------------
90
  load_curriculum()
91
  load_manual_qa()
92
+ load_glossary()
93
  rebuild_combined_qa()
94
  _build_entry_embeddings()
95
+ _build_glossary_embeddings()
96
 
97
+ # -----------------------------
98
+ # System prompt (Natural Science)
99
+ # -----------------------------
100
  SYSTEM_PROMPT = (
101
+ "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານວິທະຍາສາດທໍາມະຊາດ "
102
+ "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1-ມ.4. "
103
  "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
104
+ "ໃຫ້ອີງຈາກຂໍ້ມູນອ້າງອີງຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
105
  "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
106
  )
107
 
108
 
109
+ # -----------------------------
110
+ # Helper: history formatting
111
+ # -----------------------------
112
+ def _format_history(history: Optional[List]) -> str:
113
+ """
114
+ Convert last few chat turns into a Lao conversation snippet
115
+ to give the model context for follow-up questions.
116
+ Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
117
+ """
118
+ if not history:
119
+ return ""
120
+
121
+ # keep only the last 3 turns to avoid very long prompts
122
+ recent = history[-3:]
123
+
124
+ lines: List[str] = []
125
+ for turn in recent:
126
+ if not isinstance(turn, (list, tuple)) or len(turn) != 2:
127
+ continue
128
+ user_msg, bot_msg = turn
129
+ lines.append(f"ນັກຮຽນ: {user_msg}")
130
+ lines.append(f"ອາຈານ AI: {bot_msg}")
131
+
132
+ if not lines:
133
+ return ""
134
+
135
+ joined = "\n".join(lines) + "\n\n"
136
+ return joined
137
+
138
+
139
+ # -----------------------------
140
+ # RAG: retrieve textbook context
141
+ # -----------------------------
142
  def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
143
  """
144
  Embedding-based retrieval over textbook entries.
145
+ Falls back to raw knowledge if no entries are loaded, or to the first entries if embeddings are missing.
146
  """
147
+ if not getattr(qa_store, "ENTRIES", None):
148
+ # Fallback: raw knowledge (if available) or empty string
149
+ return getattr(qa_store, "RAW_KNOWLEDGE", "")
150
 
151
  if qa_store.TEXT_EMBEDDINGS is None:
152
  top_entries = qa_store.ENTRIES[:max_entries]
 
161
  # 2) Cosine similarity with all entry embeddings
162
  sims = cos_sim(q_vec, qa_store.TEXT_EMBEDDINGS)[0] # shape [N]
163
 
164
+ # 3) Take top-k
165
+ top_indices = torch.topk(sims, k=min(max_entries, sims.shape[0])).indices
 
 
 
166
  top_entries = [qa_store.ENTRIES[i] for i in top_indices.tolist()]
167
 
168
  # Build context string for the prompt
 
170
  for e in top_entries:
171
  header = (
172
  f"[ຊັ້ນ {e.get('grade','')}, "
173
+ f"ໜ່ວຍ {e.get('unit','')}, "
174
+ f"ບົດ {e.get('chapter_title','')}, "
175
+ f"ຫົວຂໍ້ {e.get('section_title','')}]"
176
  )
177
  context_blocks.append(f"{header}\n{e.get('text','')}")
178
 
179
  return "\n\n".join(context_blocks)
180
 
181
 
182
+ # -----------------------------
183
+ # Glossary-based answering
184
+ # -----------------------------
185
def answer_from_glossary(
    message: str,
    min_similarity: float = 0.55,
) -> Optional[str]:
    """
    Try to answer using the glossary index.

    Args:
        message: The student's question.
        min_similarity: Minimum similarity score the best glossary match
            must reach (default 0.55; tune this threshold if needed).

    Returns:
        The definition of the best-matching glossary item, followed by its
        example when one exists, or None if not confident (no glossary, no
        embeddings, best score below the threshold, or an empty definition).
    """
    glossary = getattr(qa_store, "GLOSSARY", None)
    embeddings = getattr(qa_store, "GLOSSARY_EMBEDDINGS", None)
    if not glossary or embeddings is None:
        return None

    # The embedding matrix is built separately from the glossary list. If the
    # two no longer line up, the best index could point at the wrong item or
    # past the end of the list, so decline to answer instead.
    if len(embeddings) != len(glossary):
        return None

    # Encode question
    q_emb = embed_model.encode(
        [message],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]

    # The question vector is normalized; the glossary vectors are expected to
    # be normalized as well, which makes this dot product a cosine similarity.
    sims = np.dot(embeddings, q_emb)
    best_idx = int(np.argmax(sims))
    best_sim = float(sims[best_idx])

    if best_sim < min_similarity:
        return None

    item = glossary[best_idx]
    # `or ""` guards against fields that are present but None, which would
    # otherwise raise AttributeError on .strip().
    definition = (item.get("definition", "") or "").strip()
    example = (item.get("example", "") or "").strip()

    # Without a definition there is nothing meaningful to answer with.
    if not definition:
        return None

    if example:
        return f"{definition} ຕົວຢ່າງ: {example}"
    return definition
216
 
217
 
218
+ # -----------------------------
219
+ # Prompt + LLM generation
220
+ # -----------------------------
221
  def build_prompt(question: str, history: Optional[List] = None) -> str:
222
  context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
223
  history_block = _format_history(history)
 
246
  generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
247
  answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
248
 
249
+ # Enforce 2–3 sentence answers for students
250
  sentences = re.split(r"(?<=[\.?!…])\s+", answer)
251
  short_answer = " ".join(sentences[:3]).strip()
252
  return short_answer if short_answer else answer
253
 
254
 
255
+ # -----------------------------
256
+ # QA lookup (exact + fuzzy)
257
+ # -----------------------------
258
  def answer_from_qa(question: str) -> Optional[str]:
259
  """
260
  1) Exact match in QA_INDEX
 
292
  return None
293
 
294
 
295
+ # -----------------------------
296
+ # Main chatbot entry
297
+ # -----------------------------
298
  def laos_history_bot(message: str, history: List) -> str:
299
  """
300
  Main chatbot function for Student tab (Gradio ChatInterface).
 
302
  if not message.strip():
303
  return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."
304
 
305
+ # 0) Try glossary first for key concepts
306
+ gloss = answer_from_glossary(message)
307
+ if gloss:
308
+ return gloss
309
+
310
  # 1) Then try exact / fuzzy Q&A
311
  direct = answer_from_qa(message)
312
  if direct: