File size: 2,715 Bytes
6bce99b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from typing import Any, Dict

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from app.db.session import get_session
from app.ml.gating import gate_signal
from app.ml.inference import infer_ecg
from app.models import schemas
from app.models.ecg import ECGSample
from app.rules.engine import evaluate_ecg_rules

# Router holding this module's ECG endpoints (the `/infer` POST route is
# registered on it below). Where it is mounted, and under which prefix, is
# configured elsewhere — not visible in this module.
router = APIRouter()


def _describe_gating(
    gated_len: int, original_len: int, gating_meta: Dict[str, Any]
) -> str:
    """Return the human-readable gating summary that leads the explanations list.

    A falsy ``gating_meta`` (empty or ``None``) is reported as "Gating skipped.".
    """
    if not gating_meta:
        return "Gating skipped."
    return (
        f"Gated {gated_len}/{original_len} samples across "
        f"{gating_meta.get('selected_windows', 0)}/{gating_meta.get('total_windows', 0)} windows "
        f"(thr={gating_meta.get('threshold')}, temp={gating_meta.get('temperature')})."
    )


def _coerce_score(raw_score: Any) -> float:
    """Convert the model's raw score to ``float`` for the API response.

    A missing/``None`` score maps to ``0.0``.

    Raises:
        HTTPException: 500 if the model returned a value that ``float()`` rejects.
    """
    if raw_score is None:
        return 0.0
    try:
        return float(raw_score)
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model inference returned a non-numeric score: {raw_score!r}",
        ) from exc


@router.post("/infer", response_model=schemas.ECGInferenceResponse)
def infer_ecg_endpoint(
    payload: schemas.ECGInferenceRequest,
    session: Session = Depends(get_session),
) -> schemas.ECGInferenceResponse:
    """
    Ingest ECG samples, store them, run ML inference, and apply rules.

    Steps: gate the raw signal, run model inference on the gated signal,
    evaluate the rules engine against the model output plus patient context,
    persist an ``ECGSample`` row, and return the combined result with a
    gating summary prepended to the rule explanations.

    Raises:
        HTTPException: 500 if model inference raises, or if the model returns
            a score that cannot be converted to ``float``.
    """
    gated_signal, gating_meta = gate_signal(payload.signal)
    original_len = len(payload.signal)

    try:
        model_output: Dict[str, Any] = infer_ecg(
            gated_signal,
            original_len=original_len,
            gating_meta=gating_meta,
        )
    except Exception as exc:  # pragma: no cover - defensive, should not trip often
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model inference failed: {exc}",
        ) from exc

    patient_context = {
        "patient_id": payload.patient_id,
        "device_id": payload.device_id,
        "sampling_rate": payload.sampling_rate,
        "age": payload.age,
        "has_prior_stroke": payload.has_prior_stroke,
    }
    rules_result = evaluate_ecg_rules(patient_context, model_output)
    alert_level = rules_result.get("alert_level")

    # Validate/normalise every response field *before* writing to the
    # database, so malformed model output cannot leave a persisted row
    # behind a failed request. `.get(key, default)` would not cover keys
    # that are present but None, hence the explicit None checks.
    response_score = _coerce_score(model_output.get("score"))
    response_label = model_output.get("label")
    if response_label is None:
        response_label = "unknown"
    response_alert_level = alert_level if alert_level is not None else "none"

    explanations = [
        _describe_gating(len(gated_signal), original_len, gating_meta),
        # `or []` also tolerates an explicit None from the rules engine.
        *(rules_result.get("explanations") or []),
    ]

    # The persisted row keeps the raw model/rules values (possibly None),
    # exactly as before; only the API response is normalised.
    sample = ECGSample(
        patient_id=payload.patient_id,
        signal=payload.signal,
        label=model_output.get("label"),
        score=model_output.get("score"),
        alert_level=alert_level,
        hr=model_output.get("hr"),
        device_id=payload.device_id,
        sampling_rate=payload.sampling_rate,
    )

    session.add(sample)
    try:
        session.commit()
    except Exception:
        # A failed commit leaves the session in a state that must be rolled
        # back before reuse; undo and let the original error propagate.
        session.rollback()
        raise
    session.refresh(sample)

    return schemas.ECGInferenceResponse(
        patient_id=payload.patient_id,
        label=response_label,
        score=response_score,
        alert_level=response_alert_level,
        hr=model_output.get("hr"),
        sample_id=sample.id,
        created_at=sample.created_at,
        explanations=explanations,
    )