huydt11502 commited on
Commit ·
74b76f3
0
Parent(s):
Add RAG integration: Flask API server, disease selector, evaluation system with improved case generation
Browse files- .env.example +15 -0
- .gitignore +22 -0
- api_server.py +262 -0
- data/BoYTe200_v3.json +0 -0
- data/NHIKHOA2.json +0 -0
- data/PHACDODIEUTRI_2016.json +0 -0
- pediatric_finetune_15k_vietnamese.jsonl +0 -0
- requirements.txt +16 -0
- requirements_api.txt +2 -0
- src/__init__.py +0 -0
- src/build_faiss.py +9 -0
- src/config.py +29 -0
- src/data_loader.py +68 -0
- src/doctor_evaluator.py +166 -0
- src/embeddings.py +65 -0
- src/hybrid_retriever.py +39 -0
- src/main.py +166 -0
- src/rag_chain.py +83 -0
- src/vector_store.py +169 -0
- start_server.bat +7 -0
- test/test vector_store.py +12 -0
- test/test_chunks_by_file.py +80 -0
- test/test_data_loader.py +68 -0
- test/test_embeddings_single.py +22 -0
- test/test_faiss_single.py +28 -0
- test/test_hybrid.py +38 -0
- test/test_load_model_ft.py +31 -0
- test/test_rag_single.py +67 -0
- test/test_rq.py +4 -0
- test/test_single_file_loader.py +61 -0
- test/testapi.py +13 -0
- test_api.py +22 -0
- test_debug.py +19 -0
- test_suy_tim.py +23 -0
.env.example
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RAG Project Configuration
|
| 2 |
+
|
| 3 |
+
# Gemini API Key for RAG evaluation
|
| 4 |
+
GOOGLE_API_KEY=your_gemini_api_key_here
|
| 5 |
+
|
| 6 |
+
# Embedding Model
|
| 7 |
+
EMBEDDING_MODEL=VoVanPhuc/sup-SimCSE-VietNamese-phobert-base
|
| 8 |
+
|
| 9 |
+
# Flask API Configuration
|
| 10 |
+
FLASK_PORT=5000
|
| 11 |
+
FLASK_HOST=0.0.0.0
|
| 12 |
+
FLASK_DEBUG=True
|
| 13 |
+
|
| 14 |
+
# FAISS Configuration
|
| 15 |
+
FAISS_INDEX_PATH=faiss_cache/faiss_index
|
.gitignore
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# secrets
|
| 2 |
+
.env
|
| 3 |
+
*.env
|
| 4 |
+
|
| 5 |
+
# python
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.pyc
|
| 8 |
+
.venv/
|
| 9 |
+
venv/
|
| 10 |
+
|
| 11 |
+
# vscode
|
| 12 |
+
.vscode/
|
| 13 |
+
|
| 14 |
+
# data & cache (RAG)
|
| 15 |
+
faiss_cache/
|
| 16 |
+
store/
|
| 17 |
+
wandb/
|
| 18 |
+
|
| 19 |
+
# model files
|
| 20 |
+
*.bin
|
| 21 |
+
*.pt
|
| 22 |
+
*.ckpt
|
api_server.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Flask API Server để tích hợp RAG vào Chatbot
|
| 3 |
+
Endpoints:
|
| 4 |
+
- GET /api/diseases - Lấy danh sách bệnh từ JSON
|
| 5 |
+
- POST /api/start-case - Nhận bệnh, tạo case với triệu chứng
|
| 6 |
+
- POST /api/evaluate - Nhận đáp án user, trả về kết quả so sánh
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from flask import Flask, request, jsonify
|
| 10 |
+
from flask_cors import CORS
|
| 11 |
+
import json
|
| 12 |
+
import sys
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
# Add src to path
|
| 16 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
| 17 |
+
|
| 18 |
+
from data_loader import DataLoader
|
| 19 |
+
from config import Config
|
| 20 |
+
from doctor_evaluator import DoctorEvaluator
|
| 21 |
+
from vector_store import VectorStoreManager
|
| 22 |
+
from rag_chain import RAGChain
|
| 23 |
+
|
| 24 |
+
app = Flask(__name__)
|
| 25 |
+
CORS(app) # Enable CORS for React app
|
| 26 |
+
|
| 27 |
+
# Initialize RAG system
|
| 28 |
+
print("🚀 Initializing RAG system...")
|
| 29 |
+
vs_manager = VectorStoreManager()
|
| 30 |
+
if not vs_manager.vector_store:
|
| 31 |
+
print("❌ FAISS index not found. Run: python build_faiss.py")
|
| 32 |
+
sys.exit(1)
|
| 33 |
+
|
| 34 |
+
rag = RAGChain(vs_manager)
|
| 35 |
+
evaluator = DoctorEvaluator(rag)
|
| 36 |
+
print("✅ RAG system ready!")
|
| 37 |
+
|
| 38 |
+
# Store active sessions
|
| 39 |
+
active_sessions = {}
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@app.route('/api/health', methods=['GET'])
|
| 43 |
+
def health_check():
|
| 44 |
+
"""Health check endpoint"""
|
| 45 |
+
return jsonify({
|
| 46 |
+
'status': 'healthy',
|
| 47 |
+
'message': 'RAG API Server is running',
|
| 48 |
+
'embedding_model': Config.EMBEDDING_MODEL
|
| 49 |
+
})
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@app.route('/api/diseases', methods=['GET'])
|
| 53 |
+
def get_diseases():
|
| 54 |
+
"""
|
| 55 |
+
Lấy danh sách bệnh từ 3 file JSON (Index field)
|
| 56 |
+
Returns: { diseases: [{ id, name, category, source }] }
|
| 57 |
+
"""
|
| 58 |
+
try:
|
| 59 |
+
diseases = []
|
| 60 |
+
data_dir = os.path.join(os.path.dirname(__file__), 'data')
|
| 61 |
+
|
| 62 |
+
# Mapping files to categories
|
| 63 |
+
files = [
|
| 64 |
+
('BoYTe200_v3.json', 'procedures'),
|
| 65 |
+
('NHIKHOA2.json', 'pediatrics'),
|
| 66 |
+
('PHACDODIEUTRI_2016.json', 'treatment')
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
for filename, category in files:
|
| 70 |
+
filepath = os.path.join(data_dir, filename)
|
| 71 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 72 |
+
data = json.load(f)
|
| 73 |
+
for item in data:
|
| 74 |
+
diseases.append({
|
| 75 |
+
'id': f"{category}_{item['id']}",
|
| 76 |
+
'name': item['Index'],
|
| 77 |
+
'category': category,
|
| 78 |
+
'source': filename,
|
| 79 |
+
'sections': item.get('level1_items', [])
|
| 80 |
+
})
|
| 81 |
+
|
| 82 |
+
return jsonify({
|
| 83 |
+
'success': True,
|
| 84 |
+
'diseases': diseases,
|
| 85 |
+
'total': len(diseases)
|
| 86 |
+
})
|
| 87 |
+
|
| 88 |
+
except Exception as e:
|
| 89 |
+
return jsonify({
|
| 90 |
+
'success': False,
|
| 91 |
+
'error': str(e)
|
| 92 |
+
}), 500
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@app.route('/api/start-case', methods=['POST'])
|
| 96 |
+
def start_case():
|
| 97 |
+
"""
|
| 98 |
+
Nhận tên bệnh, tìm triệu chứng và tạo case
|
| 99 |
+
Input: { disease: string, sessionId: string }
|
| 100 |
+
Output: { case: string, symptoms: string, sessionId: string }
|
| 101 |
+
"""
|
| 102 |
+
try:
|
| 103 |
+
data = request.json
|
| 104 |
+
disease = data.get('disease', '').strip()
|
| 105 |
+
session_id = data.get('sessionId')
|
| 106 |
+
|
| 107 |
+
if not disease:
|
| 108 |
+
return jsonify({
|
| 109 |
+
'success': False,
|
| 110 |
+
'error': 'Disease name is required'
|
| 111 |
+
}), 400
|
| 112 |
+
|
| 113 |
+
print(f"📋 Starting case for disease: {disease}")
|
| 114 |
+
|
| 115 |
+
# 1. RAG tìm triệu chứng
|
| 116 |
+
print("🔍 Finding symptoms...")
|
| 117 |
+
symptoms, symptom_sources = evaluator.find_symptoms(disease)
|
| 118 |
+
|
| 119 |
+
# 2. Gemini tạo case
|
| 120 |
+
print("✍️ Generating patient case...")
|
| 121 |
+
patient_case = evaluator.generate_case(disease, symptoms)
|
| 122 |
+
|
| 123 |
+
# Store session data
|
| 124 |
+
session_data = {
|
| 125 |
+
'disease': disease,
|
| 126 |
+
'symptoms': symptoms,
|
| 127 |
+
'case': patient_case,
|
| 128 |
+
'symptom_sources': [
|
| 129 |
+
{
|
| 130 |
+
'file': doc.metadata.get('source_file', ''),
|
| 131 |
+
'title': doc.metadata.get('main_title', ''),
|
| 132 |
+
'section': doc.metadata.get('sub_title', '')
|
| 133 |
+
}
|
| 134 |
+
for doc in symptom_sources[:3]
|
| 135 |
+
]
|
| 136 |
+
}
|
| 137 |
+
active_sessions[session_id] = session_data
|
| 138 |
+
|
| 139 |
+
return jsonify({
|
| 140 |
+
'success': True,
|
| 141 |
+
'sessionId': session_id,
|
| 142 |
+
'case': patient_case,
|
| 143 |
+
'symptoms': symptoms[:300] + "...", # Truncate for display
|
| 144 |
+
'sources': session_data['symptom_sources']
|
| 145 |
+
})
|
| 146 |
+
|
| 147 |
+
except Exception as e:
|
| 148 |
+
print(f"❌ Error in start_case: {str(e)}")
|
| 149 |
+
return jsonify({
|
| 150 |
+
'success': False,
|
| 151 |
+
'error': str(e)
|
| 152 |
+
}), 500
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@app.route('/api/evaluate', methods=['POST'])
|
| 156 |
+
def evaluate_diagnosis():
|
| 157 |
+
"""
|
| 158 |
+
Nhận câu trả lời user, so sánh với đáp án chuẩn
|
| 159 |
+
Input: {
|
| 160 |
+
sessionId: string,
|
| 161 |
+
diagnosis: {
|
| 162 |
+
clinical: string,
|
| 163 |
+
paraclinical: string,
|
| 164 |
+
definitiveDiagnosis: string,
|
| 165 |
+
differentialDiagnosis: string,
|
| 166 |
+
treatment: string,
|
| 167 |
+
medication: string
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
Output: {
|
| 171 |
+
standardAnswer: { ... },
|
| 172 |
+
evaluation: { ... },
|
| 173 |
+
sources: [ ... ]
|
| 174 |
+
}
|
| 175 |
+
"""
|
| 176 |
+
try:
|
| 177 |
+
data = request.json
|
| 178 |
+
session_id = data.get('sessionId')
|
| 179 |
+
diagnosis = data.get('diagnosis', {})
|
| 180 |
+
|
| 181 |
+
if not session_id or session_id not in active_sessions:
|
| 182 |
+
return jsonify({
|
| 183 |
+
'success': False,
|
| 184 |
+
'error': 'Invalid session ID'
|
| 185 |
+
}), 400
|
| 186 |
+
|
| 187 |
+
session_data = active_sessions[session_id]
|
| 188 |
+
disease = session_data['disease']
|
| 189 |
+
|
| 190 |
+
print(f"📊 Evaluating diagnosis for: {disease}")
|
| 191 |
+
|
| 192 |
+
# Format user's answer
|
| 193 |
+
user_answer = f"""
|
| 194 |
+
CHẨN ĐOÁN:
|
| 195 |
+
- Lâm sàng: {diagnosis.get('clinical', 'Không có')}
|
| 196 |
+
- Cận lâm sàng: {diagnosis.get('paraclinical', 'Không có')}
|
| 197 |
+
- Chẩn đoán xác định: {diagnosis.get('definitiveDiagnosis', 'Không có')}
|
| 198 |
+
- Chẩn đoán phân biệt: {diagnosis.get('differentialDiagnosis', 'Không có')}
|
| 199 |
+
|
| 200 |
+
KẾ HOẠCH ĐIỀU TRỊ:
|
| 201 |
+
- Cách điều trị: {diagnosis.get('treatment', 'Không có')}
|
| 202 |
+
- Thuốc: {diagnosis.get('medication', 'Không có')}
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
print("🔍 Finding standard answer...")
|
| 206 |
+
# Get standard answer from RAG
|
| 207 |
+
standard_data, all_sources = evaluator.get_detailed_standard_knowledge(disease)
|
| 208 |
+
|
| 209 |
+
print("🤖 Evaluating with Gemini...")
|
| 210 |
+
# Evaluate with Gemini
|
| 211 |
+
evaluation_json = evaluator.detailed_evaluation(user_answer, standard_data)
|
| 212 |
+
|
| 213 |
+
# Parse JSON from evaluation
|
| 214 |
+
try:
|
| 215 |
+
# Extract JSON from markdown code blocks if present
|
| 216 |
+
eval_text = evaluation_json.strip()
|
| 217 |
+
if eval_text.startswith('```'):
|
| 218 |
+
eval_text = eval_text.split('```')[1]
|
| 219 |
+
if eval_text.startswith('json'):
|
| 220 |
+
eval_text = eval_text[4:]
|
| 221 |
+
evaluation_obj = json.loads(eval_text.strip())
|
| 222 |
+
except:
|
| 223 |
+
# If parsing fails, return as text
|
| 224 |
+
evaluation_obj = {
|
| 225 |
+
'evaluation_text': evaluation_json,
|
| 226 |
+
'diem_so': 'N/A'
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
# Format sources
|
| 230 |
+
formatted_sources = [
|
| 231 |
+
{
|
| 232 |
+
'file': doc.metadata.get('source_file', ''),
|
| 233 |
+
'title': doc.metadata.get('main_title', ''),
|
| 234 |
+
'section': doc.metadata.get('sub_title', ''),
|
| 235 |
+
'content': doc.page_content[:200] + "..."
|
| 236 |
+
}
|
| 237 |
+
for doc in all_sources[:5]
|
| 238 |
+
]
|
| 239 |
+
|
| 240 |
+
return jsonify({
|
| 241 |
+
'success': True,
|
| 242 |
+
'case': session_data['case'],
|
| 243 |
+
'standardAnswer': standard_data,
|
| 244 |
+
'evaluation': evaluation_obj,
|
| 245 |
+
'sources': formatted_sources
|
| 246 |
+
})
|
| 247 |
+
|
| 248 |
+
except Exception as e:
|
| 249 |
+
print(f"❌ Error in evaluate: {str(e)}")
|
| 250 |
+
import traceback
|
| 251 |
+
traceback.print_exc()
|
| 252 |
+
return jsonify({
|
| 253 |
+
'success': False,
|
| 254 |
+
'error': str(e)
|
| 255 |
+
}), 500
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
if __name__ == '__main__':
|
| 259 |
+
print("🌟 Starting Flask API Server...")
|
| 260 |
+
print(f"📡 Server will run on http://localhost:5000")
|
| 261 |
+
print(f"🔑 Using API Key: {Config.GOOGLE_API_KEY[:20]}...")
|
| 262 |
+
app.run(debug=True, host='0.0.0.0', port=5000)
|
data/BoYTe200_v3.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/NHIKHOA2.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/PHACDODIEUTRI_2016.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pediatric_finetune_15k_vietnamese.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
langchain-core>=0.3.17,<0.4.0
|
| 2 |
+
langchain-community>=0.3.7,<0.4.0
|
| 3 |
+
langchain-huggingface
|
| 4 |
+
langchain-google-genai
|
| 5 |
+
|
| 6 |
+
faiss-cpu
|
| 7 |
+
sentence-transformers
|
| 8 |
+
google-generativeai
|
| 9 |
+
python-dotenv
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
unsloth
|
| 13 |
+
transformers
|
| 14 |
+
torch
|
| 15 |
+
accelerate
|
| 16 |
+
bitsandbytes
|
requirements_api.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask==3.0.0
|
| 2 |
+
flask-cors==4.0.0
|
src/__init__.py
ADDED
|
File without changes
|
src/build_faiss.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from vector_store import VectorStoreManager
|
| 2 |
+
from data_loader import DataLoader
|
| 3 |
+
|
| 4 |
+
if __name__ == "__main__":
|
| 5 |
+
print(" BUILD FAISS CACHE (2p)")
|
| 6 |
+
all_docs = DataLoader.load_all_chunks()
|
| 7 |
+
vs = VectorStoreManager()
|
| 8 |
+
vs.build_and_cache(all_docs) # Embed + SAVE
|
| 9 |
+
print(" DONE! Cache ready!")
|
src/config.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
load_dotenv()
|
| 5 |
+
|
| 6 |
+
class Config:
|
| 7 |
+
# SỬA ĐƯỜNG DẪN LẠI NHA
|
| 8 |
+
BASE_DIR = r"D:\Storage\rag_project"
|
| 9 |
+
DATA_DIR = f"{BASE_DIR}/data"
|
| 10 |
+
CHUNK_FILES = [
|
| 11 |
+
f"{DATA_DIR}/BoYTe200_v3.json",
|
| 12 |
+
f"{DATA_DIR}/NHIKHOA2.json",
|
| 13 |
+
f"{DATA_DIR}/PHACDODIEUTRI_2016.json"
|
| 14 |
+
]
|
| 15 |
+
# EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # Model embedding
|
| 16 |
+
# EMBEDDING_MODEL = "bkai-foundation-models/vietnamese-bi-encoder"
|
| 17 |
+
EMBEDDING_MODEL = "VoVanPhuc/sup-SimCSE-VietNamese-phobert-base"
|
| 18 |
+
|
| 19 |
+
# GOOGLE API KEY - Thay bằng key của bạn từ https://makersuite.google.com/app/apikey
|
| 20 |
+
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY', 'YOUR_API_KEY_HERE')
|
| 21 |
+
|
| 22 |
+
LLM_MODEL = "gemini-2.5-flash"
|
| 23 |
+
K_RETRIEVE = 3 # Số Document muốn truy
|
| 24 |
+
TEMPERATURE = 0
|
| 25 |
+
|
| 26 |
+
"""
|
| 27 |
+
AIzaSyABvC8mPrwa0Kgy08mFFzkyeh2_N-Bb3lY
|
| 28 |
+
AIzaSyDJqr4nKDrcfmmuKOdDCHkXRvKA48htD6o
|
| 29 |
+
"""
|
src/data_loader.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List
|
| 4 |
+
from langchain_core.documents import Document
|
| 5 |
+
from config import Config
|
| 6 |
+
|
| 7 |
+
class DataLoader:
|
| 8 |
+
@staticmethod
|
| 9 |
+
def load_all_chunks() -> List[Document]:
|
| 10 |
+
BASE_DIR = Path(Config.BASE_DIR)
|
| 11 |
+
DATA_DIR = BASE_DIR / 'data'
|
| 12 |
+
all_docs = []
|
| 13 |
+
|
| 14 |
+
json_files = [Path(f) for f in Config.CHUNK_FILES if Path(f).exists()]
|
| 15 |
+
print("LOADING CHUNKS TỪ TẤT CẢ JSON:")
|
| 16 |
+
|
| 17 |
+
for json_file in json_files:
|
| 18 |
+
print(f"{json_file.name}...")
|
| 19 |
+
with open(json_file, 'r', encoding='utf-8') as f:
|
| 20 |
+
chapters = json.load(f)
|
| 21 |
+
|
| 22 |
+
# MỖI chap = 1 chunk
|
| 23 |
+
"""
|
| 24 |
+
# Gom lại theo 1 FILE JSON - CÓ 1 SAMPLE - 1 CHUNK
|
| 25 |
+
# file_chunks chứa all chunks của json đó
|
| 26 |
+
# Các chunk được tổ chức lại theo DOCUMENT (page_content + metadata)
|
| 27 |
+
# Với page_content là từng content trong contents của section
|
| 28 |
+
"""
|
| 29 |
+
file_chunks = []
|
| 30 |
+
for chap in chapters:
|
| 31 |
+
chunk = {
|
| 32 |
+
"chunk_id": chap.get("id"),
|
| 33 |
+
"title": chap.get("Index"),
|
| 34 |
+
"level1_items": chap.get("level1_items", []),
|
| 35 |
+
"contents": chap.get("contents", [])
|
| 36 |
+
}
|
| 37 |
+
# Mặc dù là content khác nhau nhưng vẫn thuộc cùng chunk_id, chunk_title
|
| 38 |
+
for i, section in enumerate(chunk["contents"]):
|
| 39 |
+
doc = Document(
|
| 40 |
+
page_content=section.get("content", ""),
|
| 41 |
+
metadata={
|
| 42 |
+
"source_file": json_file.name,
|
| 43 |
+
"chunk_id": str(chunk["chunk_id"]),
|
| 44 |
+
"chunk_title": chunk["title"],
|
| 45 |
+
"section_id": f"{chunk['chunk_id']}.{i+1}",
|
| 46 |
+
"section_title": section.get("title", "")
|
| 47 |
+
}
|
| 48 |
+
)
|
| 49 |
+
file_chunks.append(doc)
|
| 50 |
+
|
| 51 |
+
all_docs.extend(file_chunks)
|
| 52 |
+
print(f"{len(file_chunks)} docs từ {len(chapters)} chunks")
|
| 53 |
+
|
| 54 |
+
print(f"\nTỔNG {len(all_docs)} documents!")
|
| 55 |
+
return all_docs
|
| 56 |
+
|
| 57 |
+
# TEST: python data_loader.py
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
docs = DataLoader.load_all_chunks()
|
| 60 |
+
print("\n SAMPLE DOC:")
|
| 61 |
+
print("Content:", docs[0].page_content)
|
| 62 |
+
print("Metadata:", docs[0].metadata)
|
| 63 |
+
|
| 64 |
+
"""
|
| 65 |
+
OUTPUT SAMPLE
|
| 66 |
+
Content: Đặt stent khí phế quản là kỹ thuật đặt một giá đỡ vào khí, phế quản làm rộng và duy trì khẩu kính đường thở để điều trị một số trường hợp hẹp khí, phế quản bẩm sinh hoặc mắc phải. Đặt stent có thể th
|
| 67 |
+
Metadata: {'source_file': 'BoYTe200_v3.json', 'chunk_id': '1', 'chunk_title': 'NỘI SOI ĐẶT STENT KHÍ PHẾ QUẢN BẰNG ỐNG CỨNG', 'section_id': '1.1', 'section_title': 'ĐẠI CƯƠNG'}
|
| 68 |
+
"""
|
src/doctor_evaluator.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rag_chain import RAGChain
|
| 2 |
+
from vector_store import VectorStoreManager
|
| 3 |
+
from data_loader import DataLoader
|
| 4 |
+
from config import Config
|
| 5 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 6 |
+
|
| 7 |
+
class DoctorEvaluator:
|
| 8 |
+
def __init__(self, rag):
|
| 9 |
+
self.rag = rag
|
| 10 |
+
self.evaluator_llm = ChatGoogleGenerativeAI(
|
| 11 |
+
model="gemini-2.5-flash",
|
| 12 |
+
google_api_key=Config.GOOGLE_API_KEY,
|
| 13 |
+
temperature=0.1
|
| 14 |
+
)
|
| 15 |
+
print("DoctorEvaluator: Ready (Gemini + RAG)!")
|
| 16 |
+
|
| 17 |
+
def generate_case(self, disease: str, symptoms: str):
|
| 18 |
+
"""DÙNG GEMINI TẠO CASE - NHANH + ỔN ĐỊNH"""
|
| 19 |
+
prompt = f"""
|
| 20 |
+
Bạn là bác sĩ nhi khoa. Tạo một ca bệnh THỰC TẾ cho bệnh: {disease}
|
| 21 |
+
|
| 22 |
+
TRIỆU CHỨNG TỪ TÀI LIỆU:
|
| 23 |
+
{symptoms}
|
| 24 |
+
|
| 25 |
+
YÊU CẦU:
|
| 26 |
+
1. Chỉ tạo lời thoại của mẹ bệnh nhân (3-4 câu)
|
| 27 |
+
2. PHẢI MÔ TẢ các triệu chứng CỤ THỂ của bệnh {disease} từ tài liệu trên
|
| 28 |
+
3. Dùng ngôn ngữ đời thường, tự nhiên
|
| 29 |
+
4. Format: "Bé [tên] nhà chị [tên mẹ] bữa nay bị [triệu chứng cụ thể]. Chị lo lắm! [thêm chi tiết triệu chứng]."
|
| 30 |
+
|
| 31 |
+
VÍ DỤ TỐT:
|
| 32 |
+
- Bệnh Viêm phổi → "Bé An bị sốt cao 39 độ, ho có đờm, thở nhanh phì phò"
|
| 33 |
+
- Bệnh Suy tim → "Bé Minh thở nhanh, mệt lả, bú kém, chân tay lạnh"
|
| 34 |
+
|
| 35 |
+
CASE BỆNH:
|
| 36 |
+
"""
|
| 37 |
+
result = self.evaluator_llm.invoke([prompt])
|
| 38 |
+
return result.content.strip()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def evaluate_doctor(self, disease: str):
|
| 42 |
+
print(f"\n ĐÁNH GIÁ: {disease}")
|
| 43 |
+
print("=" * 80)
|
| 44 |
+
|
| 45 |
+
# 1. RAG tìm TRIỆU CHỨNG
|
| 46 |
+
print("Hệ thống đang TRUY TÌM TRIỆU CHỨNG:")
|
| 47 |
+
symptoms, symptom_sources = self.find_symptoms(disease)
|
| 48 |
+
print(f"Xác định triệu chứng: {symptoms[:100]}...")
|
| 49 |
+
|
| 50 |
+
# 2. GEMINI tạo CASE
|
| 51 |
+
print("Tiến hành tạo case...")
|
| 52 |
+
patient_case = self.generate_case(disease, symptoms)
|
| 53 |
+
print(f"Case hoàn chỉnh:\n{patient_case}")
|
| 54 |
+
|
| 55 |
+
# 3. NHẬP TRẢ LỜI BS
|
| 56 |
+
doctor_answer = input("\n NHẬP CÂU TRẢ LỜI CỦA BÁC SĨ:\n").strip()
|
| 57 |
+
|
| 58 |
+
# 4. RAG chi tiết + Đánh giá (giữ nguyên)
|
| 59 |
+
print("\n TRUY TÌM ĐÁP ÁN CHUẨN:")
|
| 60 |
+
standard_data, all_sources = self.get_detailed_standard_knowledge(disease)
|
| 61 |
+
evaluation = self.detailed_evaluation(doctor_answer, standard_data)
|
| 62 |
+
|
| 63 |
+
return {
|
| 64 |
+
'case': patient_case,
|
| 65 |
+
'standard': standard_data,
|
| 66 |
+
'evaluation': evaluation,
|
| 67 |
+
'sources': all_sources
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
def find_symptoms(self, disease: str):
|
| 71 |
+
"""RAG tìm triệu chứng bệnh - CẢI THIỆN"""
|
| 72 |
+
# Query chi tiết hơn để tìm đúng bệnh
|
| 73 |
+
queries = [
|
| 74 |
+
f"{disease} lâm sàng",
|
| 75 |
+
f"{disease} triệu chứng",
|
| 76 |
+
f"{disease} dấu hiệu"
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
all_symptoms = []
|
| 80 |
+
sources = []
|
| 81 |
+
for q in queries:
|
| 82 |
+
print(f"🔍 Query: {q}")
|
| 83 |
+
answer, src = self.rag.query(q)
|
| 84 |
+
if answer and len(answer.strip()) > 50: # Chỉ lấy answer có nội dung
|
| 85 |
+
all_symptoms.append(answer)
|
| 86 |
+
sources.extend(src)
|
| 87 |
+
|
| 88 |
+
# Gom triệu chứng đầy đủ hơn (không cắt quá ngắn)
|
| 89 |
+
if all_symptoms:
|
| 90 |
+
# Lấy 2 answer tốt nhất, mỗi cái 500 ký tự
|
| 91 |
+
symptoms_summary = "\n\n".join([s[:500] for s in all_symptoms[:2]])
|
| 92 |
+
else:
|
| 93 |
+
symptoms_summary = f"Không tìm thấy thông tin triệu chứng cho {disease}"
|
| 94 |
+
|
| 95 |
+
print(f"📋 Tìm thấy triệu chứng: {symptoms_summary[:200]}...")
|
| 96 |
+
return symptoms_summary, sources
|
| 97 |
+
|
| 98 |
+
def get_detailed_standard_knowledge(self, disease: str):
|
| 99 |
+
"""RAG CHẨN ĐOÁN CHI TIẾT + ĐIỀU TRỊ"""
|
| 100 |
+
queries = {
|
| 101 |
+
'LAM_SANG': [f"{disease} lâm sàng"],
|
| 102 |
+
'CAN_LAM_SANG': [f"{disease} cận lâm sàng"],
|
| 103 |
+
'CHAN_DOAN_XAC_DINH': [f"{disease} chẩn đoán xác định"],
|
| 104 |
+
'CHAN_DOAN_PHAN_BIET': [f"{disease} chẩn đoán phân biệt"],
|
| 105 |
+
'DIEU_TRI': [f"{disease} điều trị", f"{disease} thuốc"]
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
results = {}
|
| 109 |
+
all_sources = []
|
| 110 |
+
|
| 111 |
+
for section, qlist in queries.items():
|
| 112 |
+
print(f" {section}:")
|
| 113 |
+
section_content = []
|
| 114 |
+
for q in qlist:
|
| 115 |
+
print(f" {q}")
|
| 116 |
+
answer, sources = self.rag.query(q)
|
| 117 |
+
section_content.append(answer)
|
| 118 |
+
all_sources.extend(sources)
|
| 119 |
+
results[section] = "\n".join(section_content[:2])
|
| 120 |
+
|
| 121 |
+
# Format đẹp
|
| 122 |
+
standard_text = f"""
|
| 123 |
+
CHẨN ĐOÁN LÂM SÀNG:
|
| 124 |
+
{results['LAM_SANG']}
|
| 125 |
+
|
| 126 |
+
CHẨN ĐOÁN CẬN LÂM SÀNG:
|
| 127 |
+
{results['CAN_LAM_SANG']}
|
| 128 |
+
|
| 129 |
+
CHẨN ĐOÁN XÁC ĐỊNH:
|
| 130 |
+
{results['CHAN_DOAN_XAC_DINH']}
|
| 131 |
+
|
| 132 |
+
CHẨN ĐOÁN PHÂN BIỆT:
|
| 133 |
+
{results['CHAN_DOAN_PHAN_BIET']}
|
| 134 |
+
|
| 135 |
+
CÁCH ĐIỀU TRỊ:
|
| 136 |
+
{results['DIEU_TRI']}
|
| 137 |
+
"""
|
| 138 |
+
return standard_text, all_sources
|
| 139 |
+
|
| 140 |
+
def detailed_evaluation(self, doctor_answer: str, standard_data: str):
|
| 141 |
+
"""ĐÁNH GIÁ CHI TIẾT + DIỄN GIẢI"""
|
| 142 |
+
prompt = f"""
|
| 143 |
+
BẠN LÀ CHUYÊN GIA Y KHOA ĐÁNH GIÁ BÁC SĨ
|
| 144 |
+
|
| 145 |
+
CÂU TRẢ LỜI BÁC SĨ:
|
| 146 |
+
{doctor_answer}
|
| 147 |
+
|
| 148 |
+
KIẾN THỨC CHUẨN:
|
| 149 |
+
{standard_data}
|
| 150 |
+
|
| 151 |
+
PHÂN TÍCH CHI TIẾT (JSON):
|
| 152 |
+
{{
|
| 153 |
+
"diem_manh": ["..."],
|
| 154 |
+
"diem_yeu": ["..."],
|
| 155 |
+
"da_co": ["..."],
|
| 156 |
+
"thieu": ["..."],
|
| 157 |
+
"dien_giai": ["Giải thích vì sao đúng/thiếu..."],
|
| 158 |
+
"diem_so": "85/100",
|
| 159 |
+
"nhan_xet_tong_quan": "..."
|
| 160 |
+
}}
|
| 161 |
+
|
| 162 |
+
JSON PURE:
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
result = self.evaluator_llm.invoke([prompt])
|
| 166 |
+
return result.content
|
src/embeddings.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 2 |
+
from config import Config
|
| 3 |
+
from numpy import dot
|
| 4 |
+
from numpy.linalg import norm
|
| 5 |
+
|
| 6 |
+
class EmbeddingsManager:
|
| 7 |
+
# Khởi tạo model embedding từ cofig ngay khi gọi class
|
| 8 |
+
def __init__(self):
|
| 9 |
+
self.embeddings = HuggingFaceEmbeddings(
|
| 10 |
+
model_name=Config.EMBEDDING_MODEL
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
def get_embeddings(self):
|
| 14 |
+
return self.embeddings
|
| 15 |
+
|
| 16 |
+
# Code thêm cái này xíu hiểu hơn bản chất so sánh vector embed
|
| 17 |
+
def cosine(a, b):
|
| 18 |
+
return dot(a, b) / (norm(a) * norm(b))
|
| 19 |
+
|
| 20 |
+
# TEST: python embeddings.py
|
| 21 |
+
if __name__ == "__main__":
|
| 22 |
+
print("Loading embedding model...")
|
| 23 |
+
em = EmbeddingsManager()
|
| 24 |
+
embeddings = em.get_embeddings()
|
| 25 |
+
|
| 26 |
+
text = "Sốt ở trẻ em là tình trạng thân nhiệt tăng."
|
| 27 |
+
vec = embeddings.embed_query(text)
|
| 28 |
+
|
| 29 |
+
print("Vector length:", len(vec))
|
| 30 |
+
print("First 5 values:", vec[:5])
|
| 31 |
+
|
| 32 |
+
# ---- Test 2: cosine similarity ----
|
| 33 |
+
v1 = embeddings.embed_query("Sốt ở trẻ em")
|
| 34 |
+
v2 = embeddings.embed_query("Trẻ bị sốt cao")
|
| 35 |
+
v3 = embeddings.embed_query("Gãy xương tay")
|
| 36 |
+
|
| 37 |
+
print("\n COSINE SIMILARITY TEST")
|
| 38 |
+
print("v1 ↔ v2 (gần nghĩa):", cosine(v1, v2))
|
| 39 |
+
print("v1 ↔ v3 (khác nghĩa):", cosine(v1, v3))
|
| 40 |
+
|
| 41 |
+
"""
|
| 42 |
+
OUTPUT VÀ EXPLAIN DỄ HIỂU:
|
| 43 |
+
1. OUTPUT:
|
| 44 |
+
Vector length: 384
|
| 45 |
+
First 5 values: [-0.02630099654197693, 0.01091383583843708, 0.0058159008622169495, -0.05811420455574989, -0.051191169768571854]
|
| 46 |
+
|
| 47 |
+
2. EXPLAIN
|
| 48 |
+
- Là tọa độ ngữ nghĩa của câu
|
| 49 |
+
|
| 50 |
+
- Không có ý nghĩa đơn lẻ
|
| 51 |
+
|
| 52 |
+
- Chỉ có ý nghĩa khi so sánh với vector khác
|
| 53 |
+
|
| 54 |
+
v1 = embed("Sốt ở trẻ em")
|
| 55 |
+
v2 = embed("Trẻ bị sốt cao")
|
| 56 |
+
v3 = embed("Gãy xương tay")
|
| 57 |
+
|
| 58 |
+
=> v1 ≈ v2 (gần)
|
| 59 |
+
v1 ≠ v3 (xa)
|
| 60 |
+
|
| 61 |
+
=> FAISS dùng cosine similarity / L2 distance để tìm câu gần nghĩa nhất
|
| 62 |
+
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
|
src/hybrid_retriever.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.vectorstores import FAISS
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
class HybridRetriever:
|
| 5 |
+
def __init__(self, vectorstore):
|
| 6 |
+
self.vs = vectorstore
|
| 7 |
+
|
| 8 |
+
def keyword_search(self, query, k=5):
|
| 9 |
+
"""Exact keyword matching - PRIORITY 1"""
|
| 10 |
+
keywords = re.findall(r'\b\w{3,}\b', query.lower())
|
| 11 |
+
scored_docs = []
|
| 12 |
+
|
| 13 |
+
for doc_id, doc in self.vs.docstore._dict.items():
|
| 14 |
+
content_lower = doc.page_content.lower()
|
| 15 |
+
title_lower = doc.metadata.get('chunk_title', '').lower()
|
| 16 |
+
|
| 17 |
+
# Score cao nếu match title + content
|
| 18 |
+
score = sum(2 if kw in title_lower else 1
|
| 19 |
+
for kw in keywords if kw in content_lower or kw in title_lower)
|
| 20 |
+
|
| 21 |
+
if score > 0:
|
| 22 |
+
scored_docs.append((score, doc))
|
| 23 |
+
|
| 24 |
+
scored_docs.sort(reverse=True, key=lambda x: x[0])
|
| 25 |
+
return [doc for _, doc in scored_docs[:k]]
|
| 26 |
+
|
| 27 |
+
def hybrid_search(self, query, k=3):
|
| 28 |
+
"""KEYWORD FIRST → Semantic backup"""
|
| 29 |
+
# PRIORITY 1: Keyword exact match
|
| 30 |
+
keyword_docs = self.keyword_search(query, k=k*2)
|
| 31 |
+
|
| 32 |
+
if keyword_docs:
|
| 33 |
+
print(f" KEYWORD HIT: {len(keyword_docs)} docs")
|
| 34 |
+
return keyword_docs[:k]
|
| 35 |
+
|
| 36 |
+
# PRIORITY 2: Semantic fallback
|
| 37 |
+
print(" Semantic fallback...")
|
| 38 |
+
semantic_docs = self.vs.similarity_search(query, k=k)
|
| 39 |
+
return semantic_docs
|
src/main.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from data_loader import DataLoader
|
| 2 |
+
from config import Config
|
| 3 |
+
from doctor_evaluator import DoctorEvaluator
|
| 4 |
+
from vector_store import VectorStoreManager
|
| 5 |
+
from rag_chain import RAGChain
|
| 6 |
+
|
| 7 |
+
def main():
|
| 8 |
+
# print("LOADING TẤT CẢ CHUNKS...")
|
| 9 |
+
|
| 10 |
+
# # 1. Load TẤT CẢ 3 JSON → Gộp 1 lần
|
| 11 |
+
# docs = DataLoader.load_all_chunks()
|
| 12 |
+
|
| 13 |
+
# 2. Build FAISS chung
|
| 14 |
+
vs_manager = VectorStoreManager()
|
| 15 |
+
# vs_manager.build_from_docs(docs)
|
| 16 |
+
if not vs_manager.vector_store:
|
| 17 |
+
print("Chạy: python build_faise.py trước")
|
| 18 |
+
# 3. Tạo RAG chain
|
| 19 |
+
rag = RAGChain(vs_manager)
|
| 20 |
+
|
| 21 |
+
print("\n CHATBOT Y TẾ READY!")
|
| 22 |
+
print(f"API KEY: {Config.GOOGLE_API_KEY}")
|
| 23 |
+
print(f"Embedding: {Config.EMBEDDING_MODEL}")
|
| 24 |
+
print(" Gõ câu hỏi (quit để thoát)")
|
| 25 |
+
print("-" * 50)
|
| 26 |
+
|
| 27 |
+
print("\n CHỈNH THỨC: 1=Chatbot | 2=Đánh giá bác sĩ")
|
| 28 |
+
mode = input("Chọn mode (1/2): ").strip()
|
| 29 |
+
|
| 30 |
+
if mode == "2":
|
| 31 |
+
"""
|
| 32 |
+
# evaluator = DoctorEvaluator(rag)
|
| 33 |
+
# print("\n ĐÁNH GIÁ BÁC SĨ")
|
| 34 |
+
# print("Format: Bệnh + Triệu chứng + Trả lời bác sĩ")
|
| 35 |
+
|
| 36 |
+
# disease = input("Bệnh: ").strip()
|
| 37 |
+
# symptoms = input("Triệu chứng: ").strip()
|
| 38 |
+
# doctor_answer = input("Trả lời bác sĩ: ").strip()
|
| 39 |
+
"""
|
| 40 |
+
# standard, evaluation = evaluator.evaluate_doctor(disease, symptoms, doctor_answer)
|
| 41 |
+
|
| 42 |
+
# print("\n ĐÁP ÁN CHUẨN:")
|
| 43 |
+
# print(standard)
|
| 44 |
+
# print("\n NHẬN XÉT:")
|
| 45 |
+
# print(evaluation)
|
| 46 |
+
|
| 47 |
+
"""
|
| 48 |
+
1. Nhận case: gồm bệnh + triệu chứng
|
| 49 |
+
2. Nhận câu trả lời của bác sĩ cho case
|
| 50 |
+
2. RAG truy tìm đáp án chuẩn (CHẨN ĐOÁN + ĐIỀU TRỊ) của case đó
|
| 51 |
+
3. So sánh câu trả lời của bác sĩ vs key answer qua gemini theo các tiêu chí (điem manh, diem yeu, da co, thieu, nhan xet tong quan)
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
"""
|
| 55 |
+
CÓ THỂ THAY ĐỔI LUỒNG TRÊN THÀNH:
|
| 56 |
+
1. Nhận mỗi triệu chứng --> RAG ra được tên bệnh --> RAG tiếp tên bệnh + CHẨN ĐOÁN + ĐIỀU TRỊ (có thể thêm nếu cần thiết)
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
"""
|
| 60 |
+
NHỮNG THỨ CẦN RAG ĐƯỢC:
|
| 61 |
+
1. Nhận một câu bất kỳ nằm trong pdf --> RAG đúng mục chứa câu đó
|
| 62 |
+
VD: Liều:
|
| 63 |
+
+Trẻ ≤ 5 tuổi: 4 mg
|
| 64 |
+
+Trẻ ≥ 5 tuổi và ≤ 13 (15) tuổi: 5 mg
|
| 65 |
+
+Trẻ ≥ 13 (15) tuổi: 10 mg
|
| 66 |
+
==> SUYỄN TRẺ EM (VI. PHÒNG NGỪA)
|
| 67 |
+
2. Nhận triêu chứng --> RAG đúng bệnh có triệu chứng đó
|
| 68 |
+
3. Nhận mục --> RAG đúng mục đó trả về
|
| 69 |
+
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
# result = evaluator.evaluate_doctor(disease, symptoms, doctor_answer)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# print("\n ĐÁP ÁN CHUẨN:")
|
| 76 |
+
# print(result['standard'])
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
disease = input("\nBỆNH: ").strip()
|
| 80 |
+
|
| 81 |
+
evaluator = DoctorEvaluator(rag)
|
| 82 |
+
result = evaluator.evaluate_doctor(disease)
|
| 83 |
+
|
| 84 |
+
print("=" * 80)
|
| 85 |
+
print("\nTHU THẬP ĐẦY ĐỦ - TIẾN HÀNH LUỒNG")
|
| 86 |
+
print("=" * 80)
|
| 87 |
+
print("\nCASE BỆNH NHÂN:")
|
| 88 |
+
print(result['case'])
|
| 89 |
+
print("=" * 80)
|
| 90 |
+
print("\nĐÁP ÁN CHUẨN:")
|
| 91 |
+
print(result['standard'])
|
| 92 |
+
print("=" * 80)
|
| 93 |
+
print("\nPHÂN TÍCH CHI TIẾT:")
|
| 94 |
+
print(result['evaluation'])
|
| 95 |
+
|
| 96 |
+
# TOP 1 + TÀI LIỆU
|
| 97 |
+
print("\n TOP 1 CHẮC CHẮN:")
|
| 98 |
+
if result['sources'] and result['sources'][0]:
|
| 99 |
+
top_doc = result['sources'][0]
|
| 100 |
+
print(f"[{top_doc.metadata['source_file']}]")
|
| 101 |
+
print(f" {top_doc.metadata['chunk_title']}")
|
| 102 |
+
print(f" {top_doc.metadata['section_title']}")
|
| 103 |
+
print(f" {top_doc.page_content.strip()[:300]}")
|
| 104 |
+
print()
|
| 105 |
+
|
| 106 |
+
print(" TÀI LIỆU THAM KHẢO:")
|
| 107 |
+
if result['sources']:
|
| 108 |
+
for i, doc in enumerate(result['sources'][:3], 1):
|
| 109 |
+
file = doc.metadata.get("source_file", "N/A")
|
| 110 |
+
chunk_id = doc.metadata.get("chunk_id", "N/A")
|
| 111 |
+
chunk_title = doc.metadata.get("chunk_title", "N/A")
|
| 112 |
+
section_id = doc.metadata.get("section_id", "N/A")
|
| 113 |
+
section_title = doc.metadata.get("section_title", "N/A")
|
| 114 |
+
preview = doc.page_content[:80] + "..."
|
| 115 |
+
print(f" {i}. [{file}] {chunk_id} | {chunk_title} | {section_id} | {section_title}")
|
| 116 |
+
print(f" {preview}")
|
| 117 |
+
print()
|
| 118 |
+
|
| 119 |
+
else:
|
| 120 |
+
print("\n CHATBOT Y TẾ READY!")
|
| 121 |
+
print(f"API KEY: {Config.GOOGLE_API_KEY}")
|
| 122 |
+
print(f"Embedding: {Config.EMBEDDING_MODEL}")
|
| 123 |
+
print("Gõ câu hỏi (quit để thoát)")
|
| 124 |
+
print("-" * 50)
|
| 125 |
+
|
| 126 |
+
# 4. CHAT LOOP - CHỈ HỎI → TRẢ LỜI
|
| 127 |
+
while True:
|
| 128 |
+
query = input(" Bạn: ").strip()
|
| 129 |
+
if query.lower() in ['quit', 'exit', 'bye']:
|
| 130 |
+
print(" Tạm biệt!")
|
| 131 |
+
break
|
| 132 |
+
|
| 133 |
+
if not query:
|
| 134 |
+
continue
|
| 135 |
+
|
| 136 |
+
print(" Đang trả lời...")
|
| 137 |
+
answer, sources = rag.query(query) # LẤY CẢ sources
|
| 138 |
+
|
| 139 |
+
print(f"\n Trả lời: {answer}\n")
|
| 140 |
+
|
| 141 |
+
# TOP 1 CHẮC CHẮN - FULL CONTENT
|
| 142 |
+
if sources and sources[0]:
|
| 143 |
+
top_doc = sources[0]
|
| 144 |
+
print(" TOP 1 CHẮC CHẮN:")
|
| 145 |
+
print(f" [{top_doc.metadata['source_file']}]")
|
| 146 |
+
print(f" {top_doc.metadata['chunk_title']}")
|
| 147 |
+
print(f" {top_doc.metadata['section_title']}")
|
| 148 |
+
print(f" {top_doc.page_content.strip()[:300]}") # FULL content
|
| 149 |
+
print()
|
| 150 |
+
|
| 151 |
+
# IN 3 RETRIEVED DOCS
|
| 152 |
+
if sources:
|
| 153 |
+
print(" TÀI LIỆU THAM KHẢO:")
|
| 154 |
+
for i, doc in enumerate(sources[:3], 1): # Top 3
|
| 155 |
+
file = doc.metadata.get("source_file", "N/A")
|
| 156 |
+
chunk_id = doc.metadata.get("chunk_id", "N/A")
|
| 157 |
+
chunk_title = doc.metadata.get("chunk_title", "N/A")
|
| 158 |
+
section_id = doc.metadata.get("section_id", "N/A")
|
| 159 |
+
section_title = doc.metadata.get("section_title", "N/A")
|
| 160 |
+
preview = doc.page_content[:80] + "..."
|
| 161 |
+
print(f" {i}. [{file}] {chunk_id} | {chunk_title} | {section_id} | {section_title}")
|
| 162 |
+
print(f" {preview}")
|
| 163 |
+
print()
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
|
| 166 |
+
main()
|
src/rag_chain.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 2 |
+
from langchain_core.prompts import PromptTemplate
|
| 3 |
+
from config import Config
|
| 4 |
+
from hybrid_retriever import HybridRetriever
|
| 5 |
+
from vector_store import VectorStoreManager
|
| 6 |
+
|
| 7 |
+
class RAGChain:
|
| 8 |
+
def __init__(self, vector_store_manager: VectorStoreManager):
|
| 9 |
+
self.llm = ChatGoogleGenerativeAI(
|
| 10 |
+
model=Config.LLM_MODEL,
|
| 11 |
+
google_api_key=Config.GOOGLE_API_KEY,
|
| 12 |
+
temperature=0 # 0 để deterministic
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
self.vectorstore = vector_store_manager.vector_store
|
| 16 |
+
self.retriever = HybridRetriever(self.vectorstore) # FIX TYPO
|
| 17 |
+
|
| 18 |
+
# PROMPT MỚI: TRẢ NỘI DUNG CHUNK + TÓM TẮT
|
| 19 |
+
self.custom_prompt = PromptTemplate(
|
| 20 |
+
input_variables=["context", "question"],
|
| 21 |
+
template="""
|
| 22 |
+
Bạn là bác sĩ y khoa. Dựa vào TÀI LIỆU sau:
|
| 23 |
+
|
| 24 |
+
CONTEXT:
|
| 25 |
+
{context}
|
| 26 |
+
|
| 27 |
+
CÂU HỎI: {question}
|
| 28 |
+
|
| 29 |
+
TRẢ LỜI:
|
| 30 |
+
1. TRÍCH DẪN ĐÚNG nội dung từ CONTEXT (giữ nguyên văn bản)
|
| 31 |
+
2. Tóm tắt ngắn gọn nếu cần
|
| 32 |
+
3. Luôn ưu tiên thông tin từ chunk chính xác nhất
|
| 33 |
+
|
| 34 |
+
NỘI DUNG TÀI LIỆU:
|
| 35 |
+
"""
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def query(self, question: str):
|
| 39 |
+
"""HYBRID RETRIEVAL + FULL CHUNK CONTENT"""
|
| 40 |
+
|
| 41 |
+
# BƯỚC 1: HYBRID SEARCH - PRIORITY KEYWORD
|
| 42 |
+
sources = self.retriever.hybrid_search(question, k=4)
|
| 43 |
+
|
| 44 |
+
# BƯỚC 2: RE-RANK theo keyword match
|
| 45 |
+
ranked_sources = self.rerank_sources(sources, question)
|
| 46 |
+
|
| 47 |
+
# BƯỚC 3: Tạo context FULL CONTENT
|
| 48 |
+
context = self.build_context(ranked_sources)
|
| 49 |
+
|
| 50 |
+
# BƯỚC 4: Generate với prompt rõ ràng
|
| 51 |
+
formatted_prompt = self.custom_prompt.format(
|
| 52 |
+
context=context,
|
| 53 |
+
question=question
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
result = self.llm.invoke([formatted_prompt])
|
| 57 |
+
return result.content, ranked_sources
|
| 58 |
+
|
| 59 |
+
def rerank_sources(self, sources, question):
|
| 60 |
+
"""RE-RANK: Keyword match > Semantic"""
|
| 61 |
+
keywords = question.lower().split()
|
| 62 |
+
|
| 63 |
+
def score_doc(doc):
|
| 64 |
+
content = doc.page_content.lower()
|
| 65 |
+
title = doc.metadata.get('chunk_title', '').lower()
|
| 66 |
+
score = sum(1 for kw in keywords if kw in content or kw in title)
|
| 67 |
+
return score
|
| 68 |
+
|
| 69 |
+
return sorted(sources, key=score_doc, reverse=True)
|
| 70 |
+
|
| 71 |
+
def build_context(self, sources):
|
| 72 |
+
"""FULL CHUNK CONTENT + METADATA"""
|
| 73 |
+
context_parts = []
|
| 74 |
+
for i, doc in enumerate(sources[:3]):
|
| 75 |
+
file = doc.metadata.get('source_file', 'N/A')
|
| 76 |
+
chunk_title = doc.metadata.get('chunk_title', 'N/A')
|
| 77 |
+
section_title = doc.metadata.get('section_title', 'N/A')
|
| 78 |
+
|
| 79 |
+
context_parts.append(
|
| 80 |
+
f"[{i+1}] {file} | {chunk_title} | {section_title}\n"
|
| 81 |
+
f"NỘI DUNG:\n{doc.page_content}\n{'='*80}"
|
| 82 |
+
)
|
| 83 |
+
return "\n\n".join(context_parts)
|
src/vector_store.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.vectorstores import FAISS
|
| 2 |
+
from langchain_core.documents import Document
|
| 3 |
+
from embeddings import EmbeddingsManager
|
| 4 |
+
from typing import List
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import json
|
| 7 |
+
from pymongo import MongoClient
|
| 8 |
+
|
| 9 |
+
from data_loader import DataLoader
|
| 10 |
+
|
| 11 |
+
class VectorStoreManager:
|
| 12 |
+
def __init__(self):
|
| 13 |
+
self.embeddings_manager = EmbeddingsManager()
|
| 14 |
+
self.embeddings = self.embeddings_manager.get_embeddings() # Chứa model embedding từ class EmbeddingsManager
|
| 15 |
+
self.vector_store = None
|
| 16 |
+
# Use relative path from current file location
|
| 17 |
+
current_dir = Path(__file__).parent.parent # Go up to rag_project folder
|
| 18 |
+
self.cache_dir = current_dir / "faiss_cache"
|
| 19 |
+
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
| 20 |
+
|
| 21 |
+
self.load_faiss_cache()
|
| 22 |
+
|
| 23 |
+
def load_faiss_cache(self):
|
| 24 |
+
"""LOAD FAISS 0.1s"""
|
| 25 |
+
index_path = self.cache_dir / "faiss_index"
|
| 26 |
+
if index_path.exists():
|
| 27 |
+
print(" LOADING FAISS CACHE...")
|
| 28 |
+
self.vector_store = FAISS.load_local(
|
| 29 |
+
str(index_path),
|
| 30 |
+
self.embeddings,
|
| 31 |
+
allow_dangerous_deserialization=True
|
| 32 |
+
)
|
| 33 |
+
print(f"LOADED {len(self.vector_store.docstore._dict)} chunks!")
|
| 34 |
+
return True
|
| 35 |
+
return False
|
| 36 |
+
|
| 37 |
+
def build_and_cache(self, docs):
|
| 38 |
+
"""EMBED + SAVE"""
|
| 39 |
+
# FIX: Chỉ check 1 lần
|
| 40 |
+
if self.vector_store is not None: # Đã load/cache → skip
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
+
print(f" Embedding {len(docs)} chunks...")
|
| 44 |
+
self.vector_store = FAISS.from_documents(docs, self.embeddings)
|
| 45 |
+
|
| 46 |
+
# SAVE CACHE
|
| 47 |
+
index_path = self.cache_dir / "faiss_index"
|
| 48 |
+
self.vector_store.save_local(str(index_path))
|
| 49 |
+
print(f"SAVED FAISS: {index_path}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Hàm này truyền all_docs (List[Document] từ class DataLoader để embedding
|
| 53 |
+
def build_from_docs(self, docs: List[Document]):
|
| 54 |
+
"""Xây dựng FAISS từ list chunks"""
|
| 55 |
+
|
| 56 |
+
"""
|
| 57 |
+
self.vector_store CHÍNH LÀ nơi lưu trữ:
|
| 58 |
+
|
| 59 |
+
+ Vector embedding (dạng số) của từng Document
|
| 60 |
+
+ Nội dung gốc (page_content)
|
| 61 |
+
+ Metadata (source_file, chunk_id, title, …)
|
| 62 |
+
+ Chỉ mục FAISS để tìm kiếm nhanh theo độ giống (cosine / L2)
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
self.vector_store = FAISS.from_documents(docs, self.embeddings)
|
| 66 |
+
print(f"FAISS built with {len(docs)} medical chunks indexed!")
|
| 67 |
+
|
| 68 |
+
def get_retriever(self, k: int = 3):
|
| 69 |
+
"""Trả về retriever với k docs"""
|
| 70 |
+
if self.vector_store is None:
|
| 71 |
+
raise ValueError("Vector store not built. Call build_from_docs first.")
|
| 72 |
+
return self.vector_store.as_retriever(search_kwargs={"k": k})
|
| 73 |
+
|
| 74 |
+
def save_documents(self, docs):
|
| 75 |
+
|
| 76 |
+
output_dir = Path(r"D:\Storage\rag_project\store")
|
| 77 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 78 |
+
|
| 79 |
+
records = []
|
| 80 |
+
for i, doc in enumerate(docs):
|
| 81 |
+
records.append({
|
| 82 |
+
"doc_id": f"doc_{i}",
|
| 83 |
+
"page_content": doc.page_content,
|
| 84 |
+
"metadata": doc.metadata
|
| 85 |
+
})
|
| 86 |
+
|
| 87 |
+
# JSON
|
| 88 |
+
with open(output_dir / "documents.json", "w", encoding="utf-8") as f:
|
| 89 |
+
json.dump(records, f, ensure_ascii=False, indent=2)
|
| 90 |
+
|
| 91 |
+
# Mongo
|
| 92 |
+
client = MongoClient("mongodb://localhost:27017/")
|
| 93 |
+
col = client["medchat"]["store"]
|
| 94 |
+
col.delete_many({})
|
| 95 |
+
col.insert_many(records)
|
| 96 |
+
client.close()
|
| 97 |
+
|
| 98 |
+
print(f"Saved {len(records)} documents")
|
| 99 |
+
|
| 100 |
+
def save_embeddings(self, docs):
|
| 101 |
+
|
| 102 |
+
texts = [doc.page_content for doc in docs]
|
| 103 |
+
vectors = self.embeddings.embed_documents(texts)
|
| 104 |
+
|
| 105 |
+
output_dir = Path(r"D:\Storage\rag_project\store")
|
| 106 |
+
|
| 107 |
+
records = []
|
| 108 |
+
for i, vec in enumerate(vectors):
|
| 109 |
+
records.append({
|
| 110 |
+
"doc_id": f"doc_{i}",
|
| 111 |
+
"embedding": vec
|
| 112 |
+
})
|
| 113 |
+
|
| 114 |
+
# JSON
|
| 115 |
+
with open(output_dir / "embeddings.json", "w", encoding="utf-8") as f:
|
| 116 |
+
json.dump(records, f)
|
| 117 |
+
|
| 118 |
+
# Mongo
|
| 119 |
+
client = MongoClient("mongodb://localhost:27017/")
|
| 120 |
+
col = client["medchat"]["embeddings"]
|
| 121 |
+
col.delete_many({})
|
| 122 |
+
col.insert_many(records)
|
| 123 |
+
client.close()
|
| 124 |
+
|
| 125 |
+
print(f" Saved {len(records)} embeddings")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# TEST: python vector_store.py
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
print("TEST VECTOR STORE...")
|
| 131 |
+
|
| 132 |
+
# docs sample
|
| 133 |
+
docs = [
|
| 134 |
+
Document(
|
| 135 |
+
page_content="Sốt ở trẻ em là tình trạng thân nhiệt tăng trên 38 độ.",
|
| 136 |
+
metadata={"title": "Sốt trẻ em"}
|
| 137 |
+
),
|
| 138 |
+
Document(
|
| 139 |
+
page_content="Tiêu chảy cấp ở trẻ thường do virus hoặc vi khuẩn.",
|
| 140 |
+
metadata={"title": "Tiêu chảy"}
|
| 141 |
+
),
|
| 142 |
+
Document(
|
| 143 |
+
page_content="Gãy xương tay cần được cố định và đưa đến cơ sở y t���.",
|
| 144 |
+
metadata={"title": "Chấn thương"}
|
| 145 |
+
),
|
| 146 |
+
]
|
| 147 |
+
|
| 148 |
+
vs = VectorStoreManager()
|
| 149 |
+
vs.build_from_docs(docs)
|
| 150 |
+
print(f"Đã build xong docs")
|
| 151 |
+
|
| 152 |
+
retriever = vs.get_retriever(k=2)
|
| 153 |
+
results = retriever.get_relevant_documents("vi khuẩn")
|
| 154 |
+
|
| 155 |
+
print("\n KẾT QUẢ TRUY XUẤT:")
|
| 156 |
+
for i, doc in enumerate(results):
|
| 157 |
+
print(f"{i+1}. {doc.metadata['title']} | {doc.page_content[:60]}...")
|
| 158 |
+
|
| 159 |
+
print("\n VECTOR STORE OK!")
|
| 160 |
+
|
| 161 |
+
all_docs = DataLoader.load_all_chunks()
|
| 162 |
+
|
| 163 |
+
vs.build_from_docs(docs)
|
| 164 |
+
|
| 165 |
+
print(f"Đã build xong all_docs")
|
| 166 |
+
vs.save_documents(all_docs)
|
| 167 |
+
vs.save_embeddings(all_docs)
|
| 168 |
+
|
| 169 |
+
print("\n HOÀN THÀNH LƯU DOCUMENT VÀ EMBEDING HOÀN TẤT")
|
start_server.bat
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
cd /d "%~dp0"
|
| 3 |
+
echo Starting RAG API Server...
|
| 4 |
+
echo This will take 30-60 seconds to load models...
|
| 5 |
+
echo.
|
| 6 |
+
python api_server.py
|
| 7 |
+
pause
|
test/test vector_store.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from src.data_loader import DataLoader
|
| 3 |
+
from src.embeddings import EmbeddingsManager
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
docs = DataLoader.load_all_chunks()
|
| 7 |
+
|
| 8 |
+
em = EmbeddingsManager()
|
| 9 |
+
embeddings_model = em.get_embeddings()
|
| 10 |
+
|
| 11 |
+
save_documents(docs)
|
| 12 |
+
save_embeddings(docs, embeddings_model)
|
test/test_chunks_by_file.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
BoYTe200_v3.json: 157 chunks
|
| 6 |
+
NHIKHOA2.json: 47 chunks
|
| 7 |
+
PHACDODIEUTRI_2016.json: 156 chunks
|
| 8 |
+
TỔNG: 360 chunks logic
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
chunk = {
|
| 12 |
+
"chunk_id": chap["id"],
|
| 13 |
+
"title": chap["Index"],
|
| 14 |
+
"level1_items": chap["level1_items"],
|
| 15 |
+
"contents": [ # Mảng sections
|
| 16 |
+
{"title": "Section 1", "content": "..."},
|
| 17 |
+
{"title": "Section 2", "content": "..."},
|
| 18 |
+
...
|
| 19 |
+
]
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
BASE_DIR = Path(r"D:\Storage\rag_project") # sửa cho đúng đường dẫn
|
| 25 |
+
DATA_DIR = BASE_DIR / "data"
|
| 26 |
+
|
| 27 |
+
def load_chapters(json_path: Path):
|
| 28 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
| 29 |
+
chapters = json.load(f)
|
| 30 |
+
return chapters
|
| 31 |
+
|
| 32 |
+
def main():
|
| 33 |
+
print(" BASE_DIR:", BASE_DIR)
|
| 34 |
+
print(" DATA_DIR:", DATA_DIR, "\n")
|
| 35 |
+
|
| 36 |
+
json_files = list(DATA_DIR.glob("*.json"))
|
| 37 |
+
if not json_files:
|
| 38 |
+
print(" Không tìm thấy file .json nào trong data/")
|
| 39 |
+
return
|
| 40 |
+
|
| 41 |
+
total_chunks = 0
|
| 42 |
+
|
| 43 |
+
for json_file in json_files:
|
| 44 |
+
print(f"==============================")
|
| 45 |
+
print(f" FILE: {json_file.name}")
|
| 46 |
+
chapters = load_chapters(json_file)
|
| 47 |
+
|
| 48 |
+
all_chunks = []
|
| 49 |
+
for chap in chapters:
|
| 50 |
+
all_chunks.append({
|
| 51 |
+
"chunk_id": chap.get("id"),
|
| 52 |
+
"title": chap.get("Index"),
|
| 53 |
+
"level1_items": chap.get("level1_items", []),
|
| 54 |
+
"contents": chap.get("contents", []),
|
| 55 |
+
})
|
| 56 |
+
|
| 57 |
+
num_chunks = len(all_chunks)
|
| 58 |
+
total_chunks += num_chunks
|
| 59 |
+
print(f" Số chunk (theo id) trong file này: {num_chunks}")
|
| 60 |
+
|
| 61 |
+
# In MẪU 1 chunk đầu tiên của file
|
| 62 |
+
if num_chunks > 0:
|
| 63 |
+
sample = all_chunks[0]
|
| 64 |
+
print("\n MẪU CHUNK ĐẦU TIÊN:")
|
| 65 |
+
print(" chunk_id:", sample["chunk_id"])
|
| 66 |
+
print(" title :", sample["title"])
|
| 67 |
+
print(" level1_items:", sample["level1_items"])
|
| 68 |
+
print(" Số sections trong contents:", len(sample["contents"]))
|
| 69 |
+
if sample["contents"]:
|
| 70 |
+
sec0 = sample["contents"][0]
|
| 71 |
+
print(" ➜ Section 1 title :", sec0.get("title"))
|
| 72 |
+
print(" ➜ Section 1 content preview:",
|
| 73 |
+
(sec0.get("content") or "")[:150], "...")
|
| 74 |
+
print()
|
| 75 |
+
|
| 76 |
+
print("====================================")
|
| 77 |
+
print(" TỔNG SỐ CHUNK (theo id) TỪ TẤT CẢ FILE:", total_chunks)
|
| 78 |
+
|
| 79 |
+
if __name__ == "__main__":
|
| 80 |
+
main()
|
test/test_data_loader.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from langchain_core.documents import Document
|
| 6 |
+
|
| 7 |
+
# PATH ĐÚNG
|
| 8 |
+
BASE_DIR = Path(r'D:\Storage\rag_project')
|
| 9 |
+
sys.path.insert(0, str(BASE_DIR / 'src'))
|
| 10 |
+
|
| 11 |
+
print(" Đường dẫn Python search:")
|
| 12 |
+
print(f" - BASE_DIR: {BASE_DIR}")
|
| 13 |
+
|
| 14 |
+
# LOAD TẤT CẢ JSON TRONG DATA/
|
| 15 |
+
DATA_DIR = BASE_DIR / 'data'
|
| 16 |
+
all_docs = []
|
| 17 |
+
|
| 18 |
+
print("\n TẤT CẢ JSON TRONG DATA:")
|
| 19 |
+
json_files = list(DATA_DIR.glob("*.json"))
|
| 20 |
+
for json_file in json_files:
|
| 21 |
+
print(f" {json_file.name}")
|
| 22 |
+
|
| 23 |
+
# Load TẤT CẢ JSON files
|
| 24 |
+
total_chunks = 0
|
| 25 |
+
for json_file in json_files:
|
| 26 |
+
print(f"\n Đang load {json_file.name}...")
|
| 27 |
+
|
| 28 |
+
with open(json_file, 'r', encoding='utf-8') as f:
|
| 29 |
+
chapters = json.load(f)
|
| 30 |
+
|
| 31 |
+
# Tạo chunks như notebook gốc
|
| 32 |
+
file_chunks = []
|
| 33 |
+
for chap in chapters:
|
| 34 |
+
all_chunks = {
|
| 35 |
+
"chunkid": chap.get("id", "unknown"),
|
| 36 |
+
"title": chap.get("index", "unknown"),
|
| 37 |
+
"level1items": chap.get("level1items", []),
|
| 38 |
+
"contents": chap.get("contents", [])
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
# Tạo Documents
|
| 42 |
+
for i, section in enumerate(all_chunks["contents"]):
|
| 43 |
+
doc = Document(
|
| 44 |
+
page_content=section["content"],
|
| 45 |
+
metadata={
|
| 46 |
+
"source_file": json_file.name,
|
| 47 |
+
"chunkid": all_chunks["chunkid"],
|
| 48 |
+
"sectionid": f"{all_chunks['chunkid']}.{i+1}",
|
| 49 |
+
"title": all_chunks["title"],
|
| 50 |
+
"sectiontitle": section["title"]
|
| 51 |
+
}
|
| 52 |
+
)
|
| 53 |
+
file_chunks.append(doc)
|
| 54 |
+
|
| 55 |
+
all_docs.extend(file_chunks)
|
| 56 |
+
total_chunks += len(file_chunks)
|
| 57 |
+
print(f" {json_file.name}: {len(file_chunks)} chunks")
|
| 58 |
+
|
| 59 |
+
print(f"\n TỔNG KẾT:")
|
| 60 |
+
print(f" Tổng chunks từ {len(json_files)} files: {total_chunks}")
|
| 61 |
+
print(f" Chunk mẫu 1:")
|
| 62 |
+
if all_docs:
|
| 63 |
+
doc = all_docs[0]
|
| 64 |
+
print(f" File: {doc.metadata['source_file']}")
|
| 65 |
+
print(f" Content: {doc.page_content[:150]}...")
|
| 66 |
+
print(f" Title: {doc.metadata['title']}")
|
| 67 |
+
|
| 68 |
+
print("\n READY CHO RAG - FAISS + LLM!")
|
test/test_embeddings_single.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append(r'D:\Storage\rag_project\src')
|
| 3 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 4 |
+
from test_single_file_loader import test_single_file
|
| 5 |
+
|
| 6 |
+
def test_embed_single(filename):
|
| 7 |
+
print(f"\n EMBED TEST: {filename}")
|
| 8 |
+
docs = test_single_file(filename)
|
| 9 |
+
|
| 10 |
+
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
| 11 |
+
|
| 12 |
+
# Embed 1 doc mẫu
|
| 13 |
+
sample_text = docs[0].page_content[:500]
|
| 14 |
+
vector = embeddings.embed_query(sample_text)
|
| 15 |
+
|
| 16 |
+
print(f" Embedding shape: {len(vector)}")
|
| 17 |
+
print(f" Vector preview: {vector[:5]}...")
|
| 18 |
+
print(f" READY cho FAISS!")
|
| 19 |
+
|
| 20 |
+
if __name__ == "__main__":
|
| 21 |
+
test_embed_single("NHIKHOA2.json")
|
| 22 |
+
test_embed_single("PHACDODIEUTRI_2016.json")
|
test/test_faiss_single.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append(r'D:\Storage\rag_project\src')
|
| 3 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 4 |
+
from langchain_community.vectorstores import FAISS
|
| 5 |
+
from test_single_file_loader import test_single_file
|
| 6 |
+
|
| 7 |
+
def test_faiss_single(filename):
|
| 8 |
+
print(f"\n FAISS TEST: {filename}")
|
| 9 |
+
docs = test_single_file(filename)
|
| 10 |
+
|
| 11 |
+
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
| 12 |
+
vectorstore = FAISS.from_documents(docs, embeddings)
|
| 13 |
+
|
| 14 |
+
print(f" FAISS index created: {len(docs)} vectors")
|
| 15 |
+
|
| 16 |
+
# Test retrieve
|
| 17 |
+
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
|
| 18 |
+
query = "tuần hoàn" if "NHIKHOA" in filename else "đột tử"
|
| 19 |
+
results = retriever.get_relevant_documents(query)
|
| 20 |
+
|
| 21 |
+
print(f" Query '{query}' → Found {len(results)} docs:")
|
| 22 |
+
for i, doc in enumerate(results):
|
| 23 |
+
print(f" {i+1}. {doc.metadata['chunk_title']}")
|
| 24 |
+
print(" FAISS OK!")
|
| 25 |
+
|
| 26 |
+
if __name__ == "__main__":
|
| 27 |
+
test_faiss_single("NHIKHOA2.json")
|
| 28 |
+
test_faiss_single("PHACDODIEUTRI_2016.json")
|
test/test_hybrid.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# ✅ FIX PATH - QUAN TRỌNG!
|
| 5 |
+
sys.path.insert(0, r'D:\Storage\rag_project\src') # Thêm src vào đầu path
|
| 6 |
+
|
| 7 |
+
from data_loader import DataLoader
|
| 8 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 9 |
+
from langchain_community.vectorstores import FAISS
|
| 10 |
+
|
| 11 |
+
docs = DataLoader.load_all_chunks()
|
| 12 |
+
print(f"✅ Loaded {len(docs)} docs")
|
| 13 |
+
|
| 14 |
+
embeddings = HuggingFaceEmbeddings(model_name="bkai-foundation-models/vietnamese-bi-encoder")
|
| 15 |
+
vs = FAISS.from_documents(docs, embeddings)
|
| 16 |
+
|
| 17 |
+
query = "X-quang phổi 60-83% bình thường 3 tháng"
|
| 18 |
+
|
| 19 |
+
# OLD: FAISS thuần
|
| 20 |
+
old_docs = vs.similarity_search(query, k=3)
|
| 21 |
+
print("\n❌ OLD FAISS:")
|
| 22 |
+
for doc in old_docs:
|
| 23 |
+
print(f" {doc.metadata['chunk_title'][:40]} | {doc.metadata['source_file']}")
|
| 24 |
+
|
| 25 |
+
# NEW: Keyword boost
|
| 26 |
+
keywords = ["X-quang", "phổi", "bình thường", "3 tháng"]
|
| 27 |
+
boosted = []
|
| 28 |
+
for doc_id, doc in vs.docstore._dict.items():
|
| 29 |
+
score = sum(1 for kw in keywords if kw in doc.page_content.lower())
|
| 30 |
+
if score > 0:
|
| 31 |
+
boosted.append(doc)
|
| 32 |
+
print(f"\n✅ KEYWORD BOOST HIT: score={score}")
|
| 33 |
+
print(f" {doc.metadata['chunk_title']}")
|
| 34 |
+
print(f" Preview: {doc.page_content[:100]}...")
|
| 35 |
+
break
|
| 36 |
+
|
| 37 |
+
if not boosted:
|
| 38 |
+
print("\n❌ KHÔNG TÌM THẤY KEYWORD NÀO!")
|
test/test_load_model_ft.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# test_load_model_fixed.py
|
| 2 |
+
from sentence_transformers import SentenceTransformer
|
| 3 |
+
from transformers import AutoTokenizer, AutoModel
|
| 4 |
+
import torch
|
| 5 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
texts = ["Sốt ở trẻ em", "Trẻ bị sốt cao", "Gãy xương tay"]
|
| 9 |
+
|
| 10 |
+
print("=== sup-SimCSE-VietNamese-phobert-base ===")
|
| 11 |
+
simcse_model = SentenceTransformer("VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
|
| 12 |
+
simcse_embeds = simcse_model.encode(texts)
|
| 13 |
+
print("v1 ↔ v2:", cosine_similarity([simcse_embeds[0]], [simcse_embeds[1]])[0][0])
|
| 14 |
+
print("v1 ↔ v3:", cosine_similarity([simcse_embeds[0]], [simcse_embeds[2]])[0][0])
|
| 15 |
+
|
| 16 |
+
print("\n=== PhoBERT gốc ===")
|
| 17 |
+
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
|
| 18 |
+
phobert_model = AutoModel.from_pretrained("vinai/phobert-base")
|
| 19 |
+
|
| 20 |
+
def phobert_embed(text):
|
| 21 |
+
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=256)
|
| 22 |
+
with torch.no_grad():
|
| 23 |
+
outputs = phobert_model(**inputs)
|
| 24 |
+
attention_mask = inputs['attention_mask']
|
| 25 |
+
embeddings = outputs.last_hidden_state
|
| 26 |
+
mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
|
| 27 |
+
return torch.sum(embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
|
| 28 |
+
|
| 29 |
+
phobert_embeds = [phobert_embed(t).squeeze().numpy() for t in texts]
|
| 30 |
+
print("v1 ↔ v2:", cosine_similarity([phobert_embeds[0]], [phobert_embeds[1]])[0][0])
|
| 31 |
+
print("v1 ↔ v3:", cosine_similarity([phobert_embeds[0]], [phobert_embeds[2]])[0][0])
|
test/test_rag_single.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
sys.path.append(r'D:\Storage\rag_project\src')
|
| 4 |
+
|
| 5 |
+
# os.environ["GOOGLE_API_KEY"] = "AIzaSyABvC8mPrwa0Kgy08mFFzkyeh2_N-Bb3lY" # Thay key thật
|
| 6 |
+
|
| 7 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 8 |
+
from langchain.chains import RetrievalQA
|
| 9 |
+
from langchain_core.prompts import PromptTemplate
|
| 10 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 11 |
+
from langchain_community.vectorstores import FAISS
|
| 12 |
+
from test_single_file_loader import test_single_file
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_rag_single(filename):
|
| 16 |
+
print(f"\n FULL RAG TEST: {filename}")
|
| 17 |
+
docs = test_single_file(filename)
|
| 18 |
+
|
| 19 |
+
# Build FAISS
|
| 20 |
+
print(" Building FAISS...")
|
| 21 |
+
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
| 22 |
+
vectorstore = FAISS.from_documents(docs, embeddings)
|
| 23 |
+
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
|
| 24 |
+
|
| 25 |
+
# LLM + Prompt
|
| 26 |
+
print(" Init Gemini...")
|
| 27 |
+
llm = ChatGoogleGenerativeAI(
|
| 28 |
+
model="gemini-2.5-flash",
|
| 29 |
+
temperature=0,
|
| 30 |
+
google_api_key="AIzaSyBKUfFRLphY4AgTY-j5sr-6s0SFWW0ATyg" # API KEY Ở ĐÂY
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
custom_prompt = PromptTemplate(
|
| 34 |
+
input_variables=["context", "question"],
|
| 35 |
+
template="""Bạn là bác sĩ nhi khoa. Dựa vào TÀI LIỆU Y KHOA sau:
|
| 36 |
+
|
| 37 |
+
CONTEXT: {context}
|
| 38 |
+
|
| 39 |
+
CÂU HỎI: {question}
|
| 40 |
+
|
| 41 |
+
TRẢ LỜI chính xác dựa trên CONTEXT, ngắn gọn, chuyên nghiệp."""
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
qa_chain = RetrievalQA.from_chain_type(
|
| 45 |
+
llm=llm,
|
| 46 |
+
chain_type="stuff",
|
| 47 |
+
retriever=retriever,
|
| 48 |
+
chain_type_kwargs={"prompt": custom_prompt},
|
| 49 |
+
return_source_documents=True
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# CHỈ NHIKHOA2.json - 1 query đúng
|
| 53 |
+
query = "điều trị suy hô hấp"
|
| 54 |
+
|
| 55 |
+
print(f"\n Query: {query}")
|
| 56 |
+
result = qa_chain.invoke({"query": query})
|
| 57 |
+
print(f" Answer: {result['result'][:400]}...")
|
| 58 |
+
# FIX: Kiểm tra key tồn tại
|
| 59 |
+
if 'source_documents' in result:
|
| 60 |
+
print(f" Sources: {len(result['source_documents'])} docs")
|
| 61 |
+
else:
|
| 62 |
+
print(" Sources: Không có source_documents (Gemini 2.5 format)")
|
| 63 |
+
|
| 64 |
+
print("\n RAG SINGLE FILE OK!")
|
| 65 |
+
|
| 66 |
+
if __name__ == "__main__":
|
| 67 |
+
test_rag_single("NHIKHOA2.json")
|
test/test_rq.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.documents import Document
|
| 2 |
+
from langchain_community.vectorstores import FAISS
|
| 3 |
+
|
| 4 |
+
print("LangChain OK")
|
test/test_single_file_loader.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
|
| 5 |
+
BASE_DIR = Path(r"D:\Storage\rag_project")
|
| 6 |
+
DATA_DIR = BASE_DIR / "data"
|
| 7 |
+
|
| 8 |
+
def test_single_file(filename):
|
| 9 |
+
json_path = DATA_DIR / filename
|
| 10 |
+
print(f"\n{'='*60}")
|
| 11 |
+
print(f" TEST FILE: {filename}")
|
| 12 |
+
|
| 13 |
+
with open(json_path, 'r', encoding='utf-8') as f:
|
| 14 |
+
chapters = json.load(f)
|
| 15 |
+
|
| 16 |
+
print(f" Chapters: {len(chapters)}")
|
| 17 |
+
|
| 18 |
+
# Tạo chunks ĐÚNG Colab
|
| 19 |
+
all_chunks = []
|
| 20 |
+
# Tạo từng chunk dựa vào id mỗi sample
|
| 21 |
+
for chap in chapters:
|
| 22 |
+
chunk = {
|
| 23 |
+
"chunk_id": chap.get("id"),
|
| 24 |
+
"title": chap.get("Index"),
|
| 25 |
+
"level1_items": chap.get("level1_items", []),
|
| 26 |
+
"contents": chap.get("contents", [])
|
| 27 |
+
}
|
| 28 |
+
all_chunks.append(chunk)
|
| 29 |
+
|
| 30 |
+
print(f" Chunks: {len(all_chunks)}")
|
| 31 |
+
|
| 32 |
+
# Tạo Documents
|
| 33 |
+
docs = []
|
| 34 |
+
# Tạo documents từ all_chunk
|
| 35 |
+
|
| 36 |
+
for chunk in all_chunks:
|
| 37 |
+
for i, section in enumerate(chunk["contents"]):
|
| 38 |
+
doc = Document(
|
| 39 |
+
page_content=section.get("content", ""),
|
| 40 |
+
metadata={
|
| 41 |
+
"source_file": filename,
|
| 42 |
+
"chunk_id": chunk["chunk_id"],
|
| 43 |
+
"chunk_title": chunk["title"],
|
| 44 |
+
"section_id": f"{chunk['chunk_id']}.{i+1}",
|
| 45 |
+
"section_title": section.get("title", "")
|
| 46 |
+
}
|
| 47 |
+
)
|
| 48 |
+
docs.append(doc)
|
| 49 |
+
|
| 50 |
+
print(f" Documents: {len(docs)}")
|
| 51 |
+
print(f" Mẫu doc 0:")
|
| 52 |
+
print(f" Title: {docs[0].metadata['chunk_title']}")
|
| 53 |
+
print(f" Section: {docs[0].metadata["section_id"]}")
|
| 54 |
+
print(f" Content: {docs[0].page_content[:100]}...")
|
| 55 |
+
return docs
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
# Test từng file
|
| 59 |
+
test_single_file("NHIKHOA2.json")
|
| 60 |
+
test_single_file("BoYTe200_v3.json")
|
| 61 |
+
test_single_file("PHACDODIEUTRI_2016.json")
|
test/testapi.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 2 |
+
|
| 3 |
+
API_KEY = "AIzaSyBhsaHDerIO-IAf7gEPvjs5Mb1hYOjqOMk"
|
| 4 |
+
|
| 5 |
+
llm = ChatGoogleGenerativeAI(
|
| 6 |
+
model="gemini-2.5-flash",
|
| 7 |
+
google_api_key=API_KEY,
|
| 8 |
+
temperature=0
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
print(f"API KEY: {API_KEY} READY")
|
| 12 |
+
resp = llm.invoke("helo")
|
| 13 |
+
print(resp.content)
|
test_api.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test API endpoint để debug lỗi 500"""
|
| 2 |
+
import requests
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
url = "http://localhost:5000/api/start-case"
|
| 6 |
+
data = {
|
| 7 |
+
"disease": "VIÊM PHỔI",
|
| 8 |
+
"sessionId": "test_123"
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
print("Testing /api/start-case endpoint...")
|
| 12 |
+
print(f"URL: {url}")
|
| 13 |
+
print(f"Data: {data}")
|
| 14 |
+
print("-" * 80)
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
response = requests.post(url, json=data)
|
| 18 |
+
print(f"Status Code: {response.status_code}")
|
| 19 |
+
print(f"Response:\n{json.dumps(response.json(), indent=2, ensure_ascii=False)}")
|
| 20 |
+
except Exception as e:
|
| 21 |
+
print(f"Error: {e}")
|
| 22 |
+
print(f"Response text: {response.text if 'response' in locals() else 'No response'}")
|
test_debug.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test with detailed error logging"""
|
| 2 |
+
import requests
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
url = "http://localhost:5000/api/start-case"
|
| 6 |
+
data = {
|
| 7 |
+
"disease": "SUY TIM",
|
| 8 |
+
"sessionId": "debug_test"
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
response = requests.post(url, json=data, timeout=60)
|
| 13 |
+
print(f"Status: {response.status_code}")
|
| 14 |
+
print(f"Response:\n{json.dumps(response.json(), indent=2, ensure_ascii=False)}")
|
| 15 |
+
except requests.exceptions.Timeout:
|
| 16 |
+
print("⏱️ Request timeout - server taking too long")
|
| 17 |
+
except Exception as e:
|
| 18 |
+
print(f"❌ Error: {e}")
|
| 19 |
+
print(f"Type: {type(e)}")
|
test_suy_tim.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test API với bệnh SUY TIM"""
|
| 2 |
+
import requests
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
url = "http://localhost:5000/api/start-case"
|
| 6 |
+
data = {
|
| 7 |
+
"disease": "SUY TIM",
|
| 8 |
+
"sessionId": "test_suy_tim_456"
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
print("Testing SUY TIM case generation...")
|
| 12 |
+
print(f"URL: {url}")
|
| 13 |
+
print(f"Data: {data}")
|
| 14 |
+
print("-" * 80)
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
response = requests.post(url, json=data)
|
| 18 |
+
print(f"Status Code: {response.status_code}")
|
| 19 |
+
result = response.json()
|
| 20 |
+
print(f"\n✅ Case được tạo:\n{result.get('case', 'N/A')}")
|
| 21 |
+
print(f"\n📋 Triệu chứng tìm được (500 ký tự đầu):\n{result.get('symptoms', 'N/A')[:500]}")
|
| 22 |
+
except Exception as e:
|
| 23 |
+
print(f"❌ Error: {e}")
|