# Provenance: Hugging Face Space file "model_utils.py" by Heng2004,
# commit fe6264c (verified), 4.98 kB. (Web-page header converted to a comment.)
# model_utils.py
# Model loading, retrieval, and answer generation for the Lao-history chatbot.
from typing import List, Optional
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import qa_store
from loader import load_curriculum, load_manual_qa, rebuild_combined_qa
# -----------------------------
# Model
# -----------------------------
# Small SEA-languages chat model; float32 keeps it CPU-friendly (no GPU assumed).
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
)
# Load data once at import time
# NOTE: these populate qa_store's module-level state (ENTRIES, QA_INDEX,
# ALL_QA_KNOWLEDGE) that the functions below read — presumably; confirm in loader.py.
load_curriculum()
load_manual_qa()
rebuild_combined_qa()
# System prompt (Lao): instructs the model to act as a Lao-history assistant
# for grade M.1 students, answer briefly in Lao using only the provided
# context, and say it is unsure when the context is insufficient.
SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. "
    "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)
def _entry_block(e: dict) -> str:
    """Format one textbook entry as a labelled context block (header + text)."""
    header = (
        f"[ຊັ້ນ {e.get('grade','')}, "
        f"ບົດ {e.get('chapter','')}, "
        f"ຫົວຂໍ້ {e.get('section','')}{e.get('title','')}]"
    )
    # FIX: use .get('text','') — the original no-terms branch used e['text'],
    # which raised KeyError for entries missing 'text' (every other access
    # in this function used .get).
    return f"{header}\n{e.get('text','')}"


def retrieve_context(question: str, max_entries: int = 2) -> str:
    """
    Simple keyword retrieval over textbook entries.

    Scores each entry by term frequency in text+title (+2 per keyword hit,
    +1 for a topic hit) and returns the top `max_entries` entries joined as
    labelled context blocks. Falls back to the first entries when the query
    yields no usable terms or no entry scores above zero, and to the raw
    knowledge dump when no entries are loaded at all.
    """
    if not qa_store.ENTRIES:
        return qa_store.RAW_KNOWLEDGE
    q = question.lower().strip()
    # Keep only tokens longer than one character; single chars are too noisy.
    terms = [t for t in re.split(r"\s+", q) if len(t) > 1]
    if not terms:
        # No usable query terms: fall back to the first few entries.
        return "\n\n".join(_entry_block(e) for e in qa_store.ENTRIES[:max_entries])
    scored = []
    for e in qa_store.ENTRIES:
        base = (e.get("text", "") + " " + e.get("title", "")).lower()
        # Term frequency in text+title.
        score = sum(base.count(t) for t in terms)
        # Keyword hits are worth more than plain text hits.
        for kw in e.get("keywords", []):
            kw_lower = kw.lower()
            score += sum(2 for t in terms if t in kw_lower)
        # FIX: lowercase the topic — terms are lowercased, so the original
        # case-sensitive comparison silently missed topics with uppercase.
        topic = e.get("topic", "").lower()
        if topic and any(t in topic for t in terms):
            score += 1
        if score > 0:
            scored.append((score, e))
    # key= avoids comparing the dict payloads on score ties.
    scored.sort(key=lambda x: x[0], reverse=True)
    top_entries = [e for _, e in scored[:max_entries]]
    if not top_entries:
        # Nothing matched: still provide some context rather than none.
        top_entries = qa_store.ENTRIES[:max_entries]
    return "\n\n".join(_entry_block(e) for e in top_entries)
def build_prompt(question: str) -> str:
    """Assemble the full generation prompt.

    Layout: system prompt, a reference-context section retrieved for the
    question, the question itself, then the Lao answer cue the model
    completes from.
    """
    reference = retrieve_context(question)
    parts = [
        SYSTEM_PROMPT,
        "ຂໍ້ມູນອ້າງອີງ:",
        reference,
        f"ຄຳຖາມ: {question}",
        "ຄຳຕອບດ້ວຍພາສາລາວ:",
    ]
    return "\n".join(parts)
def generate_answer(question: str) -> str:
    """Generate an answer with the causal LM using greedy decoding.

    Builds the retrieval-augmented prompt, generates up to 160 new tokens,
    and decodes only the tokens produced after the prompt.
    """
    prompt = build_prompt(question)
    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=160,
            do_sample=False,  # deterministic, greedy decoding
        )
    # Strip the echoed prompt; keep only the newly generated continuation.
    new_token_ids = output_ids[0][prompt_len:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
def answer_from_qa(question: str) -> Optional[str]:
    """Answer from the curated QA store, or None when nothing matches.

    Strategy:
      1) exact match on the normalized question in QA_INDEX;
      2) otherwise, fuzzy match: the stored question sharing the most
         words (length > 1) with the query wins, if it shares at least one.
    """
    norm_q = qa_store.normalize_question(question)
    if not norm_q:
        return None
    if norm_q in qa_store.QA_INDEX:
        return qa_store.QA_INDEX[norm_q]
    query_terms = [w for w in norm_q.split(" ") if len(w) > 1]
    if not query_terms:
        return None
    best_overlap = 0
    best_answer: Optional[str] = None
    for item in qa_store.ALL_QA_KNOWLEDGE:
        candidate_terms = [w for w in item["norm_q"].split(" ") if len(w) > 1]
        overlap = sum(w in candidate_terms for w in query_terms)
        if overlap > best_overlap:
            best_overlap, best_answer = overlap, item["a"]
    # A single shared word is enough to accept the fuzzy match.
    return best_answer if best_overlap >= 1 else None
def laos_history_bot(message: str, history: List) -> str:
    """Main chatbot function for the Student tab.

    Tries a direct QA-store lookup first and falls back to LM generation.
    `history` matches the chat-UI callback signature but is not used.
    Returns a Lao error message instead of raising on generation failure.
    """
    if not message.strip():
        # Empty input: ask the student to type a question first.
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."
    direct = answer_from_qa(message)
    if direct:
        return direct
    try:
        return generate_answer(message)
    except Exception as e:  # noqa: BLE001 — top-level UI boundary
        return f"ລະບົບມີບັນຫາ: {e}"