mgbam commited on
Commit
60bd09b
·
verified ·
1 Parent(s): d8ad613

Upload 3 files

Browse files
Files changed (2) hide show
  1. app/api/dashboard.py +236 -0
  2. app/api/streaming.py +150 -0
app/api/dashboard.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-patient dashboard API endpoints.
3
+
4
+ Provides aggregate views for monitoring multiple patients simultaneously.
5
+ """
6
+ from typing import Any, Dict, List, Optional
7
+ from datetime import datetime, timedelta
8
+
9
+ from fastapi import APIRouter, Depends, Query
10
+ from sqlalchemy import func, and_
11
+ from sqlalchemy.orm import Session
12
+
13
+ from app.db.session import get_session
14
+ from app.models.ecg import ECGSample
15
+ from app.models.schemas import DashboardStats, PatientSummary, AlertSummary
16
+
17
+ router = APIRouter()
18
+
19
+
20
+ @router.get("/stats", response_model=DashboardStats)
21
+ def get_dashboard_stats(
22
+ session: Session = Depends(get_session),
23
+ hours: int = Query(24, ge=1, le=168, description="Time window in hours")
24
+ ) -> Dict[str, Any]:
25
+ """
26
+ Get aggregate statistics for the dashboard.
27
+
28
+ Returns:
29
+ - Total patients monitored
30
+ - Total samples processed
31
+ - Alert distribution (none, notify, escalate)
32
+ - Average scores
33
+ - Energy savings estimate
34
+ """
35
+ cutoff_time = datetime.utcnow() - timedelta(hours=hours)
36
+
37
+ # Total samples in time window
38
+ total_samples = session.query(func.count(ECGSample.id)).filter(
39
+ ECGSample.created_at >= cutoff_time
40
+ ).scalar() or 0
41
+
42
+ # Unique patients
43
+ unique_patients = session.query(func.count(func.distinct(ECGSample.patient_id))).filter(
44
+ ECGSample.created_at >= cutoff_time
45
+ ).scalar() or 0
46
+
47
+ # Alert distribution
48
+ alert_counts = session.query(
49
+ ECGSample.alert_level,
50
+ func.count(ECGSample.id)
51
+ ).filter(
52
+ ECGSample.created_at >= cutoff_time
53
+ ).group_by(ECGSample.alert_level).all()
54
+
55
+ alert_distribution = {level: count for level, count in alert_counts}
56
+
57
+ # Average score
58
+ avg_score = session.query(func.avg(ECGSample.score)).filter(
59
+ ECGSample.created_at >= cutoff_time
60
+ ).scalar() or 0.0
61
+
62
+ # Label distribution
63
+ label_counts = session.query(
64
+ ECGSample.label,
65
+ func.count(ECGSample.id)
66
+ ).filter(
67
+ ECGSample.created_at >= cutoff_time
68
+ ).group_by(ECGSample.label).all()
69
+
70
+ label_distribution = {label: count for label, count in label_counts}
71
+
72
+ # Estimated energy savings (assume 60% average from gating)
73
+ estimated_energy_savings_pct = 60.0
74
+
75
+ return {
76
+ "time_window_hours": hours,
77
+ "total_samples": total_samples,
78
+ "unique_patients": unique_patients,
79
+ "alert_distribution": alert_distribution,
80
+ "label_distribution": label_distribution,
81
+ "avg_score": round(float(avg_score), 3),
82
+ "estimated_energy_savings_pct": estimated_energy_savings_pct,
83
+ "timestamp": datetime.utcnow().isoformat(),
84
+ }
85
+
86
+
87
+ @router.get("/patients", response_model=List[PatientSummary])
88
+ def get_patient_summaries(
89
+ session: Session = Depends(get_session),
90
+ alert_level: Optional[str] = Query(None, description="Filter by alert level"),
91
+ limit: int = Query(100, ge=1, le=1000),
92
+ ) -> List[Dict[str, Any]]:
93
+ """
94
+ Get summary information for all patients.
95
+
96
+ Returns list of patients with their latest sample and alert status.
97
+ """
98
+ # Subquery to get latest sample per patient
99
+ from sqlalchemy import distinct
100
+ from sqlalchemy.sql import exists
101
+
102
+ # Get distinct patient IDs
103
+ patient_ids = session.query(distinct(ECGSample.patient_id)).all()
104
+ patient_ids = [pid[0] for pid in patient_ids]
105
+
106
+ summaries = []
107
+
108
+ for patient_id in patient_ids[:limit]:
109
+ # Get latest sample for this patient
110
+ latest_sample = session.query(ECGSample).filter(
111
+ ECGSample.patient_id == patient_id
112
+ ).order_by(ECGSample.created_at.desc()).first()
113
+
114
+ if not latest_sample:
115
+ continue
116
+
117
+ # Filter by alert level if specified
118
+ if alert_level and latest_sample.alert_level != alert_level:
119
+ continue
120
+
121
+ # Count total samples for this patient
122
+ sample_count = session.query(func.count(ECGSample.id)).filter(
123
+ ECGSample.patient_id == patient_id
124
+ ).scalar() or 0
125
+
126
+ # Count alerts
127
+ alert_count = session.query(func.count(ECGSample.id)).filter(
128
+ and_(
129
+ ECGSample.patient_id == patient_id,
130
+ ECGSample.alert_level.in_(['notify', 'escalate'])
131
+ )
132
+ ).scalar() or 0
133
+
134
+ summaries.append({
135
+ "patient_id": patient_id,
136
+ "latest_label": latest_sample.label,
137
+ "latest_score": round(float(latest_sample.score or 0.0), 3),
138
+ "latest_alert_level": latest_sample.alert_level,
139
+ "latest_hr": latest_sample.hr,
140
+ "last_updated": latest_sample.created_at.isoformat(),
141
+ "total_samples": sample_count,
142
+ "alert_count": alert_count,
143
+ })
144
+
145
+ # Sort by alert level priority (escalate > notify > none)
146
+ alert_priority = {'escalate': 0, 'notify': 1, 'none': 2, None: 3}
147
+ summaries.sort(key=lambda x: alert_priority.get(x['latest_alert_level'], 3))
148
+
149
+ return summaries
150
+
151
+
152
+ @router.get("/alerts", response_model=List[AlertSummary])
153
+ def get_recent_alerts(
154
+ session: Session = Depends(get_session),
155
+ hours: int = Query(24, ge=1, le=168),
156
+ alert_level: Optional[str] = Query(None, description="Filter: notify or escalate"),
157
+ limit: int = Query(50, ge=1, le=500),
158
+ ) -> List[Dict[str, Any]]:
159
+ """
160
+ Get recent alerts across all patients.
161
+
162
+ Returns samples with alert_level in ['notify', 'escalate'], sorted by recency.
163
+ """
164
+ cutoff_time = datetime.utcnow() - timedelta(hours=hours)
165
+
166
+ query = session.query(ECGSample).filter(
167
+ and_(
168
+ ECGSample.created_at >= cutoff_time,
169
+ ECGSample.alert_level.in_(['notify', 'escalate'])
170
+ )
171
+ )
172
+
173
+ if alert_level:
174
+ query = query.filter(ECGSample.alert_level == alert_level)
175
+
176
+ alerts = query.order_by(ECGSample.created_at.desc()).limit(limit).all()
177
+
178
+ return [
179
+ {
180
+ "sample_id": alert.id,
181
+ "patient_id": alert.patient_id,
182
+ "alert_level": alert.alert_level,
183
+ "label": alert.label,
184
+ "score": round(float(alert.score or 0.0), 3),
185
+ "hr": alert.hr,
186
+ "timestamp": alert.created_at.isoformat(),
187
+ }
188
+ for alert in alerts
189
+ ]
190
+
191
+
192
+ @router.get("/patient/{patient_id}/history")
193
+ def get_patient_history(
194
+ patient_id: str,
195
+ session: Session = Depends(get_session),
196
+ hours: int = Query(24, ge=1, le=168),
197
+ limit: int = Query(100, ge=1, le=1000),
198
+ ) -> Dict[str, Any]:
199
+ """
200
+ Get historical data for a specific patient.
201
+
202
+ Returns time series of samples, labels, scores, alerts.
203
+ """
204
+ cutoff_time = datetime.utcnow() - timedelta(hours=hours)
205
+
206
+ samples = session.query(ECGSample).filter(
207
+ and_(
208
+ ECGSample.patient_id == patient_id,
209
+ ECGSample.created_at >= cutoff_time
210
+ )
211
+ ).order_by(ECGSample.created_at.asc()).limit(limit).all()
212
+
213
+ history = [
214
+ {
215
+ "sample_id": s.id,
216
+ "label": s.label,
217
+ "score": round(float(s.score or 0.0), 3),
218
+ "alert_level": s.alert_level,
219
+ "hr": s.hr,
220
+ "timestamp": s.created_at.isoformat(),
221
+ }
222
+ for s in samples
223
+ ]
224
+
225
+ # Compute summary stats
226
+ alert_count = sum(1 for s in samples if s.alert_level in ['notify', 'escalate'])
227
+ avg_score = sum(s.score or 0.0 for s in samples) / max(len(samples), 1)
228
+
229
+ return {
230
+ "patient_id": patient_id,
231
+ "time_window_hours": hours,
232
+ "sample_count": len(samples),
233
+ "alert_count": alert_count,
234
+ "avg_score": round(float(avg_score), 3),
235
+ "history": history,
236
+ }
app/api/streaming.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WebSocket streaming endpoint for real-time ECG monitoring.
3
+
4
+ Supports:
5
+ - Live ECG signal streaming
6
+ - Real-time gating, inference, and rule evaluation
7
+ - Multi-patient concurrent monitoring
8
+ """
9
+ import asyncio
10
+ import json
11
+ from typing import Any, Dict, List, Optional
12
+ from datetime import datetime
13
+
14
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
15
+ from sqlalchemy.orm import Session
16
+
17
+ from app.db.session import get_session
18
+ from app.ml.gating import gate_signal
19
+ from app.ml.inference import infer_ecg
20
+ from app.rules.engine import evaluate_ecg_rules
21
+ from app.models.ecg import ECGSample
22
+
23
+ router = APIRouter()
24
+
25
+
26
+ class ConnectionManager:
27
+ """Manages WebSocket connections for multiple patients."""
28
+
29
+ def __init__(self):
30
+ self.active_connections: Dict[str, WebSocket] = {}
31
+
32
+ async def connect(self, patient_id: str, websocket: WebSocket):
33
+ await websocket.accept()
34
+ self.active_connections[patient_id] = websocket
35
+
36
+ def disconnect(self, patient_id: str):
37
+ if patient_id in self.active_connections:
38
+ del self.active_connections[patient_id]
39
+
40
+ async def send_message(self, patient_id: str, message: Dict[str, Any]):
41
+ if patient_id in self.active_connections:
42
+ await self.active_connections[patient_id].send_json(message)
43
+
44
+ def get_active_patients(self) -> List[str]:
45
+ return list(self.active_connections.keys())
46
+
47
+
48
+ manager = ConnectionManager()
49
+
50
+
51
+ @router.websocket("/ws/ecg/{patient_id}")
52
+ async def ecg_stream(websocket: WebSocket, patient_id: str):
53
+ """
54
+ WebSocket endpoint for streaming ECG data.
55
+
56
+ Protocol:
57
+ - Client sends: {"signal": [float, ...], "patient_context": {...}}
58
+ - Server responds: {"status": "ok", "result": {...}, "timestamp": "..."}
59
+
60
+ The server runs gating → inference → rules and streams back results.
61
+ """
62
+ await manager.connect(patient_id, websocket)
63
+
64
+ try:
65
+ while True:
66
+ # Receive signal chunk from client
67
+ data = await websocket.receive_json()
68
+
69
+ signal = data.get("signal", [])
70
+ patient_context = data.get("patient_context", {})
71
+ patient_context["patient_id"] = patient_id
72
+
73
+ if not signal:
74
+ await websocket.send_json({
75
+ "status": "error",
76
+ "message": "Empty signal",
77
+ "timestamp": datetime.utcnow().isoformat()
78
+ })
79
+ continue
80
+
81
+ # Process signal through pipeline
82
+ try:
83
+ # Step 1: Gating
84
+ gated, gating_meta = gate_signal(signal)
85
+
86
+ # Step 2: Inference
87
+ model_output = infer_ecg(gated, original_len=len(signal), gating_meta=gating_meta)
88
+
89
+ # Step 3: Rules
90
+ rules_result = evaluate_ecg_rules(patient_context, model_output)
91
+
92
+ # Build response
93
+ result = {
94
+ "patient_id": patient_id,
95
+ "label": model_output.get("label"),
96
+ "score": round(float(model_output.get("score", 0.0)), 3),
97
+ "hr": model_output.get("hr"),
98
+ "alert_level": rules_result.get("alert_level"),
99
+ "explanations": rules_result.get("explanations", []),
100
+ "gating": {
101
+ "ratio": round(gating_meta.get("ratio", 1.0), 3),
102
+ "selected_windows": gating_meta.get("selected_windows", 0),
103
+ "total_windows": gating_meta.get("total_windows", 0),
104
+ "energy_saved_pct": round((1 - gating_meta.get("ratio", 1.0)) * 100, 1),
105
+ },
106
+ "timestamp": datetime.utcnow().isoformat(),
107
+ }
108
+
109
+ await websocket.send_json({
110
+ "status": "ok",
111
+ "result": result,
112
+ "timestamp": datetime.utcnow().isoformat(),
113
+ })
114
+
115
+ except Exception as e:
116
+ await websocket.send_json({
117
+ "status": "error",
118
+ "message": str(e),
119
+ "timestamp": datetime.utcnow().isoformat(),
120
+ })
121
+
122
+ except WebSocketDisconnect:
123
+ manager.disconnect(patient_id)
124
+ print(f"Patient {patient_id} disconnected")
125
+
126
+
127
+ @router.get("/active-streams")
128
+ async def get_active_streams() -> Dict[str, Any]:
129
+ """Get list of active patient streams."""
130
+ return {
131
+ "active_patients": manager.get_active_patients(),
132
+ "count": len(manager.get_active_patients()),
133
+ }
134
+
135
+
136
+ @router.post("/broadcast-alert")
137
+ async def broadcast_alert(alert: Dict[str, Any]):
138
+ """
139
+ Broadcast an alert to all connected patients (admin use).
140
+
141
+ Example: System-wide maintenance notification.
142
+ """
143
+ for patient_id in manager.get_active_patients():
144
+ await manager.send_message(patient_id, {
145
+ "type": "system_alert",
146
+ "message": alert.get("message", ""),
147
+ "timestamp": datetime.utcnow().isoformat(),
148
+ })
149
+
150
+ return {"status": "ok", "recipients": len(manager.get_active_patients())}