# model_utils.py
from typing import List, Optional
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import qa_store
from loader import load_curriculum, load_manual_qa, rebuild_combined_qa
# -----------------------------
# Base chat model
# -----------------------------
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
MAX_CONTEXT_ENTRIES = 3  # how many textbook chunks to retrieve per question
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
).to(device)
model.eval()
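# float32 is the safe default for CPU inference. On a GPU you could likely
# halve memory with half precision (assumption: SeaLLMs-v3-1.5B-Chat loads
# cleanly in fp16, as most chat models of this size do), e.g.:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_NAME,
#       torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
#   ).to(device)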
# -----------------------------
# Embedding model for retrieval
# -----------------------------
EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embed_model = SentenceTransformer(EMBED_MODEL_NAME)
# (optional) Move the embedding model to the same device as the chat model;
# leaving it on CPU also works.
embed_model = embed_model.to(device)
# NOTE: called once after load_curriculum() to precompute embeddings.
# If you ever reload curriculum at runtime, call _build_entry_embeddings() again.
def _build_entry_embeddings() -> None:
"""
Build embeddings for each textbook entry using title + summary + text
and store them in qa_store.TEXT_EMBEDDINGS.
"""
if not qa_store.ENTRIES:
qa_store.TEXT_EMBEDDINGS = None
return
texts = []
for e in qa_store.ENTRIES:
title = e.get("title", "") or ""
summary = e.get("summary", "") or ""
text = e.get("text", "") or ""
combined = f"{title}\n{summary}\n{text}"
texts.append(combined)
qa_store.TEXT_EMBEDDINGS = embed_model.encode(
texts,
convert_to_tensor=True,
show_progress_bar=False,
)
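# Resulting store (sketch, assuming N curriculum entries):
#   qa_store.TEXT_EMBEDDINGS -> torch.Tensor of shape [N, 384]
# (384 is the embedding width of paraphrase-multilingual-MiniLM-L12-v2.)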
# -----------------------------
# Load data once at import time
# -----------------------------
load_curriculum()
load_manual_qa()
rebuild_combined_qa()
_build_entry_embeddings()
# System prompt (Lao). English gist: "You are a Lao history assistant for
# Grade M.1 students. Answer only in Lao, in 2-3 short, easy-to-understand
# sentences. Rely only on the information below. If the information is
# insufficient or unclear, say you are not sure."
SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. "
    "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)
def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
"""
Embedding-based retrieval over textbook entries.
Falls back to first entries if embeddings are missing.
"""
if not qa_store.ENTRIES:
return qa_store.RAW_KNOWLEDGE
if qa_store.TEXT_EMBEDDINGS is None:
top_entries = qa_store.ENTRIES[:max_entries]
else:
# 1) Encode the question
q_vec = embed_model.encode(
question,
convert_to_tensor=True,
show_progress_bar=False,
)
# 2) Cosine similarity with all entry embeddings
sims = cos_sim(q_vec, qa_store.TEXT_EMBEDDINGS)[0] # shape [N]
# 3) Pick top-k indices
k = min(max_entries, len(qa_store.ENTRIES))
_, top_indices = torch.topk(sims, k=k)
# 4) Map indices back to entries
top_entries = [qa_store.ENTRIES[i] for i in top_indices.tolist()]
# Build context string for the prompt
context_blocks: List[str] = []
for e in top_entries:
header = (
f"[ຊັ້ນ {e.get('grade','')}, "
f"ບົດ {e.get('chapter','')}, "
f"ຫົວຂໍ້ {e.get('section','')}{e.get('title','')}]"
)
context_blocks.append(f"{header}\n{e.get('text','')}")
return "\n\n".join(context_blocks)
def _format_history(history: Optional[List]) -> str:
"""
Convert last few chat turns into a Lao conversation snippet
to give the model context for follow-up questions.
Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
"""
if not history:
return ""
# keep only the last 3 turns to avoid very long prompts
recent = history[-3:]
lines: List[str] = []
for turn in recent:
if not isinstance(turn, (list, tuple)) or len(turn) != 2:
continue
user_msg, bot_msg = turn
lines.append(f"ນັກຮຽນ: {user_msg}")
lines.append(f"ອາຈານ AI: {bot_msg}")
if not lines:
return ""
joined = "\n".join(lines)
return f"ປະຫວັດການສົນທະນາກ່ອນໜ້າ:\n{joined}\n\n"
def build_prompt(question: str, history: Optional[List] = None) -> str:
    context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
    history_block = _format_history(history)
    # Lao labels below: "Reference information:", "Question:", "Answer in Lao:"
    return f"""{SYSTEM_PROMPT}
{history_block}ຂໍ້ມູນອ້າງອີງ:
{context}
ຄຳຖາມ: {question}
ຄຳຕອບດ້ວຍພາສາລາວ:"""
def generate_answer(question: str, history: Optional[List] = None) -> str:
    prompt = build_prompt(question, history)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,
            do_sample=False,
        )
    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    # Enforce 2–3 sentence answers for M.1 students. Lao prose often omits
    # Western sentence punctuation, so this split is a best-effort heuristic.
    sentences = re.split(r"(?<=[\.?!…])\s+", answer)
    short_answer = " ".join(sentences[:3]).strip()
    return short_answer if short_answer else answer
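# do_sample=False means deterministic greedy decoding, which keeps answers
# reproducible for students. If more varied phrasing were wanted, a lightly
# sampled variant (standard transformers generate kwargs) would be:
#   model.generate(**inputs, max_new_tokens=160, do_sample=True,
#                  temperature=0.7, top_p=0.9)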
def answer_from_qa(question: str) -> Optional[str]:
"""
1) Exact match in QA_INDEX
2) Fuzzy match via word overlap with ALL_QA_KNOWLEDGE
"""
norm_q = qa_store.normalize_question(question)
if not norm_q:
return None
# Exact match
if norm_q in qa_store.QA_INDEX:
return qa_store.QA_INDEX[norm_q]
# Fuzzy match
q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
if not q_terms:
return None
best_score = 0
best_answer: Optional[str] = None
for item in qa_store.ALL_QA_KNOWLEDGE:
stored_terms = [t for t in item["norm_q"].split(" ") if len(t) > 1]
overlap = sum(1 for t in q_terms if t in stored_terms)
if overlap > best_score:
best_score = overlap
best_answer = item["a"]
# require at least 2 overlapping words to accept fuzzy match
if best_score >= 2 and best_answer is not None:
# optional: log when fuzzy match is used
print(f"[FUZZY MATCH] score={best_score} -> {best_answer[:50]!r}")
return best_answer
return None
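# Worked example of the overlap score (hypothetical normalized terms):
#   question terms: ["ລ້ານຊ້າງ", "ສ້າງຕັ້ງ", "ປີໃດ"]
#   stored terms:   ["ລ້ານຊ້າງ", "ສ້າງຕັ້ງ", "ເມື່ອໃດ"]
#   overlap = 2 -> accepted, since the threshold is 2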
def laos_history_bot(message: str, history: List) -> str:
"""
Main chatbot function for Student tab (Gradio ChatInterface).
"""
if not message.strip():
return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."
# 1) Try exact / fuzzy Q&A first
direct = answer_from_qa(message)
if direct:
return direct
# 2) Fall back to LLM + retrieved context
try:
answer = generate_answer(message, history)
except Exception as e: # noqa: BLE001
return f"ລະບົບມີບັນຫາ: {e}"
return answer
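# -----------------------------
# Manual smoke test (sketch)
# -----------------------------
# Minimal end-to-end check outside Gradio; the sample question is hypothetical.
if __name__ == "__main__":
    demo_question = "ອານາຈັກລ້ານຊ້າງສ້າງຕັ້ງຂຶ້ນເມື່ອໃດ?"  # "When was Lan Xang founded?"
    print(laos_history_bot(demo_question, history=[]))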