chatbot_nihe / src /model /rag_runner.py
Auto Deploy Script
Auto deploy from local machine
f9b0dca
import os
import sys
import io
# Fix encoding issues for Windows console
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8')
if hasattr(sys.stderr, 'reconfigure'):
sys.stderr.reconfigure(encoding='utf-8')
import faiss
import pickle
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Paths
# Paths
# Determine project root dynamically
current_dir = os.path.dirname(os.path.abspath(__file__))
# src/model -> src -> root
PROJECT_ROOT = os.path.abspath(os.path.join(current_dir, "../../"))
INDEX_DIR = os.path.join(PROJECT_ROOT, "data/index")
MODEL_CONFIG_PATH = os.path.join(PROJECT_ROOT, "config/model.yaml")
class RAGRunner:
def __init__(self, device=None, load_llm=True):
# Only enforce offline if explicitly set (Default to Online if not specified,
# but our local scripts will set it to 1)
if os.environ.get("FORCE_OFFLINE") == "1":
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
print(f"Using device: {self.device}")
self.tokenizer = None
self.model = None
self.embedder = None
self.index = None
self.index = None
self.metadata = []
self.model_load_error = None
self.load_resources(load_llm)
def load_resources(self, load_llm=True):
print("[DEBUG] Step 1: Loading embedder...", flush=True)
self.embedder = SentenceTransformer('keepitreal/vietnamese-sbert', device=self.device)
print("[DEBUG] Step 2: Embedder loaded. Loading index...", flush=True)
index_path = os.path.join(INDEX_DIR, "nihe_faiss.index")
metadata_path = os.path.join(INDEX_DIR, "metadata.pkl")
if os.path.exists(index_path):
self.index = faiss.read_index(index_path)
with open(metadata_path, 'rb') as f:
self.metadata = pickle.load(f)
print(f"[DEBUG] Step 3: Index loaded with {self.index.ntotal} vectors.", flush=True)
else:
print("WARNING: Index not found!", flush=True)
if not load_llm:
print("[DEBUG] Skipping LLM load as requested.", flush=True)
return
print(f"[DEBUG] Step 4: Loading LLM on {self.device}...", flush=True)
model_name = "AITeamVN/Vi-Qwen2-1.5B-RAG"
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if self.device == "cuda":
# Use 4-bit quantization for efficiency on GPU
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quant_config,
device_map="auto",
trust_remote_code=True
)
else:
# Standard CPU load
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu",
torch_dtype=torch.float32,
trust_remote_code=True
)
except Exception as e:
import traceback
traceback.print_exc()
self.model_load_error = f"Model load failed: {str(e)}"
print(f"CRITICAL ERROR loading LLM: {e}")
self.model = None
self.tokenizer = None
def retrieve(self, query, top_k=8):
"""Retrieve relevant chunks with hybrid semantic-keyword logic."""
if not self.index: return []
# 1. Semantic Search (Broad pool)
# Increase candidate pool for larger dataset
candidate_pool_size = min(100, self.index.ntotal)
query_vec = self.embedder.encode([query]).astype('float32')
faiss.normalize_L2(query_vec)
distances, indices = self.index.search(query_vec, candidate_pool_size)
# 2. Key Term Detection (Expanded for YTDP Data)
critical_keywords = [
"chó cắn", "mèo cắn", "dại", "rắn cắn", "sơ cứu", "cấp cứu", "xử lý", # Emergency
"sốt xuất huyết", "sởi", "tay chân miệng", "cúm", "covid", "bạch hầu", # Diseases
"tiêm chủng", "vắc xin", "lịch tiêm", "giá tiêm" # Vaccination
]
found_keywords = [k for k in critical_keywords if k in query.lower()]
# 3. Build Candidate Pool (Combine Semantic + Keyword Force-Match)
candidate_indices = list(indices[0])
# Force-match: find all chunks matching critical keywords
if found_keywords:
for i, chunk in enumerate(self.metadata):
if any(k in chunk['text'].lower() for k in found_keywords):
if i not in candidate_indices:
candidate_indices.append(i)
candidates = []
for idx in candidate_indices:
if idx == -1 or idx >= len(self.metadata): continue
chunk = self.metadata[idx]
# Re-calculate semantic distance if it wasn't in original top_k search
# (In a real system we'd use index.reconstruct or a different approach,
# but for 300 chunks we can just use 0.0 as base for new ones)
base_score = 0.0
if idx in indices[0]:
matches = np.where(indices[0] == idx)[0]
if len(matches) > 0:
base_score = float(distances[0][matches[0]])
# Boost logic
boost = 0
text_lower = chunk['text'].lower()
title_lower = chunk.get('title', '').lower()
for kw in found_keywords:
if kw in text_lower: boost += 0.4 # Higher boost
if kw in title_lower: boost += 0.5 # Title match is very strong
# Bonus for "Guideline/Manual" content during emergency queries
if found_keywords and ("hướng dẫn" in text_lower or "sơ cứu" in text_lower):
boost += 0.5
candidates.append({
'text': chunk['text'],
'score': base_score + boost,
'source': chunk['url'],
'id': chunk['id']
})
# Sort by boosted score
candidates.sort(key=lambda x: x['score'], reverse=True)
# Diversity Filter: Limit max 2 chunks per URL
final_results = []
url_counts = {}
for cand in candidates:
url = cand['source']
count = url_counts.get(url, 0)
if count < 2:
final_results.append(cand)
url_counts[url] = count + 1
if len(final_results) >= top_k:
break
return final_results
def generate(self, query, context_chunks):
"""Generate answer with RAG, maintaining subject context."""
if not self.model or not self.tokenizer:
error_msg = self.model_load_error if self.model_load_error else "Lỗi: Model hoặc Tokenizer chưa được load (Unknown reason)."
return f"HỆ THỐNG ĐANG GẶP SỰ CỐ: {error_msg}"
context_text = "\n\n".join([f"TÀI LIỆU {i+1}:\n{c['text']}" for i, c in enumerate(context_chunks)])
# Refined prompt for better subject handling & Synthesis
system_prompt = (
"Bạn là chuyên gia tư vấn y tế thông thái của Viện Vệ sinh Dịch tễ Trung ương (NIHE). "
"Nhiệm vụ của bạn là tổng hợp thông tin từ các tài liệu được cung cấp để trả lời người dùng một cách chính xác và đầy đủ nhất.\n"
"QUY TẮC CỐ ĐỊNH:\n"
"1. TỔNG HỢP THÔNG TIN: Hãy đọc kỹ TẤT CẢ các tài liệu. Nếu tài liệu 1 nói về nguyên nhân, tài liệu 2 nói về triệu chứng, hãy gộp lại thành câu trả lời hoàn chỉnh.\n"
"2. CHỈ sử dụng thông tin trong tài liệu. Không bịa đặt.\n"
"3. Chú ý ĐÚNG ĐỐI TƯỢNG: Nếu người hỏi hỏi cho 'con tôi', 'người già', hãy trả lời phù hợp với đối tượng đó.\n"
"4. Nếu thông tin không có trong tài liệu, hãy hướng dẫn họ liên hệ Hotline hoặc đến cơ sở y tế gần nhất.\n"
"5. Câu trả lời cần: Ngắn gọn, Súc tích, Dễ hiểu (dùng gạch đầu dòng)."
)
user_prompt = f"""DỰA VÀO CÁC TÀI LIỆU SAU:
{context_text}
CÂU HỎI CỦA NGƯỜI DÙNG: {query}
HƯỚNG DẪN TRẢ LỜI:"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
input_text = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=1024,
temperature=0.1, # Lower temperature for better precision
do_sample=True,
top_p=0.9,
repetition_penalty=1.1,
pad_token_id=self.tokenizer.eos_token_id
)
response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
return response.strip()
def run(self, query):
try:
print(f"Query: {query}", flush=True)
# Search
results = self.retrieve(query)
print(f"Retrieved {len(results)} chunks.", flush=True)
# Debug: check scores and snippets
for i, res in enumerate(results):
print(f" [{i}] Score: {res['score']:.4f} | Source: {res['source']}", flush=True)
print(f" Snippet: {res['text'][:150]}...", flush=True)
# Generate
answer = self.generate(query, results)
return answer
except Exception as e:
import traceback
print(f"CRITICAL ERROR in RAGRunner.run: {e}", flush=True)
traceback.print_exc()
raise e
if __name__ == "__main__":
runner = RAGRunner()
while True:
q = input("\nBạn: ")
if q.lower() in ['exit', 'quit']: break
ans = runner.run(q)
print(f"Bot: {ans}")