Spaces:
Running
Running
| from typing import Any, Dict | |
| from fastapi import APIRouter, Depends, HTTPException, status | |
| from sqlalchemy.orm import Session | |
| from app.db.session import get_session | |
| from app.ml.gating import gate_signal | |
| from app.ml.inference import infer_ecg | |
| from app.models import schemas | |
| from app.models.ecg import ECGSample | |
| from app.rules.engine import evaluate_ecg_rules | |
# Module-level router; endpoint functions in this module are meant to be
# registered on it, and it is presumably mounted by the application elsewhere.
router: APIRouter = APIRouter()
def _gating_explanation(gated_len: int, original_len: int, gating_meta: Any) -> str:
    """Build the human-readable summary of the gating step.

    ``gating_meta`` is whatever ``gate_signal`` returned as metadata
    (presumably a dict); a falsy value means gating was skipped.
    """
    if not gating_meta:
        return "Gating skipped."
    return (
        f"Gated {gated_len}/{original_len} samples across "
        f"{gating_meta.get('selected_windows', 0)}/{gating_meta.get('total_windows', 0)} windows "
        f"(thr={gating_meta.get('threshold')}, temp={gating_meta.get('temperature')})."
    )


def _value_or_default(value: Any, default: Any) -> Any:
    """Return ``default`` when ``value`` is ``None`` (missing or explicit null)."""
    return default if value is None else value


def _coerce_score(raw: Any) -> float:
    """Convert the model's raw score to ``float``; ``None`` maps to ``0.0``.

    Raises:
        HTTPException: 500 if the score cannot be converted to ``float``.
    """
    if raw is None:
        return 0.0
    try:
        return float(raw)
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model returned a non-numeric score: {raw!r}",
        ) from exc


# NOTE(review): no route decorator (e.g. ``@router.post(...)``) is attached to
# this function, so as written it is not registered on ``router``. Confirm the
# decorator was not lost, or that registration happens elsewhere.
def infer_ecg_endpoint(
    payload: schemas.ECGInferenceRequest,
    session: Session = Depends(get_session),
) -> schemas.ECGInferenceResponse:
    """
    Ingest ECG samples, store them, run ML inference, and apply rules.

    Steps:
      1. Gate the raw signal with ``gate_signal`` and run ``infer_ecg`` on the
         gated samples.
      2. Evaluate ``evaluate_ecg_rules`` on the patient context and model output.
      3. Normalize the response fields, then persist an ``ECGSample`` row that
         holds the raw signal and the (un-normalized) results.
      4. Return the results together with human-readable explanations.

    Args:
        payload: Inference request carrying the signal and patient/device data.
        session: Database session injected via ``get_session``.

    Returns:
        An ``ECGInferenceResponse`` including the stored sample's id and
        creation timestamp.

    Raises:
        HTTPException: 500 if model inference fails or the model returns a
            score that cannot be converted to ``float``.
        Exception: any error raised by ``session.commit()`` is re-raised
            unchanged after the session is rolled back.
    """
    gated_signal, gating_meta = gate_signal(payload.signal)
    try:
        model_output: Dict[str, Any] = infer_ecg(
            gated_signal,
            original_len=len(payload.signal),
            gating_meta=gating_meta,
        )
    except Exception as exc:  # pragma: no cover - defensive, should not trip often
        # NOTE(review): the exception text is echoed to the client; consider
        # logging it server-side and returning a generic message instead.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Model inference failed: {exc}",
        ) from exc

    patient_context = {
        "patient_id": payload.patient_id,
        "device_id": payload.device_id,
        "sampling_rate": payload.sampling_rate,
        "age": payload.age,
        "has_prior_stroke": payload.has_prior_stroke,
    }
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    # Normalize the response fields *before* persisting, so a malformed model
    # output fails the request without leaving an orphaned row in the database.
    # Missing keys and explicit ``None`` values both fall back to the defaults.
    label = _value_or_default(model_output.get("label"), "unknown")
    score = _coerce_score(model_output.get("score"))
    alert_level = _value_or_default(rules_result.get("alert_level"), "none")

    explanations = [
        _gating_explanation(len(gated_signal), len(payload.signal), gating_meta),
        *rules_result.get("explanations", []),
    ]

    # The stored row keeps the raw model/rule values, as before.
    sample = ECGSample(
        patient_id=payload.patient_id,
        signal=payload.signal,
        label=model_output.get("label"),
        score=model_output.get("score"),
        alert_level=rules_result.get("alert_level"),
        hr=model_output.get("hr"),
        device_id=payload.device_id,
        sampling_rate=payload.sampling_rate,
    )
    session.add(sample)
    try:
        session.commit()
    except Exception:
        # Roll back so the session is not left in a failed transaction state.
        session.rollback()
        raise
    session.refresh(sample)

    return schemas.ECGInferenceResponse(
        patient_id=payload.patient_id,
        label=label,
        score=score,
        alert_level=alert_level,
        hr=model_output.get("hr"),
        sample_id=sample.id,
        created_at=sample.created_at,
        explanations=explanations,
    )