Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import io | |
| # Fix encoding issues for Windows console | |
| if hasattr(sys.stdout, 'reconfigure'): | |
| sys.stdout.reconfigure(encoding='utf-8') | |
| if hasattr(sys.stderr, 'reconfigure'): | |
| sys.stderr.reconfigure(encoding='utf-8') | |
| import faiss | |
| import pickle | |
| import torch | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| # Paths | |
| # Paths | |
| # Determine project root dynamically | |
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |
| # src/model -> src -> root | |
| PROJECT_ROOT = os.path.abspath(os.path.join(current_dir, "../../")) | |
| INDEX_DIR = os.path.join(PROJECT_ROOT, "data/index") | |
| MODEL_CONFIG_PATH = os.path.join(PROJECT_ROOT, "config/model.yaml") | |
| class RAGRunner: | |
| def __init__(self, device=None, load_llm=True): | |
| # Only enforce offline if explicitly set (Default to Online if not specified, | |
| # but our local scripts will set it to 1) | |
| if os.environ.get("FORCE_OFFLINE") == "1": | |
| os.environ["HF_HUB_OFFLINE"] = "1" | |
| os.environ["TRANSFORMERS_OFFLINE"] = "1" | |
| if device is None: | |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
| else: | |
| self.device = device | |
| print(f"Using device: {self.device}") | |
| self.tokenizer = None | |
| self.model = None | |
| self.embedder = None | |
| self.index = None | |
| self.index = None | |
| self.metadata = [] | |
| self.model_load_error = None | |
| self.load_resources(load_llm) | |
| def load_resources(self, load_llm=True): | |
| print("[DEBUG] Step 1: Loading embedder...", flush=True) | |
| self.embedder = SentenceTransformer('keepitreal/vietnamese-sbert', device=self.device) | |
| print("[DEBUG] Step 2: Embedder loaded. Loading index...", flush=True) | |
| index_path = os.path.join(INDEX_DIR, "nihe_faiss.index") | |
| metadata_path = os.path.join(INDEX_DIR, "metadata.pkl") | |
| if os.path.exists(index_path): | |
| self.index = faiss.read_index(index_path) | |
| with open(metadata_path, 'rb') as f: | |
| self.metadata = pickle.load(f) | |
| print(f"[DEBUG] Step 3: Index loaded with {self.index.ntotal} vectors.", flush=True) | |
| else: | |
| print("WARNING: Index not found!", flush=True) | |
| if not load_llm: | |
| print("[DEBUG] Skipping LLM load as requested.", flush=True) | |
| return | |
| print(f"[DEBUG] Step 4: Loading LLM on {self.device}...", flush=True) | |
| model_name = "AITeamVN/Vi-Qwen2-1.5B-RAG" | |
| try: | |
| self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| if self.device == "cuda": | |
| # Use 4-bit quantization for efficiency on GPU | |
| quant_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_compute_dtype=torch.float16, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_use_double_quant=True | |
| ) | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| quantization_config=quant_config, | |
| device_map="auto", | |
| trust_remote_code=True | |
| ) | |
| else: | |
| # Standard CPU load | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| device_map="cpu", | |
| torch_dtype=torch.float32, | |
| trust_remote_code=True | |
| ) | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| self.model_load_error = f"Model load failed: {str(e)}" | |
| print(f"CRITICAL ERROR loading LLM: {e}") | |
| self.model = None | |
| self.tokenizer = None | |
| def retrieve(self, query, top_k=8): | |
| """Retrieve relevant chunks with hybrid semantic-keyword logic.""" | |
| if not self.index: return [] | |
| # 1. Semantic Search (Broad pool) | |
| # Increase candidate pool for larger dataset | |
| candidate_pool_size = min(100, self.index.ntotal) | |
| query_vec = self.embedder.encode([query]).astype('float32') | |
| faiss.normalize_L2(query_vec) | |
| distances, indices = self.index.search(query_vec, candidate_pool_size) | |
| # 2. Key Term Detection (Expanded for YTDP Data) | |
| critical_keywords = [ | |
| "chó cắn", "mèo cắn", "dại", "rắn cắn", "sơ cứu", "cấp cứu", "xử lý", # Emergency | |
| "sốt xuất huyết", "sởi", "tay chân miệng", "cúm", "covid", "bạch hầu", # Diseases | |
| "tiêm chủng", "vắc xin", "lịch tiêm", "giá tiêm" # Vaccination | |
| ] | |
| found_keywords = [k for k in critical_keywords if k in query.lower()] | |
| # 3. Build Candidate Pool (Combine Semantic + Keyword Force-Match) | |
| candidate_indices = list(indices[0]) | |
| # Force-match: find all chunks matching critical keywords | |
| if found_keywords: | |
| for i, chunk in enumerate(self.metadata): | |
| if any(k in chunk['text'].lower() for k in found_keywords): | |
| if i not in candidate_indices: | |
| candidate_indices.append(i) | |
| candidates = [] | |
| for idx in candidate_indices: | |
| if idx == -1 or idx >= len(self.metadata): continue | |
| chunk = self.metadata[idx] | |
| # Re-calculate semantic distance if it wasn't in original top_k search | |
| # (In a real system we'd use index.reconstruct or a different approach, | |
| # but for 300 chunks we can just use 0.0 as base for new ones) | |
| base_score = 0.0 | |
| if idx in indices[0]: | |
| matches = np.where(indices[0] == idx)[0] | |
| if len(matches) > 0: | |
| base_score = float(distances[0][matches[0]]) | |
| # Boost logic | |
| boost = 0 | |
| text_lower = chunk['text'].lower() | |
| title_lower = chunk.get('title', '').lower() | |
| for kw in found_keywords: | |
| if kw in text_lower: boost += 0.4 # Higher boost | |
| if kw in title_lower: boost += 0.5 # Title match is very strong | |
| # Bonus for "Guideline/Manual" content during emergency queries | |
| if found_keywords and ("hướng dẫn" in text_lower or "sơ cứu" in text_lower): | |
| boost += 0.5 | |
| candidates.append({ | |
| 'text': chunk['text'], | |
| 'score': base_score + boost, | |
| 'source': chunk['url'], | |
| 'id': chunk['id'] | |
| }) | |
| # Sort by boosted score | |
| candidates.sort(key=lambda x: x['score'], reverse=True) | |
| # Diversity Filter: Limit max 2 chunks per URL | |
| final_results = [] | |
| url_counts = {} | |
| for cand in candidates: | |
| url = cand['source'] | |
| count = url_counts.get(url, 0) | |
| if count < 2: | |
| final_results.append(cand) | |
| url_counts[url] = count + 1 | |
| if len(final_results) >= top_k: | |
| break | |
| return final_results | |
| def generate(self, query, context_chunks): | |
| """Generate answer with RAG, maintaining subject context.""" | |
| if not self.model or not self.tokenizer: | |
| error_msg = self.model_load_error if self.model_load_error else "Lỗi: Model hoặc Tokenizer chưa được load (Unknown reason)." | |
| return f"HỆ THỐNG ĐANG GẶP SỰ CỐ: {error_msg}" | |
| context_text = "\n\n".join([f"TÀI LIỆU {i+1}:\n{c['text']}" for i, c in enumerate(context_chunks)]) | |
| # Refined prompt for better subject handling & Synthesis | |
| system_prompt = ( | |
| "Bạn là chuyên gia tư vấn y tế thông thái của Viện Vệ sinh Dịch tễ Trung ương (NIHE). " | |
| "Nhiệm vụ của bạn là tổng hợp thông tin từ các tài liệu được cung cấp để trả lời người dùng một cách chính xác và đầy đủ nhất.\n" | |
| "QUY TẮC CỐ ĐỊNH:\n" | |
| "1. TỔNG HỢP THÔNG TIN: Hãy đọc kỹ TẤT CẢ các tài liệu. Nếu tài liệu 1 nói về nguyên nhân, tài liệu 2 nói về triệu chứng, hãy gộp lại thành câu trả lời hoàn chỉnh.\n" | |
| "2. CHỈ sử dụng thông tin trong tài liệu. Không bịa đặt.\n" | |
| "3. Chú ý ĐÚNG ĐỐI TƯỢNG: Nếu người hỏi hỏi cho 'con tôi', 'người già', hãy trả lời phù hợp với đối tượng đó.\n" | |
| "4. Nếu thông tin không có trong tài liệu, hãy hướng dẫn họ liên hệ Hotline hoặc đến cơ sở y tế gần nhất.\n" | |
| "5. Câu trả lời cần: Ngắn gọn, Súc tích, Dễ hiểu (dùng gạch đầu dòng)." | |
| ) | |
| user_prompt = f"""DỰA VÀO CÁC TÀI LIỆU SAU: | |
| {context_text} | |
| CÂU HỎI CỦA NGƯỜI DÙNG: {query} | |
| HƯỚNG DẪN TRẢ LỜI:""" | |
| messages = [ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": user_prompt} | |
| ] | |
| input_text = self.tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True | |
| ) | |
| inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device) | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| **inputs, | |
| max_new_tokens=1024, | |
| temperature=0.1, # Lower temperature for better precision | |
| do_sample=True, | |
| top_p=0.9, | |
| repetition_penalty=1.1, | |
| pad_token_id=self.tokenizer.eos_token_id | |
| ) | |
| response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True) | |
| return response.strip() | |
| def run(self, query): | |
| try: | |
| print(f"Query: {query}", flush=True) | |
| # Search | |
| results = self.retrieve(query) | |
| print(f"Retrieved {len(results)} chunks.", flush=True) | |
| # Debug: check scores and snippets | |
| for i, res in enumerate(results): | |
| print(f" [{i}] Score: {res['score']:.4f} | Source: {res['source']}", flush=True) | |
| print(f" Snippet: {res['text'][:150]}...", flush=True) | |
| # Generate | |
| answer = self.generate(query, results) | |
| return answer | |
| except Exception as e: | |
| import traceback | |
| print(f"CRITICAL ERROR in RAGRunner.run: {e}", flush=True) | |
| traceback.print_exc() | |
| raise e | |
| if __name__ == "__main__": | |
| runner = RAGRunner() | |
| while True: | |
| q = input("\nBạn: ") | |
| if q.lower() in ['exit', 'quit']: break | |
| ans = runner.run(q) | |
| print(f"Bot: {ans}") | |