Spaces:
Sleeping
Sleeping
| """ | |
| Multi-patient dashboard API endpoints. | |
| Provides aggregate views for monitoring multiple patients simultaneously. | |
| """ | |
| from typing import Any, Dict, List, Optional | |
| from datetime import datetime, timedelta | |
| from fastapi import APIRouter, Depends, Query | |
| from sqlalchemy import func, and_ | |
| from sqlalchemy.orm import Session | |
| from app.db.session import get_session | |
| from app.models.ecg import ECGSample | |
| from app.models.schemas import DashboardStats, PatientSummary, AlertSummary | |
# Module-level FastAPI router intended for the dashboard endpoints in this file.
# NOTE(review): no `@router.get(...)`-style decorators are visible on the handler
# functions in this copy of the file, so nothing shown here registers them on
# `router` -- confirm they are attached elsewhere (or restore the decorators).
router = APIRouter()
def get_dashboard_stats(
    session: Session = Depends(get_session),
    hours: int = Query(24, ge=1, le=168, description="Time window in hours"),
) -> Dict[str, Any]:
    """
    Summarise ECG activity recorded during the last ``hours`` hours.

    Every figure is computed over samples whose ``created_at`` is at or after
    ``datetime.utcnow() - timedelta(hours=hours)``.

    Args:
        session: Database session (injected via ``Depends(get_session)``).
        hours: Size of the look-back window, 1-168 hours (default 24).

    Returns:
        A dict containing:
        - ``time_window_hours``: echo of ``hours``
        - ``total_samples``: number of samples in the window
        - ``unique_patients``: number of distinct ``patient_id`` values
        - ``alert_distribution``: ``{alert_level: sample_count}``
        - ``label_distribution``: ``{label: sample_count}``
        - ``avg_score``: SQL AVG of ``score`` (NULL scores are skipped by
          AVG), rounded to 3 decimals; 0.0 when there is no data
        - ``estimated_energy_savings_pct``: fixed 60.0 placeholder, not
          derived from the data
        - ``timestamp``: UTC ISO-8601 time the response was built
    """
    window_start = datetime.utcnow() - timedelta(hours=hours)
    in_window = ECGSample.created_at >= window_start

    def windowed_scalar(expression, fallback):
        # One aggregate value over the window; a NULL/falsy result becomes `fallback`.
        return session.query(expression).filter(in_window).scalar() or fallback

    def windowed_breakdown(column):
        # {value of `column`: number of samples} over the window.
        grouped_rows = (
            session.query(column, func.count(ECGSample.id))
            .filter(in_window)
            .group_by(column)
            .all()
        )
        return {bucket: tally for bucket, tally in grouped_rows}

    sample_total = windowed_scalar(func.count(ECGSample.id), 0)
    patient_total = windowed_scalar(func.count(func.distinct(ECGSample.patient_id)), 0)
    alerts_by_level = windowed_breakdown(ECGSample.alert_level)
    mean_score = windowed_scalar(func.avg(ECGSample.score), 0.0)
    samples_by_label = windowed_breakdown(ECGSample.label)

    # Placeholder figure: assumes gating saves ~60% energy on average (not measured).
    energy_savings_pct = 60.0

    return {
        "time_window_hours": hours,
        "total_samples": sample_total,
        "unique_patients": patient_total,
        "alert_distribution": alerts_by_level,
        "label_distribution": samples_by_label,
        "avg_score": round(float(mean_score), 3),
        "estimated_energy_savings_pct": energy_savings_pct,
        "timestamp": datetime.utcnow().isoformat(),
    }
def get_patient_summaries(
    session: Session = Depends(get_session),
    alert_level: Optional[str] = Query(None, description="Filter by alert level"),
    limit: int = Query(100, ge=1, le=1000),
) -> List[Dict[str, Any]]:
    """
    Get summary information for all patients.

    Each entry is built from the patient's most recent sample (largest
    ``created_at``), plus the patient's total sample count and the number of
    samples whose ``alert_level`` is ``'notify'`` or ``'escalate'``.

    Filtering, ordering and truncation are applied in that order:
    1. keep only patients whose latest ``alert_level`` equals ``alert_level``
       (when it is truthy);
    2. sort by priority (escalate > notify > none > anything else), newest
       update first within the same priority;
    3. return at most ``limit`` entries.
    Truncating before steps 1-2 would silently drop matching or escalating
    patients that merely appear late in database order.

    Samples with a NULL ``patient_id`` or NULL ``created_at`` never satisfy
    the equality join below and are therefore not summarised.

    Args:
        session: Database session (injected via ``Depends(get_session)``).
        alert_level: Optional exact alert level to filter on.
        limit: Maximum number of summaries returned (1-1000).

    Returns:
        List of dicts with keys ``patient_id``, ``latest_label``,
        ``latest_score``, ``latest_alert_level``, ``latest_hr``,
        ``last_updated``, ``total_samples`` and ``alert_count``.
    """
    alerting_levels = ['notify', 'escalate']
    alert_priority = {'escalate': 0, 'notify': 1, 'none': 2, None: 3}

    # Latest sample per patient in one round trip: join every sample against its
    # patient's MAX(created_at) instead of issuing one query per patient.
    latest_per_patient = (
        session.query(
            ECGSample.patient_id.label("patient_id"),
            func.max(ECGSample.created_at).label("latest_at"),
        )
        .group_by(ECGSample.patient_id)
        .subquery()
    )
    candidates = (
        session.query(ECGSample)
        .join(
            latest_per_patient,
            and_(
                ECGSample.patient_id == latest_per_patient.c.patient_id,
                ECGSample.created_at == latest_per_patient.c.latest_at,
            ),
        )
        .all()
    )

    latest_by_patient = {}
    for sample in candidates:
        # Several samples may share the same max timestamp; keep one per patient.
        latest_by_patient.setdefault(sample.patient_id, sample)

    latest_samples = [
        s for s in latest_by_patient.values()
        if not alert_level or s.alert_level == alert_level
    ]
    # Two stable sorts: newest first, then by alert priority (primary key).
    latest_samples.sort(key=lambda s: s.created_at, reverse=True)
    latest_samples.sort(key=lambda s: alert_priority.get(s.alert_level, 3))

    # Apply `limit` only after filtering and prioritising.
    page = latest_samples[:limit]
    if not page:
        return []
    page_ids = [s.patient_id for s in page]

    def count_per_patient(*criteria):
        # {patient_id: number of matching samples}, restricted to this page.
        rows = (
            session.query(ECGSample.patient_id, func.count(ECGSample.id))
            .filter(ECGSample.patient_id.in_(page_ids), *criteria)
            .group_by(ECGSample.patient_id)
            .all()
        )
        return {pid: n for pid, n in rows}

    sample_counts = count_per_patient()
    alert_counts = count_per_patient(ECGSample.alert_level.in_(alerting_levels))

    return [
        {
            "patient_id": s.patient_id,
            "latest_label": s.label,
            "latest_score": round(float(s.score or 0.0), 3),
            "latest_alert_level": s.alert_level,
            "latest_hr": s.hr,
            "last_updated": s.created_at.isoformat(),
            "total_samples": sample_counts.get(s.patient_id, 0),
            "alert_count": alert_counts.get(s.patient_id, 0),
        }
        for s in page
    ]
def get_recent_alerts(
    session: Session = Depends(get_session),
    hours: int = Query(24, ge=1, le=168),
    alert_level: Optional[str] = Query(None, description="Filter: notify or escalate"),
    limit: int = Query(50, ge=1, le=500),
) -> List[Dict[str, Any]]:
    """
    List the newest alert-level samples across all patients.

    Only samples created within the last ``hours`` hours whose ``alert_level``
    is ``'notify'`` or ``'escalate'`` are considered.

    Args:
        session: Database session (injected via ``Depends(get_session)``).
        hours: Look-back window, 1-168 hours (default 24).
        alert_level: If truthy, additionally require this exact alert level.
            A value other than notify/escalate therefore matches nothing.
        limit: Maximum number of rows returned (1-500).

    Returns:
        Up to ``limit`` dicts, newest first, each with ``sample_id``,
        ``patient_id``, ``alert_level``, ``label``, ``score`` (rounded to 3
        decimals, NULL treated as 0.0), ``hr`` and ``timestamp`` (ISO-8601).
    """
    since = datetime.utcnow() - timedelta(hours=hours)

    criteria = [
        ECGSample.created_at >= since,
        ECGSample.alert_level.in_(['notify', 'escalate']),
    ]
    if alert_level:
        criteria.append(ECGSample.alert_level == alert_level)

    rows = (
        session.query(ECGSample)
        .filter(and_(*criteria))
        .order_by(ECGSample.created_at.desc())
        .limit(limit)
        .all()
    )

    def to_payload(row):
        # Flatten one alerting sample into the response shape.
        return {
            "sample_id": row.id,
            "patient_id": row.patient_id,
            "alert_level": row.alert_level,
            "label": row.label,
            "score": round(float(row.score or 0.0), 3),
            "hr": row.hr,
            "timestamp": row.created_at.isoformat(),
        }

    return list(map(to_payload, rows))
def get_patient_history(
    patient_id: str,
    session: Session = Depends(get_session),
    hours: int = Query(24, ge=1, le=168),
    limit: int = Query(100, ge=1, le=1000),
) -> Dict[str, Any]:
    """
    Get historical data for a specific patient.

    Returns the patient's most recent ``limit`` samples created within the
    last ``hours`` hours, ordered oldest -> newest so they can be plotted
    directly as a time series.

    Args:
        patient_id: Patient whose samples are returned.
        session: Database session (injected via ``Depends(get_session)``).
        hours: Look-back window, 1-168 hours (default 24).
        limit: Maximum number of samples returned (1-1000).

    Returns:
        Dict with ``patient_id``, ``time_window_hours``, ``sample_count``,
        ``alert_count`` (samples at ``'notify'``/``'escalate'``), ``avg_score``
        (mean over the returned samples, NULL scores counted as 0.0, rounded to
        3 decimals; 0.0 when empty) and ``history`` (per-sample dicts). The
        summary figures cover only the returned samples, not the whole window.
    """
    cutoff_time = datetime.utcnow() - timedelta(hours=hours)

    # Take the NEWEST `limit` rows, then flip to chronological order. Sorting
    # ascending before LIMIT would keep the oldest rows in the window and drop
    # the most recent readings once a patient exceeds `limit` samples.
    newest_first = (
        session.query(ECGSample)
        .filter(
            and_(
                ECGSample.patient_id == patient_id,
                ECGSample.created_at >= cutoff_time,
            )
        )
        .order_by(ECGSample.created_at.desc())
        .limit(limit)
        .all()
    )
    samples = newest_first[::-1]

    history = []
    score_total = 0.0
    alert_count = 0
    for s in samples:
        # float() up front so Decimal-typed scores (if the column is Numeric)
        # do not get mixed with the float 0.0 fallback when summing.
        score = float(s.score or 0.0)
        score_total += score
        if s.alert_level in ('notify', 'escalate'):
            alert_count += 1
        history.append({
            "sample_id": s.id,
            "label": s.label,
            "score": round(score, 3),
            "alert_level": s.alert_level,
            "hr": s.hr,
            "timestamp": s.created_at.isoformat(),
        })

    avg_score = score_total / len(samples) if samples else 0.0

    return {
        "patient_id": patient_id,
        "time_window_hours": hours,
        "sample_count": len(samples),
        "alert_count": alert_count,
        "avg_score": round(avg_score, 3),
        "history": history,
    }