"""Tier 3 consumer-side workflows for Pulse-ER.

These workflows sit above the low-level tool layer and transform the current
observation or episode trace into judge-friendly and agent-friendly outputs.
They are intentionally backend-agnostic so the same logic can be used against
mock and real Pulse runtimes.
"""

from __future__ import annotations

from enum import Enum

from pydantic import BaseModel, ConfigDict, Field

from .episode_runner import EpisodeTrace
from .models import PulsePhysiologyObservation


class DeteriorationStatus(str, Enum):
    """Clinical status buckets used for judge-facing Tier 3 reasoning."""

    STABLE = "stable"
    MONITORING = "monitoring"
    DETERIORATING = "deteriorating"
    CRITICAL = "critical"


# Alert names considered low severity.
MINOR_ALERTS = {"tachycardia", "tachypnea"}
# Alert names considered high severity.
MAJOR_ALERTS = {"hypotension", "hypoxemia", "blood_loss", "cardiovascular_collapse"}
# Vitals whose recent trends are tracked when classifying deterioration.
MAJOR_VITALS = (
    "mean_arterial_pressure_mmhg",
    "spo2",
    "heart_rate_bpm",
    "respiration_rate_bpm",
)
# Subset of MAJOR_VITALS checked for accelerating drift (respiration rate is not included).
CRITICAL_ACCELERATION_VITALS = (
    "mean_arterial_pressure_mmhg",
    "spo2",
    "heart_rate_bpm",
)


def _mental_status_value(mental_status: object) -> str:
    """Return the plain value of a mental-status field.

    Enum-like inputs expose ``.value``; anything else falls back to ``str()``.
    """
    return getattr(mental_status, "value", str(mental_status))


def _risk_level(observation: PulsePhysiologyObservation) -> str:
    """Bucket the current observation into a coarse risk level.

    Returns:
        ``"critical"`` when the episode is done, a cardiovascular-collapse alert
        is active, or the patient is unresponsive; ``"high"`` for hypotension,
        blood loss, hypoxemia, or a ``pain``/``verbal`` mental status;
        ``"moderate"`` for any other active alert; otherwise ``"low"``.
    """
    alerts = set(observation.active_alerts)
    mental_status = _mental_status_value(observation.mental_status)
    if observation.done or "cardiovascular_collapse" in alerts or mental_status == "unresponsive":
        return "critical"
    if {"hypotension", "blood_loss", "hypoxemia"} & alerts or mental_status in {"pain", "verbal"}:
        return "high"
    if alerts:
        return "moderate"
    return "low"


def _priority_reasons(observation: PulsePhysiologyObservation) -> list[str]:
    """Explain in plain language why each recognised active alert matters.

    Reasons are emitted in a fixed order per alert type. When no recognised
    alert applies, a single reassurance message is returned, so the list is
    never empty.
    """
    alerts = set(observation.active_alerts)
    reasons: list[str] = []
    if "blood_loss" in alerts:
        reasons.append("Ongoing blood loss threatens perfusion and should be controlled early.")
    if "hypotension" in alerts:
        reasons.append("Low blood pressure suggests reduced perfusion and possible shock.")
    if "hypoxemia" in alerts:
        reasons.append("Low oxygen saturation requires respiratory support or oxygen therapy.")
    if "tachypnea" in alerts:
        reasons.append("High respiratory rate suggests respiratory stress or compensation.")
    if "tachycardia" in alerts:
        reasons.append("Persistent tachycardia may reflect compensation, stress, or ongoing instability.")
    # The alerts below are the ones _risk_level / recommend_next_step treat as the
    # most urgent. Without an entry here, a critical patient would fall through to
    # the "no high-priority alerts" message. They are listed after the alerts above
    # so the order of those existing reasons stays stable.
    if "cardiovascular_collapse" in alerts:
        reasons.append("Cardiovascular collapse is present and requires immediate resuscitation.")
    if {"possible_tension_pneumothorax", "unilateral_absent_breath_sounds"} & alerts:
        reasons.append(
            "Findings suggest possible tension pneumothorax, which can rapidly compromise breathing and circulation."
        )
    if "possible_cardiac_tamponade" in alerts:
        reasons.append("Possible cardiac tamponade can impair cardiac filling and needs urgent decompression.")
    if not reasons:
        reasons.append("No active high-priority alerts are present, so reassessment and safe monitoring are appropriate.")
    return reasons


def _mean_arterial_pressure(observation: PulsePhysiologyObservation) -> float | None:
    """Return MAP directly or derive it from systolic and diastolic pressure.

    The derived value uses the standard (SBP + 2 * DBP) / 3 estimate. Returns
    ``None`` when MAP is absent and either blood-pressure component is missing.
    """
    if observation.mean_arterial_pressure_mmhg is not None:
        return observation.mean_arterial_pressure_mmhg
    if observation.systolic_bp_mmhg is None or observation.diastolic_bp_mmhg is None:
        return None
    return (observation.systolic_bp_mmhg + (2 * observation.diastolic_bp_mmhg)) / 3


def _spo2_percent(observation: PulsePhysiologyObservation) -> float | None:
    """Return oxygen saturation as a percentage rounded to one decimal place.

    ``observation.spo2`` is stored as a fraction (it is multiplied by 100 here).
    Returns ``None`` when SpO2 is unavailable.
    """
    if observation.spo2 is None:
        return None
    return round(observation.spo2 * 100, 1)


def _vital_value(observation: PulsePhysiologyObservation, vital_name: str) -> float | None:
    """Resolve one vital from an observation, including derived MAP support.

    Unknown attribute names resolve to ``None``.
    """
    if vital_name == "mean_arterial_pressure_mmhg":
        return _mean_arterial_pressure(observation)
    return getattr(observation, vital_name, None)


def _vital_deviation(vital_name: str, value: float | None) -> float:
    """Measure how far a vital has drifted outside its safe range.

    MAP and SpO2 only count drift below their floor (65 mmHg and 0.94). Heart
    rate (60-100 bpm) and respiration rate (12-20 bpm) count drift in either
    direction. Missing values and unknown vitals score 0.0.
    """
    if value is None:
        return 0.0
    if vital_name == "mean_arterial_pressure_mmhg":
        return max(0.0, 65.0 - value)
    if vital_name == "spo2":
        return max(0.0, 0.94 - value)
    if vital_name == "heart_rate_bpm":
        if 60.0 <= value <= 100.0:
            return 0.0
        # Outside the band, the nearer bound is the one that was crossed.
        return min(abs(value - 60.0), abs(value - 100.0))
    if vital_name == "respiration_rate_bpm":
        if 12.0 <= value <= 20.0:
            return 0.0
        return min(abs(value - 12.0), abs(value - 20.0))
    return 0.0
def _trend_tolerance(vital_name: str) -> float:
    """Return the minimum meaningful drift for one vital trend calculation.

    Tolerances are in each vital's own units (SpO2 is a fraction, so 0.02 is two
    percentage points). Unknown vitals get a tolerance of 0.0.
    """
    if vital_name == "mean_arterial_pressure_mmhg":
        return 3.0
    if vital_name == "spo2":
        return 0.02
    if vital_name == "heart_rate_bpm":
        return 5.0
    if vital_name == "respiration_rate_bpm":
        return 2.0
    return 0.0


def _recent_observations(
    observation: PulsePhysiologyObservation,
    previous_observation: PulsePhysiologyObservation | None = None,
    observations: list[PulsePhysiologyObservation] | None = None,
    *,
    window: int = 3,
) -> list[PulsePhysiologyObservation]:
    """Assemble the most recent observation window for trend-aware Tier 3 logic.

    A non-empty ``observations`` list takes precedence over
    ``previous_observation``. The current ``observation`` is appended when it is
    not already the last entry, and at most ``window`` newest entries are returned.
    """
    recent: list[PulsePhysiologyObservation] = []
    if observations:
        recent.extend(observations)
    elif previous_observation is not None:
        recent.extend([previous_observation, observation])
    else:
        recent.append(observation)
    if not recent or recent[-1] != observation:
        recent.append(observation)
    return recent[-window:]


def get_trend(
    vital_name: str,
    observations: list[PulsePhysiologyObservation],
    window: int = 3,
) -> str:
    """Classify a vital as improving, stable, or worsening over recent observations.

    Compares the out-of-range deviation of the oldest and newest observation in
    the window. Changes within the vital's tolerance count as ``"stable"``, as
    does a window with fewer than two observations.
    """
    recent = observations[-window:]
    deviations = [
        _vital_deviation(vital_name, _vital_value(observation, vital_name))
        for observation in recent
    ]
    if len(deviations) < 2:
        return "stable"
    delta = deviations[-1] - deviations[0]
    tolerance = _trend_tolerance(vital_name)
    if delta > tolerance:
        return "worsening"
    if delta < -tolerance:
        return "improving"
    return "stable"


def _is_accelerating(vital_name: str, observations: list[PulsePhysiologyObservation]) -> bool:
    """Detect acceleration when a vital drifts farther out of range step over step.

    Requires three observations. Each step's increase in deviation must be
    positive and growing, the latest step must exceed half the tolerance, and
    the latest deviation must exceed twice the tolerance.
    """
    recent = observations[-3:]
    if len(recent) < 3:
        return False
    deviations = [
        _vital_deviation(vital_name, _vital_value(observation, vital_name))
        for observation in recent
    ]
    first_delta = deviations[1] - deviations[0]
    second_delta = deviations[2] - deviations[1]
    tolerance = _trend_tolerance(vital_name)
    return (
        first_delta > 0
        and second_delta > first_delta
        and second_delta > (tolerance / 2)
        and deviations[-1] > (2 * tolerance)
    )


def _hard_threshold_reason(observation: PulsePhysiologyObservation) -> str | None:
    """Return the first critical threshold breach, if one is present.

    Checks, in order: MAP < 50 mmHg, SpO2 < 0.85, heart rate > 150 bpm, and
    heart rate < 40 bpm. Missing readings are skipped.
    """
    map_value = _mean_arterial_pressure(observation)
    if map_value is not None and map_value < 50.0:
        return "MAP is below 50 mmHg, indicating critical perfusion failure."
    if observation.spo2 is not None and observation.spo2 < 0.85:
        return "SpO2 is below 85%, indicating critical hypoxemia."
    if observation.heart_rate_bpm is not None and observation.heart_rate_bpm > 150.0:
        return "Heart rate is above 150 bpm, indicating critical cardiovascular stress."
    if observation.heart_rate_bpm is not None and observation.heart_rate_bpm < 40.0:
        return "Heart rate is below 40 bpm, indicating critical bradycardia."
    return None


def _classify_deterioration_status(
    observation: PulsePhysiologyObservation,
    recent_observations: list[PulsePhysiologyObservation],
) -> tuple[DeteriorationStatus, str, dict[str, str]]:
    """Classify status from alert severity and recent vital trends.

    Rules, first match wins:
      1. A hard threshold breach, or an accelerating MAP/SpO2/HR drift -> CRITICAL.
      2. No active alerts and no worsening trend -> STABLE.
      3. A worsening MAP/SpO2/HR trend, or two or more minor alerts -> DETERIORATING.
      4. Any active major alert (see MAJOR_ALERTS) -> DETERIORATING.
      5. Otherwise -> MONITORING.

    Returns:
        ``(status, human-readable reason, per-vital trend map)``.
    """
    alerts = set(observation.active_alerts)
    trend_map = {
        vital_name: get_trend(vital_name, recent_observations)
        for vital_name in MAJOR_VITALS
    }

    hard_threshold_reason = _hard_threshold_reason(observation)
    if hard_threshold_reason is not None:
        return DeteriorationStatus.CRITICAL, hard_threshold_reason, trend_map

    if any(_is_accelerating(vital_name, recent_observations) for vital_name in CRITICAL_ACCELERATION_VITALS):
        return (
            DeteriorationStatus.CRITICAL,
            "At least one major vital is accelerating away from the safe range.",
            trend_map,
        )

    # The same three vitals (MAP, SpO2, HR) drive the "major vital worsening" check.
    major_vital_worsening = any(
        trend_map[vital_name] == "worsening" for vital_name in CRITICAL_ACCELERATION_VITALS
    )
    minor_alert_count = len(alerts & MINOR_ALERTS)

    if not alerts and all(trend != "worsening" for trend in trend_map.values()):
        return DeteriorationStatus.STABLE, "No active alerts and no worsening vital trends are present.", trend_map

    if major_vital_worsening or minor_alert_count >= 2:
        return (
            DeteriorationStatus.DETERIORATING,
            "A major vital is trending the wrong way or multiple minor alerts are accumulating.",
            trend_map,
        )

    # A major alert (e.g. active blood loss or hypotension) must not be reported
    # as "only minor abnormalities", even when the short trend window is flat.
    active_major_alerts = sorted(alerts & MAJOR_ALERTS)
    if active_major_alerts:
        return (
            DeteriorationStatus.DETERIORATING,
            f"Major alert(s) active ({', '.join(active_major_alerts)}), so the patient is not yet stable "
            "despite the absence of a worsening trend.",
            trend_map,
        )

    return (
        DeteriorationStatus.MONITORING,
        "Only minor or non-progressive abnormalities are present, so close monitoring is appropriate.",
        trend_map,
    )


class NextStepRecommendation(BaseModel):
    """Tier 3 recommendation for the next best action."""

    model_config = ConfigDict(extra="forbid")

    scenario_id: str
    risk_level: str
    recommended_tool: str
    arguments: dict = Field(default_factory=dict)
    rationale: str
    alternatives: list[str] = Field(default_factory=list)


class TriageSummary(BaseModel):
    """Tier 3 triage framing for the current patient state."""

    model_config = ConfigDict(extra="forbid")

    scenario_id: str
    acuity: str
    headline: str
    active_alerts: list[str]
    vitals_snapshot: dict
    immediate_focus: list[str]


class DeteriorationExplanation(BaseModel):
    """Tier 3 explanation of why the patient is worsening or stable."""

    model_config = ConfigDict(extra="forbid")

    scenario_id: str
    status: str
    cascade_risk: str
    primary_driver: str
    supporting_findings: list[str]
    recommended_response: str


class InterventionPlan(BaseModel):
    """Tier 3 short intervention plan based on current state."""

    model_config = ConfigDict(extra="forbid")

    scenario_id: str
    risk_level: str
    ordered_steps: list[dict]
    monitoring_targets: list[str]
    escalation_trigger: str


class EpisodeReport(BaseModel):
    """Tier 3 episode-level report for demos and judge summaries."""

    model_config = ConfigDict(extra="forbid")

    scenario_id: str
    policy_name: str
    total_reward: float
    outcome: str
    key_actions: list[str]
    final_alerts: list[str]
    summary: str


def recommend_next_step(observation: PulsePhysiologyObservation) -> NextStepRecommendation:
    """Recommend the next best tool call from the current observation.

    Rules are checked in priority order, and only tools present in
    ``observation.available_tools`` are recommended. The exception is the final
    ``advance_time`` fallback, which is returned unconditionally. Alternatives
    are likewise filtered to available tools.
    """
    alerts = set(observation.active_alerts)
    available_tools = set(observation.available_tools)
    risk_level = _risk_level(observation)
    mean_arterial_pressure = _mean_arterial_pressure(observation)

    if (
        {"possible_tension_pneumothorax", "unilateral_absent_breath_sounds"} & alerts
        and "needle_decompression" in available_tools
    ):
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="needle_decompression",
            rationale="The current alert pattern suggests tension physiology, so decompression is the highest-yield immediate action.",
            alternatives=[tool for tool in ("get_respiratory_status", "airway_support", "give_oxygen") if tool in available_tools],
        )

    if "possible_cardiac_tamponade" in alerts and "pericardiocentesis" in available_tools:
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="pericardiocentesis",
            rationale="Possible tamponade physiology is present, so pericardial drainage is the most direct next step.",
            alternatives=[tool for tool in ("give_fluids", "give_pressor", "check_deterioration") if tool in available_tools],
        )

    if "blood_loss" in alerts and "control_bleeding" in available_tools:
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="control_bleeding",
            rationale="Active blood loss is present and should be controlled before deterioration progresses.",
            alternatives=[tool for tool in ("give_fluids", "give_oxygen", "check_deterioration") if tool in available_tools],
        )

    if (
        observation.scenario_id == "hemorrhagic_shock"
        and "tachycardia" in alerts
        and "give_fluids" in available_tools
    ):
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="give_fluids",
            arguments={"volume_ml": 250},
            rationale="Persistent tachycardia after initial hemorrhage control suggests the patient may still benefit from additional perfusion support.",
            alternatives=[tool for tool in ("check_deterioration", "summarize_state", "advance_time") if tool in available_tools],
        )

    if (
        mean_arterial_pressure is not None
        and mean_arterial_pressure < 65
        and observation.active_infusions
        and "give_pressor" in available_tools
    ):
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="give_pressor",
            rationale="Perfusion remains low despite active infusions, so vasopressor support is a reasonable escalation.",
            alternatives=[tool for tool in ("give_fluids", "check_deterioration", "get_blood_gas") if tool in available_tools],
        )

    if "hypotension" in alerts and "give_fluids" in available_tools:
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="give_fluids",
            arguments={"volume_ml": 500},
            rationale="Hypotension suggests poor perfusion and fluid resuscitation is the next most direct support.",
            alternatives=[tool for tool in ("control_bleeding", "position_patient", "check_deterioration") if tool in available_tools],
        )

    if "hypoxemia" in alerts and "give_oxygen" in available_tools:
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="give_oxygen",
            arguments={"flow_lpm": 15},
            rationale="Hypoxemia is active and oxygen support is the fastest way to improve oxygenation.",
            alternatives=[tool for tool in ("airway_support", "position_patient", "check_deterioration") if tool in available_tools],
        )

    if "hypoxemia" in alerts and "get_respiratory_status" in available_tools:
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="get_respiratory_status",
            rationale="A focused respiratory reassessment can clarify whether the next move should be oxygen, airway support, or decompression.",
            alternatives=[tool for tool in ("give_oxygen", "airway_support", "position_patient") if tool in available_tools],
        )

    if "tachypnea" in alerts and "airway_support" in available_tools:
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="airway_support",
            arguments={"mode": "auto"},
            rationale="Respiratory effort remains elevated and airway support may prevent further deterioration.",
            alternatives=[tool for tool in ("give_oxygen", "position_patient", "check_deterioration") if tool in available_tools],
        )

    for diagnostic_tool in ("get_blood_gas", "get_cbc", "get_bmp"):
        if diagnostic_tool in available_tools and diagnostic_tool in observation.ready_diagnostics:
            return NextStepRecommendation(
                scenario_id=observation.scenario_id,
                risk_level=risk_level,
                recommended_tool=diagnostic_tool,
                rationale="A diagnostic result is ready and should be reviewed before the next intervention sequence.",
                alternatives=[tool for tool in ("summarize_state", "check_deterioration", "get_vitals") if tool in available_tools],
            )

    if "check_deterioration" in available_tools:
        return NextStepRecommendation(
            scenario_id=observation.scenario_id,
            risk_level=risk_level,
            recommended_tool="check_deterioration",
            rationale="The patient is not in obvious immediate crisis, so reassessment is the safest next step.",
            alternatives=[tool for tool in ("summarize_state", "advance_time") if tool in available_tools],
        )

    return NextStepRecommendation(
        scenario_id=observation.scenario_id,
        risk_level=risk_level,
        recommended_tool="advance_time",
        arguments={"seconds": 30},
        rationale="No higher-priority intervention is exposed, so advance time to generate the next signal.",
        alternatives=[],
    )


def _format_reading(value: float | None, spec: str, suffix: str = "") -> str:
    """Format an optional reading for display text.

    A missing reading renders as ``"n/a"``; applying a numeric format spec to
    ``None`` would otherwise raise ``TypeError``.
    """
    if value is None:
        return "n/a"
    return f"{value:{spec}}{suffix}"


def build_triage_summary(observation: PulsePhysiologyObservation) -> TriageSummary:
    """Generate a compact triage summary from the current state.

    Missing vitals appear as ``"n/a"`` in the headline and as ``None`` in
    ``vitals_snapshot``, so a partial observation no longer raises.
    """
    acuity = _risk_level(observation)
    alerts = list(observation.active_alerts)
    mental_status = _mental_status_value(observation.mental_status)
    spo2_percent = _spo2_percent(observation)

    headline = (
        f"{observation.scenario_id}: {acuity.upper()} acuity with "
        f"HR {_format_reading(observation.heart_rate_bpm, '.0f')}, "
        f"BP {_format_reading(observation.systolic_bp_mmhg, '.0f')}"
        f"/{_format_reading(observation.diastolic_bp_mmhg, '.0f')}, "
        f"SpO2 {_format_reading(spo2_percent, '.1f', '%')}, "
        f"mental status {mental_status}."
    )

    return TriageSummary(
        scenario_id=observation.scenario_id,
        acuity=acuity,
        headline=headline,
        active_alerts=alerts,
        vitals_snapshot={
            "heart_rate_bpm": observation.heart_rate_bpm,
            "systolic_bp_mmhg": observation.systolic_bp_mmhg,
            "diastolic_bp_mmhg": observation.diastolic_bp_mmhg,
            "spo2": observation.spo2,
            "spo2_percent": spo2_percent,
            "respiration_rate_bpm": observation.respiration_rate_bpm,
            "blood_volume_ml": observation.blood_volume_ml,
        },
        immediate_focus=_priority_reasons(observation),
    )


def explain_deterioration(
    observation: PulsePhysiologyObservation,
    previous_observation: PulsePhysiologyObservation | None = None,
    observations: list[PulsePhysiologyObservation] | None = None,
) -> DeteriorationExplanation:
    """Explain the likely deterioration driver or current stability.

    Combines the alert-based primary driver with the trend-based status from
    ``_classify_deterioration_status`` over a window of up to three observations.
    ``cascade_risk`` maps CRITICAL -> "imminent", DETERIORATING -> "medium", and
    everything else -> "low".
    """
    recent_observations = _recent_observations(
        observation,
        previous_observation,
        observations,
        window=3,
    )
    alerts = set(observation.active_alerts)
    status, status_reason, trend_map = _classify_deterioration_status(observation, recent_observations)

    if "blood_loss" in alerts:
        primary_driver = "hemorrhagic shock physiology"
        response = "Control bleeding and support perfusion with fluids before reassessing."
    elif observation.scenario_id == "hemorrhagic_shock" and "tachycardia" in alerts:
        primary_driver = "residual shock burden after initial resuscitation"
        response = "Reassess perfusion closely and consider additional volume support if the trend does not settle."
    elif "hypoxemia" in alerts:
        primary_driver = "respiratory decompensation"
        response = "Provide oxygen and airway or positioning support, then reassess oxygenation."
    elif "tachycardia" in alerts and "hypotension" in alerts:
        primary_driver = "compensated shock"
        response = "Support perfusion and reassess for ongoing blood loss or inadequate resuscitation."
    elif alerts:
        primary_driver = "ongoing physiological stress"
        response = "Use focused reassessment and the highest-yield intervention exposed by the current tool set."
    else:
        primary_driver = "no active deterioration signal"
        response = "Continue reassessment and controlled monitoring over time."

    supporting_findings = _priority_reasons(observation)
    supporting_findings.append(status_reason)
    if trend_map["spo2"] == "worsening":
        supporting_findings.append("Oxygenation is worsening over the recent observation window.")
    elif trend_map["spo2"] == "improving":
        supporting_findings.append("Oxygenation is improving over the recent observation window.")
    if trend_map["mean_arterial_pressure_mmhg"] == "worsening":
        supporting_findings.append("Perfusion is worsening based on the recent MAP trend.")
    elif trend_map["mean_arterial_pressure_mmhg"] == "improving":
        supporting_findings.append("Perfusion is improving based on the recent MAP trend.")
    if trend_map["heart_rate_bpm"] == "worsening":
        supporting_findings.append("Heart rate is drifting farther from the safe range.")
    if trend_map["respiration_rate_bpm"] == "worsening":
        supporting_findings.append("Respiratory rate is moving in the wrong direction.")

    cascade_risk = "low"
    if status == DeteriorationStatus.DETERIORATING:
        cascade_risk = "medium"
    elif status == DeteriorationStatus.CRITICAL:
        cascade_risk = "imminent"

    return DeteriorationExplanation(
        scenario_id=observation.scenario_id,
        status=status.value,
        cascade_risk=cascade_risk,
        primary_driver=primary_driver,
        supporting_findings=supporting_findings,
        recommended_response=response,
    )


def generate_intervention_plan(observation: PulsePhysiologyObservation) -> InterventionPlan:
    """Create a short ordered plan from the current observation.

    Step 1 is the ``recommend_next_step`` choice. Up to three of its
    alternatives follow as argument-less backup steps with priorities 2-4.
    """
    recommendation = recommend_next_step(observation)
    steps: list[dict] = [
        {
            "priority": 1,
            "tool_name": recommendation.recommended_tool,
            "arguments": recommendation.arguments,
            "why": recommendation.rationale,
        }
    ]
    steps.extend(
        {
            "priority": priority,
            "tool_name": alternative,
            "arguments": {},
            "why": f"Keep {alternative} ready if the primary step does not adequately stabilize the patient.",
        }
        for priority, alternative in enumerate(recommendation.alternatives[:3], start=2)
    )

    return InterventionPlan(
        scenario_id=observation.scenario_id,
        risk_level=recommendation.risk_level,
        ordered_steps=steps,
        monitoring_targets=[
            "heart_rate_bpm",
            "systolic_bp_mmhg",
            "spo2",
            "respiration_rate_bpm",
            "active_alerts",
        ],
        escalation_trigger="Escalate if alerts increase, mental status worsens, or perfusion/oxygenation declines after intervention.",
    )


def build_episode_report(trace: EpisodeTrace) -> EpisodeReport:
    """Summarize one episode into a compact Tier 3 report.

    The outcome is "critical deterioration" when the final observation is done,
    "partially stabilized" when alerts remain, and "stabilized" otherwise.
    ``key_actions`` lists each tool used, once, in first-use order.
    """
    final_alerts = list(trace.final_observation.active_alerts)
    if trace.final_observation.done:
        outcome = "critical deterioration"
    elif final_alerts:
        outcome = "partially stabilized"
    else:
        outcome = "stabilized"

    # dict.fromkeys keeps insertion order, giving an O(n) order-preserving dedupe.
    key_actions = list(dict.fromkeys(step.action.tool_name for step in trace.steps))

    summary = (
        f"{trace.policy_name} completed {trace.num_steps} steps in {trace.scenario_id} "
        f"with total reward {trace.total_reward:.3f}; outcome: {outcome}."
    )
    return EpisodeReport(
        scenario_id=trace.scenario_id,
        policy_name=trace.policy_name,
        total_reward=trace.total_reward,
        outcome=outcome,
        key_actions=key_actions,
        final_alerts=final_alerts,
        summary=summary,
    )