import re
from typing import List, Optional

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

import qa_store
from loader import load_curriculum, load_manual_qa, rebuild_combined_qa, load_glossary


MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
EMBED_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Half precision saves GPU memory; CPU inference needs float32.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype)
model.to(device)
model.eval()

embed_model = SentenceTransformer(EMBED_MODEL_NAME)
embed_model = embed_model.to(device)

# How many retrieved textbook entries go into the generation prompt.
MAX_CONTEXT_ENTRIES = 4


def _build_entry_embeddings() -> None:
    """
    Build embeddings for each textbook entry using chapter + section + text
    and store them in qa_store.TEXT_EMBEDDINGS.

    Call this after loading / reloading the curriculum.
    """
    if not getattr(qa_store, "ENTRIES", None):
        qa_store.TEXT_EMBEDDINGS = None
        return

    texts: List[str] = []
    for e in qa_store.ENTRIES:
        chapter = e.get("chapter_title", "") or e.get("chapter", "") or ""
        section = e.get("section_title", "") or e.get("section", "") or ""
        text = e.get("text", "") or ""
        combined = f"{chapter}\n{section}\n{text}"
        texts.append(combined)

    qa_store.TEXT_EMBEDDINGS = embed_model.encode(
        texts,
        convert_to_tensor=True,
        show_progress_bar=False,
    )
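
# Note: entry embeddings are kept as a torch tensor (convert_to_tensor=True)
# so cos_sim in retrieve_context() can score them on `device`, while the
# glossary embeddings below are L2-normalized numpy vectors scored with a
# plain dot product. Both paths compute cosine similarity; only the backing
# array type differs.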


def _build_glossary_embeddings() -> None:
    """Create embeddings for glossary terms + definitions."""
    if not getattr(qa_store, "GLOSSARY", None):
        qa_store.GLOSSARY_EMBEDDINGS = None
        print("[INFO] No glossary terms to embed.")
        return

    texts = [
        f"{item.get('term', '')} :: {item.get('definition', '')}"
        for item in qa_store.GLOSSARY
    ]

    # normalize_embeddings=True makes each vector unit-length, so the dot
    # product in answer_from_glossary() is exactly cosine similarity.
    embeddings = embed_model.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    qa_store.GLOSSARY_EMBEDDINGS = embeddings
    print(f"[INFO] Built glossary embeddings for {len(texts)} terms.")


# Load all data sources and build the retrieval indexes at import time.
load_curriculum()
load_manual_qa()
load_glossary()
rebuild_combined_qa()
_build_entry_embeddings()
_build_glossary_embeddings()
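

# Hypothetical convenience helper (not part of the original startup flow):
# if the app ever needs to reload edited curriculum / QA / glossary files at
# runtime, the same load-then-rebuild sequence can be re-run in one call.
# A minimal sketch, assuming the loaders are safe to call repeatedly.
def reload_all_data() -> None:
    """Re-run the full load + index-rebuild sequence."""
    load_curriculum()
    load_manual_qa()
    load_glossary()
    rebuild_combined_qa()
    _build_entry_embeddings()
    _build_glossary_embeddings()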


# System prompt (Lao). English gloss: "You are a natural-science assistant
# for students in grades M.1-M.4. Answer only in Lao, in 2-3 short,
# easy-to-understand sentences. Rely only on the reference material below.
# If the information is insufficient or unclear, say you are not sure."
SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານວິທະຍາສາດທໍາມະຊາດ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1-ມ.4. "
    "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນອ້າງອີງຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)


def _format_history(history: Optional[List]) -> str:
    """
    Convert the last few chat turns into a Lao conversation snippet that
    gives the model context for follow-up questions.

    Gradio history format: [[user_msg, bot_msg], [user_msg, bot_msg], ...]
    """
    if not history:
        return ""

    # Keep only the three most recent turns to bound prompt length.
    recent = history[-3:]

    lines: List[str] = []
    for turn in recent:
        # Skip malformed turns defensively.
        if not isinstance(turn, (list, tuple)) or len(turn) != 2:
            continue
        user_msg, bot_msg = turn
        lines.append(f"ນັກຮຽນ: {user_msg}")  # "Student: ..."
        lines.append(f"ອາຈານ AI: {bot_msg}")  # "AI teacher: ..."

    if not lines:
        return ""

    return "\n".join(lines) + "\n\n"


def retrieve_context(question: str, max_entries: int = MAX_CONTEXT_ENTRIES) -> str:
    """
    Embedding-based retrieval over textbook entries.

    Falls back to the concatenated raw knowledge when no entries are loaded,
    and to the first few entries when embeddings are missing.
    """
    if not getattr(qa_store, "ENTRIES", None):
        return getattr(qa_store, "RAW_KNOWLEDGE", "")

    if qa_store.TEXT_EMBEDDINGS is None:
        top_entries = qa_store.ENTRIES[:max_entries]
    else:
        q_vec = embed_model.encode(
            question,
            convert_to_tensor=True,
            show_progress_bar=False,
        )

        # Cosine similarity between the question and every entry embedding.
        sims = cos_sim(q_vec, qa_store.TEXT_EMBEDDINGS)[0]

        # Indices of the highest-scoring entries (at most max_entries).
        top_indices = torch.topk(sims, k=min(max_entries, sims.shape[0])).indices
        top_entries = [qa_store.ENTRIES[i] for i in top_indices.tolist()]

    context_blocks: List[str] = []
    for e in top_entries:
        # Lao header: [grade ..., unit ..., chapter ..., section ...]
        header = (
            f"[ຊັ້ນ {e.get('grade', '')}, "
            f"ໜ່ວຍ {e.get('unit', '')}, "
            f"ບົດ {e.get('chapter_title', '')}, "
            f"ຫົວຂໍ້ {e.get('section_title', '')}]"
        )
        context_blocks.append(f"{header}\n{e.get('text', '')}")

    return "\n\n".join(context_blocks)
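
# Illustrative shape of the returned context (entry values hypothetical):
#
#   [ຊັ້ນ ມ.2, ໜ່ວຍ 3, ບົດ <chapter>, ຫົວຂໍ້ <section>]
#   <entry text>
#
#   [ຊັ້ນ ...]
#   <entry text>
#
# i.e. up to max_entries header+text blocks joined by blank lines.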


def answer_from_glossary(message: str) -> Optional[str]:
    """
    Try to answer using the glossary index.

    Priority 1: exact string match of a glossary term inside the message.
    Priority 2: embedding match, if confidence is high enough.
    """
    if not getattr(qa_store, "GLOSSARY", None):
        return None

    # --- Priority 1: exact substring match ---
    normalized_msg = message.lower().strip()

    for item in qa_store.GLOSSARY:
        term = item.get("term", "").lower().strip()
        if term and term in normalized_msg:
            # Only treat this as a definition request when the message is not
            # much longer than the term itself (a short "what is X?" query).
            if len(normalized_msg) < len(term) + 20:
                definition = item.get("definition", "").strip()
                example = item.get("example", "").strip()
                if example:
                    return f"{definition} ຕົວຢ່າງ: {example}"  # "Example: ..."
                return definition

    # --- Priority 2: embedding match ---
    if qa_store.GLOSSARY_EMBEDDINGS is None:
        return None

    q_emb = embed_model.encode(
        [message],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )[0]

    # Both sides are unit vectors, so the dot product is cosine similarity.
    sims = np.dot(qa_store.GLOSSARY_EMBEDDINGS, q_emb)
    best_idx = int(np.argmax(sims))
    best_sim = float(sims[best_idx])

    # Empirical confidence threshold: below it, defer to QA / generation.
    if best_sim < 0.65:
        return None

    item = qa_store.GLOSSARY[best_idx]
    definition = item.get("definition", "").strip()
    example = item.get("example", "").strip()

    if example:
        return f"{definition} ຕົວຢ່າງ: {example}"
    return definition


def build_prompt(question: str, history: Optional[List] = None) -> str:
    context = retrieve_context(question, max_entries=MAX_CONTEXT_ENTRIES)
    history_block = _format_history(history)

    # Lao labels: "ຂໍ້ມູນອ້າງອີງ" = reference material, "ຄຳຖາມ" = question,
    # "ຄຳຕອບດ້ວຍພາສາລາວ" = answer in Lao.
    return f"""{SYSTEM_PROMPT}

{history_block}ຂໍ້ມູນອ້າງອີງ:
{context}

ຄຳຖາມ: {question}

ຄຳຕອບດ້ວຍພາສາລາວ:"""


def generate_answer(question: str, history: Optional[List] = None) -> str:
    prompt = build_prompt(question, history)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,
            do_sample=False,  # greedy decoding for reproducible answers
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Enforce the "2-3 sentences" instruction: keep at most three sentences.
    sentences = re.split(r"(?<=[\.?!…])\s+", answer)
    short_answer = " ".join(sentences[:3]).strip()
    return short_answer if short_answer else answer


def answer_from_qa(question: str) -> Optional[str]:
    """
    1) Exact match in QA_INDEX.
    2) Fuzzy match via word overlap with ALL_QA_KNOWLEDGE.
    """
    norm_q = qa_store.normalize_question(question)
    if not norm_q:
        return None

    # 1) Exact match on the normalized question.
    if norm_q in qa_store.QA_INDEX:
        return qa_store.QA_INDEX[norm_q]

    # 2) Fuzzy match: count shared terms longer than one character.
    q_terms = [t for t in norm_q.split(" ") if len(t) > 1]
    if not q_terms:
        return None

    best_score = 0
    best_answer: Optional[str] = None

    for item in qa_store.ALL_QA_KNOWLEDGE:
        stored_terms = [t for t in item["norm_q"].split(" ") if len(t) > 1]
        overlap = sum(1 for t in q_terms if t in stored_terms)
        if overlap > best_score:
            best_score = overlap
            best_answer = item["a"]

    # Require at least two overlapping terms to avoid spurious matches.
    if best_score >= 2 and best_answer is not None:
        print(f"[FUZZY MATCH] score={best_score} -> {best_answer[:50]!r}")
        return best_answer

    return None


def laos_science_bot(message: str, history: List) -> str:
    """
    Main chatbot function for the Student tab (Gradio ChatInterface).

    Answer priority: glossary -> curated QA pairs -> model generation.
    """
    if not message.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."  # "Please type a question first."

    # 1) Glossary lookup (exact term or high-confidence embedding match).
    gloss = answer_from_glossary(message)
    if gloss:
        return gloss

    # 2) Curated question-answer pairs.
    direct = answer_from_qa(message)
    if direct:
        return direct

    # 3) Fall back to retrieval-augmented generation.
    try:
        answer = generate_answer(message, history)
    except Exception as e:
        return f"ລະບົບມີບັນຫາ: {e}"  # "The system had a problem: ..."

    return answer
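

# Minimal manual smoke test; a sketch only, assuming the data files loaded at
# import time are present. The example question ("What is water?") is
# hypothetical.
if __name__ == "__main__":
    print(laos_science_bot("ນໍ້າແມ່ນຫຍັງ?", []))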