|
|
|
|
|
from typing import List, Optional |
|
|
import re |
|
|
|
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
import qa_store |
|
|
from loader import load_curriculum, load_manual_qa, rebuild_combined_qa |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Chat-model checkpoint pulled from the Hugging Face hub.
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


model = AutoModelForCausalLM.from_pretrained(


    MODEL_NAME,


    # float32 keeps the model runnable on CPU-only hosts (no fp16 support needed).
    torch_dtype=torch.float32,


)


# Populate qa_store's module-level state at import time.
# NOTE(review): rebuild_combined_qa() presumably merges the data loaded by the
# two calls above, so this call order matters — confirm against loader.py.
load_curriculum()


load_manual_qa()


rebuild_combined_qa()
|
|
|
|
|
# System prompt (in Lao). It instructs the model to: act as a Lao-history
# assistant for Grade-7 (ມ.1) students, answer only in Lao using 2–3 short and
# easy-to-understand sentences, rely solely on the reference text provided
# below it, and say it is unsure when that text is insufficient or unclear.
SYSTEM_PROMPT = (


    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ "


    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. "


    "ຕອບແຕ່ພາສາລາວ ໃຫ້ຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ແລະເຂົ້າໃຈງ່າຍ. "


    "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "


    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."


)
|
|
|
|
|
|
|
|
def _format_entries(entries) -> str:
    """Render textbook entries as double-newline-separated blocks, each with a
    Lao '[grade, chapter, section – title]' header followed by the entry text."""
    blocks = []
    for e in entries:
        header = (
            f"[ຊັ້ນ {e.get('grade','')}, "
            f"ບົດ {e.get('chapter','')}, "
            f"ຫົວຂໍ້ {e.get('section','')} – {e.get('title','')}]"
        )
        blocks.append(f"{header}\n{e.get('text','')}")
    return "\n\n".join(blocks)


def retrieve_context(question: str, max_entries: int = 2) -> str:
    """
    Simple keyword retrieval over textbook entries.

    Scores every entry in qa_store.ENTRIES by term frequency in its
    text+title, curated-keyword hits (weight 2 per term/keyword pair) and a
    topic substring match (weight 1), then returns the ``max_entries``
    highest-scoring entries formatted with Lao headers.

    Fallbacks: qa_store.RAW_KNOWLEDGE when no entries are loaded; the first
    ``max_entries`` entries when the question yields no usable terms or when
    nothing scores above zero, so the model always receives some context.
    """
    if not qa_store.ENTRIES:
        return qa_store.RAW_KNOWLEDGE

    q = question.lower().strip()
    # Keep multi-character whitespace-separated tokens.
    # NOTE(review): Lao is typically written without spaces, so this mostly
    # splits off Latin/numeric terms — confirm tokenization is adequate.
    terms = [t for t in re.split(r"\s+", q) if len(t) > 1]

    if not terms:
        # Was a duplicated inline formatter using e['text'] (KeyError on a
        # missing key); now shares _format_entries, which tolerates absence.
        return _format_entries(qa_store.ENTRIES[:max_entries])

    scored = []
    for e in qa_store.ENTRIES:
        base = (e.get("text", "") + " " + e.get("title", "")).lower()

        # Term frequency over text+title.
        score = sum(base.count(t) for t in terms)

        # Curated keywords are a stronger signal: +2 per (term, keyword) hit.
        for kw in e.get("keywords", []):
            kw_lower = kw.lower()
            score += 2 * sum(1 for t in terms if t in kw_lower)

        # Bug fix: lowercase the topic before matching — terms are already
        # lowercased, so a capitalized topic could never match before.
        topic = e.get("topic", "").lower()
        if topic and any(t in topic for t in terms):
            score += 1

        if score > 0:
            scored.append((score, e))

    # Stable sort keeps the original entry order among equal scores.
    scored.sort(key=lambda x: x[0], reverse=True)
    top_entries = [e for _, e in scored[:max_entries]]

    if not top_entries:
        top_entries = qa_store.ENTRIES[:max_entries]

    return _format_entries(top_entries)
|
|
|
|
|
|
|
|
def build_prompt(question: str) -> str:
    """Assemble the full generation prompt.

    Combines the Lao system instructions, the retrieved textbook context,
    the student's question, and a trailing Lao answer cue so the model
    continues directly with the answer text.
    """
    context = retrieve_context(question)


    return f"""{SYSTEM_PROMPT}




ຂໍ້ມູນອ້າງອີງ:


{context}




ຄຳຖາມ: {question}




ຄຳຕອບດ້ວຍພາສາລາວ:"""
|
|
|
|
|
|
|
|
def generate_answer(question: str) -> str:
    """
    Generate an answer for *question* with the causal LM.

    Builds the retrieval-augmented prompt, runs greedy decoding (up to 160
    new tokens, gradients disabled), and returns only the newly generated
    continuation, decoded without special tokens and stripped of whitespace.
    """
    prompt = build_prompt(question)
    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[1]

    with torch.no_grad():
        generation = model.generate(
            **encoded,
            max_new_tokens=160,
            do_sample=False,
        )

    # Drop the echoed prompt tokens; keep only the model's continuation.
    continuation = generation[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
def answer_from_qa(question: str) -> Optional[str]:
    """
    Look up a stored answer for *question*.

    Resolution order:
    1) exact match on the normalized question in QA_INDEX;
    2) fuzzy match: the ALL_QA_KNOWLEDGE item whose stored question shares
       the most (multi-character) words with the query, first-wins on ties.

    Returns None when the question normalizes to nothing, yields no usable
    words, or shares no word with any stored question.
    """
    norm_q = qa_store.normalize_question(question)
    if not norm_q:
        return None

    # Fast path: exact hit on the normalized form.
    try:
        return qa_store.QA_INDEX[norm_q]
    except KeyError:
        pass

    query_words = [w for w in norm_q.split(" ") if len(w) > 1]
    if not query_words:
        return None

    # Track (overlap, answer); strictly-greater updates keep the first of ties.
    best = (0, None)
    for item in qa_store.ALL_QA_KNOWLEDGE:
        stored_words = [w for w in item["norm_q"].split(" ") if len(w) > 1]
        overlap = sum(1 for w in query_words if w in stored_words)
        if overlap > best[0]:
            best = (overlap, item["a"])

    score, answer = best
    return answer if score >= 1 else None
|
|
|
|
|
|
|
|
def laos_history_bot(message: str, history: List) -> str:
    """
    Main chatbot function for the Student tab.

    Resolution order: (1) ask the student to type a question when the
    message is blank, (2) direct QA-store lookup, (3) LLM generation.
    ``history`` is accepted for the chat-UI callback signature but unused.
    Generation failures are reported back in Lao instead of raised.
    """
    if not message.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."

    stored = answer_from_qa(message)
    if stored:
        return stored

    try:
        return generate_answer(message)
    except Exception as e:
        # Boundary handler: surface the failure to the student
        # rather than crashing the UI.
        return f"ລະບົບມີບັນຫາ: {e}"
|
|
|