Spaces:
Sleeping
Sleeping
| """Backend monitor payload builder for the hackathon demo dashboard.""" | |
| from __future__ import annotations | |
| from dataclasses import asdict, dataclass | |
| from typing import Any, Literal, Sequence | |
| from pulse_physiology_env.patient_state import PatientState | |
| from .reward_engine import ActionRecord | |
# Closed set of tones shared by tiles (``MonitorTile.status``) and feed
# events (``MonitorEvent.tone``).
MonitorStatus = Literal[
    "stable",
    "warning",
    "critical",
    "improving",
    "info",
]
@dataclass
class MonitorTile:
    """One monitor tile for the eventual UI.

    Declared as a dataclass so it can be built with keyword arguments and
    serialized through ``dataclasses.asdict`` together with the payload.
    """

    # Machine-readable identifier of the vital (e.g. "heart_rate_bpm").
    key: str
    # Human-readable caption shown on the tile (e.g. "Heart Rate").
    label: str
    # Pre-formatted display string; "n/a" when the reading is missing.
    value: str
    # Severity tone used to colour the tile.
    status: MonitorStatus
    # Change versus the previous sample; None when no comparison is possible.
    trend_delta: float | None = None
@dataclass
class MonitorTrendPoint:
    """One sampled point for the patient monitor charts.

    Declared as a dataclass so it can be built with keyword arguments and
    serialized through ``dataclasses.asdict``. Every vital may be ``None``
    when the corresponding reading is unavailable for that sample.
    """

    # Simulation timestamp of the sample, in seconds.
    sim_time_s: float
    mean_arterial_pressure_mmhg: float | None
    # NOTE(review): appears to be a 0-1 fraction (thresholds elsewhere in this
    # file compare it against 0.88 / 0.94) -- confirm against PatientState.
    spo2: float | None
    heart_rate_bpm: float | None
    respiration_rate_bpm: float | None
    etco2_mmhg: float | None
@dataclass
class MonitorEvent:
    """Recent event for the intervention and milestone feed.

    Declared as a dataclass so it can be built with keyword arguments and
    serialized through ``dataclasses.asdict``.
    """

    # Simulation timestamp of the event, in seconds.
    sim_time_s: float
    # Short title of the event (the tool name of the recorded action).
    label: str
    # Free-form detail text shown next to the label.
    detail: str
    # Severity tone used to colour the feed entry.
    tone: MonitorStatus
@dataclass
class PatientMonitorPayload:
    """Structured monitor payload that a frontend can render directly.

    Must be a dataclass: ``as_dict`` relies on ``dataclasses.asdict``, which
    raises ``TypeError`` for anything that is not a dataclass instance.
    """

    # One-sentence summary of the most notable change on the monitor.
    headline: str
    # Short description of the overall physiological status.
    status_banner: str
    # Optional highlight text; None when the highlight does not apply.
    scenario_2_moment: str | None
    tiles: list[MonitorTile]
    alerts: list[str]
    active_interventions: list[str]
    # Built with the keys "pending" and "ready" by the payload builder.
    diagnostic_status: dict[str, Any]
    recent_events: list[MonitorEvent]
    trends: list[MonitorTrendPoint]

    def as_dict(self) -> dict[str, Any]:
        """Return the payload as plain nested dicts and lists.

        Nested dataclass instances (tiles, events, trend points) are
        converted recursively by ``dataclasses.asdict``.
        """
        return asdict(self)
class PatientMonitorVisualization:
    """Transforms environment state into a demo-friendly monitor payload.

    The only public entry point is :meth:`build`. Helpers that do not read
    instance or class state are static methods; they are still invoked through
    ``self`` so subclasses can override them.
    """

    # Maximum number of most-recent states included in the trend series.
    MAX_TREND_POINTS = 24
    # Maximum number of most-recent actions included in the event feed.
    MAX_RECENT_EVENTS = 6

    def build(
        self,
        *,
        history: Sequence[PatientState],
        action_history: Sequence[ActionRecord],
        current_state: PatientState,
    ) -> PatientMonitorPayload:
        """Assemble the full monitor payload.

        Args:
            history: Chronological patient states; the last
                ``MAX_TREND_POINTS`` entries feed the trend series.
            action_history: Chronological action records; the last
                ``MAX_RECENT_EVENTS`` entries feed the event feed.
            current_state: State rendered in tiles, banner and alerts.

        Returns:
            The populated ``PatientMonitorPayload``.
        """
        # NOTE(review): the "previous" state is history[-2], which presumes
        # history[-1] corresponds to current_state -- confirm against caller.
        previous_state = history[-2] if len(history) >= 2 else None
        return PatientMonitorPayload(
            headline=self._build_headline(previous_state, current_state, action_history),
            status_banner=self._build_status_banner(current_state),
            scenario_2_moment=self._build_scenario_2_moment(previous_state, current_state, action_history),
            tiles=self._build_tiles(previous_state, current_state),
            alerts=list(current_state.active_alerts),
            active_interventions=self._build_active_interventions(current_state),
            diagnostic_status={
                "pending": dict(current_state.pending_diagnostics),
                "ready": list(current_state.ready_diagnostics),
            },
            recent_events=self._build_recent_events(action_history),
            trends=self._build_trends(history),
        )

    def _build_headline(
        self,
        previous_state: PatientState | None,
        current_state: PatientState,
        action_history: Sequence[ActionRecord],
    ) -> str:
        """Pick the headline sentence, most specific condition first.

        Order: SpO2 jump of at least 0.08 right after a successful needle
        decompression, MAP crossing 65 upward, critical SpO2 (< 0.88),
        critical MAP (< 55), otherwise a generic ready message.
        """
        if previous_state is not None:
            if current_state.spo2 is not None and previous_state.spo2 is not None:
                spo2_jump = current_state.spo2 - previous_state.spo2
                if spo2_jump >= 0.08 and self._last_action_has_tag(action_history, "needle_decompression"):
                    return (
                        "Respiratory rescue visible: SpO2 improved from "
                        f"{previous_state.spo2:.2f} to {current_state.spo2:.2f} after decompression."
                    )
            if (
                current_state.mean_arterial_pressure_mmhg is not None
                and previous_state.mean_arterial_pressure_mmhg is not None
                and previous_state.mean_arterial_pressure_mmhg < 65.0 <= current_state.mean_arterial_pressure_mmhg
            ):
                return "Perfusion restored: MAP crossed the 65 mmHg resuscitation threshold."
        if current_state.spo2 is not None and current_state.spo2 < 0.88:
            return "Critical respiratory compromise on monitor; airway and chest interventions should be obvious."
        if current_state.mean_arterial_pressure_mmhg is not None and current_state.mean_arterial_pressure_mmhg < 55.0:
            return "Critical perfusion deficit on monitor; immediate resuscitation sequence is needed."
        return "Live trauma monitor ready for the decompression-first demo moment."

    @staticmethod
    def _build_status_banner(current_state: PatientState) -> str:
        """Return the banner text for the current state.

        Checks, in order: a "cardiac_arrest" alert, SpO2 below 0.88, MAP
        below 55, any remaining active alert, and finally the stable message.
        """
        if "cardiac_arrest" in current_state.active_alerts:
            return "Cardiac arrest physiology active."
        if current_state.spo2 is not None and current_state.spo2 < 0.88:
            return "Oxygenation is critical."
        if current_state.mean_arterial_pressure_mmhg is not None and current_state.mean_arterial_pressure_mmhg < 55.0:
            return "Perfusion is critical."
        if current_state.active_alerts:
            return "Active trauma alerts require attention."
        return "Physiology is currently stable enough for continued observation."

    def _build_scenario_2_moment(
        self,
        previous_state: PatientState | None,
        current_state: PatientState,
        action_history: Sequence[ActionRecord],
    ) -> str | None:
        """Return the decompression highlight text, or None when it does not apply.

        Requires a previous state, a successful last action tagged
        "needle_decompression", SpO2 present in both states, and an SpO2
        increase of at least 0.05.
        """
        if previous_state is None:
            return None
        if not self._last_action_has_tag(action_history, "needle_decompression"):
            return None
        if previous_state.spo2 is None or current_state.spo2 is None:
            return None
        improvement = current_state.spo2 - previous_state.spo2
        if improvement < 0.05:
            return None
        return (
            "Scenario 2 moment: decompression produced a visible SpO2 jump "
            f"from {previous_state.spo2:.2f} to {current_state.spo2:.2f}."
        )

    def _build_tiles(
        self,
        previous_state: PatientState | None,
        current_state: PatientState,
    ) -> list[MonitorTile]:
        """Build the six vital-sign tiles in display order."""
        return [
            self._tile(
                key="heart_rate_bpm",
                label="Heart Rate",
                value=self._format_value(current_state.heart_rate_bpm, "bpm"),
                status=self._status_for_heart_rate(current_state.heart_rate_bpm),
                trend_delta=self._delta(previous_state, current_state, "heart_rate_bpm"),
            ),
            self._tile(
                key="mean_arterial_pressure_mmhg",
                label="MAP",
                value=self._format_value(current_state.mean_arterial_pressure_mmhg, "mmHg"),
                status=self._status_for_map(current_state.mean_arterial_pressure_mmhg),
                trend_delta=self._delta(previous_state, current_state, "mean_arterial_pressure_mmhg"),
            ),
            self._tile(
                key="spo2",
                label="SpO2",
                value=self._format_fraction(current_state.spo2),
                status=self._status_for_spo2(current_state.spo2),
                trend_delta=self._delta(previous_state, current_state, "spo2"),
            ),
            self._tile(
                key="respiration_rate_bpm",
                label="Respiratory Rate",
                value=self._format_value(current_state.respiration_rate_bpm, "bpm"),
                status=self._status_for_respiration(current_state.respiration_rate_bpm),
                trend_delta=self._delta(previous_state, current_state, "respiration_rate_bpm"),
            ),
            self._tile(
                key="etco2_mmhg",
                label="EtCO2",
                value=self._format_value(current_state.etco2_mmhg, "mmHg"),
                status=self._status_for_etco2(current_state.etco2_mmhg),
                trend_delta=self._delta(previous_state, current_state, "etco2_mmhg"),
            ),
            self._tile(
                key="shock_index",
                label="Shock Index",
                value=self._format_number(current_state.shock_index, precision=2),
                status=self._status_for_shock_index(current_state.shock_index),
                trend_delta=self._delta(previous_state, current_state, "shock_index"),
            ),
        ]

    def _build_active_interventions(self, current_state: PatientState) -> list[str]:
        """List display strings for oxygen, airway, infusions and hemorrhages.

        Infusions and hemorrhages are emitted in sorted key order so the
        output is deterministic.
        """
        interventions: list[str] = []
        if current_state.oxygen_device is not None:
            flow = self._format_number(current_state.oxygen_flow_lpm)
            interventions.append(f"Oxygen: {current_state.oxygen_device} at {flow} L/min")
        if current_state.airway_support is not None:
            interventions.append(f"Airway support: {current_state.airway_support}")
        if current_state.intubated:
            interventions.append("Airway secured")
        for name, rate in sorted(current_state.active_infusions.items()):
            interventions.append(f"Infusion: {name} at {self._format_number(rate)}")
        for site, flow in sorted(current_state.active_hemorrhages.items()):
            interventions.append(f"Active hemorrhage: {site} at {self._format_number(flow)} mL/min")
        return interventions

    def _build_recent_events(self, action_history: Sequence[ActionRecord]) -> list[MonitorEvent]:
        """Convert the last ``MAX_RECENT_EVENTS`` actions into feed events.

        Tone precedence: decompression/pericardiocentesis tags are
        "improving", CPR or induced arrest is "critical", bleeding-control
        tools are "warning", anything else is "info".
        """
        events: list[MonitorEvent] = []
        for record in action_history[-self.MAX_RECENT_EVENTS :]:
            tone: MonitorStatus = "info"
            if "needle_decompression" in record.tags or "pericardiocentesis" in record.tags:
                tone = "improving"
            elif record.tool_name in {"perform_cpr", "induce_cardiac_arrest"}:
                tone = "critical"
            elif record.tool_name in {"apply_tourniquet", "control_bleeding", "apply_direct_pressure"}:
                tone = "warning"
            events.append(
                MonitorEvent(
                    sim_time_s=record.sim_time_s,
                    label=record.tool_name,
                    detail=f"successful={record.success}",
                    tone=tone,
                )
            )
        return events

    def _build_trends(self, history: Sequence[PatientState]) -> list[MonitorTrendPoint]:
        """Sample the last ``MAX_TREND_POINTS`` states into chart points."""
        trend_states = history[-self.MAX_TREND_POINTS :]
        return [
            MonitorTrendPoint(
                sim_time_s=state.sim_time_s,
                mean_arterial_pressure_mmhg=state.mean_arterial_pressure_mmhg,
                spo2=state.spo2,
                heart_rate_bpm=state.heart_rate_bpm,
                respiration_rate_bpm=state.respiration_rate_bpm,
                etco2_mmhg=state.etco2_mmhg,
            )
            for state in trend_states
        ]

    @staticmethod
    def _tile(
        *,
        key: str,
        label: str,
        value: str,
        status: MonitorStatus,
        trend_delta: float | None,
    ) -> MonitorTile:
        """Create a tile, rounding ``trend_delta`` to three decimals."""
        return MonitorTile(
            key=key,
            label=label,
            value=value,
            status=status,
            trend_delta=None if trend_delta is None else round(trend_delta, 3),
        )

    @staticmethod
    def _delta(
        previous_state: PatientState | None,
        current_state: PatientState,
        field_name: str,
    ) -> float | None:
        """Return ``current - previous`` for ``field_name``.

        Returns None when there is no previous state or either value is None.
        """
        if previous_state is None:
            return None
        previous_value = getattr(previous_state, field_name)
        current_value = getattr(current_state, field_name)
        if previous_value is None or current_value is None:
            return None
        return float(current_value) - float(previous_value)

    @staticmethod
    def _format_value(value: float | None, unit: str) -> str:
        """Format ``value`` with one decimal and a unit suffix, or "n/a"."""
        if value is None:
            return "n/a"
        return f"{value:.1f} {unit}"

    @staticmethod
    def _format_fraction(value: float | None) -> str:
        """Format ``value`` with two decimals, or "n/a" when missing."""
        if value is None:
            return "n/a"
        return f"{value:.2f}"

    @staticmethod
    def _format_number(value: float | None, *, precision: int = 1) -> str:
        """Format ``value`` with ``precision`` decimals, or "n/a" when missing."""
        if value is None:
            return "n/a"
        return f"{value:.{precision}f}"

    @staticmethod
    def _last_action_has_tag(action_history: Sequence[ActionRecord], tag: str) -> bool:
        """Tell whether the most recent action succeeded and carries ``tag``.

        Returns False for an empty history.
        """
        if not action_history:
            return False
        last_record = action_history[-1]
        # bool() keeps the annotated return type even if ``success`` is a
        # non-bool falsy value such as None.
        return bool(last_record.success) and tag in last_record.tags

    @staticmethod
    def _status_for_heart_rate(value: float | None) -> MonitorStatus:
        """Classify heart rate: critical <40 or >140, warning <55 or >120."""
        if value is None:
            return "info"
        if value < 40 or value > 140:
            return "critical"
        if value < 55 or value > 120:
            return "warning"
        return "stable"

    @staticmethod
    def _status_for_map(value: float | None) -> MonitorStatus:
        """Classify mean arterial pressure: critical <55, warning <65."""
        if value is None:
            return "info"
        if value < 55:
            return "critical"
        if value < 65:
            return "warning"
        return "stable"

    @staticmethod
    def _status_for_spo2(value: float | None) -> MonitorStatus:
        """Classify SpO2: critical <0.88, warning <0.94."""
        if value is None:
            return "info"
        if value < 0.88:
            return "critical"
        if value < 0.94:
            return "warning"
        return "stable"

    @staticmethod
    def _status_for_respiration(value: float | None) -> MonitorStatus:
        """Classify respiration rate: critical <8 or >35, warning <10 or >28."""
        if value is None:
            return "info"
        if value < 8 or value > 35:
            return "critical"
        if value < 10 or value > 28:
            return "warning"
        return "stable"

    @staticmethod
    def _status_for_etco2(value: float | None) -> MonitorStatus:
        """Classify EtCO2: critical <20 or >60, warning <25 or >50."""
        if value is None:
            return "info"
        if value < 20 or value > 60:
            return "critical"
        if value < 25 or value > 50:
            return "warning"
        return "stable"

    @staticmethod
    def _status_for_shock_index(value: float | None) -> MonitorStatus:
        """Classify shock index: critical >1.3, warning >0.9."""
        if value is None:
            return "info"
        if value > 1.3:
            return "critical"
        if value > 0.9:
            return "warning"
        return "stable"