Heng2004's picture
Update model_utils.py
cc59a28 verified
raw
history blame
10.1 kB
# model_utils.py
from typing import List, Optional
import re
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import qa_store
from loader import load_curriculum, load_manual_qa, rebuild_combined_qa, load_glossary
# -----------------------------
# Base chat model
# -----------------------------
# Everything in this section runs at import time: importing this module
# loads the tokenizer, the causal LM and the sentence-embedding model.
# Hugging Face model id of the chat model used for answer generation.
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
# Sentence-embedding model used for textbook and glossary retrieval.
EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Use float16 on GPU to save memory, float32 on CPU
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype)
model.to(device)
# Switch to evaluation mode; this module only generates, it never trains.
model.eval()
# The embedding model is moved to the same device as the chat model.
embed_model = SentenceTransformer(EMBED_MODEL_NAME)
embed_model = embed_model.to(device)
# Number of textbook entries to include in the RAG context
MAX_CONTEXT_ENTRIES = 4
# -----------------------------
# Embedding builders
# -----------------------------
def _build_entry_embeddings() -> None:
    """
    Embed every textbook entry and cache the result on ``qa_store``.

    Each entry is turned into one text made of its chapter, section and body
    (joined by newlines) and encoded with ``embed_model``. The resulting
    tensor is stored in ``qa_store.TEXT_EMBEDDINGS``; when there are no
    entries the attribute is set to ``None`` instead.

    Call this again whenever the curriculum is loaded or reloaded.
    """
    # Nothing to embed: clear any cached embeddings and stop.
    if not getattr(qa_store, "ENTRIES", None):
        qa_store.TEXT_EMBEDDINGS = None
        return

    def _entry_text(e) -> str:
        # Prefer the "*_title" keys and fall back to the short key names.
        chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
        section = e.get("section_title", "") or e.get("section", "") or ""
        text = e.get("text", "") or ""
        return f"{chapter}\n{section}\n{text}"

    entry_texts: List[str] = [_entry_text(e) for e in qa_store.ENTRIES]
    qa_store.TEXT_EMBEDDINGS = embed_model.encode(
        entry_texts,
        convert_to_tensor=True,
        show_progress_bar=False,
    )
def _build_glossary_embeddings() -> None:
    """
    Embed every glossary item and cache the result on ``qa_store``.

    Each item is encoded as ``"<term> :: <definition>"``. The normalized
    NumPy embeddings are stored in ``qa_store.GLOSSARY_EMBEDDINGS``; when the
    glossary is empty or missing the attribute is set to ``None``.
    """
    if not getattr(qa_store, "GLOSSARY", None):
        qa_store.GLOSSARY_EMBEDDINGS = None
        print("[INFO] No glossary terms to embed.")
        return

    # One text per glossary item: term and definition embedded together.
    glossary_texts = []
    for item in qa_store.GLOSSARY:
        glossary_texts.append(
            f"{item.get('term', '')} :: {item.get('definition', '')}"
        )

    qa_store.GLOSSARY_EMBEDDINGS = embed_model.encode(
        glossary_texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    print(f"[INFO] Built glossary embeddings for {len(glossary_texts)} terms.")
# -----------------------------
# Load data once at import time
# -----------------------------
# These calls run as a side effect of importing this module.
# The four loaders come from the project's `loader` module.
# NOTE(review): presumably they populate the qa_store attributes read below
# (ENTRIES, GLOSSARY, QA_INDEX, ALL_QA_KNOWLEDGE) - confirm in loader.py.
load_curriculum()
load_manual_qa()
load_glossary()
rebuild_combined_qa()
# Embeddings are built after loading because both builders read from qa_store
# (qa_store.ENTRIES and qa_store.GLOSSARY respectively).
_build_entry_embeddings()
_build_glossary_embeddings()
# -----------------------------
# System prompt (Natural Science)
# -----------------------------
# Lao-language instruction placed at the very start of every prompt.
# Approximate English meaning (review the translation before relying on it):
#   "You are a natural-science assistant for students in grades M.1-M.4.
#    Answer only in Lao, briefly (2-3 sentences) and in an easy-to-understand
#    way. Rely only on the reference information below. If the information is
#    insufficient or unclear, say that you are not sure."
# The text is part of the model input; changing any character changes behavior.
SYSTEM_PROMPT = (
"ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານວິທະຍາສາດທໍາມະຊາດ "
"ສໍາລັບນັກຮຽນຊັ້ນ ມ.1-ມ.4. "
"ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
"ໃຫ້ອີງຈາກຂໍ້ມູນອ້າງອີງຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
"ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)
# -----------------------------
# Helper: history formatting
# -----------------------------
def _format_history(history: Optional[List]) -> str:
"""
Convert last few chat turns into a Lao conversation snippet
to give the model context for follow-up questions.
Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
"""
if not history:
return ""
# keep only the last 3 turns to avoid very long prompts
recent = history[-3:]
lines: List[str] = []
for turn in recent:
if not isinstance(turn, (list, tuple)) or len(turn) != 2:
continue
user_msg, bot_msg = turn
lines.append(f"ນັກຮຽນ: {user_msg}")
lines.append(f"ອາຈານ AI: {bot_msg}")
if not lines:
return ""
joined = "\n".join(lines) + "\n\n"
return joined
# -----------------------------
# RAG: retrieve textbook context
# -----------------------------
def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
    """
    Build the reference-text block for the prompt from the textbook entries.

    Entries are ranked by cosine similarity between the embedded question and
    the cached entry embeddings in ``qa_store.TEXT_EMBEDDINGS``.

    Fallbacks:
      * no entries loaded -> ``qa_store.RAW_KNOWLEDGE`` (or ``""``);
      * entries loaded but no embeddings -> the first ``max_entries`` entries.

    Args:
        question: The student's question.
        max_entries: Maximum number of entries to include. Values <= 0 yield
            an empty context.

    Returns:
        The selected entries, each as a bracketed header line (grade, unit,
        chapter, section) followed by the entry text, separated by blank
        lines.
    """
    entries = getattr(qa_store, "ENTRIES", None)
    if not entries:
        # Fallback: raw knowledge (if available) or empty string.
        return getattr(qa_store, "RAW_KNOWLEDGE", "") or ""

    # A non-positive limit would mis-slice the list or make torch.topk fail.
    if max_entries <= 0:
        return ""

    # getattr so that a never-built embedding cache takes the fallback path
    # instead of raising AttributeError.
    embeddings = getattr(qa_store, "TEXT_EMBEDDINGS", None)
    if embeddings is None:
        top_entries = entries[:max_entries]
    else:
        # 1) Encode the question
        q_vec = embed_model.encode(
            question,
            convert_to_tensor=True,
            show_progress_bar=False,
        )
        # 2) Cosine similarity with all entry embeddings
        sims = cos_sim(q_vec, embeddings)[0]  # shape [N]
        # 3) Take top-k (never more than the number of embedded entries)
        k = min(max_entries, sims.shape[0])
        top_indices = torch.topk(sims, k=k).indices
        top_entries = [entries[i] for i in top_indices.tolist()]

    # Build context string for the prompt
    context_blocks: List[str] = []
    for e in top_entries:
        # Same key fallback as the embedding builder, so an entry that only
        # has "chapter" / "section" keys still gets a filled-in header.
        chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
        section = e.get("section_title", "") or e.get("section", "") or ""
        header = (
            f"[ຊັ້ນ {e.get('grade','')}, "
            f"ໜ່ວຍ {e.get('unit','')}, "
            f"ບົດ {chapter}, "
            f"ຫົວຂໍ້ {section}]"
        )
        context_blocks.append(f"{header}\n{e.get('text','')}")
    return "\n\n".join(context_blocks)
# -----------------------------
# Glossary-based answering
# -----------------------------
def answer_from_glossary(message: str, *, min_similarity: float = 0.55) -> Optional[str]:
    """
    Try to answer a question directly from the glossary.

    The question is embedded and compared with the cached glossary
    embeddings; the best-scoring glossary item is used when its score
    reaches ``min_similarity``.

    Args:
        message: The student's question.
        min_similarity: Minimum similarity score required to accept the best
            glossary match. Tune this threshold if matches are too loose or
            too strict.

    Returns:
        The item's definition, followed by its example when it has one, or
        ``None`` when the glossary is unavailable, no item is similar enough,
        or the matched item has no definition.
    """
    glossary = getattr(qa_store, "GLOSSARY", None)
    embeddings = getattr(qa_store, "GLOSSARY_EMBEDDINGS", None)
    if not glossary or embeddings is None:
        return None

    # Encode question (normalized, like the stored glossary embeddings are
    # expected to be, so the dot product below acts as cosine similarity).
    q_emb = embed_model.encode(
        [message],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]
    sims = np.dot(embeddings, q_emb)
    best_idx = int(np.argmax(sims))
    best_sim = float(sims[best_idx])
    if best_sim < min_similarity:
        return None

    # Embeddings that are out of sync with the glossary (e.g. glossary
    # reloaded without rebuilding) must not cause an IndexError.
    if best_idx >= len(glossary):
        return None

    item = glossary[best_idx]
    # "or ''" keeps a null / missing field from crashing on .strip().
    definition = (item.get("definition") or "").strip()
    example = (item.get("example") or "").strip()
    if not definition:
        # An example without a definition is not a usable answer.
        return None
    if example:
        return f"{definition} ຕົວຢ່າງ: {example}"
    return definition
# -----------------------------
# Prompt + LLM generation
# -----------------------------
def build_prompt(question: str, history: Optional[List] = None) -> str:
    """
    Assemble the full text prompt sent to the chat model.

    The prompt is made of, in order: the system prompt, the formatted recent
    chat history (may be empty), the retrieved reference text, the question
    and a final cue asking for the answer in Lao.
    """
    history_block = _format_history(history)
    context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
    # The template lines stay at column 0 on purpose: any leading whitespace
    # inside the triple-quoted string would become part of the prompt.
    prompt = f"""{SYSTEM_PROMPT}
{history_block}ຂໍ້ມູນອ້າງອີງ:
{context}
ຄຳຖາມ: {question}
ຄຳຕອບດ້ວຍພາສາລາວ:"""
    return prompt
def generate_answer(question: str, history: Optional[List] = None) -> str:
    """
    Generate an answer with the chat model and shorten it for students.

    Builds the prompt, runs greedy decoding (no sampling) for up to 160 new
    tokens, decodes only the newly generated tokens and keeps at most the
    first three sentences of the result.

    Returns:
        The shortened answer, or the full decoded answer if shortening
        produced an empty string.
    """
    encoded = tokenizer(build_prompt(question, history), return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[1]

    with torch.no_grad():
        sequences = model.generate(
            **encoded,
            max_new_tokens=160,
            do_sample=False,
        )

    # The returned sequence starts with the prompt tokens; drop them so only
    # the model's continuation is decoded.
    completion_ids = sequences[0][prompt_length:]
    answer = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    # Enforce 2–3 sentence answers for students: split after sentence-ending
    # punctuation and keep the first three pieces.
    leading_sentences = re.split(r"(?<=[\.?!…])\s+", answer)[:3]
    short_answer = " ".join(leading_sentences).strip()
    return short_answer or answer
# -----------------------------
# QA lookup (exact + fuzzy)
# -----------------------------
def answer_from_qa(question: str) -> Optional[str]:
    """
    Look the question up in the stored Q&A knowledge.

    1) Exact match of the normalized question in ``qa_store.QA_INDEX``.
    2) Fuzzy match: the stored question in ``qa_store.ALL_QA_KNOWLEDGE``
       sharing the most distinct words (longer than one character) with the
       question, accepted only when at least 2 distinct words overlap.

    Args:
        question: The student's question, not yet normalized.

    Returns:
        The stored answer, or ``None`` when nothing matches well enough.
    """
    norm_q = qa_store.normalize_question(question)
    if not norm_q:
        return None

    # Exact match
    if norm_q in qa_store.QA_INDEX:
        return qa_store.QA_INDEX[norm_q]

    # Fuzzy match. Sets are used so that a word repeated in the question is
    # counted once (a single shared word typed twice must not satisfy the
    # "2 overlapping words" rule) and membership tests are O(1).
    q_terms = {t for t in norm_q.split(" ") if len(t) > 1}
    if not q_terms:
        return None

    best_score = 0
    best_answer: Optional[str] = None
    for item in qa_store.ALL_QA_KNOWLEDGE:
        stored_terms = {t for t in item["norm_q"].split(" ") if len(t) > 1}
        overlap = len(q_terms & stored_terms)
        # Strict ">" keeps the first item on ties.
        if overlap > best_score:
            best_score = overlap
            best_answer = item["a"]

    # require at least 2 overlapping words to accept fuzzy match
    if best_score >= 2 and best_answer is not None:
        # optional: log when fuzzy match is used
        print(f"[FUZZY MATCH] score={best_score} -> {best_answer[:50]!r}")
        return best_answer
    return None
# -----------------------------
# Main chatbot entry
# -----------------------------
def laos_science_bot(message: str, history: List) -> str:
    """
    Main chatbot function for Student tab (Gradio ChatInterface).

    Answer sources are tried in this order, returning the first non-empty
    result:
      0) glossary lookup,
      1) exact / fuzzy match against the stored Q&A,
      2) the chat model with retrieved textbook context.

    Args:
        message: The student's message. ``None``, non-string or
            whitespace-only input yields a Lao "please type a question"
            reply.
        history: Chat history passed by the UI; only used for generation.

    Returns:
        The answer text, or a Lao error message containing the exception
        text when generation fails.
    """
    # Guard before calling .strip(): a None or non-string message would
    # otherwise raise AttributeError.
    if not isinstance(message, str) or not message.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."

    # 0) Try glossary first for key concepts
    gloss = answer_from_glossary(message)
    if gloss:
        return gloss

    # 1) Then try exact / fuzzy Q&A
    direct = answer_from_qa(message)
    if direct:
        return direct

    # 2) Fall back to LLM + retrieved context. Any failure is reported to
    #    the user as text instead of crashing the chat UI.
    try:
        answer = generate_answer(message, history)
    except Exception as e:  # noqa: BLE001
        return f"ລະບົບມີບັນຫາ: {e}"
    return answer