"""
FastAPI server that integrates the RAG pipeline into the chatbot.

Endpoints:
- GET  /api/health      - Health check
- GET  /api/diseases    - List diseases from the JSON data files
- POST /api/start-case  - Accept a disease, create a case with symptoms
- POST /api/evaluate    - Accept the user's answer, return the comparison result
- Docs: http://localhost:8001/docs (Swagger UI)
"""

import sys
import io

# Fix encoding for Vietnamese characters in Windows console.
# The streams are reconfigured in place instead of being re-wrapped in a new
# io.TextIOWrapper: a fresh wrapper is not line-buffered (log lines show up
# late) and building it fails when a stream is None or has no ``.buffer``.
for _stream in (sys.stdout, sys.stderr):
    _reconfigure = getattr(_stream, "reconfigure", None)
    if _reconfigure is not None:
        _reconfigure(encoding='utf-8', errors='replace')

import asyncio
import threading
from fastapi import FastAPI, HTTPException, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import json
import sys
import os
import uvicorn

# Add src to path (must happen before the project-local imports below)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from data_loader import DataLoader
from config import Config
from doctor_evaluator import DoctorEvaluator
from vector_store import VectorStoreManager
from rag_chain import RAGChain
from session_store import SessionStore
from disease_cache import DiseaseCache

# ── Background initialization ─────────────────────────────────────────────────
# Heavy work (model load + FAISS) runs in a background thread so uvicorn binds
# port 7860 immediately — HF Spaces sees the port up and marks the Space as
# "Running" within seconds, while initialization continues in the background.
# Shared singletons. Each stays None until _background_init() assigns it, so
# request handlers must depend on _require_ready() before touching them.
vs_manager: VectorStoreManager = None  # type: ignore[assignment]
rag: RAGChain = None  # type: ignore[assignment]
evaluator: DoctorEvaluator = None  # type: ignore[assignment]
session_store: SessionStore = None  # type: ignore[assignment]
disease_cache: DiseaseCache = None  # type: ignore[assignment]

# Signalled once initialization has finished, whether it succeeded or failed.
_init_done = threading.Event()
# Holds the exception raised during initialization, if any.
_init_error: Optional[Exception] = None


def _background_init():
    """Build the RAG components and publish them through the module globals.

    Any exception is recorded in ``_init_error`` and logged with its
    traceback; ``_init_done`` is set in every case so callers stop waiting.
    """
    global vs_manager, rag, evaluator, session_store, disease_cache, _init_error
    import traceback

    try:
        print("[*] Initializing RAG system in background thread...")
        vs_manager = VectorStoreManager()
        if not vs_manager.vector_store:
            raise RuntimeError("FAISS index not found — run: python src/build_faiss.py")

        rag = RAGChain(vs_manager)
        evaluator = DoctorEvaluator(rag)

        session_store = SessionStore()
        session_store.cleanup_expired()

        disease_cache = DiseaseCache()
        print("[OK] RAG system ready!")
    except Exception as err:
        _init_error = err
        print(f"[ERROR] Background initialization failed: {err}")
        traceback.print_exc()
    finally:
        _init_done.set()


# Start immediately — server is up before this finishes
_init_thread = threading.Thread(target=_background_init, daemon=True, name="rag-init")
_init_thread.start()


def _require_ready():
    """FastAPI dependency that gates endpoints on initialization state.

    Raises:
        HTTPException: 503 while initialization is still running, 500 if it
            finished with an error.
    """
    if _init_done.is_set():
        if _init_error:
            raise HTTPException(status_code=500, detail=f"Initialization failed: {_init_error}")
        return
    raise HTTPException(status_code=503, detail="Service is initializing, please retry in a moment")


# Configure CORS — restrict to known frontend origins via ALLOWED_ORIGINS env var.
# Default "*" so HuggingFace Spaces / fresh deploys work without manual config.
# For production hardening, set ALLOWED_ORIGINS=https://your-app.vercel.app
_allowed_origins_env = os.getenv("ALLOWED_ORIGINS", "*")
ALLOWED_ORIGINS = [o.strip() for o in _allowed_origins_env.split(",") if o.strip()]

app = FastAPI(
    title="Medical RAG API",
    description="RAG-based Medical Diagnosis Assistant",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# NOTE(review): the CORS spec does not allow a wildcard origin together with
# credentials, and FastAPI's CORS docs call for explicit origins when
# allow_credentials=True. Left unchanged here to avoid breaking existing
# clients — confirm whether credentialed cross-origin requests are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Optional API key authentication (set API_SECRET_KEY env var to enable)
_API_SECRET_KEY = os.getenv("API_SECRET_KEY", "")
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


async def verify_api_key(api_key: str = Security(_api_key_header)):
    """If API_SECRET_KEY is configured, require a matching X-API-Key header.

    Returns the supplied key (``None`` when the header is absent and
    authentication is disabled).

    Raises:
        HTTPException: 403 when a secret is configured and the header is
            missing or does not match.
    """
    if _API_SECRET_KEY:
        import hmac

        # Constant-time comparison so response timing does not leak how much
        # of the key matched. Bytes are used because compare_digest() rejects
        # non-ASCII str arguments.
        supplied = (api_key or "").encode("utf-8")
        expected = _API_SECRET_KEY.encode("utf-8")
        if not hmac.compare_digest(supplied, expected):
            raise HTTPException(status_code=403, detail="Invalid or missing API key")
    return api_key


# Pydantic models for request/response
# (documented with comments rather than docstrings so the generated OpenAPI
# schema descriptions stay unchanged)

# Payload of GET /api/health.
class HealthResponse(BaseModel):
    status: str
    message: str
    embedding_model: str


# One disease entry as listed by GET /api/diseases.
class Disease(BaseModel):
    id: str
    name: str
    category: str
    source: str
    sections: List[str]


# Payload of GET /api/diseases.
class DiseasesResponse(BaseModel):
    success: bool
    diseases: List[Disease]
    total: int


# Body of POST /api/start-case.
class StartCaseRequest(BaseModel):
    disease: str
    sessionId: str


# Payload of POST /api/start-case.
class StartCaseResponse(BaseModel):
    success: bool
    sessionId: str
    case: str
    symptoms: str
    sources: List[Dict[str, str]]


# The user's free-text answer; every field is optional.
class DiagnosisData(BaseModel):
    clinical: Optional[str] = ""
    paraclinical: Optional[str] = ""
    definitiveDiagnosis: Optional[str] = ""
    differentialDiagnosis: Optional[str] = ""
    treatment: Optional[str] = ""
    medication: Optional[str] = ""


# Body of POST /api/evaluate.
class EvaluateRequest(BaseModel):
    sessionId: str
    diagnosis: DiagnosisData


# Payload of POST /api/evaluate.
class EvaluateResponse(BaseModel):
    success: bool
    case: str
    standardAnswer: Dict[str, Any]
    evaluation: Dict[str, Any]
    sources: List[Dict[str, str]]


@app.get("/", include_in_schema=False)
async def root():
    """Root redirect to API docs"""
    from fastapi.responses import RedirectResponse
    return RedirectResponse(url="/docs")


@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint — always returns 200 so HF Spaces marks Space as Running."""
    if not _init_done.is_set():
        return HealthResponse(
            status='loading',
            message='RAG system is initializing, please wait...',
            embedding_model=Config.EMBEDDING_MODEL
        )
    if _init_error:
        return HealthResponse(
            status='error',
            message=f'Initialization failed: {_init_error}',
            embedding_model=Config.EMBEDDING_MODEL
        )
    return HealthResponse(
        status='healthy',
        message='FastAPI RAG Server is running',
        embedding_model=Config.EMBEDDING_MODEL
    )


@app.get("/api/diseases", response_model=DiseasesResponse, dependencies=[Depends(_require_ready)])
async def get_diseases(
    category: Optional[str] = None,
    search: Optional[str] = None
):
    """
    List diseases from the three JSON data files (``Index`` field).

    Query params:
    - category: Filter by category (procedures, pediatrics, treatment)
    - search: Search in disease names
    """
    try:
        diseases = []
        data_dir = os.path.join(os.path.dirname(__file__), 'data')

        # Mapping files to categories
        files = [
            ('BoYTe200_v3.json', 'procedures'),
            ('NHIKHOA2.json', 'pediatrics'),
            ('PHACDODIEUTRI_2016.json', 'treatment')
        ]

        # Lower-case the search term once instead of once per entry.
        needle = search.lower() if search else None

        for filename, cat in files:
            # Filter by category if specified
            if category and category != 'all' and category != cat:
                continue

            filepath = os.path.join(data_dir, filename)
            if not os.path.exists(filepath):
                continue

            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)

            for item in data:
                # ``or`` also covers an explicit null in the JSON, which would
                # otherwise break .lower() and the Disease model validation.
                disease_name = item.get('Index') or ''

                # Filter by search if specified
                if needle and needle not in disease_name.lower():
                    continue

                diseases.append(Disease(
                    id=f"{cat}_{item['id']}",
                    name=disease_name,
                    category=cat,
                    source=filename,
                    sections=item.get('level1_items') or []
                ))

        return DiseasesResponse(
            success=True,
            diseases=diseases,
            total=len(diseases)
        )

    except Exception as e:
        print(f"[ERROR] Error in get_diseases: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


def _format_sources(sources, limit=5):
    """Return up to *limit* sources as plain dicts suitable for JSON storage.

    Entries that are already dicts (as read back from the disease cache) pass
    through unchanged; any other entry is expected to expose a ``metadata``
    mapping, from which the file, title and section fields are read.
    """
    formatted = []
    for src in (sources or [])[:limit]:
        if isinstance(src, dict):
            formatted.append(src)
        else:
            formatted.append({
                'file': src.metadata.get('source_file', ''),
                'title': src.metadata.get('chunk_title', ''),
                'section': src.metadata.get('section_title', ''),
            })
    return formatted


@app.post("/api/start-case", response_model=StartCaseResponse, dependencies=[Depends(verify_api_key), Depends(_require_ready)])
async def start_case(request: StartCaseRequest):
    """
    1. find_symptoms() + get_detailed_standard_knowledge() run IN PARALLEL
       (both skipped on a disease-cache hit)
    2. generate_case() runs after symptoms are ready
    3. Session persisted to SQLite
    """
    try:
        disease = request.disease.strip()
        session_id = request.sessionId

        if not disease:
            raise HTTPException(status_code=400, detail="Disease name is required")

        print(f"[INFO] Starting case for disease: {disease}")
        print(f"[INFO] Session ID: {session_id}")

        loop = asyncio.get_running_loop()

        # Check disease-level cache first (0 LLM calls if HIT)
        cached = disease_cache.get(disease)
        if cached:
            print(f"[INFO] Disease cache HIT for '{disease}' — skipping RAG queries")
            symptoms = cached["symptoms"]
            standard_data = cached["standard"]
            formatted_sources = _format_sources(cached["sources"])
        else:
            # Cache MISS — run symptoms + standard in parallel, then cache results
            print("[INFO] Disease cache MISS — running RAG queries in parallel...")
            (symptoms, symptom_sources), (standard_data, std_sources) = await asyncio.gather(
                loop.run_in_executor(None, evaluator.find_symptoms, disease),
                loop.run_in_executor(None, evaluator.get_detailed_standard_knowledge, disease),
            )
            formatted_sources = _format_sources(symptom_sources + std_sources)
            # Cache for future requests (a copy, so the cache and the session
            # do not share one list object)
            disease_cache.set(disease, symptoms, standard_data, list(formatted_sources))

        print(f"[INFO] Symptoms (first 200 chars): {symptoms[:200]}...")
        print(f"[INFO] Standard data length: {len(standard_data)} chars")

        # Step 2: generate case (depends on symptoms output)
        print("[INFO] Step 2: Generating patient case...")
        patient_case = await loop.run_in_executor(
            None, evaluator.generate_case, disease, symptoms
        )
        print(f"[INFO] Generated case (first 200 chars): {patient_case[:200]}...")

        # Persist session to SQLite
        session_store.set(session_id, {
            'disease': disease,
            'case': patient_case,
            'symptoms': symptoms,
            'standard': standard_data,
            'sources': formatted_sources,
        })

        return StartCaseResponse(
            success=True,
            sessionId=session_id,
            case=patient_case,
            symptoms=symptoms[:300] + "...",
            sources=formatted_sources[:3],
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"[ERROR] Error in start_case: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e


def _parse_evaluation_json(raw):
    """Parse the evaluator's reply into a dict.

    Tolerates a Markdown code fence (with or without a closing fence or a
    ``json`` language tag) and prose surrounding the JSON object.

    Raises:
        ValueError: if no JSON object can be decoded from *raw*.
    """
    text = raw.strip()
    if text.startswith('```'):
        text = text[3:]
        # Only drop a closing fence that is actually present, so an
        # unterminated fence no longer costs the last line of the JSON.
        if text.endswith('```'):
            text = text[:-3]
        text = text.strip()
    # A leading "json" can only be a language tag: no JSON document starts with it.
    if text[:4].lower() == 'json':
        text = text[4:].strip()

    try:
        parsed = json.loads(text)
    except ValueError:
        # Fall back to the outermost {...} span when the model added prose.
        start, end = text.find('{'), text.rfind('}')
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end + 1])

    if not isinstance(parsed, dict):
        # The response model requires an object; treat anything else as a parse failure.
        raise ValueError("evaluation JSON is not an object")
    return parsed


@app.post("/api/evaluate", response_model=EvaluateResponse, dependencies=[Depends(verify_api_key), Depends(_require_ready)])
async def evaluate_diagnosis(request: EvaluateRequest):
    """
    Take the user's answer and compare it with the standard answer already
    stored in the session.
    """
    try:
        session_id = request.sessionId
        diagnosis = request.diagnosis

        if not session_id:
            raise HTTPException(status_code=400, detail="Session ID required")

        session_data = session_store.get(session_id)
        if session_data is None:
            raise HTTPException(status_code=400, detail="Invalid or expired session")

        disease = session_data['disease']
        patient_case = session_data['case']
        standard_answer = session_data['standard']

        print(f"[INFO] Evaluating diagnosis for: {disease}")
        print(f"[INFO] Session ID: {session_id}")
        print(f"[INFO] User diagnosis: {diagnosis.dict()}")

        # Format user's answer
        user_answer = f"""
CHẨN ĐOÁN:
- Lâm sàng: {diagnosis.clinical or 'Không có'}
- Cận lâm sàng: {diagnosis.paraclinical or 'Không có'}
- Chẩn đoán xác định: {diagnosis.definitiveDiagnosis or 'Không có'}
- Chẩn đoán phân biệt: {diagnosis.differentialDiagnosis or 'Không có'}

KẾ HOẠCH ĐIỀU TRỊ:
- Cách điều trị: {diagnosis.treatment or 'Không có'}
- Thuốc: {diagnosis.medication or 'Không có'}
"""

        print(f"[INFO] Formatted user answer (first 300 chars): {user_answer[:300]}...")

        # Run Groq evaluation (blocking I/O — executed off the event loop)
        print("[INFO] Step 1: Evaluating with Groq...")
        loop = asyncio.get_running_loop()
        evaluation_result = await loop.run_in_executor(
            None, evaluator.detailed_evaluation, user_answer, standard_answer
        )
        print(f"[INFO] Evaluation result (first 500 chars): {evaluation_result[:500]}...")

        # Parse JSON from evaluation
        print("[INFO] Step 2: Parsing JSON evaluation...")
        try:
            evaluation_obj = _parse_evaluation_json(evaluation_result)
            print("[INFO] Successfully parsed JSON")
        except ValueError as parse_error:
            # Degrade gracefully: hand the raw text back instead of failing the request.
            print(f"[ERROR] Failed to parse JSON: {parse_error}")
            evaluation_obj = {
                'evaluation_text': evaluation_result,
                'diem_so': 'N/A',
                'diem_manh': [],
                'diem_yeu': ['Không thể parse JSON từ đánh giá'],
                'da_co': [],
                'thieu': [],
                'dien_giai': evaluation_result,
                'nhan_xet_tong_quan': 'Lỗi parse JSON'
            }

        # Sources are already pre-formatted plain dicts (stored in session)
        formatted_sources = session_data.get('sources', [])[:3]

        print("[INFO] Step 3: Formatting response...")
        return EvaluateResponse(
            success=True,
            case=patient_case,
            standardAnswer={
                "content": standard_answer,
                "disease": disease
            },
            evaluation=evaluation_obj,
            sources=formatted_sources
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"[ERROR] Error in evaluate: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == '__main__':
    # Resolve the port once so the banner shows the address actually bound.
    port = int(os.getenv("PORT", "8001"))

    print("[*] Starting FastAPI Server...")
    print(f"[*] Server: http://localhost:{port}")
    print(f"[*] Docs: http://localhost:{port}/docs")

    api_key_status = "configured" if Config.GROQ_API_KEY_1 else "NOT SET (set GROQ_API_KEY_1 in .env)"
    print(f"[*] Groq Key status: {api_key_status}")

    cors_status = "restricted" if ALLOWED_ORIGINS != ["*"] else "OPEN (*)"
    print(f"[*] CORS origins ({cors_status}): {ALLOWED_ORIGINS}")

    auth_status = "enabled" if _API_SECRET_KEY else "disabled (set API_SECRET_KEY to enable)"
    print(f"[*] API auth: {auth_status}")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info",
        reload=False
    )