Heng2004 committed on
Commit
8db4a34
·
verified ·
1 Parent(s): a0daaac

Create rag.py

Browse files
Files changed (1) hide show
  1. rag.py +136 -0
rag.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rag.py – retrieval + model generation
2
+
3
+ import re
4
+ from typing import List, Dict
5
+
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+
9
+ from data.loader import ENTRIES, RAW_KNOWLEDGE
10
+ from data.qa_index import answer_from_qa
11
+
12
# Identifier handed to ``from_pretrained`` for both the tokenizer and the model.
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B-Chat"

# Run on the GPU when PyTorch reports one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: both objects are created at import time, so importing this module
# triggers the (potentially slow) model load.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # CPU on HF free tier
)
model.to(device)


# Instruction text (in Lao) placed at the top of every prompt. Roughly: act as
# a Lao-history assistant for grade M.1 students, answer only in Lao in 2-3
# short sentences, rely solely on the reference data that follows, and say so
# when that data is insufficient or unclear.
SYSTEM_PROMPT = " ".join(
    [
        "ທ່ານແມ່ນຜູ້ຊ່ວຍເຫຼືອດ້ານປະຫວັດສາດຂອງປະເທດລາວ",
        "ສໍາລັບນັກຮຽນຊັ້ນ ມ.1.",
        "ຕອບແຕ່ພາສາລາວ ແລະຕອບສັ້ນໆ 2–3 ປະໂຫຍກ ໃຫ້ເຂົ້າໃຈງ່າຍ.",
        "ໃຫ້ອີງຈາກຂໍ້ມູນຂ້າງລຸ່ມນີ້ເທົ່ານັ້ນ.",
        "ຖ້າຂໍ້ມູນບໍ່ພຽງພໍ ຫຼືບໍ່ຊັດເຈນ ໃຫ້ບອກວ່າບໍ່ແນ່ໃຈ.",
    ]
)
31
+
32
+
33
def _score_entry(entry: Dict, terms: List[str]) -> int:
    """Return the keyword-match score of one knowledge entry for ``terms``."""
    body = f"{entry.get('text') or ''} {entry.get('title') or ''}".lower()
    # One point for every occurrence of a term in the text/title.
    score = sum(body.count(term) for term in terms)

    keywords = entry.get("keywords") or []
    if isinstance(keywords, str):
        # Tolerate a single keyword stored as a bare string.
        keywords = [keywords]
    # Two points for every (keyword, term) pair where the term appears.
    score += 2 * sum(
        1 for kw in keywords for term in terms if term in str(kw).lower()
    )

    # Terms are lower-cased, so the topic must be too for the match to work.
    topic = str(entry.get("topic") or "").lower()
    if topic and any(term in topic for term in terms):
        score += 1

    return score


def _format_entry(entry: Dict) -> str:
    """Render one entry as a bracketed header line followed by its text."""
    header = (
        f"[ຊັ້ນ {entry.get('grade', '')}, "
        f"ບົດ {entry.get('chapter', '')}, "
        f"ຫົວຂໍ້ {entry.get('section', '')} – {entry.get('title', '')}]"
    )
    return f"{header}\n{entry.get('text', '')}"


def retrieve_context(question: str, max_entries: int = 2) -> str:
    """Return the reference passages that best match ``question``.

    Simple keyword matching over ENTRIES: the lower-cased question is split
    on whitespace, terms of a single character are dropped, and each entry
    is scored on its text, title, keywords and topic.

    Args:
        question: The user's question.
        max_entries: Maximum number of entries to include; negative values
            are treated as 0.

    Returns:
        ``RAW_KNOWLEDGE`` when ``ENTRIES`` is empty. Otherwise the formatted
        top-scoring entries joined by blank lines; when the question yields
        no usable terms or nothing scores above zero, the first
        ``max_entries`` entries are used instead.
    """
    if not ENTRIES:
        return RAW_KNOWLEDGE

    limit = max(max_entries, 0)

    # NOTE(review): whitespace splitting gives very coarse terms for text
    # written without spaces between words, which may apply to Lao input;
    # confirm whether a finer tokenizer is wanted here.
    terms = [term for term in question.lower().split() if len(term) > 1]

    scored = []
    for entry in ENTRIES:
        score = _score_entry(entry, terms)
        if score > 0:
            scored.append((score, entry))

    # list.sort is stable, so ties keep their original ENTRIES order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    # With no terms every score is 0, so this also covers the empty-question
    # case with the same "first entries" fallback.
    chosen = [entry for _, entry in scored[:limit]] or ENTRIES[:limit]

    return "\n\n".join(_format_entry(entry) for entry in chosen)
93
+
94
+
95
def build_prompt(question: str) -> str:
    """Assemble the plain-text prompt sent to the model for ``question``.

    The prompt consists of the system instructions, a reference-data section
    filled with the output of ``retrieve_context``, the question itself, and
    a trailing cue asking for the answer in Lao.
    """
    template = (
        "{system}\n"
        "\n"
        "ຂໍ້ມູນອ້າງອີງ:\n"
        "{context}\n"
        "\n"
        "ຄໍາຖາມ: {question}\n"
        "\n"
        "ຄໍາຕອບດ້ວຍພາສາລາວ:"
    )
    return template.format(
        system=SYSTEM_PROMPT,
        context=retrieve_context(question),
        question=question,
    )
105
+
106
+
107
def generate_answer_with_model(question: str) -> str:
    """Generate an answer to ``question`` with the language model.

    The prompt from ``build_prompt`` is sent as a single user turn through
    the tokenizer's chat template when the tokenizer defines one, because
    chat-tuned checkpoints expect that format. Tokenizers without a chat
    template receive the raw prompt.

    Args:
        question: The user's question.

    Returns:
        The decoded newly generated tokens (prompt excluded), stripped of
        surrounding whitespace.
    """
    prompt = build_prompt(question)

    if getattr(tokenizer, "chat_template", None):
        chat_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        # The template already emits the special tokens; do not add them twice.
        inputs = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False)
    else:
        inputs = tokenizer(prompt, return_tensors="pt")
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=160,
            do_sample=False,  # greedy → stable, a bit faster
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return answer.strip()
121
+
122
+
123
def answer_question(question: str) -> str:
    """Answer ``question``, preferring the QA index over the model.

    Args:
        question: The user's question. ``None``, empty and whitespace-only
            input are all treated as "no question".

    Returns:
        A fixed Lao message asking the user to type a question when the
        input is blank; otherwise the direct answer from ``answer_from_qa``
        when it returns a truthy value; otherwise the model-generated
        answer. If generation raises, a Lao error message containing the
        exception text is returned instead of propagating the exception.
    """
    import logging

    # Guard against None as well as blank strings before calling .strip().
    if not question or not question.strip():
        return "ກະລຸນາພິມຄໍາຖາມກ່ອນ."

    # 1) try manual + dataset QA first (instant, no model)
    direct = answer_from_qa(question)
    if direct:
        return direct

    # 2) fall back to model + RAG
    try:
        return generate_answer_with_model(question)
    except Exception as exc:
        # Top-level boundary: keep the user-facing message, but record the
        # traceback so the failure is not lost.
        logging.getLogger(__name__).exception("Model generation failed")
        return f"ລະບົບມີບັນຫາ: {exc}"