|
|
|
|
|
|
|
|
import os |
|
|
import re |
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
|
|
|
# Hugging Face model id of the chat model used to answer questions.
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"

# Load tokenizer and model at import time so the Gradio app can answer
# immediately. NOTE(review): this downloads the weights on first run and
# keeps them in full float32 precision (CPU-friendly, but higher RAM use).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
)
|
|
|
|
|
|
|
|
# Plain-text knowledge base with the Lao history notes.
DATA_PATH = "data/laos_history.txt"

# Read the whole knowledge file into memory; if it is missing, fall back
# to a Lao placeholder meaning "no history data has been loaded yet".
if os.path.exists(DATA_PATH):
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        RAW_KNOWLEDGE = f.read()
else:
    RAW_KNOWLEDGE = "ຍັງບໍ່ມີຂໍ້ມູນປະຫວັດສາດຖືກໂຫຼດ."

# Split the knowledge base on blank lines into paragraphs; these are the
# retrieval units scored by retrieve_context().
PARAGRAPHS = [p.strip() for p in RAW_KNOWLEDGE.split("\n\n") if p.strip()]
|
|
|
|
|
|
|
|
def retrieve_context(question: str, max_paragraphs: int = 5) -> str:
    """Return up to *max_paragraphs* knowledge paragraphs relevant to *question*.

    VERY simple keyword-based retrieval over the module-level PARAGRAPHS
    list: each paragraph is scored by how often the question's terms occur
    in it. Good enough for a first prototype; later this can be replaced
    with embedding-based retrieval.

    Args:
        question: The user's question (any language / casing).
        max_paragraphs: Maximum number of paragraphs to return.

    Returns:
        The selected paragraphs joined by blank lines. Falls back to the
        first paragraphs when nothing matches or no usable terms can be
        extracted, and to RAW_KNOWLEDGE when PARAGRAPHS is empty.
    """
    if not PARAGRAPHS:
        return RAW_KNOWLEDGE

    # Tokenize on word characters instead of splitting on whitespace so
    # punctuation glued to a word (e.g. the trailing "?" every question
    # ends with) no longer prevents matching. \w is Unicode-aware, so
    # Lao characters are kept.
    terms = [w for w in re.findall(r"\w+", question.lower()) if len(w) > 2]
    if not terms:
        return "\n\n".join(PARAGRAPHS[:max_paragraphs])

    # Score each paragraph by the total occurrence count of all terms.
    scored = []
    for p in PARAGRAPHS:
        p_lower = p.lower()
        score = sum(p_lower.count(t) for t in terms)
        if score > 0:
            scored.append((score, p))

    scored.sort(key=lambda x: x[0], reverse=True)
    top = [p for _, p in scored[:max_paragraphs]]
    if not top:
        # No paragraph matched any term: fall back to the opening ones.
        top = PARAGRAPHS[:max_paragraphs]

    return "\n\n".join(top)
|
|
|
|
|
|
|
|
# System instruction, in Lao. Roughly: "You are an assistant for the
# history of Laos. Answer only in Lao, explain simply and concisely.
# Rely only on the information below; if the information is insufficient
# or unclear, say you are not sure."
SYSTEM_PROMPT = (
    "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ. "
    "ຕອບແຕ່ພາສາລາວ, ອະທິບາຍໃຫ້ເຂົ້າໃຈງ່າຍ ແລະສັ້ນກະທັດຮັດ. "
    "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ. "
    "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ."
)
|
|
|
|
|
|
|
|
def build_prompt(question: str) -> str:
    """Assemble the full generation prompt for *question*.

    Combines the system instruction, the retrieved reference paragraphs
    and the question, ending with a Lao "answer in Lao:" cue for the model.
    """
    reference = retrieve_context(question)
    # Section labels are Lao for "reference information" / "question" /
    # "answer in the Lao language".
    sections = (
        f"{SYSTEM_PROMPT}\n",
        "ຂໍ້ມູນອ້າງອີງ:",
        f"{reference}\n",
        f"ຄຳຖາມ: {question}\n",
        "ຄຳຕອບດ້ວຍພາສາລາວ:",
    )
    return "\n".join(sections)
|
|
|
|
|
|
|
|
def generate_answer(question: str) -> str:
    """Generate a Lao-language answer to *question* with the loaded model.

    Builds the retrieval-augmented prompt, samples one completion, and
    returns only the newly generated text (the prompt tokens stripped).
    """
    prompt = build_prompt(question)
    encoded = tokenizer(prompt, return_tensors="pt")

    # no_grad: pure inference, skip autograd bookkeeping. Sampling keeps
    # answers a bit varied rather than greedy.
    with torch.no_grad():
        output_ids = model.generate(
            **encoded,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # generate() returns prompt + continuation; keep the continuation only.
    prompt_len = encoded["input_ids"].shape[1]
    completion = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    return completion.strip()
|
|
|
|
|
|
|
|
|
|
|
def laos_history_bot(message: str, history: list):
    """
    Chat callback for gr.ChatInterface: receives (message, history) and
    returns the reply as a string. The conversation history is currently
    unused (it could later be folded into the prompt).
    """
    # Guard: ask the user (in Lao) to type a question first.
    if not message.strip():
        return "ກະລຸນາພິມຄຳຖາມກ່ອນ."

    try:
        return generate_answer(message)
    except Exception as e:
        # UI boundary: surface the failure to the user ("the system has
        # a problem: ...") instead of crashing the app.
        return f"ລະບົບມີບັນຫາ: {e}"
|
|
|
|
|
|
|
|
# Gradio chat UI wired to the bot callback. The example prompts are Lao
# questions (roughly: "In what year did the Lan Xang kingdom arise?" and
# "Which kingdom once had Vientiane as its capital?").
demo = gr.ChatInterface(
    fn=laos_history_bot,
    title="Laos History Chatbot (Lao language)",
    description="ຖາມຂໍ້ມູນກ່ຽວກັບປະຫວັດສາດຂອງປະເທດລາວ",
    examples=[
        "ອານາຈັກລ້ານຊ້າງເກີດຂຶ້ນໃນປີໃດ?",
        "ເມືອງຫວຽງຈັນເຄີຍເປັນນະຄອນຫຼວງຂອງອານາຈັກໃດ?",
    ],
)
|
|
|
|
|
# Launch the web UI only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|
|