Spaces:
Sleeping
Sleeping
File size: 11,244 Bytes
f9b0dca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 | import os
import sys
import io
# Fix encoding issues for Windows console
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8')
if hasattr(sys.stderr, 'reconfigure'):
sys.stderr.reconfigure(encoding='utf-8')
import faiss
import pickle
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Paths
# Paths
# Determine project root dynamically
current_dir = os.path.dirname(os.path.abspath(__file__))
# src/model -> src -> root
PROJECT_ROOT = os.path.abspath(os.path.join(current_dir, "../../"))
INDEX_DIR = os.path.join(PROJECT_ROOT, "data/index")
MODEL_CONFIG_PATH = os.path.join(PROJECT_ROOT, "config/model.yaml")
class RAGRunner:
def __init__(self, device=None, load_llm=True):
# Only enforce offline if explicitly set (Default to Online if not specified,
# but our local scripts will set it to 1)
if os.environ.get("FORCE_OFFLINE") == "1":
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
print(f"Using device: {self.device}")
self.tokenizer = None
self.model = None
self.embedder = None
self.index = None
self.index = None
self.metadata = []
self.model_load_error = None
self.load_resources(load_llm)
def load_resources(self, load_llm=True):
print("[DEBUG] Step 1: Loading embedder...", flush=True)
self.embedder = SentenceTransformer('keepitreal/vietnamese-sbert', device=self.device)
print("[DEBUG] Step 2: Embedder loaded. Loading index...", flush=True)
index_path = os.path.join(INDEX_DIR, "nihe_faiss.index")
metadata_path = os.path.join(INDEX_DIR, "metadata.pkl")
if os.path.exists(index_path):
self.index = faiss.read_index(index_path)
with open(metadata_path, 'rb') as f:
self.metadata = pickle.load(f)
print(f"[DEBUG] Step 3: Index loaded with {self.index.ntotal} vectors.", flush=True)
else:
print("WARNING: Index not found!", flush=True)
if not load_llm:
print("[DEBUG] Skipping LLM load as requested.", flush=True)
return
print(f"[DEBUG] Step 4: Loading LLM on {self.device}...", flush=True)
model_name = "AITeamVN/Vi-Qwen2-1.5B-RAG"
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if self.device == "cuda":
# Use 4-bit quantization for efficiency on GPU
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quant_config,
device_map="auto",
trust_remote_code=True
)
else:
# Standard CPU load
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu",
torch_dtype=torch.float32,
trust_remote_code=True
)
except Exception as e:
import traceback
traceback.print_exc()
self.model_load_error = f"Model load failed: {str(e)}"
print(f"CRITICAL ERROR loading LLM: {e}")
self.model = None
self.tokenizer = None
def retrieve(self, query, top_k=8):
"""Retrieve relevant chunks with hybrid semantic-keyword logic."""
if not self.index: return []
# 1. Semantic Search (Broad pool)
# Increase candidate pool for larger dataset
candidate_pool_size = min(100, self.index.ntotal)
query_vec = self.embedder.encode([query]).astype('float32')
faiss.normalize_L2(query_vec)
distances, indices = self.index.search(query_vec, candidate_pool_size)
# 2. Key Term Detection (Expanded for YTDP Data)
critical_keywords = [
"chó cắn", "mèo cắn", "dại", "rắn cắn", "sơ cứu", "cấp cứu", "xử lý", # Emergency
"sốt xuất huyết", "sởi", "tay chân miệng", "cúm", "covid", "bạch hầu", # Diseases
"tiêm chủng", "vắc xin", "lịch tiêm", "giá tiêm" # Vaccination
]
found_keywords = [k for k in critical_keywords if k in query.lower()]
# 3. Build Candidate Pool (Combine Semantic + Keyword Force-Match)
candidate_indices = list(indices[0])
# Force-match: find all chunks matching critical keywords
if found_keywords:
for i, chunk in enumerate(self.metadata):
if any(k in chunk['text'].lower() for k in found_keywords):
if i not in candidate_indices:
candidate_indices.append(i)
candidates = []
for idx in candidate_indices:
if idx == -1 or idx >= len(self.metadata): continue
chunk = self.metadata[idx]
# Re-calculate semantic distance if it wasn't in original top_k search
# (In a real system we'd use index.reconstruct or a different approach,
# but for 300 chunks we can just use 0.0 as base for new ones)
base_score = 0.0
if idx in indices[0]:
matches = np.where(indices[0] == idx)[0]
if len(matches) > 0:
base_score = float(distances[0][matches[0]])
# Boost logic
boost = 0
text_lower = chunk['text'].lower()
title_lower = chunk.get('title', '').lower()
for kw in found_keywords:
if kw in text_lower: boost += 0.4 # Higher boost
if kw in title_lower: boost += 0.5 # Title match is very strong
# Bonus for "Guideline/Manual" content during emergency queries
if found_keywords and ("hướng dẫn" in text_lower or "sơ cứu" in text_lower):
boost += 0.5
candidates.append({
'text': chunk['text'],
'score': base_score + boost,
'source': chunk['url'],
'id': chunk['id']
})
# Sort by boosted score
candidates.sort(key=lambda x: x['score'], reverse=True)
# Diversity Filter: Limit max 2 chunks per URL
final_results = []
url_counts = {}
for cand in candidates:
url = cand['source']
count = url_counts.get(url, 0)
if count < 2:
final_results.append(cand)
url_counts[url] = count + 1
if len(final_results) >= top_k:
break
return final_results
def generate(self, query, context_chunks):
"""Generate answer with RAG, maintaining subject context."""
if not self.model or not self.tokenizer:
error_msg = self.model_load_error if self.model_load_error else "Lỗi: Model hoặc Tokenizer chưa được load (Unknown reason)."
return f"HỆ THỐNG ĐANG GẶP SỰ CỐ: {error_msg}"
context_text = "\n\n".join([f"TÀI LIỆU {i+1}:\n{c['text']}" for i, c in enumerate(context_chunks)])
# Refined prompt for better subject handling & Synthesis
system_prompt = (
"Bạn là chuyên gia tư vấn y tế thông thái của Viện Vệ sinh Dịch tễ Trung ương (NIHE). "
"Nhiệm vụ của bạn là tổng hợp thông tin từ các tài liệu được cung cấp để trả lời người dùng một cách chính xác và đầy đủ nhất.\n"
"QUY TẮC CỐ ĐỊNH:\n"
"1. TỔNG HỢP THÔNG TIN: Hãy đọc kỹ TẤT CẢ các tài liệu. Nếu tài liệu 1 nói về nguyên nhân, tài liệu 2 nói về triệu chứng, hãy gộp lại thành câu trả lời hoàn chỉnh.\n"
"2. CHỈ sử dụng thông tin trong tài liệu. Không bịa đặt.\n"
"3. Chú ý ĐÚNG ĐỐI TƯỢNG: Nếu người hỏi hỏi cho 'con tôi', 'người già', hãy trả lời phù hợp với đối tượng đó.\n"
"4. Nếu thông tin không có trong tài liệu, hãy hướng dẫn họ liên hệ Hotline hoặc đến cơ sở y tế gần nhất.\n"
"5. Câu trả lời cần: Ngắn gọn, Súc tích, Dễ hiểu (dùng gạch đầu dòng)."
)
user_prompt = f"""DỰA VÀO CÁC TÀI LIỆU SAU:
{context_text}
CÂU HỎI CỦA NGƯỜI DÙNG: {query}
HƯỚNG DẪN TRẢ LỜI:"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
input_text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=1024,
temperature=0.1, # Lower temperature for better precision
do_sample=True,
top_p=0.9,
repetition_penalty=1.1,
pad_token_id=self.tokenizer.eos_token_id
)
response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
return response.strip()
def run(self, query):
try:
print(f"Query: {query}", flush=True)
# Search
results = self.retrieve(query)
print(f"Retrieved {len(results)} chunks.", flush=True)
# Debug: check scores and snippets
for i, res in enumerate(results):
print(f" [{i}] Score: {res['score']:.4f} | Source: {res['source']}", flush=True)
print(f" Snippet: {res['text'][:150]}...", flush=True)
# Generate
answer = self.generate(query, results)
return answer
except Exception as e:
import traceback
print(f"CRITICAL ERROR in RAGRunner.run: {e}", flush=True)
traceback.print_exc()
raise e
if __name__ == "__main__":
runner = RAGRunner()
while True:
q = input("\nBạn: ")
if q.lower() in ['exit', 'quit']: break
ans = runner.run(q)
print(f"Bot: {ans}")
|