Spaces:
Sleeping
Sleeping
File size: 2,715 Bytes
6bce99b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
from typing import Any, Dict
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.db.session import get_session
from app.ml.gating import gate_signal
from app.ml.inference import infer_ecg
from app.models import schemas
from app.models.ecg import ECGSample
from app.rules.engine import evaluate_ecg_rules
router = APIRouter()


def _gating_explanation(gated_len: int, total_len: int, gating_meta: Any) -> str:
    """Return a one-line, human-readable summary of the gating step.

    ``gating_meta`` is the second value returned by ``gate_signal``. When it is
    falsy, gating is reported as skipped. Otherwise the window counts default to
    ``0`` and threshold/temperature default to ``None`` if their keys are missing.
    """
    if not gating_meta:
        return "Gating skipped."
    return (
        f"Gated {gated_len}/{total_len} samples across "
        f"{gating_meta.get('selected_windows', 0)}/{gating_meta.get('total_windows', 0)} windows "
        f"(thr={gating_meta.get('threshold')}, temp={gating_meta.get('temperature')})."
    )


@router.post("/infer", response_model=schemas.ECGInferenceResponse)
def infer_ecg_endpoint(
    payload: schemas.ECGInferenceRequest,
    session: Session = Depends(get_session),
) -> schemas.ECGInferenceResponse:
    """
    Gate the ECG signal, run ML inference and rules, persist the sample, and respond.

    Steps, in order:
      1. ``gate_signal`` selects part of ``payload.signal``.
      2. ``infer_ecg`` runs on the gated signal (told the original length).
      3. ``evaluate_ecg_rules`` combines patient context with the model output.
      4. An ``ECGSample`` holding the *original* (ungated) signal is committed.
      5. An ``ECGInferenceResponse`` is built from the model/rules output.

    Raises:
        HTTPException: 500 if ``infer_ecg`` raises.
    """
    gated_signal, gating_meta = gate_signal(payload.signal)
    total_len = len(payload.signal)

    try:
        model_output: Dict[str, Any] = infer_ecg(
            gated_signal,
            original_len=total_len,
            gating_meta=gating_meta,
        )
    except Exception as exc:  # pragma: no cover - defensive, should not trip often
        # NOTE(review): this echoes the raw exception text to API clients;
        # consider logging it server-side and returning a generic message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model inference failed: {exc}",
        ) from exc

    patient_context = {
        "patient_id": payload.patient_id,
        "device_id": payload.device_id,
        "sampling_rate": payload.sampling_rate,
        "age": payload.age,
        "has_prior_stroke": payload.has_prior_stroke,
    }
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    # `or []` guards against the rules engine returning "explanations": None,
    # which would otherwise make the unpacking below raise TypeError.
    explanations = [
        _gating_explanation(len(gated_signal), total_len, gating_meta),
        *(rules_result.get("explanations") or []),
    ]

    label = model_output.get("label")
    score = model_output.get("score")
    alert_level = rules_result.get("alert_level")

    # NOTE(review): the stored row keeps the raw (possibly None) values while the
    # response below substitutes defaults; confirm this divergence is intended.
    sample = ECGSample(
        patient_id=payload.patient_id,
        signal=payload.signal,
        label=label,
        score=score,
        alert_level=alert_level,
        hr=model_output.get("hr"),
        device_id=payload.device_id,
        sampling_rate=payload.sampling_rate,
    )
    session.add(sample)
    session.commit()
    session.refresh(sample)

    # Explicit `is None` checks: a key that is present but None must also fall
    # back to the default (previously float(None) raised TypeError here).
    return schemas.ECGInferenceResponse(
        patient_id=payload.patient_id,
        label=label if label is not None else "unknown",
        score=float(score) if score is not None else 0.0,
        alert_level=alert_level if alert_level is not None else "none",
        hr=model_output.get("hr"),
        sample_id=sample.id,
        created_at=sample.created_at,
        explanations=explanations,
    )
|