File size: 16,944 Bytes
1149349
 
 
 
 
 
 
b59fc2c
1149349
 
 
 
 
 
 
 
b59fc2c
29d1146
b59fc2c
 
1149349
b59fc2c
1149349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b59fc2c
 
1149349
29d1146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1149349
 
 
 
 
 
 
 
 
 
b59fc2c
1149349
 
 
 
 
b59fc2c
 
 
 
 
 
 
 
 
 
 
1149349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf61583
 
 
 
 
 
 
1149349
 
29d1146
 
 
 
 
 
 
 
 
 
 
 
 
1149349
 
 
 
 
 
 
29d1146
1149349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29d1146
1149349
 
b59fc2c
 
 
1149349
 
 
 
b59fc2c
1149349
 
b59fc2c
1149349
 
b59fc2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1149349
b59fc2c
 
 
1149349
b59fc2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1149349
 
 
 
b59fc2c
 
 
1149349
 
 
 
 
b59fc2c
1149349
b59fc2c
1149349
 
 
 
 
 
 
 
 
29d1146
1149349
 
 
 
 
 
 
b59fc2c
 
 
 
 
 
1149349
b59fc2c
1149349
 
 
b59fc2c
1149349
 
 
b59fc2c
 
1149349
 
 
 
 
 
 
 
 
 
 
 
b59fc2c
 
 
 
 
 
 
 
 
1149349
b59fc2c
1149349
b59fc2c
1149349
 
 
 
 
 
b59fc2c
 
1149349
 
 
 
 
 
 
 
 
 
 
 
b59fc2c
 
 
 
 
1149349
 
 
 
 
 
 
 
 
 
b59fc2c
1149349
 
 
 
 
 
 
 
 
 
 
 
b59fc2c
 
 
 
 
 
 
1149349
 
 
 
b59fc2c
1149349
b59fc2c
1149349
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
"""
FastAPI Server để tích hợp RAG vào Chatbot
Endpoints:
- GET /api/health - Health check
- GET /api/diseases - Lấy danh sách bệnh từ JSON
- POST /api/start-case - Nhận bệnh, tạo case với triệu chứng
- POST /api/evaluate - Nhận đáp án user, trả về kết quả so sánh
- Docs: http://localhost:8001/docs (Swagger UI)
"""
import sys
import io

# Force UTF-8 on stdout/stderr so Vietnamese text prints correctly on Windows
# consoles. reconfigure() (Python 3.7+) changes the encoding in place and keeps
# the stream's existing line-buffering mode; wrapping ``.buffer`` in a fresh
# TextIOWrapper instead yields a stream that is not line-buffered (log lines
# show up late) and raises AttributeError when the stream has no ``.buffer``
# or is None.
for _stream_name in ("stdout", "stderr"):
    _stream = getattr(sys, _stream_name, None)
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8", errors="replace")
    elif hasattr(_stream, "buffer"):
        # Fallback for replaced streams that expose a binary buffer but do not
        # support reconfigure().
        setattr(
            sys,
            _stream_name,
            io.TextIOWrapper(_stream.buffer, encoding="utf-8", errors="replace"),
        )
del _stream_name, _stream

import asyncio
import threading

from fastapi import FastAPI, HTTPException, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security.api_key import APIKeyHeader
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import json
import sys
import os
import uvicorn

# Add src to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from data_loader import DataLoader
from config import Config
from doctor_evaluator import DoctorEvaluator
from vector_store import VectorStoreManager
from rag_chain import RAGChain
from session_store import SessionStore
from disease_cache import DiseaseCache

# ── Background initialization ─────────────────────────────────────────────────
# Heavy work (model load + FAISS) runs in a background thread so uvicorn binds
# port 7860 immediately — HF Spaces sees the port up and marks the Space as
# "Running" within seconds, while initialization continues in the background.
#
# The five globals below start as None and are filled in by _background_init().
# Request handlers must not touch them until _init_done is set and _init_error
# is still None (the _require_ready dependency enforces this).

vs_manager: VectorStoreManager = None   # type: ignore[assignment]
rag: RAGChain = None                    # type: ignore[assignment]
evaluator: DoctorEvaluator = None       # type: ignore[assignment]
session_store: SessionStore = None      # type: ignore[assignment]
disease_cache: DiseaseCache = None      # type: ignore[assignment]

# _init_done is set in both the success and the failure case; check
# _init_error afterwards to tell the two apart.
_init_done = threading.Event()   # set when initialization finishes
_init_error: Exception = None    # set if initialization fails

def _background_init():
    """Build every shared RAG component; intended to run once in a worker thread.

    Fills the module-level ``vs_manager``, ``rag``, ``evaluator``,
    ``session_store`` and ``disease_cache`` globals, in that order. A failure
    is recorded in ``_init_error`` and printed together with its traceback
    instead of being propagated. ``_init_done`` is set in every case so that
    callers can tell that initialization has finished, successfully or not.
    """
    global vs_manager, rag, evaluator, session_store, disease_cache, _init_error
    import traceback

    try:
        print("[*] Initializing RAG system in background thread...")

        vs_manager = VectorStoreManager()
        index_missing = not vs_manager.vector_store
        if index_missing:
            raise RuntimeError("FAISS index not found — run: python src/build_faiss.py")

        rag = RAGChain(vs_manager)
        evaluator = DoctorEvaluator(rag)

        session_store = SessionStore()
        # Ask the store to drop whatever it considers expired before serving.
        session_store.cleanup_expired()

        disease_cache = DiseaseCache()
        print("[OK] RAG system ready!")
    except Exception as exc:
        _init_error = exc
        print(f"[ERROR] Background initialization failed: {exc}")
        traceback.print_exc()
    finally:
        # Always signal completion; _init_error distinguishes failure from success.
        _init_done.set()


# Kick off initialization right away on a daemon thread; the server can start
# accepting connections before this work completes.
_init_thread = threading.Thread(name="rag-init", target=_background_init, daemon=True)
_init_thread.start()


def _require_ready():
    """FastAPI dependency that gates an endpoint on background initialization.

    Raises:
        HTTPException: 503 while initialization is still running, or 500 when
            it finished with an error (the error text is put in ``detail``).
    """
    if _init_done.is_set():
        if _init_error:
            raise HTTPException(status_code=500, detail=f"Initialization failed: {_init_error}")
        return
    raise HTTPException(status_code=503, detail="Service is initializing, please retry in a moment")


# Configure CORS — restrict to known frontend origins via ALLOWED_ORIGINS env var.
# Default "*" so HuggingFace Spaces / fresh deploys work without manual config.
# For production hardening, set ALLOWED_ORIGINS=https://your-app.vercel.app
# ALLOWED_ORIGINS is a comma-separated list; blank entries are dropped, so an
# explicitly empty variable yields an empty list (no origin allowed).
_allowed_origins_env = os.getenv("ALLOWED_ORIGINS", "*")
ALLOWED_ORIGINS = [o.strip() for o in _allowed_origins_env.split(",") if o.strip()]

# Single application instance; Swagger UI is served at /docs and ReDoc at /redoc.
app = FastAPI(
    title="Medical RAG API",
    description="RAG-based Medical Diagnosis Assistant",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# NOTE(review): allow_credentials=True together with the default wildcard
# origin is discouraged by the FastAPI/Starlette CORS documentation. The
# endpoints in this file authenticate with the X-API-Key header, so
# credentialed (cookie) requests may not be needed — confirm with the frontend
# before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Optional API key authentication (set API_SECRET_KEY env var to enable)
# An empty/unset API_SECRET_KEY disables the check in verify_api_key.
# auto_error=False makes a missing header arrive as None instead of being
# rejected by FastAPI before verify_api_key runs.
_API_SECRET_KEY = os.getenv("API_SECRET_KEY", "")
_api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


async def verify_api_key(api_key: str = Security(_api_key_header)):
    """Enforce the optional ``X-API-Key`` header.

    When ``API_SECRET_KEY`` is not configured (empty), every request is
    accepted. Otherwise the supplied header must equal the configured secret.

    Args:
        api_key: Value of the ``X-API-Key`` header, or ``None`` when absent.

    Returns:
        The supplied ``api_key`` unchanged (possibly ``None`` when auth is disabled).

    Raises:
        HTTPException: 403 when a secret is configured and the header is
            missing or does not match.
    """
    import secrets  # local import keeps this block self-contained

    if _API_SECRET_KEY:
        # Compare as bytes in constant time: a plain ``!=`` returns as soon as
        # the first differing character is found, which leaks timing
        # information about the secret. Encoding first also lets
        # compare_digest accept non-ASCII header values.
        supplied = (api_key or "").encode("utf-8")
        expected = _API_SECRET_KEY.encode("utf-8")
        if not secrets.compare_digest(supplied, expected):
            raise HTTPException(status_code=403, detail="Invalid or missing API key")
    return api_key


# Pydantic models for request/response
class HealthResponse(BaseModel):
    """Response body of GET /api/health."""
    # 'loading', 'error' or 'healthy', as set by health_check
    status: str
    # Human-readable description of the current state
    message: str
    # Taken from Config.EMBEDDING_MODEL
    embedding_model: str


class Disease(BaseModel):
    """One entry of the disease list returned by GET /api/diseases."""
    # "<category>_<id of the item in its JSON file>"
    id: str
    # Value of the item's 'Index' field
    name: str
    # 'procedures', 'pediatrics' or 'treatment', depending on the source file
    category: str
    # Name of the JSON file the item was read from
    source: str
    # Value of the item's 'level1_items' field (empty list when absent)
    sections: List[str]


class DiseasesResponse(BaseModel):
    """Response body of GET /api/diseases."""
    success: bool
    # Diseases that passed the category/search filters
    diseases: List[Disease]
    # Number of entries in ``diseases``
    total: int


class StartCaseRequest(BaseModel):
    """Request body of POST /api/start-case."""
    # Disease name; surrounding whitespace is stripped and it must not be empty
    disease: str
    # Client-chosen key under which the generated case is stored for /api/evaluate
    sessionId: str


class StartCaseResponse(BaseModel):
    """Response body of POST /api/start-case."""
    success: bool
    # Echo of the session id sent in the request
    sessionId: str
    # Generated patient case text
    case: str
    # First 300 characters of the retrieved symptoms followed by "..."
    symptoms: str
    # Up to three dicts with 'file', 'title' and 'section' keys
    sources: List[Dict[str, str]]


class DiagnosisData(BaseModel):
    """Free-text answer submitted by the user; every field is optional.

    Empty or missing fields are rendered as 'Không có' when the answer is
    formatted for evaluation.
    """
    # Clinical findings
    clinical: Optional[str] = ""
    # Paraclinical findings
    paraclinical: Optional[str] = ""
    # Definitive diagnosis
    definitiveDiagnosis: Optional[str] = ""
    # Differential diagnosis
    differentialDiagnosis: Optional[str] = ""
    # Treatment approach
    treatment: Optional[str] = ""
    # Medication
    medication: Optional[str] = ""


class EvaluateRequest(BaseModel):
    """Request body of POST /api/evaluate."""
    # Must match a session previously created by /api/start-case
    sessionId: str
    # The user's answer to be compared with the stored standard answer
    diagnosis: DiagnosisData


class EvaluateResponse(BaseModel):
    """Response body of POST /api/evaluate."""
    success: bool
    # Patient case text stored in the session
    case: str
    # Dict with 'content' (standard answer) and 'disease' keys
    standardAnswer: Dict[str, Any]
    # Parsed evaluation JSON, or a fallback dict when parsing failed
    evaluation: Dict[str, Any]
    # Up to three source dicts stored in the session
    sources: List[Dict[str, str]]


@app.get("/", include_in_schema=False)
async def root():
    """Send visitors of the bare root URL to the interactive API docs."""
    from fastapi.responses import RedirectResponse

    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect


@app.get("/api/health", response_model=HealthResponse)
async def health_check():
    """Report initialization state without ever failing the request.

    The endpoint answers normally in all three states ('loading', 'error',
    'healthy') so that HF Spaces marks the Space as Running; clients must read
    ``status`` to learn whether the RAG system is actually usable.
    """
    if not _init_done.is_set():
        state, detail = 'loading', 'RAG system is initializing, please wait...'
    elif _init_error:
        state, detail = 'error', f'Initialization failed: {_init_error}'
    else:
        state, detail = 'healthy', 'FastAPI RAG Server is running'

    return HealthResponse(
        status=state,
        message=detail,
        embedding_model=Config.EMBEDDING_MODEL,
    )


def _load_diseases(category, search):
    """Read the bundled JSON data files and build the filtered disease list.

    Deliberately synchronous: it performs file I/O and JSON parsing, so the
    endpoint runs it in a worker thread instead of on the event loop.

    Args:
        category: Optional category filter ('procedures', 'pediatrics',
            'treatment'); ``None``, empty or 'all' keeps every file.
        search: Optional case-insensitive substring to look for in disease names.

    Returns:
        A list of ``Disease`` models, in file order.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'data')

    # Each data file maps to exactly one category.
    datasets = (
        ('BoYTe200_v3.json', 'procedures'),
        ('NHIKHOA2.json', 'pediatrics'),
        ('PHACDODIEUTRI_2016.json', 'treatment'),
    )

    # Lower-case the search term once instead of once per item.
    needle = search.lower() if search else None

    diseases = []
    for filename, cat in datasets:
        if category and category != 'all' and category != cat:
            continue

        filepath = os.path.join(data_dir, filename)
        if not os.path.exists(filepath):
            # Missing data files are skipped rather than treated as an error.
            continue

        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for item in data:
            disease_name = item.get('Index', '')
            if needle and needle not in disease_name.lower():
                continue

            diseases.append(Disease(
                id=f"{cat}_{item['id']}",
                name=disease_name,
                category=cat,
                source=filename,
                sections=item.get('level1_items', [])
            ))
    return diseases


@app.get("/api/diseases", response_model=DiseasesResponse, dependencies=[Depends(_require_ready)])
async def get_diseases(
    category: Optional[str] = None,
    search: Optional[str] = None
):
    """
    List the diseases found in the three JSON data files ('Index' field).

    Query params:
    - category: Filter by category (procedures, pediatrics, treatment)
    - search: Search in disease names (case-insensitive substring)

    Raises:
        HTTPException: 500 with the error text when reading or parsing fails.
    """
    try:
        # File reads and JSON parsing are blocking; keep them off the event
        # loop so other requests are not stalled while the files load.
        loop = asyncio.get_running_loop()
        diseases = await loop.run_in_executor(None, _load_diseases, category, search)

        return DiseasesResponse(
            success=True,
            diseases=diseases,
            total=len(diseases)
        )

    except Exception as e:
        print(f"[ERROR] Error in get_diseases: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


def _format_sources(sources, limit=5):
    """Convert up to *limit* retrieved sources into JSON-serializable dicts.

    Entries that are already dicts (e.g. read back from the disease cache) are
    passed through unchanged; any other entry is expected to expose a
    ``metadata`` mapping, from which 'source_file', 'chunk_title' and
    'section_title' are copied (missing keys become empty strings).
    """
    formatted = []
    for src in sources[:limit]:
        if isinstance(src, dict):
            formatted.append(src)
            continue
        meta = src.metadata
        formatted.append({
            'file': meta.get('source_file', ''),
            'title': meta.get('chunk_title', ''),
            'section': meta.get('section_title', ''),
        })
    return formatted


@app.post("/api/start-case", response_model=StartCaseResponse, dependencies=[Depends(verify_api_key), Depends(_require_ready)])
async def start_case(request: StartCaseRequest):
    """
    Create a patient case for the requested disease and store it in a session.

    1. On a disease-cache hit the cached symptoms/standard answer are reused;
       otherwise find_symptoms() + get_detailed_standard_knowledge() run IN
       PARALLEL in worker threads and their results are cached.
    2. generate_case() runs after symptoms are ready.
    3. The session is persisted through session_store.

    Raises:
        HTTPException: 400 when the disease name is empty; 500 with the error
            text on any other failure.
    """
    try:
        disease = request.disease.strip()
        session_id = request.sessionId

        if not disease:
            raise HTTPException(status_code=400, detail="Disease name is required")

        print(f"[INFO] Starting case for disease: {disease}")
        print(f"[INFO] Session ID: {session_id}")

        loop = asyncio.get_running_loop()

        # Check disease-level cache first (0 LLM calls if HIT)
        cached = disease_cache.get(disease)
        if cached:
            print(f"[INFO] Disease cache HIT for '{disease}' — skipping RAG queries")
            symptoms = cached["symptoms"]
            standard_data = cached["standard"]
            formatted_sources = _format_sources(cached["sources"])
        else:
            # Cache MISS — run symptoms + standard in parallel, then cache results
            print("[INFO] Disease cache MISS — running RAG queries in parallel...")
            (symptoms, symptom_sources), (standard_data, std_sources) = await asyncio.gather(
                loop.run_in_executor(None, evaluator.find_symptoms, disease),
                loop.run_in_executor(None, evaluator.get_detailed_standard_knowledge, disease),
            )
            # Format once and reuse for both the cache and the session; the
            # cache gets its own copies so the two stores never share objects.
            formatted_sources = _format_sources(symptom_sources + std_sources)
            disease_cache.set(
                disease, symptoms, standard_data,
                [dict(src) for src in formatted_sources],
            )

        print(f"[INFO] Symptoms (first 200 chars): {symptoms[:200]}...")
        print(f"[INFO] Standard data length: {len(standard_data)} chars")

        # Step 2: generate case (depends on symptoms output)
        print("[INFO] Step 2: Generating patient case...")
        patient_case = await loop.run_in_executor(
            None, evaluator.generate_case, disease, symptoms
        )
        print(f"[INFO] Generated case (first 200 chars): {patient_case[:200]}...")

        # Persist everything /api/evaluate needs under the client's session id
        session_store.set(session_id, {
            'disease': disease,
            'case': patient_case,
            'symptoms': symptoms,
            'standard': standard_data,
            'sources': formatted_sources,
        })

        return StartCaseResponse(
            success=True,
            sessionId=session_id,
            case=patient_case,
            symptoms=symptoms[:300] + "...",
            sources=formatted_sources[:3],
        )

    except HTTPException:
        # Deliberate client errors (e.g. the 400 above) pass through untouched.
        raise
    except Exception as e:
        print(f"[ERROR] Error in start_case: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


def _parse_evaluation_json(raw_text):
    """Parse the evaluator's reply into a dict, tolerating Markdown fences and prose.

    The reply is first parsed after removing an optional leading code fence
    (with or without a ``json`` tag) and a trailing fence line if one is
    actually present. If that fails, the span from the first '{' to the last
    '}' of the reply is parsed instead, which covers single-line fenced
    replies, missing closing fences and explanatory text around the JSON.

    Raises:
        ValueError: when no JSON object can be extracted (``json.JSONDecodeError``
            is a ``ValueError``) or the parsed value is not a JSON object.
    """
    text = raw_text.strip()
    candidate = text
    if candidate.startswith('```'):
        # Drop the opening fence line, including any language tag on it.
        lines = candidate.split('\n')[1:]
        # Only drop the last line when it really is a closing fence; removing
        # it blindly would cut off the end of an unterminated JSON block.
        if lines and lines[-1].strip().startswith('```'):
            lines = lines[:-1]
        candidate = '\n'.join(lines).strip()
        if candidate.startswith('json'):
            candidate = candidate[4:].strip()

    try:
        parsed = json.loads(candidate)
    except ValueError:
        start, end = text.find('{'), text.rfind('}')
        if start == -1 or end < start:
            raise
        parsed = json.loads(text[start:end + 1])

    if not isinstance(parsed, dict):
        # The response model requires a dict; reject arrays/scalars here so the
        # caller can fall back gracefully instead of failing validation later.
        raise ValueError(f"evaluation JSON must be an object, got {type(parsed).__name__}")
    return parsed


@app.post("/api/evaluate", response_model=EvaluateResponse, dependencies=[Depends(verify_api_key), Depends(_require_ready)])
async def evaluate_diagnosis(request: EvaluateRequest):
    """
    Compare the user's answer with the standard answer stored in the session.

    The evaluator's reply is expected to be JSON; when it cannot be parsed, a
    fallback evaluation dict carrying the raw reply is returned instead of an
    error.

    Raises:
        HTTPException: 400 when the session id is empty or the session is
            unknown/expired; 500 with the error text on any other failure.
    """
    try:
        session_id = request.sessionId
        diagnosis = request.diagnosis

        if not session_id:
            raise HTTPException(status_code=400, detail="Session ID required")

        session_data = session_store.get(session_id)
        if session_data is None:
            raise HTTPException(status_code=400, detail="Invalid or expired session")

        disease = session_data['disease']
        patient_case = session_data['case']
        standard_answer = session_data['standard']

        print(f"[INFO] Evaluating diagnosis for: {disease}")
        print(f"[INFO] Session ID: {session_id}")
        print(f"[INFO] User diagnosis: {diagnosis.dict()}")

        # Render the user's answer as the text block handed to the evaluator
        user_answer = f"""
CHẨN ĐOÁN:
- Lâm sàng: {diagnosis.clinical or 'Không có'}
- Cận lâm sàng: {diagnosis.paraclinical or 'Không có'}
- Chẩn đoán xác định: {diagnosis.definitiveDiagnosis or 'Không có'}
- Chẩn đoán phân biệt: {diagnosis.differentialDiagnosis or 'Không có'}

KẾ HOẠCH ĐIỀU TRỊ:
- Cách điều trị: {diagnosis.treatment or 'Không có'}
- Thuốc: {diagnosis.medication or 'Không có'}
"""
        print(f"[INFO] Formatted user answer (first 300 chars): {user_answer[:300]}...")

        # Run Groq evaluation (blocking I/O — executed off the event loop)
        print("[INFO] Step 1: Evaluating with Groq...")
        loop = asyncio.get_running_loop()
        evaluation_result = await loop.run_in_executor(
            None, evaluator.detailed_evaluation, user_answer, standard_answer
        )
        print(f"[INFO] Evaluation result (first 500 chars): {evaluation_result[:500]}...")

        # Parse JSON from evaluation
        print("[INFO] Step 2: Parsing JSON evaluation...")
        try:
            evaluation_obj = _parse_evaluation_json(evaluation_result)
            print("[INFO] Successfully parsed JSON")
        except Exception as parse_error:
            # Best effort: hand the raw reply back in a well-formed structure
            # rather than failing the whole request.
            print(f"[ERROR] Failed to parse JSON: {parse_error}")
            evaluation_obj = {
                'evaluation_text': evaluation_result,
                'diem_so': 'N/A',
                'diem_manh': [],
                'diem_yeu': ['Không thể parse JSON từ đánh giá'],
                'da_co': [],
                'thieu': [],
                'dien_giai': evaluation_result,
                'nhan_xet_tong_quan': 'Lỗi parse JSON'
            }

        # Sources are already pre-formatted plain dicts (stored in session)
        formatted_sources = session_data.get('sources', [])[:3]

        print("[INFO] Step 3: Formatting response...")
        return EvaluateResponse(
            success=True,
            case=patient_case,
            standardAnswer={
                "content": standard_answer,
                "disease": disease
            },
            evaluation=evaluation_obj,
            sources=formatted_sources
        )

    except HTTPException:
        # Deliberate client errors (the 400s above) pass through untouched.
        raise
    except Exception as e:
        print(f"[ERROR] Error in evaluate: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == '__main__':
    # Resolve the port once so the startup banner and the server always agree;
    # previously the banner printed 8001 even when PORT selected another port.
    port = int(os.getenv("PORT", "8001"))

    print("[*] Starting FastAPI Server...")
    print(f"[*] Server: http://localhost:{port}")
    print(f"[*] Docs:   http://localhost:{port}/docs")
    api_key_status = "configured" if Config.GROQ_API_KEY_1 else "NOT SET (set GROQ_API_KEY_1 in .env)"
    print(f"[*] Groq Key status: {api_key_status}")
    cors_status = "restricted" if ALLOWED_ORIGINS != ["*"] else "OPEN (*)"
    print(f"[*] CORS origins ({cors_status}): {ALLOWED_ORIGINS}")
    auth_status = "enabled" if _API_SECRET_KEY else "disabled (set API_SECRET_KEY to enable)"
    print(f"[*] API auth: {auth_status}")

    # Listen on all interfaces; reload stays off when launched this way.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info",
        reload=False
    )