File size: 6,337 Bytes
1e732dd
 
 
3ca1d38
 
1e732dd
 
 
 
3ca1d38
1e732dd
 
 
3ca1d38
696f787
 
1e732dd
 
 
 
 
 
 
 
 
 
 
 
3ca1d38
 
 
 
696f787
3ca1d38
9659593
696f787
3ca1d38
 
 
 
 
 
 
 
 
696f787
3ca1d38
 
 
 
 
 
 
 
 
696f787
3ca1d38
 
 
 
 
 
 
 
 
 
696f787
3ca1d38
 
 
 
 
 
696f787
3ca1d38
 
 
696f787
3ca1d38
 
 
696f787
3ca1d38
 
696f787
3ca1d38
 
 
 
 
 
696f787
9659593
3ca1d38
1e732dd
 
 
696f787
 
 
1e732dd
 
 
 
 
 
696f787
3ca1d38
 
 
 
1e732dd
 
3ca1d38
696f787
9659593
696f787
9659593
696f787
fd5543a
9659593
1e732dd
 
 
 
 
7caf4dc
1e732dd
 
 
3ca1d38
9659593
 
696f787
9659593
fd5543a
1e732dd
 
 
696f787
1e732dd
 
 
 
fd5543a
 
 
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
 
7caf4dc
1e732dd
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
MediGuard AI — Analyze Router

Unified /analyze/natural and /analyze/structured endpoints
that delegate to the ClinicalInsightGuild workflow.
"""

from __future__ import annotations

import asyncio
import logging
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import UTC, datetime
from typing import Any

from fastapi import APIRouter, HTTPException, Request

from src.schemas.schemas import (
    AnalysisResponse,
    NaturalAnalysisRequest,
    StructuredAnalysisRequest,
)

# Module logger and the router mounted under the /analyze prefix.
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/analyze", tags=["analysis"])

# Thread pool for running sync functions (the guild's blocking `run`) off the
# event loop via run_in_executor; max_workers=4 caps concurrent analyses.
_executor = ThreadPoolExecutor(max_workers=4)


def _score_disease_heuristic(biomarkers: dict[str, float]) -> dict[str, Any]:
    """Rule-based disease scoring (NOT ML prediction)."""
    scores = {"Diabetes": 0.0, "Anemia": 0.0, "Heart Disease": 0.0, "Thrombocytopenia": 0.0, "Thalassemia": 0.0}

    # Diabetes indicators
    glucose = biomarkers.get("Glucose")
    hba1c = biomarkers.get("HbA1c")
    if glucose is not None and glucose > 126:
        scores["Diabetes"] += 0.4
    if glucose is not None and glucose > 180:
        scores["Diabetes"] += 0.2
    if hba1c is not None and hba1c >= 6.5:
        scores["Diabetes"] += 0.5

    # Anemia indicators
    hemoglobin = biomarkers.get("Hemoglobin")
    mcv = biomarkers.get("Mean Corpuscular Volume", biomarkers.get("MCV"))
    if hemoglobin is not None and hemoglobin < 12.0:
        scores["Anemia"] += 0.6
    if hemoglobin is not None and hemoglobin < 10.0:
        scores["Anemia"] += 0.2
    if mcv is not None and mcv < 80:
        scores["Anemia"] += 0.2

    # Heart disease indicators
    cholesterol = biomarkers.get("Cholesterol")
    troponin = biomarkers.get("Troponin")
    ldl = biomarkers.get("LDL Cholesterol", biomarkers.get("LDL"))
    if cholesterol is not None and cholesterol > 240:
        scores["Heart Disease"] += 0.3
    if troponin is not None and troponin > 0.04:
        scores["Heart Disease"] += 0.6
    if ldl is not None and ldl > 190:
        scores["Heart Disease"] += 0.2

    # Thrombocytopenia indicators
    platelets = biomarkers.get("Platelets")
    if platelets is not None and platelets < 150000:
        scores["Thrombocytopenia"] += 0.6
    if platelets is not None and platelets < 50000:
        scores["Thrombocytopenia"] += 0.3

    # Thalassemia indicators
    if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0:
        scores["Thalassemia"] += 0.4

    # Find top prediction
    top_disease = max(scores, key=scores.get)
    confidence = min(scores[top_disease], 1.0)

    if confidence == 0.0:
        top_disease = "Undetermined"

    # Normalize probabilities
    total = sum(scores.values())
    if total > 0:
        probabilities = {k: v / total for k, v in scores.items()}
    else:
        probabilities = {k: 1.0 / len(scores) for k in scores}

    return {"disease": top_disease, "confidence": confidence, "probabilities": probabilities}


async def _run_guild_analysis(
    request: Request,
    biomarkers: dict[str, float],
    patient_ctx: dict[str, Any],
    extracted_biomarkers: dict[str, float] | None = None,
) -> AnalysisResponse:
    """Execute the ClinicalInsightGuild and build the response envelope.

    Args:
        request: Incoming FastAPI request; the guild service is read from
            ``request.app.state.ragbot_service``.
        biomarkers: Biomarker name -> value mapping fed into the pipeline.
        patient_ctx: Optional patient metadata (already dumped to a dict).
        extracted_biomarkers: Biomarkers recovered from free text, echoed
            back in the response when the /natural endpoint was used.

    Raises:
        HTTPException: 503 while the guild service has not initialized yet,
            500 when the pipeline itself fails.
    """
    started_at = time.time()
    request_id = f"req_{uuid.uuid4().hex[:12]}"

    guild = getattr(request.app.state, "ragbot_service", None)
    if guild is None:
        raise HTTPException(status_code=503, detail="Analysis service unavailable. Please wait for initialization.")

    # Rule-based prediction that accompanies the guild's own analysis.
    heuristic_prediction = _score_disease_heuristic(biomarkers)

    try:
        # Imported lazily to keep module import light.
        from src.state import PatientInput

        payload = PatientInput(
            biomarkers=biomarkers,
            patient_context=patient_ctx,
            model_prediction=heuristic_prediction,
        )
        # The guild's run() is synchronous; hand it to the thread pool so the
        # event loop stays responsive.
        event_loop = asyncio.get_running_loop()
        result = await event_loop.run_in_executor(_executor, lambda: guild.run(payload))
    except Exception as exc:
        logger.exception("Guild analysis failed: %s", exc)
        raise HTTPException(
            status_code=500,
            detail=f"Analysis pipeline error: {exc}",
        ) from exc

    elapsed_ms = (time.time() - started_at) * 1000

    # Assemble the envelope from whatever the guild produced.
    prediction = result.get("model_prediction")
    analysis = result.get("final_response", {})
    if isinstance(analysis, dict):
        conversational_summary = analysis.get("conversational_summary")
    else:
        conversational_summary = str(analysis)

    return AnalysisResponse(
        status="success",
        request_id=request_id,
        timestamp=datetime.now(UTC).isoformat(),
        extracted_biomarkers=extracted_biomarkers,
        input_biomarkers=biomarkers,
        patient_context=patient_ctx,
        processing_time_ms=round(elapsed_ms, 1),
        prediction=prediction,
        analysis=analysis,
        conversational_summary=conversational_summary,
    )


@router.post("/natural", response_model=AnalysisResponse)
async def analyze_natural(body: NaturalAnalysisRequest, request: Request):
    """Extract biomarkers from natural language and run full analysis."""
    svc = getattr(request.app.state, "extraction_service", None)
    if svc is None:
        raise HTTPException(status_code=503, detail="Extraction service unavailable")

    try:
        biomarkers = await svc.extract_biomarkers(body.message)
    except Exception as exc:
        logger.exception("Biomarker extraction failed: %s", exc)
        raise HTTPException(status_code=422, detail=f"Could not extract biomarkers: {exc}") from exc

    # Drop None fields so downstream code only sees supplied context values.
    if body.patient_context:
        ctx = body.patient_context.model_dump(exclude_none=True)
    else:
        ctx = {}
    return await _run_guild_analysis(request, biomarkers, ctx, extracted_biomarkers=biomarkers)


@router.post("/structured", response_model=AnalysisResponse)
async def analyze_structured(body: StructuredAnalysisRequest, request: Request):
    """Run full analysis on pre-structured biomarker data."""
    # Drop None fields so downstream code only sees supplied context values.
    if body.patient_context:
        ctx = body.patient_context.model_dump(exclude_none=True)
    else:
        ctx = {}
    return await _run_guild_analysis(request, body.biomarkers, ctx)