"""
FastAPI Server để tích hợp RAG vào Chatbot
Endpoints:
- GET /api/health - Health check
- GET /api/diseases - Lấy danh sách bệnh từ JSON
- POST /api/start-case - Nhận bệnh, tạo case với triệu chứng
- POST /api/evaluate - Nhận đáp án user, trả về kết quả so sánh
- Docs: http://localhost:8001/docs (Swagger UI)
"""
import sys
import io
# Fix encoding for Vietnamese characters in Windows console
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
import asyncio
import threading
from fastapi import FastAPI, HTTPException, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import json
import sys
import os
import uvicorn
# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from data_loader import DataLoader
from config import Config
from doctor_evaluator import DoctorEvaluator
from vector_store import VectorStoreManager
from rag_chain import RAGChain
from session_store import SessionStore
from disease_cache import DiseaseCache
# ── Background initialization ─────────────────────────────────────────────────
# Heavy work (model load + FAISS) runs in a background thread so uvicorn binds
# port 7860 immediately — HF Spaces sees the port up and marks the Space as
# "Running" within seconds, while initialization continues in the background.
# Globals populated by _background_init(); all None until initialization completes.
# Request handlers must pass through _require_ready() before touching them.
vs_manager: Optional[VectorStoreManager] = None  # FAISS-backed vector store manager
rag: Optional[RAGChain] = None  # RAG pipeline built on top of vs_manager
evaluator: Optional[DoctorEvaluator] = None  # generates cases and grades answers
session_store: Optional[SessionStore] = None  # SQLite-backed per-session state
disease_cache: Optional[DiseaseCache] = None  # per-disease cache of RAG results
_init_done = threading.Event() # set when initialization finishes (success OR failure)
_init_error: Optional[Exception] = None # set if initialization fails
def _background_init():
    """Load the full RAG stack off the main thread, then signal _init_done.

    On failure the exception is stashed in _init_error so request handlers can
    surface it; _init_done is set in every case so waiters never block forever.
    """
    global vs_manager, rag, evaluator, session_store, disease_cache, _init_error
    try:
        print("[*] Initializing RAG system in background thread...")
        vs_manager = VectorStoreManager()
        if not vs_manager.vector_store:
            raise RuntimeError("FAISS index not found — run: python src/build_faiss.py")
        rag = RAGChain(vs_manager)
        evaluator = DoctorEvaluator(rag)
        session_store = SessionStore()
        session_store.cleanup_expired()
        disease_cache = DiseaseCache()
        print("[OK] RAG system ready!")
    except Exception as err:
        _init_error = err
        print(f"[ERROR] Background initialization failed: {err}")
        import traceback
        traceback.print_exc()
    finally:
        # Always unblock waiters; success vs failure is told apart via _init_error.
        _init_done.set()


# Kick off initialization right away — uvicorn binds its port before this finishes.
threading.Thread(target=_background_init, daemon=True, name="rag-init").start()
def _require_ready():
    """FastAPI dependency gating endpoints on background initialization.

    503 while init is still running, 500 if it finished with an error,
    no-op once the RAG stack is ready.
    """
    if _init_done.is_set():
        if not _init_error:
            return
        raise HTTPException(status_code=500, detail=f"Initialization failed: {_init_error}")
    raise HTTPException(status_code=503, detail="Service is initializing, please retry in a moment")
# Configure CORS — restrict to known frontend origins via ALLOWED_ORIGINS env var.
# Default "*" so HuggingFace Spaces / fresh deploys work without manual config.
# For production hardening, set ALLOWED_ORIGINS=https://your-app.vercel.app
# (comma-separated list supported; surrounding whitespace and blank entries dropped).
_allowed_origins_env = os.getenv("ALLOWED_ORIGINS", "*")
ALLOWED_ORIGINS = [o.strip() for o in _allowed_origins_env.split(",") if o.strip()]
# FastAPI application — Swagger UI served at /docs, ReDoc at /redoc
app = FastAPI(
    title="Medical RAG API",
    description="RAG-based Medical Diagnosis Assistant",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# CORS. NOTE: allow_credentials=True combined with allow_origins=["*"] is
# ineffective — Starlette's CORSMiddleware cannot send a wildcard
# Access-Control-Allow-Origin together with credentials, and browsers reject
# wildcard origins on credentialed requests anyway. Only enable credentials
# when an explicit origin list has been configured via ALLOWED_ORIGINS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=ALLOWED_ORIGINS != ["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Optional API key authentication (set API_SECRET_KEY env var to enable).
# When unset (the default), verify_api_key below accepts every request.
_API_SECRET_KEY = os.getenv("API_SECRET_KEY", "")
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)  # auto_error=False: missing header handled manually
async def verify_api_key(api_key: str = Security(_api_key_header)):
    """Dependency enforcing the optional X-API-Key header.

    A no-op when API_SECRET_KEY is not configured; otherwise rejects any
    request whose header does not match with HTTP 403.
    """
    if not _API_SECRET_KEY:
        return api_key
    if api_key != _API_SECRET_KEY:
        raise HTTPException(status_code=403, detail="Invalid or missing API key")
    return api_key
# Pydantic models for request/response
class HealthResponse(BaseModel):
    """Body of GET /api/health responses."""
    status: str  # 'loading' while initializing, 'error' on init failure, 'healthy' otherwise
    message: str  # human-readable status description
    embedding_model: str  # Config.EMBEDDING_MODEL name
class Disease(BaseModel):
    """One disease entry sourced from a JSON data file."""
    id: str  # "<category>_<item id>", e.g. "pediatrics_12"
    name: str  # disease name (the 'Index' field in the JSON)
    category: str  # 'procedures' | 'pediatrics' | 'treatment'
    source: str  # JSON filename the entry came from
    sections: List[str]  # the item's 'level1_items' list (section names, presumably)
class DiseasesResponse(BaseModel):
    """Body of GET /api/diseases responses."""
    success: bool
    diseases: List[Disease]
    total: int  # len(diseases)
class StartCaseRequest(BaseModel):
    """Body of POST /api/start-case requests."""
    disease: str  # disease name to build a patient case for
    sessionId: str  # client-chosen session identifier, keyed in the SQLite session store
class StartCaseResponse(BaseModel):
    """Body of POST /api/start-case responses."""
    success: bool
    sessionId: str  # echoes the request's sessionId
    case: str  # generated patient case text
    symptoms: str  # symptom summary, truncated to 300 chars
    sources: List[Dict[str, str]]  # up to 3 refs: {'file', 'title', 'section'}
class DiagnosisData(BaseModel):
    """User's diagnosis answer; every field optional, defaulting to empty string."""
    clinical: Optional[str] = ""  # clinical findings
    paraclinical: Optional[str] = ""  # lab / imaging findings
    definitiveDiagnosis: Optional[str] = ""  # definitive diagnosis
    differentialDiagnosis: Optional[str] = ""  # differential diagnosis
    treatment: Optional[str] = ""  # treatment plan
    medication: Optional[str] = ""  # prescribed medication
class EvaluateRequest(BaseModel):
    """Body of POST /api/evaluate requests."""
    sessionId: str  # must match a session created by /api/start-case
    diagnosis: DiagnosisData  # the user's answers to grade
class EvaluateResponse(BaseModel):
    """Body of POST /api/evaluate responses."""
    success: bool
    case: str  # the patient case that was being answered
    standardAnswer: Dict[str, Any]  # {'content': standard answer text, 'disease': name}
    evaluation: Dict[str, Any]  # parsed LLM evaluation (or fallback stub on parse failure)
    sources: List[Dict[str, str]]  # up to 3 refs: {'file', 'title', 'section'}
@app.get("/", include_in_schema=False)
async def root():
    """Send visitors hitting the site root to the interactive Swagger docs."""
    from fastapi.responses import RedirectResponse as _Redirect
    return _Redirect(url="/docs")
@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint — always returns 200 so HF Spaces marks Space as Running.

    The payload's status field reflects the background init state:
    'loading', 'error', or 'healthy'.
    """
    if not _init_done.is_set():
        status, message = 'loading', 'RAG system is initializing, please wait...'
    elif _init_error:
        status, message = 'error', f'Initialization failed: {_init_error}'
    else:
        status, message = 'healthy', 'FastAPI RAG Server is running'
    return HealthResponse(
        status=status,
        message=message,
        embedding_model=Config.EMBEDDING_MODEL
    )
@app.get("/api/diseases", response_model=DiseasesResponse, dependencies=[Depends(_require_ready)])
async def get_diseases(
    category: Optional[str] = None,
    search: Optional[str] = None
):
    """
    List diseases from the three JSON data files (their 'Index' field).

    Query params:
    - category: filter by category (procedures, pediatrics, treatment);
      unset or 'all' means no filter.
    - search: case-insensitive substring match on disease names.

    Returns a DiseasesResponse; raises HTTP 500 on unexpected errors
    (e.g. a corrupt JSON file).
    """
    try:
        diseases: List[Disease] = []
        data_dir = os.path.join(os.path.dirname(__file__), 'data')
        # Mapping of data files to their category label
        files = [
            ('BoYTe200_v3.json', 'procedures'),
            ('NHIKHOA2.json', 'pediatrics'),
            ('PHACDODIEUTRI_2016.json', 'treatment')
        ]
        # Hoist the lowered search term out of the per-item loop
        search_lower = search.lower() if search else None
        for filename, cat in files:
            # Skip files not matching the requested category
            if category and category != 'all' and category != cat:
                continue
            filepath = os.path.join(data_dir, filename)
            if not os.path.exists(filepath):
                continue  # tolerate missing data files instead of failing the request
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            for item in data:
                disease_name = item.get('Index', '')
                # Filter by search if specified
                if search_lower and search_lower not in disease_name.lower():
                    continue
                diseases.append(Disease(
                    # .get() keeps one malformed record (missing 'id') from
                    # 500-ing the whole listing
                    id=f"{cat}_{item.get('id', '')}",
                    name=disease_name,
                    category=cat,
                    source=filename,
                    sections=item.get('level1_items', [])
                ))
        return DiseasesResponse(
            success=True,
            diseases=diseases,
            total=len(diseases)
        )
    except Exception as e:
        print(f"[ERROR] Error in get_diseases: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/start-case", response_model=StartCaseResponse, dependencies=[Depends(verify_api_key), Depends(_require_ready)])
async def start_case(request: StartCaseRequest):
    """
    Build a patient case for the requested disease.

    Flow:
    1. find_symptoms() + get_detailed_standard_knowledge() run IN PARALLEL
       (skipped entirely on a disease-cache hit — 0 LLM calls).
    2. generate_case() runs after symptoms are ready.
    3. Session persisted to SQLite.

    Raises HTTP 400 when the disease name is blank, 500 on unexpected errors.
    """
    try:
        disease = request.disease.strip()
        session_id = request.sessionId
        if not disease:
            raise HTTPException(status_code=400, detail="Disease name is required")
        print(f"[INFO] Starting case for disease: {disease}")
        print(f"[INFO] Session ID: {session_id}")
        loop = asyncio.get_running_loop()
        # Check disease-level cache first (0 LLM calls if HIT)
        cached = disease_cache.get(disease)
        if cached:
            print(f"[INFO] Disease cache HIT for '{disease}' — skipping RAG queries")
            symptoms = cached["symptoms"]
            standard_data = cached["standard"]
            all_sources_raw = cached["sources"]  # already plain dicts (see cache write below)
        else:
            # Cache MISS — run symptoms + standard in parallel, then cache results.
            # Both evaluator calls are blocking, so they run in the default executor.
            print("[INFO] Disease cache MISS — running RAG queries in parallel...")
            (symptoms, symptom_sources), (standard_data, std_sources) = await asyncio.gather(
                loop.run_in_executor(None, evaluator.find_symptoms, disease),
                loop.run_in_executor(None, evaluator.get_detailed_standard_knowledge, disease),
            )
            all_sources_raw = symptom_sources + std_sources
            # Cache a trimmed, JSON-safe view of the sources for future requests
            disease_cache.set(disease, symptoms, standard_data, [
                {"file": d.metadata.get("source_file", ""),
                 "title": d.metadata.get("chunk_title", ""),
                 "section": d.metadata.get("section_title", "")}
                for d in all_sources_raw[:5]
            ])
        print(f"[INFO] Symptoms (first 200 chars): {symptoms[:200]}...")
        print(f"[INFO] Standard data length: {len(standard_data)} chars")
        # Step 2: generate case (depends on symptoms output)
        print("[INFO] Step 2: Generating patient case...")
        patient_case = await loop.run_in_executor(
            None, evaluator.generate_case, disease, symptoms
        )
        print(f"[INFO] Generated case (first 200 chars): {patient_case[:200]}...")
        # Normalize sources to plain dicts for JSON storage. Cached entries are
        # already dicts; fresh ones are Document objects carrying .metadata.
        # (Fix: the original wrapped this in a ternary whose two branches were
        # identical — `all_sources_raw if not cached else all_sources_raw`.)
        formatted_sources = []
        for src in all_sources_raw[:5]:
            if isinstance(src, dict):
                formatted_sources.append(src)
            else:
                formatted_sources.append({
                    'file': src.metadata.get('source_file', ''),
                    'title': src.metadata.get('chunk_title', ''),
                    'section': src.metadata.get('section_title', ''),
                })
        # Persist session to SQLite
        session_store.set(session_id, {
            'disease': disease,
            'case': patient_case,
            'symptoms': symptoms,
            'standard': standard_data,
            'sources': formatted_sources,
        })
        return StartCaseResponse(
            success=True,
            sessionId=session_id,
            case=patient_case,
            symptoms=symptoms[:300] + "...",
            sources=formatted_sources[:3],
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"[ERROR] Error in start_case: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/evaluate", response_model=EvaluateResponse, dependencies=[Depends(verify_api_key), Depends(_require_ready)])
async def evaluate_diagnosis(request: EvaluateRequest):
    """
    Grade the user's answers against the standard answer stored in the session.

    Raises HTTP 400 for a missing/unknown session, 500 on unexpected errors.
    """
    try:
        session_id = request.sessionId
        diagnosis = request.diagnosis
        if not session_id:
            raise HTTPException(status_code=400, detail="Session ID required")
        session_data = session_store.get(session_id)
        if session_data is None:
            raise HTTPException(status_code=400, detail="Invalid or expired session")
        disease = session_data['disease']
        patient_case = session_data['case']
        standard_answer = session_data['standard']
        print(f"[INFO] Evaluating diagnosis for: {disease}")
        print(f"[INFO] Session ID: {session_id}")
        # NOTE(review): .dict() is the pydantic v1 API (model_dump() in v2);
        # it still works under v2 but emits a deprecation warning.
        print(f"[INFO] User diagnosis: {diagnosis.dict()}")
        # Format the user's answer; the Vietnamese labels are part of the
        # evaluation prompt and must not be altered.
        user_answer = f"""
CHẨN ĐOÁN:
- Lâm sàng: {diagnosis.clinical or 'Không có'}
- Cận lâm sàng: {diagnosis.paraclinical or 'Không có'}
- Chẩn đoán xác định: {diagnosis.definitiveDiagnosis or 'Không có'}
- Chẩn đoán phân biệt: {diagnosis.differentialDiagnosis or 'Không có'}
KẾ HOẠCH ĐIỀU TRỊ:
- Cách điều trị: {diagnosis.treatment or 'Không có'}
- Thuốc: {diagnosis.medication or 'Không có'}
"""
        print(f"[INFO] Formatted user answer (first 300 chars): {user_answer[:300]}...")
        # Run Groq evaluation (blocking I/O — executed off the event loop)
        print("[INFO] Step 1: Evaluating with Groq...")
        loop = asyncio.get_running_loop()
        evaluation_result = await loop.run_in_executor(
            None, evaluator.detailed_evaluation, user_answer, standard_answer
        )
        print(f"[INFO] Evaluation result (first 500 chars): {evaluation_result[:500]}...")
        # Parse the evaluation as JSON using the module-level `json` import
        # (fix: the original re-imported it locally as `_json`, redundantly).
        print("[INFO] Step 2: Parsing JSON evaluation...")
        try:
            eval_text = evaluation_result.strip()
            # Strip a Markdown code fence (``` or ```json) if the model
            # wrapped its JSON in one.
            if eval_text.startswith('```'):
                lines = eval_text.split('\n')
                eval_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else eval_text
                if eval_text.startswith('json'):
                    eval_text = eval_text[4:].strip()
            evaluation_obj = json.loads(eval_text)
            print(f"[INFO] Successfully parsed JSON")
        except Exception as parse_error:
            print(f"[ERROR] Failed to parse JSON: {parse_error}")
            # Fall back to a stub object so the client still gets a usable response
            evaluation_obj = {
                'evaluation_text': evaluation_result,
                'diem_so': 'N/A',
                'diem_manh': [],
                'diem_yeu': ['Không thể parse JSON từ đánh giá'],
                'da_co': [],
                'thieu': [],
                'dien_giai': evaluation_result,
                'nhan_xet_tong_quan': 'Lỗi parse JSON'
            }
        # Sources are already pre-formatted plain dicts (stored in session)
        formatted_sources = session_data.get('sources', [])[:3]
        print("[INFO] Step 3: Formatting response...")
        return EvaluateResponse(
            success=True,
            case=patient_case,
            standardAnswer={
                "content": standard_answer,
                "disease": disease
            },
            evaluation=evaluation_obj,
            sources=formatted_sources
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"[ERROR] Error in evaluate: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == '__main__':
    # Dev entry point — in hosted deployments uvicorn is typically launched externally.
    print("[*] Starting FastAPI Server...")
    print(f"[*] Server: http://localhost:8001")
    print(f"[*] Docs: http://localhost:8001/docs")
    # Startup diagnostics: surface misconfiguration now instead of at the first request
    api_key_status = "configured" if Config.GROQ_API_KEY_1 else "NOT SET (set GROQ_API_KEY_1 in .env)"
    print(f"[*] Groq Key status: {api_key_status}")
    cors_status = "restricted" if ALLOWED_ORIGINS != ["*"] else "OPEN (*)"
    print(f"[*] CORS origins ({cors_status}): {ALLOWED_ORIGINS}")
    auth_status = "enabled" if _API_SECRET_KEY else "disabled (set API_SECRET_KEY to enable)"
    print(f"[*] API auth: {auth_status}")
    # PORT env var lets the hosting platform override the default 8001
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", "8001")),
        log_level="info",
        reload=False
    )
|