import logging
from typing import Any, Dict, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session

from app.db.session import get_session
from app.ml.gating import gate_signal
from app.ml.inference import infer_ecg
from app.models import schemas
from app.models.ecg import ECGSample
from app.rules.engine import evaluate_ecg_rules

logger = logging.getLogger(__name__)

router = APIRouter()


def _format_gating_explanation(
    gating_meta: Optional[Dict[str, Any]],
    gated_len: int,
    original_len: int,
) -> str:
    """Return a one-line, human-readable summary of the signal-gating step.

    Args:
        gating_meta: Metadata returned by ``gate_signal``; read via ``.get`` for
            ``selected_windows``, ``total_windows``, ``threshold`` and
            ``temperature``. A falsy value means gating produced no metadata.
        gated_len: Number of samples kept after gating.
        original_len: Number of samples in the incoming request.

    Returns:
        The explanation string, or ``"Gating skipped."`` when ``gating_meta``
        is falsy.
    """
    if not gating_meta:
        return "Gating skipped."
    return (
        f"Gated {gated_len}/{original_len} samples across "
        f"{gating_meta.get('selected_windows', 0)}/{gating_meta.get('total_windows', 0)} windows "
        f"(thr={gating_meta.get('threshold')}, temp={gating_meta.get('temperature')})."
    )


@router.post("/infer", response_model=schemas.ECGInferenceResponse)
def infer_ecg_endpoint(
    payload: schemas.ECGInferenceRequest,
    session: Session = Depends(get_session),
) -> schemas.ECGInferenceResponse:
    """Ingest ECG samples, store them, run ML inference, and apply rules.

    Pipeline: gate the raw signal, run model inference on the gated signal,
    evaluate the rules engine against the patient context and model output,
    persist an ``ECGSample`` row, and return the inference response.

    Args:
        payload: The inference request (signal plus patient/device context).
        session: Database session injected by ``get_session``.

    Returns:
        The inference response, including the persisted sample's id and
        creation timestamp.

    Raises:
        HTTPException: 500 if model inference fails. Details are logged
            server-side rather than returned to the client.
    """
    gated_signal, gating_meta = gate_signal(payload.signal)

    try:
        model_output: Dict[str, Any] = infer_ecg(
            gated_signal,
            original_len=len(payload.signal),
            gating_meta=gating_meta,
        )
    except Exception as exc:  # pragma: no cover - defensive API boundary
        # Keep the traceback in server logs; don't echo internal exception
        # text (paths, shapes, library internals) back to API clients.
        logger.exception("ECG model inference failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Model inference failed.",
        ) from exc

    patient_context = {
        "patient_id": payload.patient_id,
        "device_id": payload.device_id,
        "sampling_rate": payload.sampling_rate,
        "age": payload.age,
        "has_prior_stroke": payload.has_prior_stroke,
    }
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    explanations = [
        _format_gating_explanation(
            gating_meta, len(gated_signal), len(payload.signal)
        ),
        # `or []` also guards against an explicit None from the rules engine.
        *(rules_result.get("explanations") or []),
    ]

    label = model_output.get("label")
    score = model_output.get("score")
    hr = model_output.get("hr")
    alert_level = rules_result.get("alert_level")

    # Convert the response score BEFORE persisting: `.get(key, default)` does
    # not cover a key that is present with value None, and float(None) would
    # otherwise fail only after the row had already been committed.
    response_score = float(score) if score is not None else 0.0

    sample = ECGSample(
        patient_id=payload.patient_id,
        signal=payload.signal,
        label=label,
        score=score,
        alert_level=alert_level,
        hr=hr,
        device_id=payload.device_id,
        sampling_rate=payload.sampling_rate,
    )
    session.add(sample)
    try:
        session.commit()
    except Exception:
        # A failed commit leaves the session unusable until rolled back.
        session.rollback()
        raise
    session.refresh(sample)

    return schemas.ECGInferenceResponse(
        patient_id=payload.patient_id,
        label=label if label is not None else "unknown",
        score=response_score,
        alert_level=alert_level if alert_level is not None else "none",
        hr=hr,
        sample_id=sample.id,
        created_at=sample.created_at,
        explanations=explanations,
    )