File size: 11,244 Bytes
f9b0dca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import os
import sys
import io

# Fix encoding issues for Windows console
if hasattr(sys.stdout, 'reconfigure'):
    sys.stdout.reconfigure(encoding='utf-8')
if hasattr(sys.stderr, 'reconfigure'):
    sys.stderr.reconfigure(encoding='utf-8')
import faiss
import pickle
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Paths
# Paths
# Determine project root dynamically
current_dir = os.path.dirname(os.path.abspath(__file__))
# src/model -> src -> root
PROJECT_ROOT = os.path.abspath(os.path.join(current_dir, "../../"))
INDEX_DIR = os.path.join(PROJECT_ROOT, "data/index")
MODEL_CONFIG_PATH = os.path.join(PROJECT_ROOT, "config/model.yaml")

class RAGRunner:
    def __init__(self, device=None, load_llm=True):
        # Only enforce offline if explicitly set (Default to Online if not specified, 
        # but our local scripts will set it to 1)
        if os.environ.get("FORCE_OFFLINE") == "1":
            os.environ["HF_HUB_OFFLINE"] = "1"
            os.environ["TRANSFORMERS_OFFLINE"] = "1"
        
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
            
        print(f"Using device: {self.device}")
        self.tokenizer = None
        self.model = None
        self.embedder = None
        self.index = None
        self.index = None
        self.metadata = []
        self.model_load_error = None
        
        self.load_resources(load_llm)
        
    def load_resources(self, load_llm=True):
        print("[DEBUG] Step 1: Loading embedder...", flush=True)
        self.embedder = SentenceTransformer('keepitreal/vietnamese-sbert', device=self.device)
        print("[DEBUG] Step 2: Embedder loaded. Loading index...", flush=True)
        
        index_path = os.path.join(INDEX_DIR, "nihe_faiss.index")
        metadata_path = os.path.join(INDEX_DIR, "metadata.pkl")
        
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
            with open(metadata_path, 'rb') as f:
                self.metadata = pickle.load(f)
            print(f"[DEBUG] Step 3: Index loaded with {self.index.ntotal} vectors.", flush=True)
        else:
            print("WARNING: Index not found!", flush=True)
            
        if not load_llm:
            print("[DEBUG] Skipping LLM load as requested.", flush=True)
            return

        print(f"[DEBUG] Step 4: Loading LLM on {self.device}...", flush=True)
        model_name = "AITeamVN/Vi-Qwen2-1.5B-RAG"
        
        try:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
             
             if self.device == "cuda":
                 # Use 4-bit quantization for efficiency on GPU
                 quant_config = BitsAndBytesConfig(
                     load_in_4bit=True,
                     bnb_4bit_compute_dtype=torch.float16,
                     bnb_4bit_quant_type="nf4",
                     bnb_4bit_use_double_quant=True
                 )
                 self.model = AutoModelForCausalLM.from_pretrained(
                     model_name,
                     quantization_config=quant_config,
                     device_map="auto",
                     trust_remote_code=True
                 )
             else:
                 # Standard CPU load
                 self.model = AutoModelForCausalLM.from_pretrained(
                     model_name,
                     device_map="cpu",
                     torch_dtype=torch.float32,
                     trust_remote_code=True
                 )
        except Exception as e:
            import traceback
            traceback.print_exc()
            self.model_load_error = f"Model load failed: {str(e)}"
            print(f"CRITICAL ERROR loading LLM: {e}")
            self.model = None
            self.tokenizer = None
            
    def retrieve(self, query, top_k=8):
        """Retrieve relevant chunks with hybrid semantic-keyword logic."""
        if not self.index: return []
        
        # 1. Semantic Search (Broad pool)
        # Increase candidate pool for larger dataset
        candidate_pool_size = min(100, self.index.ntotal)
        query_vec = self.embedder.encode([query]).astype('float32')
        faiss.normalize_L2(query_vec)
        distances, indices = self.index.search(query_vec, candidate_pool_size)
        
        # 2. Key Term Detection (Expanded for YTDP Data)
        critical_keywords = [
            "chó cắn", "mèo cắn", "dại", "rắn cắn", "sơ cứu", "cấp cứu", "xử lý", # Emergency
            "sốt xuất huyết", "sởi", "tay chân miệng", "cúm", "covid", "bạch hầu", # Diseases
            "tiêm chủng", "vắc xin", "lịch tiêm", "giá tiêm" # Vaccination
        ]
        found_keywords = [k for k in critical_keywords if k in query.lower()]
        
        # 3. Build Candidate Pool (Combine Semantic + Keyword Force-Match)
        candidate_indices = list(indices[0])
        
        # Force-match: find all chunks matching critical keywords
        if found_keywords:
            for i, chunk in enumerate(self.metadata):
                if any(k in chunk['text'].lower() for k in found_keywords):
                    if i not in candidate_indices:
                        candidate_indices.append(i)

        candidates = []
        for idx in candidate_indices:
            if idx == -1 or idx >= len(self.metadata): continue
            chunk = self.metadata[idx]
            
            # Re-calculate semantic distance if it wasn't in original top_k search
            # (In a real system we'd use index.reconstruct or a different approach,
            # but for 300 chunks we can just use 0.0 as base for new ones)
            base_score = 0.0
            if idx in indices[0]:
                matches = np.where(indices[0] == idx)[0]
                if len(matches) > 0:
                    base_score = float(distances[0][matches[0]])
            
            # Boost logic
            boost = 0
            text_lower = chunk['text'].lower()
            title_lower = chunk.get('title', '').lower()
            
            for kw in found_keywords:
                if kw in text_lower: boost += 0.4 # Higher boost
                if kw in title_lower: boost += 0.5 # Title match is very strong
            
            # Bonus for "Guideline/Manual" content during emergency queries
            if found_keywords and ("hướng dẫn" in text_lower or "sơ cứu" in text_lower):
                boost += 0.5

            candidates.append({
                'text': chunk['text'],
                'score': base_score + boost,
                'source': chunk['url'],
                'id': chunk['id']
            })

        # Sort by boosted score
        candidates.sort(key=lambda x: x['score'], reverse=True)
        
        # Diversity Filter: Limit max 2 chunks per URL
        final_results = []
        url_counts = {}
        
        for cand in candidates:
            url = cand['source']
            count = url_counts.get(url, 0)
            if count < 2:
                final_results.append(cand)
                url_counts[url] = count + 1
            
            if len(final_results) >= top_k:
                break
                
        return final_results

    def generate(self, query, context_chunks):
        """Generate answer with RAG, maintaining subject context."""
        if not self.model or not self.tokenizer:
            error_msg = self.model_load_error if self.model_load_error else "Lỗi: Model hoặc Tokenizer chưa được load (Unknown reason)."
            return f"HỆ THỐNG ĐANG GẶP SỰ CỐ: {error_msg}"

        context_text = "\n\n".join([f"TÀI LIỆU {i+1}:\n{c['text']}" for i, c in enumerate(context_chunks)])
        
        # Refined prompt for better subject handling & Synthesis
        system_prompt = (
            "Bạn là chuyên gia tư vấn y tế thông thái của Viện Vệ sinh Dịch tễ Trung ương (NIHE). "
            "Nhiệm vụ của bạn là tổng hợp thông tin từ các tài liệu được cung cấp để trả lời người dùng một cách chính xác và đầy đủ nhất.\n"
            "QUY TẮC CỐ ĐỊNH:\n"
            "1. TỔNG HỢP THÔNG TIN: Hãy đọc kỹ TẤT CẢ các tài liệu. Nếu tài liệu 1 nói về nguyên nhân, tài liệu 2 nói về triệu chứng, hãy gộp lại thành câu trả lời hoàn chỉnh.\n"
            "2. CHỈ sử dụng thông tin trong tài liệu. Không bịa đặt.\n"
            "3. Chú ý ĐÚNG ĐỐI TƯỢNG: Nếu người hỏi hỏi cho 'con tôi', 'người già', hãy trả lời phù hợp với đối tượng đó.\n"
            "4. Nếu thông tin không có trong tài liệu, hãy hướng dẫn họ liên hệ Hotline hoặc đến cơ sở y tế gần nhất.\n"
            "5. Câu trả lời cần: Ngắn gọn, Súc tích, Dễ hiểu (dùng gạch đầu dòng)."
        )
        
        user_prompt = f"""DỰA VÀO CÁC TÀI LIỆU SAU:
{context_text}

CÂU HỎI CỦA NGƯỜI DÙNG: {query}

HƯỚNG DẪN TRẢ LỜI:"""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=1024,
                temperature=0.1, # Lower temperature for better precision
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.eos_token_id
            )
        
        response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
        return response.strip()

    def run(self, query):
        try:
            print(f"Query: {query}", flush=True)
            # Search
            results = self.retrieve(query)
            print(f"Retrieved {len(results)} chunks.", flush=True)
            
            # Debug: check scores and snippets
            for i, res in enumerate(results):
                 print(f"  [{i}] Score: {res['score']:.4f} | Source: {res['source']}", flush=True)
                 print(f"      Snippet: {res['text'][:150]}...", flush=True)
            
            # Generate
            answer = self.generate(query, results)
            return answer
        except Exception as e:
            import traceback
            print(f"CRITICAL ERROR in RAGRunner.run: {e}", flush=True)
            traceback.print_exc()
            raise e

if __name__ == "__main__":
    runner = RAGRunner()
    while True:
        q = input("\nBạn: ")
        if q.lower() in ['exit', 'quit']: break
        ans = runner.run(q)
        print(f"Bot: {ans}")