# Copied from the Hugging Face file view: author Heng2004,
# commit 8db4a34 ("Create rag.py", verified), file size 4.11 kB.
# rag.py – retrieval + model generation
import logging
import re
from typing import Dict, List, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from data.loader import ENTRIES, RAW_KNOWLEDGE
from data.qa_index import answer_from_qa
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"

# Decide where inference runs before anything is loaded: GPU when present,
# otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and weights are loaded once, when this module is imported.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Weights are kept in float32 regardless of device (chosen for CPU on the
# HF free tier). Module.to() returns the module itself, so the chained call
# leaves `model` bound to the loaded model, now placed on `device`.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
).to(device)
# Instruction preamble placed at the top of every model prompt, written in Lao.
# Rough English gloss: "You are an assistant for the history of Laos, for
# grade M.1 students. Answer only in Lao, briefly, in 2–3 easy-to-understand
# sentences. Base your answer only on the information below. If the
# information is insufficient or unclear, say that you are not sure."
SYSTEM_PROMPT = "".join(
    (
        "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ ",
        "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1. ",
        "ຕອບແຕ່ພາສາລາວ ແລະຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ໃຫ້ເຂົ້າໃຈງ່າຍ. ",
        "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. ",
        "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ.",
    )
)
def _format_entry(entry: Dict) -> str:
    """Render one knowledge entry as a bracketed header line followed by its text."""
    header = (
        f"[ຊັ້ນ {entry.get('grade','')}, "
        f"ບົດ {entry.get('chapter','')}, "
        f"ຫົວຂໍ້ {entry.get('section','')}{entry.get('title','')}]"
    )
    # `or ''` so a missing *or* None text renders as empty rather than
    # raising (KeyError) or printing the word "None".
    return f"{header}\n{entry.get('text') or ''}"


def _score_entry(entry: Dict, terms: List[str]) -> int:
    """Keyword-overlap score of *entry* against the lower-cased query *terms*."""
    base = f"{entry.get('text') or ''} {entry.get('title') or ''}".lower()
    keywords = [str(kw).lower() for kw in (entry.get("keywords") or [])]
    # Lower-case the topic too: the terms are lower-cased, so comparing them
    # against a mixed-case topic would silently never match.
    topic = str(entry.get("topic") or "").lower()

    # +1 for every occurrence of a term in text/title.
    score = sum(base.count(t) for t in terms)
    # +2 for every (keyword, term) pair where the term is a substring of the keyword.
    score += 2 * sum(1 for kw in keywords for t in terms if t in kw)
    # +1 (once) if any term appears in the topic.
    if topic and any(t in topic for t in terms):
        score += 1
    return score


def retrieve_context(
    question: str,
    max_entries: int = 2,
    entries: Optional[List[Dict]] = None,
) -> str:
    """Return the knowledge-base context to paste into the prompt for *question*.

    Simple keyword matching over each entry's text, title, keywords and topic.
    The lower-cased question is split on whitespace into terms (terms of one
    character are dropped). Each entry is scored as follows:
    - +1 per occurrence of a term in its text/title;
    - +2 per (keyword, term) pair where the term is a substring of the keyword;
    - +1 if any term occurs in its topic.
    The top ``max_entries`` positively-scoring entries are used, with ties
    keeping corpus order. If no entry scores, or the question has no usable
    terms, the first ``max_entries`` entries are used instead.

    NOTE(review): Lao is usually written without spaces between words, so a
    Lao question may split into long multi-word "terms" that rarely match;
    a word segmenter would likely make retrieval far more effective.

    Args:
        question: Raw user question.
        max_entries: Maximum number of entries to include.
        entries: Corpus to search; defaults to ``ENTRIES`` from ``data.loader``.

    Returns:
        The chosen entries, each rendered as a bracketed header line plus its
        text, joined by blank lines; or ``RAW_KNOWLEDGE`` unchanged when the
        corpus is empty.
    """
    if entries is None:
        entries = ENTRIES
    if not entries:
        return RAW_KNOWLEDGE

    q = question.lower().strip()
    terms = [t for t in re.split(r"\s+", q) if len(t) > 1]

    top_entries: List[Dict] = []
    if terms:
        scored = [(_score_entry(e, terms), e) for e in entries]
        scored = [pair for pair in scored if pair[0] > 0]
        # list.sort is stable, also with reverse=True, so equal scores keep corpus order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        top_entries = [e for _, e in scored[:max_entries]]
    if not top_entries:
        top_entries = entries[:max_entries]

    return "\n\n".join(_format_entry(e) for e in top_entries)
def build_prompt(question: str) -> str:
    """Assemble the full text prompt fed to the language model.

    The layout is one item per line: the system instructions, a Lao
    "reference information:" label, the retrieved context, the question
    prefixed with a Lao "question:" label, and finally a Lao
    "answer in Lao:" cue that the model continues from.
    """
    sections = [
        SYSTEM_PROMPT,
        "ຂໍ້ມູນອ້າງອີງ:",
        retrieve_context(question),
        f"ຄໍາຖາມ: {question}",
        "ຄໍາຕອບດ້ວຍພາສາລາວ:",
    ]
    return "\n".join(sections)
def generate_answer_with_model(question: str, *, max_new_tokens: int = 160) -> str:
    """Generate an answer with the causal LM from a retrieval-augmented prompt.

    Args:
        question: The user's question; it is inserted into the prompt built
            by ``build_prompt``.
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to 160, the previously hard-coded value).

    Returns:
        Only the newly generated continuation (the prompt tokens are sliced
        off), decoded with special tokens skipped and surrounding whitespace
        stripped.

    NOTE(review): the model is a chat-tuned checkpoint, but the prompt is fed
    as raw text rather than via ``tokenizer.apply_chat_template``; worth
    checking whether the chat template improves answer quality.
    """
    prompt = build_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding: deterministic, a bit faster
        )

    # generate() returns prompt + continuation for decoder-only models;
    # keep just the continuation.
    generated_ids = outputs[0][prompt_len:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
def answer_question(question: str) -> str:
    """Answer a student's question, preferring the curated QA index over the LLM.

    Args:
        question: Raw user input. ``None``, empty and whitespace-only values
            are treated as "no question".

    Returns:
        The Lao prompt asking the user to type a question when there is none.
        Otherwise the answer from ``answer_from_qa`` (if it returns anything
        truthy), falling back to the retrieval-augmented model answer. If
        model generation raises, the error is logged with its traceback and a
        Lao "system has a problem" message (including the error text) is
        returned instead of propagating.
    """
    if not question or not question.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."
    # Drop stray leading/trailing whitespace (e.g. a trailing newline from a
    # text box) so it does not interfere with lookup or the prompt.
    question = question.strip()

    # 1) try manual + dataset QA first (instant, no model)
    direct = answer_from_qa(question)
    if direct:
        return direct

    # 2) fall back to model + RAG
    try:
        return generate_answer_with_model(question)
    except Exception as e:  # top-level UI boundary: report, don't crash
        logging.getLogger(__name__).exception(
            "Model generation failed for question %r", question
        )
        return f"ລະບົບມີບັນຫາ: {e}"