File size: 4,106 Bytes
8db4a34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# rag.py – retrieval + model generation

import re
from typing import List, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from data.loader import ENTRIES, RAW_KNOWLEDGE
from data.qa_index import answer_from_qa

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # CPU on HF free tier
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. "
    "ຕອບແຕ່ພາສາລາວ ແລະຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ໃຫ້ເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)


def retrieve_context(question: str, max_entries: int = 2) -> str:
    """
    Simple keyword matching over ENTRIES (text + title + keywords).
    """
    if not ENTRIES:
        return RAW_KNOWLEDGE

    q = question.lower().strip()
    terms = [t for t in re.split(r"\s+", q) if len(t) > 1]

    if not terms:
        chosen = ENTRIES[:max_entries]
        return "\n\n".join(
            f"[ຊັ້ນ {e.get('grade','')}, ບົດ {e.get('chapter','')}, "
            f"ຫົວຂໍ້ {e.get('section','')}{e.get('title','')}]\n{e['text']}"
            for e in chosen
        )

    scored = []

    for e in ENTRIES:
        text = e.get("text", "")
        title = e.get("title", "")
        kws = e.get("keywords", [])
        topic = e.get("topic", "")

        base = (text + " " + title).lower()
        score = 0

        for t in terms:
            score += base.count(t)

        for kw in kws:
            kw_lower = kw.lower()
            for t in terms:
                if t in kw_lower:
                    score += 2

        if topic and any(t in topic for t in terms):
            score += 1

        if score > 0:
            scored.append((score, e))

    scored.sort(key=lambda x: x[0], reverse=True)
    top_entries = [e for _, e in scored[:max_entries]]

    if not top_entries:
        top_entries = ENTRIES[:max_entries]

    blocks = []
    for e in top_entries:
        header = (
            f"[ຊັ້ນ {e.get('grade','')}, "
            f"ບົດ {e.get('chapter','')}, "
            f"ຫົວຂໍ້ {e.get('section','')}{e.get('title','')}]"
        )
        blocks.append(f"{header}\n{e.get('text','')}")

    return "\n\n".join(blocks)


def build_prompt(question: str) -> str:
    context = retrieve_context(question)
    return f"""{SYSTEM_PROMPT}

ຂໍ້ມູນອ້າງອີງ:
{context}

ຄໍາຖາມ: {question}

ຄໍາຕອບດ້ວຍພາສາລາວ:"""


def generate_answer_with_model(question: str) -> str:
    prompt = build_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,
            do_sample=False,   # greedy → stable, a bit faster
        )

    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return answer.strip()


def answer_question(question: str) -> str:
    if not question.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."

    # 1) try manual + dataset QA first (instant, no model)
    direct = answer_from_qa(question)
    if direct:
        return direct

    # 2) fall back to model + RAG
    try:
        return generate_answer_with_model(question)
    except Exception as e:
        return f"ລະບົບມີບັນຫາ: {e}"