import os import sys import io # Fix encoding issues for Windows console if hasattr(sys.stdout, 'reconfigure'): sys.stdout.reconfigure(encoding='utf-8') if hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(encoding='utf-8') import faiss import pickle import torch import numpy as np from sentence_transformers import SentenceTransformer from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # Paths # Paths # Determine project root dynamically current_dir = os.path.dirname(os.path.abspath(__file__)) # src/model -> src -> root PROJECT_ROOT = os.path.abspath(os.path.join(current_dir, "../../")) INDEX_DIR = os.path.join(PROJECT_ROOT, "data/index") MODEL_CONFIG_PATH = os.path.join(PROJECT_ROOT, "config/model.yaml") class RAGRunner: def __init__(self, device=None, load_llm=True): # Only enforce offline if explicitly set (Default to Online if not specified, # but our local scripts will set it to 1) if os.environ.get("FORCE_OFFLINE") == "1": os.environ["HF_HUB_OFFLINE"] = "1" os.environ["TRANSFORMERS_OFFLINE"] = "1" if device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" else: self.device = device print(f"Using device: {self.device}") self.tokenizer = None self.model = None self.embedder = None self.index = None self.index = None self.metadata = [] self.model_load_error = None self.load_resources(load_llm) def load_resources(self, load_llm=True): print("[DEBUG] Step 1: Loading embedder...", flush=True) self.embedder = SentenceTransformer('keepitreal/vietnamese-sbert', device=self.device) print("[DEBUG] Step 2: Embedder loaded. Loading index...", flush=True) index_path = os.path.join(INDEX_DIR, "nihe_faiss.index") metadata_path = os.path.join(INDEX_DIR, "metadata.pkl") if os.path.exists(index_path): self.index = faiss.read_index(index_path) with open(metadata_path, 'rb') as f: self.metadata = pickle.load(f) print(f"[DEBUG] Step 3: Index loaded with {self.index.ntotal} vectors.", flush=True) else: print("WARNING: Index not found!", flush=True) if not load_llm: print("[DEBUG] Skipping LLM load as requested.", flush=True) return print(f"[DEBUG] Step 4: Loading LLM on {self.device}...", flush=True) model_name = "AITeamVN/Vi-Qwen2-1.5B-RAG" try: self.tokenizer = AutoTokenizer.from_pretrained(model_name) if self.device == "cuda": # Use 4-bit quantization for efficiency on GPU quant_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True ) self.model = AutoModelForCausalLM.from_pretrained( model_name, quantization_config=quant_config, device_map="auto", trust_remote_code=True ) else: # Standard CPU load self.model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", torch_dtype=torch.float32, trust_remote_code=True ) except Exception as e: import traceback traceback.print_exc() self.model_load_error = f"Model load failed: {str(e)}" print(f"CRITICAL ERROR loading LLM: {e}") self.model = None self.tokenizer = None def retrieve(self, query, top_k=8): """Retrieve relevant chunks with hybrid semantic-keyword logic.""" if not self.index: return [] # 1. Semantic Search (Broad pool) # Increase candidate pool for larger dataset candidate_pool_size = min(100, self.index.ntotal) query_vec = self.embedder.encode([query]).astype('float32') faiss.normalize_L2(query_vec) distances, indices = self.index.search(query_vec, candidate_pool_size) # 2. Key Term Detection (Expanded for YTDP Data) critical_keywords = [ "chó cắn", "mèo cắn", "dại", "rắn cắn", "sơ cứu", "cấp cứu", "xử lý", # Emergency "sốt xuất huyết", "sởi", "tay chân miệng", "cúm", "covid", "bạch hầu", # Diseases "tiêm chủng", "vắc xin", "lịch tiêm", "giá tiêm" # Vaccination ] found_keywords = [k for k in critical_keywords if k in query.lower()] # 3. Build Candidate Pool (Combine Semantic + Keyword Force-Match) candidate_indices = list(indices[0]) # Force-match: find all chunks matching critical keywords if found_keywords: for i, chunk in enumerate(self.metadata): if any(k in chunk['text'].lower() for k in found_keywords): if i not in candidate_indices: candidate_indices.append(i) candidates = [] for idx in candidate_indices: if idx == -1 or idx >= len(self.metadata): continue chunk = self.metadata[idx] # Re-calculate semantic distance if it wasn't in original top_k search # (In a real system we'd use index.reconstruct or a different approach, # but for 300 chunks we can just use 0.0 as base for new ones) base_score = 0.0 if idx in indices[0]: matches = np.where(indices[0] == idx)[0] if len(matches) > 0: base_score = float(distances[0][matches[0]]) # Boost logic boost = 0 text_lower = chunk['text'].lower() title_lower = chunk.get('title', '').lower() for kw in found_keywords: if kw in text_lower: boost += 0.4 # Higher boost if kw in title_lower: boost += 0.5 # Title match is very strong # Bonus for "Guideline/Manual" content during emergency queries if found_keywords and ("hướng dẫn" in text_lower or "sơ cứu" in text_lower): boost += 0.5 candidates.append({ 'text': chunk['text'], 'score': base_score + boost, 'source': chunk['url'], 'id': chunk['id'] }) # Sort by boosted score candidates.sort(key=lambda x: x['score'], reverse=True) # Diversity Filter: Limit max 2 chunks per URL final_results = [] url_counts = {} for cand in candidates: url = cand['source'] count = url_counts.get(url, 0) if count < 2: final_results.append(cand) url_counts[url] = count + 1 if len(final_results) >= top_k: break return final_results def generate(self, query, context_chunks): """Generate answer with RAG, maintaining subject context.""" if not self.model or not self.tokenizer: error_msg = self.model_load_error if self.model_load_error else "Lỗi: Model hoặc Tokenizer chưa được load (Unknown reason)." return f"HỆ THỐNG ĐANG GẶP SỰ CỐ: {error_msg}" context_text = "\n\n".join([f"TÀI LIỆU {i+1}:\n{c['text']}" for i, c in enumerate(context_chunks)]) # Refined prompt for better subject handling & Synthesis system_prompt = ( "Bạn là chuyên gia tư vấn y tế thông thái của Viện Vệ sinh Dịch tễ Trung ương (NIHE). " "Nhiệm vụ của bạn là tổng hợp thông tin từ các tài liệu được cung cấp để trả lời người dùng một cách chính xác và đầy đủ nhất.\n" "QUY TẮC CỐ ĐỊNH:\n" "1. TỔNG HỢP THÔNG TIN: Hãy đọc kỹ TẤT CẢ các tài liệu. Nếu tài liệu 1 nói về nguyên nhân, tài liệu 2 nói về triệu chứng, hãy gộp lại thành câu trả lời hoàn chỉnh.\n" "2. CHỈ sử dụng thông tin trong tài liệu. Không bịa đặt.\n" "3. Chú ý ĐÚNG ĐỐI TƯỢNG: Nếu người hỏi hỏi cho 'con tôi', 'người già', hãy trả lời phù hợp với đối tượng đó.\n" "4. Nếu thông tin không có trong tài liệu, hãy hướng dẫn họ liên hệ Hotline hoặc đến cơ sở y tế gần nhất.\n" "5. Câu trả lời cần: Ngắn gọn, Súc tích, Dễ hiểu (dùng gạch đầu dòng)." ) user_prompt = f"""DỰA VÀO CÁC TÀI LIỆU SAU: {context_text} CÂU HỎI CỦA NGƯỜI DÙNG: {query} HƯỚNG DẪN TRẢ LỜI:""" messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ] input_text = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device) with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=1024, temperature=0.1, # Lower temperature for better precision do_sample=True, top_p=0.9, repetition_penalty=1.1, pad_token_id=self.tokenizer.eos_token_id ) response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True) return response.strip() def run(self, query): try: print(f"Query: {query}", flush=True) # Search results = self.retrieve(query) print(f"Retrieved {len(results)} chunks.", flush=True) # Debug: check scores and snippets for i, res in enumerate(results): print(f" [{i}] Score: {res['score']:.4f} | Source: {res['source']}", flush=True) print(f" Snippet: {res['text'][:150]}...", flush=True) # Generate answer = self.generate(query, results) return answer except Exception as e: import traceback print(f"CRITICAL ERROR in RAGRunner.run: {e}", flush=True) traceback.print_exc() raise e if __name__ == "__main__": runner = RAGRunner() while True: q = input("\nBạn: ") if q.lower() in ['exit', 'quit']: break ans = runner.run(q) print(f"Bot: {ans}")