File size: 4,106 Bytes
8db4a34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# rag.py – retrieval + model generation
import re
from typing import List, Dict
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from data.loader import ENTRIES, RAW_KNOWLEDGE
from data.qa_index import answer_from_qa
# --- Model setup (runs once at import time) ---
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"
# Tokenizer and weights are fetched from the Hugging Face hub on first run.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # CPU on HF free tier
)
# Prefer GPU when one is available; otherwise stay on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# System instructions (in Lao). Roughly: "You are a Lao-history assistant for
# grade M.1 students. Answer only in Lao, briefly (2-3 easy sentences). Rely
# only on the information below; if it is insufficient or unclear, say you are
# not sure." The literal must stay byte-identical - it is sent to the model.
SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ "
    "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. "
    "ຕອບແຕ່ພາສາລາວ ແລະຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ໃຫ້ເຂົ້າໃຈງ່າຍ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)
def _format_entry(e: Dict) -> str:
    """Render one knowledge entry as a labelled context block (header + text)."""
    header = (
        f"[ຊັ້ນ {e.get('grade','')}, "
        f"ບົດ {e.get('chapter','')}, "
        f"ຫົວຂໍ້ {e.get('section','')} – {e.get('title','')}]"
    )
    return f"{header}\n{e.get('text','')}"


def retrieve_context(question: str, max_entries: int = 2) -> str:
    """
    Simple keyword matching over ENTRIES (text + title + keywords).

    Scoring: +1 per substring occurrence of each query term in text/title,
    +2 per term found in a keyword, +1 if any term appears in the topic.
    Returns up to ``max_entries`` formatted entries joined by blank lines;
    falls back to the first entries (or RAW_KNOWLEDGE) when nothing matches.
    """
    if not ENTRIES:
        # No structured entries loaded — hand back the raw knowledge dump.
        return RAW_KNOWLEDGE
    q = question.lower().strip()
    # Ignore one-character tokens: they match almost everything as substrings.
    terms = [t for t in re.split(r"\s+", q) if len(t) > 1]
    if not terms:
        return "\n\n".join(_format_entry(e) for e in ENTRIES[:max_entries])
    scored = []
    for e in ENTRIES:
        base = (e.get("text", "") + " " + e.get("title", "")).lower()
        score = sum(base.count(t) for t in terms)
        for kw in e.get("keywords", []):
            kw_lower = kw.lower()
            score += 2 * sum(1 for t in terms if t in kw_lower)
        # Bug fix: topic was compared case-sensitively against lowercased
        # query terms; lowercase it like every other field.
        topic = e.get("topic", "").lower()
        if topic and any(t in topic for t in terms):
            score += 1
        if score > 0:
            scored.append((score, e))
    scored.sort(key=lambda x: x[0], reverse=True)
    # Fall back to the first entries when no entry scored at all.
    top_entries = [e for _, e in scored[:max_entries]] or ENTRIES[:max_entries]
    return "\n\n".join(_format_entry(e) for e in top_entries)
def build_prompt(question: str) -> str:
    """Assemble the generation prompt: system instructions, retrieved
    reference context, the user's question, and the answer cue (all in Lao)."""
    sections = [
        SYSTEM_PROMPT,
        "ຂໍ້ມູນອ້າງອີງ:",
        retrieve_context(question),
        f"ຄໍາຖາມ: {question}",
        "ຄໍາຕອບດ້ວຍພາສາລາວ:",
    ]
    return "\n".join(sections)
def generate_answer_with_model(question: str) -> str:
    """Run the chat model on a RAG prompt and return only the newly
    generated continuation, with special tokens removed."""
    prompt_text = build_prompt(question)
    encoded = tokenizer(prompt_text, return_tensors="pt").to(device)
    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=160,
            do_sample=False,  # greedy decoding: deterministic, a bit faster
        )
    # Drop the prompt tokens; keep only what the model added.
    continuation = output_ids[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()
def answer_question(question: str) -> str:
    """Top-level entry point: reject empty input, try the instant QA index,
    then fall back to RAG + model generation. Never raises — errors are
    returned as a Lao-language message."""
    # Guard clause: whitespace-only input gets a "please type a question" nudge.
    if not question.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."
    # 1) manual + dataset QA first — instant, no model call needed.
    if direct := answer_from_qa(question):
        return direct
    # 2) model + RAG fallback; this is a UI boundary, so surface any failure
    # as a message instead of propagating the exception.
    try:
        return generate_answer_with_model(question)
    except Exception as e:
        return f"ລະບົບມີບັນຫາ: {e}"
|