# model_utils.py
import re
from typing import List, Optional

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

import qa_store
from loader import (
    load_curriculum,
    load_manual_qa,
    rebuild_combined_qa,
    load_glossary,
    sync_download_manual_qa,  # pulls the latest manual Q&A before loading
)
# -----------------------------
# Base chat model
# -----------------------------
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Use float16 on GPU to save memory, float32 on CPU
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype)
model.to(device)
model.eval()

embed_model = SentenceTransformer(EMBED_MODEL_NAME)
embed_model = embed_model.to(device)

# Number of textbook entries to include in the RAG context
MAX_CONTEXT_ENTRIES = 4
# -----------------------------
# Embedding builders
# -----------------------------
def _build_entry_embeddings() -> None:
    """
    Build embeddings for each textbook entry using chapter + section + text
    and store them in qa_store.TEXT_EMBEDDINGS.
    Call this after loading / reloading the curriculum.
    """
    if not getattr(qa_store, "ENTRIES", None):
        qa_store.TEXT_EMBEDDINGS = None
        return

    texts: List[str] = []
    for e in qa_store.ENTRIES:
        chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
        section = e.get("section_title", "") or e.get("section", "") or ""
        text = e.get("text", "") or ""
        texts.append(f"{chapter}\n{section}\n{text}")

    qa_store.TEXT_EMBEDDINGS = embed_model.encode(
        texts,
        convert_to_tensor=True,
        show_progress_bar=False,
    )
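
# Note (informational, not in the original flow): with this MiniLM embedding
# model the result is a torch tensor of shape [len(ENTRIES), 384], kept on
# `device` so the cos_sim call below needs no extra transfers.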
def _build_glossary_embeddings() -> None:
    """Create embeddings for glossary terms + definitions."""
    if not getattr(qa_store, "GLOSSARY", None):
        qa_store.GLOSSARY_EMBEDDINGS = None
        print("[INFO] No glossary terms to embed.")
        return

    # Embed term + definition together
    texts = [
        f"{item.get('term', '')} :: {item.get('definition', '')}"
        for item in qa_store.GLOSSARY
    ]
    embeddings = embed_model.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    qa_store.GLOSSARY_EMBEDDINGS = embeddings
    print(f"[INFO] Built glossary embeddings for {len(texts)} terms.")
# -----------------------------
# Load data once at import time
# -----------------------------
sync_download_manual_qa()
load_curriculum()
load_manual_qa()
load_glossary()
rebuild_combined_qa()
_build_entry_embeddings()
_build_glossary_embeddings()
# -----------------------------
# System prompt (Natural Science)
# -----------------------------
# English gist of the Lao prompt: "You are a natural-science assistant for
# grade M1-M4 students. Answer only in Lao, in 2-3 short, easy-to-understand
# sentences. Rely only on the reference material below; if it is insufficient
# or unclear, say you are not sure."
SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານວິທະຍາສາດທໍາມະຊາດ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1-ມ.4. "
    "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນອ້າງອີງຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)
# -----------------------------
# Helper: history formatting
# -----------------------------
def _format_history(history: Optional[List]) -> str:
    """
    Convert the last few chat turns into a Lao conversation snippet
    to give the model context for follow-up questions.
    Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
    """
    if not history:
        return ""

    # Keep only the last 3 turns to avoid very long prompts
    recent = history[-3:]
    lines: List[str] = []
    for turn in recent:
        if not isinstance(turn, (list, tuple)) or len(turn) != 2:
            continue
        user_msg, bot_msg = turn
        # Lao labels: ນັກຮຽນ = "Student", ອາຈານ AI = "AI Teacher"
        lines.append(f"ນັກຮຽນ: {user_msg}")
        lines.append(f"ອາຈານ AI: {bot_msg}")

    if not lines:
        return ""
    return "\n".join(lines) + "\n\n"
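
# Illustration (hypothetical turn): _format_history([["<question>", "<answer>"]])
# returns "ນັກຮຽນ: <question>\nອາຈານ AI: <answer>\n\n".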
# -----------------------------
# RAG: retrieve textbook context
# -----------------------------
def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
    """
    Embedding-based retrieval over textbook entries.
    Falls back to concatenated raw knowledge if embeddings are missing.
    """
    if not getattr(qa_store, "ENTRIES", None):
        # Fallback: raw knowledge (if available) or empty string
        return getattr(qa_store, "RAW_KNOWLEDGE", "")

    if qa_store.TEXT_EMBEDDINGS is None:
        # No embeddings yet: just take the first few entries
        top_entries = qa_store.ENTRIES[:max_entries]
    else:
        # 1) Encode the question
        q_vec = embed_model.encode(
            question,
            convert_to_tensor=True,
            show_progress_bar=False,
        )
        # 2) Cosine similarity with all entry embeddings
        sims = cos_sim(q_vec, qa_store.TEXT_EMBEDDINGS)[0]  # shape [N]
        # 3) Take the top-k entries
        top_indices = torch.topk(sims, k=min(max_entries, sims.shape[0])).indices
        top_entries = [qa_store.ENTRIES[i] for i in top_indices.tolist()]

    # Build the context string for the prompt.
    # Lao header labels: ຊັ້ນ = grade, ໜ່ວຍ = unit, ບົດ = chapter, ຫົວຂໍ້ = section
    context_blocks: List[str] = []
    for e in top_entries:
        header = (
            f"[ຊັ້ນ {e.get('grade', '')}, "
            f"ໜ່ວຍ {e.get('unit', '')}, "
            f"ບົດ {e.get('chapter_title', '')}, "
            f"ຫົວຂໍ້ {e.get('section_title', '')}]"
        )
        context_blocks.append(f"{header}\n{e.get('text', '')}")
    return "\n\n".join(context_blocks)
# -----------------------------
# Glossary-based answering
# -----------------------------
def answer_from_glossary(message: str) -> Optional[str]:
    """
    Try to answer using the glossary index.
    Priority 1: exact string match of a term inside the user's message.
    Priority 2: vector embedding match (if confidence is high).
    """
    if not getattr(qa_store, "GLOSSARY", None):
        return None

    # Check for an exact term match first. This prevents a question like
    # "What is Science" from matching "Pollution" just because the
    # "Pollution" definition happens to contain the word "Science".
    normalized_msg = message.lower().strip()
    for item in qa_store.GLOSSARY:
        term = item.get("term", "").lower().strip()
        # The specific term appears in the message (e.g. "Science" in "What is Science?")
        if term and term in normalized_msg:
            # Only accept short messages, so long sentences that merely
            # mention a term do not trigger a glossary answer.
            if len(normalized_msg) < len(term) + 20:
                definition = item.get("definition", "").strip()
                example = item.get("example", "").strip()
                if example:
                    # ຕົວຢ່າງ = "Example"
                    return f"{definition} ຕົວຢ່າງ: {example}"
                return definition

    # No exact text match: fall back to vector similarity.
    if qa_store.GLOSSARY_EMBEDDINGS is None:
        return None
    q_emb = embed_model.encode(
        [message],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]
    # Rows are unit-normalized, so this dot product is cosine similarity
    sims = np.dot(qa_store.GLOSSARY_EMBEDDINGS, q_emb)
    best_idx = int(np.argmax(sims))
    best_sim = float(sims[best_idx])

    # Require a fairly strong match (threshold raised from 0.55 to 0.65)
    # to reject weak hits such as "Science" matching "Pollution".
    if best_sim < 0.65:
        return None

    item = qa_store.GLOSSARY[best_idx]
    definition = item.get("definition", "").strip()
    example = item.get("example", "").strip()
    if example:
        return f"{definition} ຕົວຢ່າງ: {example}"
    return definition
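
# Behavior sketch: a short message containing a glossary term verbatim is
# answered from the glossary directly; anything else is answered only when the
# best embedding similarity clears 0.65, otherwise the caller falls through to
# Q&A lookup and then the LLM.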
# -----------------------------
# Prompt + LLM generation
# -----------------------------
def build_prompt(question: str, history: Optional[List] = None) -> str:
    # Lao template labels: ຂໍ້ມູນອ້າງອີງ = "Reference material",
    # ຄຳຖາມ = "Question", ຄຳຕອບດ້ວຍພາສາລາວ = "Answer in Lao"
    context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
    history_block = _format_history(history)
    return f"""{SYSTEM_PROMPT}
{history_block}ຂໍ້ມູນອ້າງອີງ:
{context}
ຄຳຖາມ: {question}
ຄຳຕອບດ້ວຍພາສາລາວ:"""


def generate_answer(question: str, history: Optional[List] = None) -> str:
    prompt = build_prompt(question, history)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,
            do_sample=False,  # greedy decoding for reproducible answers
        )
    # Drop the prompt tokens; keep only the newly generated continuation
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Enforce 2-3 sentence answers for students
    sentences = re.split(r"(?<=[\.?!…])\s+", answer)
    short_answer = " ".join(sentences[:3]).strip()
    return short_answer if short_answer else answer
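
# Alternative worth noting (a sketch, not how this module currently works):
# SeaLLMs-v3-1.5B-Chat is a chat-tuned model, so the prompt could instead be
# wrapped in its chat template, e.g.
#   msgs = [{"role": "system", "content": SYSTEM_PROMPT},
#           {"role": "user", "content": question}]
#   input_ids = tokenizer.apply_chat_template(
#       msgs, add_generation_prompt=True, return_tensors="pt"
#   ).to(device)
# The plain-text prompt above also works; the template form simply matches the
# format the model was fine-tuned on.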
# -----------------------------
# QA lookup (exact + fuzzy)
# -----------------------------
def answer_from_qa(question: str) -> Optional[str]:
    """
    1) Exact match in QA_INDEX
    2) Fuzzy match via word overlap with ALL_QA_KNOWLEDGE
    """
    norm_q = qa_store.normalize_question(question)
    if not norm_q:
        return None

    # Exact match
    if norm_q in qa_store.QA_INDEX:
        return qa_store.QA_INDEX[norm_q]

    # Fuzzy match: count how many question words also appear in a stored question
    q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
    if not q_terms:
        return None

    best_score = 0
    best_answer: Optional[str] = None
    for item in qa_store.ALL_QA_KNOWLEDGE:
        stored_terms = [t for t in item["norm_q"].split(" ") if len(t) > 1]
        overlap = sum(1 for t in q_terms if t in stored_terms)
        if overlap > best_score:
            best_score = overlap
            best_answer = item["a"]

    # Require at least 2 overlapping words to accept a fuzzy match
    if best_score >= 2 and best_answer is not None:
        # Log fuzzy matches so mismatches are easy to spot
        print(f"[FUZZY MATCH] score={best_score} -> {best_answer[:50]!r}")
        return best_answer
    return None
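
# Worked illustration (hypothetical data): for q_terms ["what", "is",
# "photosynthesis"] and a stored question whose terms include "what" and
# "photosynthesis", overlap = 2, which meets the acceptance threshold above.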
# -----------------------------
# Main chatbot entry
# -----------------------------
def laos_science_bot(message: str, history: List) -> str:
    """
    Main chatbot function for the Student tab (Gradio ChatInterface).
    """
    if not message.strip():
        # "Please type a question first."
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."

    # 0) Try the glossary first for key concepts
    gloss = answer_from_glossary(message)
    if gloss:
        return gloss

    # 1) Then exact / fuzzy Q&A lookup
    direct = answer_from_qa(message)
    if direct:
        return direct

    # 2) Fall back to the LLM + retrieved context
    try:
        answer = generate_answer(message, history)
    except Exception as e:  # noqa: BLE001
        # "The system has a problem: ..."
        return f"ລະບົບມີບັນຫາ: {e}"
    return answer
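
# Minimal wiring sketch (an assumption, not part of this module's original
# code): gr.ChatInterface calls its fn with (message, history), which matches
# laos_science_bot's signature. The title string is made up for illustration.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.ChatInterface(fn=laos_science_bot, title="Lao Science Tutor")
    demo.launch()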