# model_utils.py
"""
Model utilities for a Lao-history chatbot (M.1 curriculum).

Loads a small causal LM for answer generation and a multilingual
sentence-embedding model for textbook-chunk retrieval, then wires them
together: exact/fuzzy QA lookup first, retrieval-augmented generation
as the fallback.

NOTE: this module has import-time side effects — it downloads/loads both
models and reads the curriculum + QA data via ``loader``.
"""

from typing import List, Optional
import logging
import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

import qa_store
from loader import load_curriculum, load_manual_qa, rebuild_combined_qa

logger = logging.getLogger(__name__)

# -----------------------------
# Base chat model
# -----------------------------
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
MAX_CONTEXT_ENTRIES = 3  # how many textbook chunks to retrieve per question
FUZZY_MIN_OVERLAP = 2    # min shared words required to accept a fuzzy QA match

# Pre-compiled once instead of re-compiling inside generate_answer() on
# every call. Splits after sentence-ending punctuation (., ?, !, …).
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[\.?!…])\s+")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
).to(device)
model.eval()

# -----------------------------
# Embedding model for retrieval
# -----------------------------
EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embed_model = SentenceTransformer(EMBED_MODEL_NAME)
# (optional) move embedding model to same device; OK to leave on CPU if you want
embed_model = embed_model.to(device)


# NOTE: called once after load_curriculum() to precompute embeddings.
# If you ever reload curriculum at runtime, call _build_entry_embeddings() again.
def _build_entry_embeddings() -> None:
    """
    Build embeddings for each textbook entry (title + summary + text)
    and store them as a tensor in ``qa_store.TEXT_EMBEDDINGS``.

    Sets ``qa_store.TEXT_EMBEDDINGS`` to ``None`` when there are no
    entries, which retrieve_context() treats as "fall back to the first
    few entries".
    """
    if not qa_store.ENTRIES:
        qa_store.TEXT_EMBEDDINGS = None
        return

    # Missing fields default to "" (the `or ""` guards against explicit None).
    texts = [
        "{0}\n{1}\n{2}".format(
            e.get("title", "") or "",
            e.get("summary", "") or "",
            e.get("text", "") or "",
        )
        for e in qa_store.ENTRIES
    ]

    qa_store.TEXT_EMBEDDINGS = embed_model.encode(
        texts,
        convert_to_tensor=True,
        show_progress_bar=False,
    )


# -----------------------------
# Load data once at import time
# -----------------------------
load_curriculum()
load_manual_qa()
rebuild_combined_qa()
_build_entry_embeddings()

SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. "
    "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)


def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
    """
    Embedding-based retrieval over textbook entries.

    Returns up to ``max_entries`` entries formatted as labelled text
    blocks for the prompt. Falls back to the first entries when the
    embedding matrix is missing, and to ``qa_store.RAW_KNOWLEDGE`` when
    there are no entries at all.
    """
    if not qa_store.ENTRIES:
        return qa_store.RAW_KNOWLEDGE

    if qa_store.TEXT_EMBEDDINGS is None:
        # Embeddings not built (e.g. empty curriculum at startup) — degrade
        # gracefully rather than crash.
        top_entries = qa_store.ENTRIES[:max_entries]
    else:
        # 1) Encode the question
        q_vec = embed_model.encode(
            question,
            convert_to_tensor=True,
            show_progress_bar=False,
        )

        # 2) Cosine similarity with all entry embeddings
        sims = cos_sim(q_vec, qa_store.TEXT_EMBEDDINGS)[0]  # shape [N]

        # 3) Pick top-k indices (k clamped so topk never overruns N)
        k = min(max_entries, len(qa_store.ENTRIES))
        _, top_indices = torch.topk(sims, k=k)

        # 4) Map indices back to entries
        top_entries = [qa_store.ENTRIES[i] for i in top_indices.tolist()]

    # Build context string for the prompt
    context_blocks: List[str] = []
    for e in top_entries:
        header = (
            f"[ຊັ້ນ {e.get('grade','')}, "
            f"ບົດ {e.get('chapter','')}, "
            f"ຫົວຂໍ້ {e.get('section','')} – {e.get('title','')}]"
        )
        context_blocks.append(f"{header}\n{e.get('text','')}")

    return "\n\n".join(context_blocks)


def _format_history(history: Optional[List]) -> str:
    """
    Convert the last few chat turns into a Lao conversation snippet so
    the model has context for follow-up questions.

    Expected (Gradio tuple) history format:
    ``[[user_msg, bot_msg], [user_msg, bot_msg], ...]`` — malformed
    turns are skipped. Returns "" when there is nothing usable.
    """
    if not history:
        return ""

    # keep only the last 3 turns to avoid very long prompts
    recent = history[-3:]

    lines: List[str] = []
    for turn in recent:
        if not isinstance(turn, (list, tuple)) or len(turn) != 2:
            continue
        user_msg, bot_msg = turn
        lines.append(f"ນັກຮຽນ: {user_msg}")
        lines.append(f"ອາຈານ AI: {bot_msg}")

    if not lines:
        return ""

    joined = "\n".join(lines)
    return f"ປະຫວັດການສົນທະນາກ່ອນໜ້າ:\n{joined}\n\n"


def build_prompt(question: str, history: Optional[List] = None) -> str:
    """Assemble the full generation prompt: system text, optional history,
    retrieved context, and the student's question."""
    context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
    history_block = _format_history(history)

    return f"""{SYSTEM_PROMPT}

{history_block}ຂໍ້ມູນອ້າງອີງ:
{context}

ຄຳຖາມ: {question}

ຄຳຕອບດ້ວຍພາສາລາວ:"""


def generate_answer(question: str, history: Optional[List] = None) -> str:
    """
    Greedy-decode an answer from the chat model, trimmed to at most
    three sentences for M.1 students.
    """
    prompt = build_prompt(question, history)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,
            do_sample=False,
            # Silence the "pad_token_id not set" warning; with greedy
            # single-sequence decoding this does not change the output.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (skip the echoed prompt).
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Enforce 2–3 sentence answers for M.1 students
    sentences = _SENTENCE_SPLIT_RE.split(answer)
    short_answer = " ".join(sentences[:3]).strip()
    return short_answer if short_answer else answer


def answer_from_qa(question: str) -> Optional[str]:
    """
    Look up an answer in the curated QA data.

    1) Exact match in ``qa_store.QA_INDEX`` (on the normalized question).
    2) Fuzzy match via word overlap with ``qa_store.ALL_QA_KNOWLEDGE``;
       requires at least ``FUZZY_MIN_OVERLAP`` shared words.

    Returns ``None`` when nothing matches confidently.
    """
    norm_q = qa_store.normalize_question(question)
    if not norm_q:
        return None

    # Exact match
    if norm_q in qa_store.QA_INDEX:
        return qa_store.QA_INDEX[norm_q]

    # Fuzzy match — single-character tokens carry no signal, drop them.
    q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
    if not q_terms:
        return None

    best_score = 0
    best_answer: Optional[str] = None

    for item in qa_store.ALL_QA_KNOWLEDGE:
        # Set gives O(1) membership tests (was an O(n) list scan per term);
        # duplicate query terms still each count, as before.
        stored_terms = {t for t in item["norm_q"].split(" ") if len(t) > 1}
        overlap = sum(1 for t in q_terms if t in stored_terms)
        if overlap > best_score:  # strictly greater: first best wins ties
            best_score = overlap
            best_answer = item["a"]

    if best_score >= FUZZY_MIN_OVERLAP and best_answer is not None:
        # Diagnostic for tuning the fuzzy threshold (lazy %-formatting).
        logger.info("[FUZZY MATCH] score=%d -> %r", best_score, best_answer[:50])
        return best_answer

    return None


def laos_history_bot(message: str, history: List) -> str:
    """
    Main chatbot function for the Student tab (Gradio ChatInterface).

    Tries the curated QA data first; falls back to retrieval-augmented
    generation. Any generation failure is reported to the user instead
    of crashing the UI (broad catch is deliberate at this boundary).
    """
    if not message.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."

    # 1) Try exact / fuzzy Q&A first
    direct = answer_from_qa(message)
    if direct:
        return direct

    # 2) Fall back to LLM + retrieved context
    try:
        answer = generate_answer(message, history)
    except Exception as e:  # noqa: BLE001 — UI boundary, surface the error
        return f"ລະບົບມີບັນຫາ: {e}"

    return answer