Spaces:
Running
Running
| """ | |
| WebSocket streaming endpoint for real-time ECG monitoring. | |
| Supports: | |
| - Live ECG signal streaming | |
| - Real-time gating, inference, and rule evaluation | |
| - Multi-patient concurrent monitoring | |
| """ | |
| import asyncio | |
| import json | |
| from typing import Any, Dict, List, Optional | |
| from datetime import datetime | |
| from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends | |
| from sqlalchemy.orm import Session | |
| from app.db.session import get_session | |
| from app.ml.gating import gate_signal | |
| from app.ml.inference import infer_ecg | |
| from app.rules.engine import evaluate_ecg_rules | |
| from app.models.ecg import ECGSample | |
# APIRouter instance for this module's streaming endpoints.
# NOTE(review): the handlers defined below carry no @router decorators in the
# code shown here — confirm they are registered on this router elsewhere
# (e.g. via router.add_api_websocket_route / add_api_route), otherwise they
# are never exposed.
router = APIRouter()
class ConnectionManager:
    """Manage WebSocket connections for multiple patients.

    At most one socket is tracked per ``patient_id``. A new connection for
    the same patient replaces the previous entry (the old socket is not
    closed here).
    """

    def __init__(self):
        # patient_id -> the currently registered WebSocket for that patient.
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, patient_id: str, websocket: WebSocket):
        """Accept ``websocket`` and register it as ``patient_id``'s connection."""
        await websocket.accept()
        self.active_connections[patient_id] = websocket

    def disconnect(self, patient_id: str, websocket: Optional[WebSocket] = None):
        """Unregister ``patient_id``'s connection; unknown ids are ignored.

        Args:
            patient_id: Patient whose connection should be removed.
            websocket: When given, the entry is removed only if it is this
                exact socket. This stops a stale connection that is shutting
                down from evicting a newer connection for the same patient.
        """
        current = self.active_connections.get(patient_id)
        if current is None:
            return
        if websocket is not None and current is not websocket:
            return
        del self.active_connections[patient_id]

    async def send_message(self, patient_id: str, message: Dict[str, Any]) -> bool:
        """Send ``message`` as JSON to ``patient_id`` if it is connected.

        Returns:
            True if a socket was registered and ``send_json`` completed,
            False if the patient has no registered connection. Exceptions
            raised by the socket's ``send_json`` propagate to the caller.
        """
        websocket = self.active_connections.get(patient_id)
        if websocket is None:
            return False
        await websocket.send_json(message)
        return True

    def get_active_patients(self) -> List[str]:
        """Return a snapshot list of patient ids with a registered connection."""
        return list(self.active_connections)


# Module-wide connection registry shared by the handlers in this module.
manager = ConnectionManager()
def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string (existing wire format)."""
    return datetime.utcnow().isoformat()


def _error_payload(message: str) -> Dict[str, Any]:
    """Build the ``{"status": "error", ...}`` frame sent back to the client."""
    return {
        "status": "error",
        "message": message,
        "timestamp": _utc_timestamp(),
    }


def _run_pipeline(signal: Any, patient_context: Dict[str, Any], patient_id: str) -> Dict[str, Any]:
    """Run gating -> inference -> rules on one signal chunk and build the result dict."""
    # Step 1: Gating
    gated, gating_meta = gate_signal(signal)
    # Step 2: Inference
    model_output = infer_ecg(gated, original_len=len(signal), gating_meta=gating_meta)
    # Step 3: Rules
    rules_result = evaluate_ecg_rules(patient_context, model_output)

    ratio = gating_meta.get("ratio", 1.0)
    return {
        "patient_id": patient_id,
        "label": model_output.get("label"),
        "score": round(float(model_output.get("score", 0.0)), 3),
        "hr": model_output.get("hr"),
        "alert_level": rules_result.get("alert_level"),
        "explanations": rules_result.get("explanations", []),
        "gating": {
            "ratio": round(ratio, 3),
            "selected_windows": gating_meta.get("selected_windows", 0),
            "total_windows": gating_meta.get("total_windows", 0),
            "energy_saved_pct": round((1 - ratio) * 100, 1),
        },
        "timestamp": _utc_timestamp(),
    }


async def ecg_stream(websocket: WebSocket, patient_id: str):
    """
    WebSocket endpoint for streaming ECG data.

    Protocol:
    - Client sends: {"signal": [float, ...], "patient_context": {...}}
    - Server responds: {"status": "ok", "result": {...}, "timestamp": "..."}
      or {"status": "error", "message": "...", "timestamp": "..."}

    The server runs gating -> inference -> rules on each message and streams
    back the result. Malformed messages and pipeline failures are reported to
    the client as error frames and the loop keeps listening. However the loop
    exits, this socket is unregistered from ``manager``, but only if it is
    still the registered connection for ``patient_id``.

    NOTE(review): gate_signal / infer_ecg / evaluate_ecg_rules are called
    synchronously, so a slow chunk blocks the event loop for every other
    connected patient. Consider asyncio.to_thread once those helpers are
    confirmed thread-safe.
    """
    import logging

    logger = logging.getLogger(__name__)
    await manager.connect(patient_id, websocket)
    try:
        while True:
            # Receive signal chunk from client.
            try:
                data = await websocket.receive_json()
            except (ValueError, KeyError, TypeError):
                # Undecodable frame (e.g. invalid JSON). The frame is already
                # consumed, so report it and keep listening instead of
                # crashing the handler.
                await websocket.send_json(_error_payload("Invalid JSON message"))
                continue

            if not isinstance(data, dict):
                await websocket.send_json(_error_payload("Message must be a JSON object"))
                continue

            signal = data.get("signal", [])
            raw_context = data.get("patient_context") or {}
            if not isinstance(raw_context, dict):
                await websocket.send_json(_error_payload("patient_context must be a JSON object"))
                continue
            # Copy rather than mutate the client payload; the URL's patient_id
            # always overrides any id the client sent.
            patient_context = {**raw_context, "patient_id": patient_id}

            if not signal:
                await websocket.send_json(_error_payload("Empty signal"))
                continue

            try:
                result = _run_pipeline(signal, patient_context, patient_id)
            except Exception as exc:
                logger.exception("ECG pipeline failed for patient %s", patient_id)
                await websocket.send_json(_error_payload(str(exc)))
                continue

            # Sent outside the pipeline's handler so a send failure is not
            # misreported as a pipeline error.
            await websocket.send_json({
                "status": "ok",
                "result": result,
                "timestamp": _utc_timestamp(),
            })
    except WebSocketDisconnect:
        logger.info("Patient %s disconnected", patient_id)
    finally:
        # Don't evict a newer connection that replaced this one for the same patient.
        if manager.active_connections.get(patient_id) is websocket:
            manager.disconnect(patient_id)
async def get_active_streams() -> Dict[str, Any]:
    """Report which patients currently have an open ECG stream.

    Returns:
        A dict with ``"active_patients"`` (list of patient ids registered in
        ``manager``) and ``"count"`` (the length of that list).
    """
    patients = manager.get_active_patients()
    summary: Dict[str, Any] = {"active_patients": patients}
    summary["count"] = len(patients)
    return summary
async def broadcast_alert(alert: Dict[str, Any]):
    """
    Broadcast an alert to all connected patients (admin use).

    Example: System-wide maintenance notification.

    Delivery is best-effort. A failure sending to one patient is logged and
    does not stop the broadcast to the remaining patients.

    Args:
        alert: Mapping whose ``"message"`` value is forwarded (``""`` if absent).

    Returns:
        ``{"status": "ok", "recipients": N}``, where N counts the recipients
        whose send completed without raising.
    """
    import logging

    logger = logging.getLogger(__name__)
    # Snapshot once: connections can be added or removed while we await each send.
    recipients = manager.get_active_patients()
    payload = {
        "type": "system_alert",
        "message": alert.get("message", ""),
        "timestamp": datetime.utcnow().isoformat(),
    }
    delivered = 0
    for patient_id in recipients:
        try:
            await manager.send_message(patient_id, payload)
        except Exception:
            logger.warning(
                "Failed to deliver system alert to patient %s", patient_id, exc_info=True
            )
            continue
        delivered += 1
    return {"status": "ok", "recipients": delivered}