mgbam commited on
Commit
6bce99b
·
verified ·
1 Parent(s): 171ebe4

Upload ecg.py

Browse files
Files changed (1) hide show
  1. app/ecg.py +79 -0
app/ecg.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException, status
4
+ from sqlalchemy.orm import Session
5
+
6
+ from app.db.session import get_session
7
+ from app.ml.gating import gate_signal
8
+ from app.ml.inference import infer_ecg
9
+ from app.models import schemas
10
+ from app.models.ecg import ECGSample
11
+ from app.rules.engine import evaluate_ecg_rules
12
+
13
+ router = APIRouter()
14
+
15
+
16
+ @router.post("/infer", response_model=schemas.ECGInferenceResponse)
17
+ def infer_ecg_endpoint(
18
+ payload: schemas.ECGInferenceRequest,
19
+ session: Session = Depends(get_session),
20
+ ) -> schemas.ECGInferenceResponse:
21
+ """
22
+ Ingest ECG samples, store them, run ML inference, and apply rules.
23
+ """
24
+ gated_signal, gating_meta = gate_signal(payload.signal)
25
+
26
+ try:
27
+ model_output: Dict[str, Any] = infer_ecg(
28
+ gated_signal,
29
+ original_len=len(payload.signal),
30
+ gating_meta=gating_meta,
31
+ )
32
+ except Exception as exc: # pragma: no cover - defensive, should not trip often
33
+ raise HTTPException(
34
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
35
+ detail=f"Model inference failed: {exc}",
36
+ ) from exc
37
+
38
+ patient_context = {
39
+ "patient_id": payload.patient_id,
40
+ "device_id": payload.device_id,
41
+ "sampling_rate": payload.sampling_rate,
42
+ "age": payload.age,
43
+ "has_prior_stroke": payload.has_prior_stroke,
44
+ }
45
+ rules_result = evaluate_ecg_rules(patient_context, model_output)
46
+ gating_expl = (
47
+ f"Gated {len(gated_signal)}/{len(payload.signal)} samples across "
48
+ f"{gating_meta.get('selected_windows', 0)}/{gating_meta.get('total_windows', 0)} windows "
49
+ f"(thr={gating_meta.get('threshold')}, temp={gating_meta.get('temperature')})."
50
+ if gating_meta
51
+ else "Gating skipped."
52
+ )
53
+ explanations = [gating_expl, *rules_result.get("explanations", [])]
54
+
55
+ sample = ECGSample(
56
+ patient_id=payload.patient_id,
57
+ signal=payload.signal,
58
+ label=model_output.get("label"),
59
+ score=model_output.get("score"),
60
+ alert_level=rules_result.get("alert_level"),
61
+ hr=model_output.get("hr"),
62
+ device_id=payload.device_id,
63
+ sampling_rate=payload.sampling_rate,
64
+ )
65
+
66
+ session.add(sample)
67
+ session.commit()
68
+ session.refresh(sample)
69
+
70
+ return schemas.ECGInferenceResponse(
71
+ patient_id=payload.patient_id,
72
+ label=model_output.get("label", "unknown"),
73
+ score=float(model_output.get("score", 0.0)),
74
+ alert_level=rules_result.get("alert_level", "none"),
75
+ hr=model_output.get("hr"),
76
+ sample_id=sample.id,
77
+ created_at=sample.created_at,
78
+ explanations=explanations,
79
+ )