Spaces:
Sleeping
Sleeping
| """Reward engine for Pulse-ER with dense shaping, terminal scoring, and anti-exploitation guards.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| from pulse_physiology_env.models import PulsePhysiologyAction | |
| from pulse_physiology_env.patient_state import PatientState | |
| from pulse_physiology_env.tool_catalog import coerce_boolean_argument, normalize_contract_token | |
| from .scenarios import ScenarioDefinition | |
@dataclass
class ActionRecord:
    """One action attempt in the reward engine history.

    Records are appended by ``RewardEngine._update_tracker`` after every step,
    whether or not the action succeeded.
    """

    # Tool name as received (already stripped of surrounding whitespace by score_step).
    tool_name: str
    # Copy of the action arguments passed to the tool.
    arguments: dict[str, Any]
    # Simulation time (seconds) of the state *after* the action was applied.
    sim_time_s: float
    # Whether the environment reported the action as successful.
    success: bool
    # Milestone/category tags; failed actions carry only their own tool name.
    tags: tuple[str, ...]
@dataclass
class RewardTracker:
    """Per-episode reward bookkeeping.

    Created by ``RewardEngine.start_episode`` and mutated in place by
    ``RewardEngine._update_tracker`` on every scored step.
    """

    # Key into RewardEngine.SCENARIO_MILESTONES used for terminal sequence scoring.
    reward_profile: str
    # Remaining actions; decremented once per step and floored at 0.
    action_budget_remaining: int
    # Tool used on the previous step (None before the first step).
    last_tool_name: str | None = None
    # Length of the current run of identical tool names (1 for a fresh tool).
    same_tool_called_consecutively: int = 0
    # Steps elapsed while ready diagnostics existed without one being reviewed.
    steps_since_last_diagnostic_review: int = 0
    # Canonical diagnostic keys newly ordered this episode.
    diagnostics_ordered: set[str] = field(default_factory=set)
    # Every action attempt, in order, including the current step once updated.
    action_history: list[ActionRecord] = field(default_factory=list)
    # Sim time at which the patient first met the stabilization criteria, if ever.
    time_to_stabilize_s: float | None = None
@dataclass
class RewardBreakdown:
    """Structured dense and terminal reward components for one step.

    ``RewardEngine.score_step`` fills in the individual components, the dense and
    terminal subtotals, and the clipped ``total``.
    """

    # Aggregates.
    dense_total: float = 0.0
    terminal_total: float = 0.0
    total: float = 0.0
    # Dense (per-step) components.
    r_map_stability: float = 0.0
    r_spo2_efficiency: float = 0.0
    r_lactate_trend: float = 0.0
    r_intervention_safety: float = 0.0
    r_diagnostic_timeliness: float = 0.0
    r_anti_exploitation: float = 0.0
    r_time_pressure: float = 0.0
    # Terminal components (only populated on the step where the episode ends).
    survival_bonus: float = 0.0
    time_efficiency_bonus: float = 0.0
    sequence_quality_bonus: float = 0.0
    difficulty_multiplier: float = 1.0
    # Snapshot of tracker state at scoring time, for diagnostics/logging.
    reward_profile: str = "polytrauma"
    action_budget_remaining: int = 0
    same_tool_called_consecutively: int = 0
    steps_since_last_diagnostic_review: int = 0
    # True when terminal components were computed for this step.
    terminal_applied: bool = False

    def as_metadata(self) -> dict[str, Any]:
        """Return every component as a flat, JSON-friendly dict (total first)."""
        return {
            "total": self.total,
            "dense_total": self.dense_total,
            "terminal_total": self.terminal_total,
            "r_map_stability": self.r_map_stability,
            "r_spo2_efficiency": self.r_spo2_efficiency,
            "r_lactate_trend": self.r_lactate_trend,
            "r_intervention_safety": self.r_intervention_safety,
            "r_diagnostic_timeliness": self.r_diagnostic_timeliness,
            "r_anti_exploitation": self.r_anti_exploitation,
            "r_time_pressure": self.r_time_pressure,
            "survival_bonus": self.survival_bonus,
            "time_efficiency_bonus": self.time_efficiency_bonus,
            "sequence_quality_bonus": self.sequence_quality_bonus,
            "difficulty_multiplier": self.difficulty_multiplier,
            "reward_profile": self.reward_profile,
            "action_budget_remaining": self.action_budget_remaining,
            "same_tool_called_consecutively": self.same_tool_called_consecutively,
            "steps_since_last_diagnostic_review": self.steps_since_last_diagnostic_review,
            "terminal_applied": self.terminal_applied,
        }
class RewardEngine:
    """Implements the Pulse-ER reward design with dense, terminal, and sequence-aware signals.

    ``score_step`` combines a weighted dense reward (MAP, SpO2, lactate trend,
    intervention safety, diagnostic timeliness) with unweighted anti-exploitation and
    time-pressure penalties. On the step where the resulting state is ``done`` it also
    adds a terminal component (survival, time-to-stabilize, milestone ordering) scaled
    by the scenario difficulty multiplier. The per-step total is clipped to [-30, 30].
    """

    # Mean arterial pressure (mmHg) treated as the perfusion target.
    MAP_TARGET = 65.0
    # SpO2 fractions: below HYPOXIA_THRESHOLD is penalized; SPO2_TARGET counts as stable.
    HYPOXIA_THRESHOLD = 0.90
    SPO2_TARGET = 0.94
    # NOTE(review): not referenced by the methods below; start_episode derives the
    # budget from scenario.max_time_s instead.
    DEFAULT_ACTION_BUDGET = 30
    # Sim-time windows (s): new orders before ORDER are rewarded, after NEGLECT penalized.
    DIAGNOSTIC_ORDER_WINDOW_S = 180.0
    DIAGNOSTIC_NEGLECT_WINDOW_S = 300.0
    # Steps a ready diagnostic may go unreviewed before the anti-exploitation penalty.
    READY_DIAGNOSTIC_GRACE_STEPS = 5
    # Accepted diagnostic tool names (including short aliases) -> canonical key that is
    # compared against PatientState.pending_diagnostics / ready_diagnostics.
    DIAGNOSTIC_TOOL_ALIASES = {
        "get_blood_gas": "order_arterial_blood_gas",
        "order_arterial_blood_gas": "order_arterial_blood_gas",
        "get_cbc": "order_complete_blood_count",
        "order_complete_blood_count": "order_complete_blood_count",
        "get_bmp": "order_basic_metabolic_panel",
        "order_basic_metabolic_panel": "order_basic_metabolic_panel",
    }
    DIAGNOSTIC_TOOLS = frozenset((*DIAGNOSTIC_TOOL_ALIASES, "order_point_of_care_ultrasound"))
    # Read-only assessments eligible for the small "first early use" bonus.
    BEDSIDE_ASSESSMENT_TOOLS = frozenset(
        {
            "auscultate_chest",
            "assess_consciousness_level",
            "check_pain_level",
            "measure_core_temperature",
            "check_end_tidal_co2",
            "calculate_shock_index",
            "assess_urine_output",
            "run_triage_assessment",
            "detect_deterioration",
            "check_deterioration",
            "get_vitals",
            "get_hemodynamics",
            "get_blood_chemistry",
            "get_respiratory_state",
            "get_respiratory_status",
            "get_shock_assessment",
            "summarize_state",
        }
    )
    # Tools that satisfy the "respiratory_assessment" milestone.
    RESPIRATORY_ASSESSMENT_TOOLS = frozenset(
        {
            "auscultate_chest",
            "check_end_tidal_co2",
            "get_respiratory_state",
            "get_respiratory_status",
            "run_triage_assessment",
            "detect_deterioration",
            "check_deterioration",
            "summarize_state",
        }
    )
    # Tools that satisfy the "cardiac_assessment" milestone.
    CARDIAC_ASSESSMENT_TOOLS = frozenset(
        {
            "get_vitals",
            "get_hemodynamics",
            "get_blood_chemistry",
            "get_shock_assessment",
            "calculate_shock_index",
            "assess_urine_output",
            "run_triage_assessment",
            "detect_deterioration",
            "check_deterioration",
            "summarize_state",
            "order_point_of_care_ultrasound",
        }
    )
    BLEEDING_CONTROL_TOOLS = frozenset({"control_bleeding", "apply_tourniquet", "apply_wound_packing", "apply_direct_pressure"})
    PRESSOR_TOOLS = frozenset(
        {
            "give_pressor",
            "start_norepinephrine_infusion",
            "start_dopamine_infusion",
            "start_phenylephrine_infusion",
            "adjust_infusion_rate",
            "stop_infusion",
        }
    )
    DECOMPRESSION_TOOLS = frozenset({"needle_decompression", "perform_needle_decompression"})
    PERICARDIOCENTESIS_TOOLS = frozenset({"pericardiocentesis", "perform_pericardiocentesis"})
    FLUID_TOOLS = frozenset(
        {
            "give_fluids",
            "administer_crystalloid_bolus",
            "administer_blood_transfusion",
            "administer_plasma",
            "activate_massive_transfusion_protocol",
        }
    )
    OXYGEN_TOOLS = frozenset(
        {
            "give_oxygen",
            "apply_nasal_cannula",
            "apply_simple_mask",
            "apply_nonrebreather_mask",
        }
    )
    ADVANCED_AIRWAY_TOOLS = frozenset(
        {
            "airway_support",
            "apply_bag_valve_mask",
            "perform_intubation",
            "set_ventilator_tidal_volume",
            "set_ventilator_rate",
            "set_ventilator_fio2",
        }
    )
    # Rapid-sequence-intubation building blocks used by _reward_rsi_sequence.
    RSI_PREOXYGENATION_TOOLS = frozenset(
        {
            "give_oxygen",
            "apply_nasal_cannula",
            "apply_simple_mask",
            "apply_nonrebreather_mask",
            "airway_support",
            "apply_bag_valve_mask",
        }
    )
    RSI_INDUCTION_TOOLS = frozenset(
        {
            "administer_ketamine_bolus",
            "administer_midazolam_bolus",
            "administer_lorazepam_bolus",
        }
    )
    RSI_PARALYTIC_TOOLS = frozenset({"administer_succinylcholine_bolus"})
    RSI_INTUBATION_TOOLS = frozenset({"perform_intubation"})
    CPR_TOOLS = frozenset({"perform_cpr"})
    # Scenario-authoring tools the agent is penalized for calling.
    RESTRICTED_TOOLS = frozenset({"initiate_hemorrhage", "induce_cardiac_arrest", "apply_pericardial_effusion"})
    # Normalized fluid-type tokens grouped by category.
    BLOOD_PRODUCTS = frozenset({"blood", "packed_rbc", "packed_rbcs", "prbc", "prbcs"})
    CRYSTALLOIDS = frozenset({"saline", "crystalloid"})
    # Scales the terminal component only (dense rewards are unscaled).
    DIFFICULTY_MULTIPLIER = {"easy": 1.0, "medium": 1.5, "hard": 2.0}
    # Ordered (milestone tag, weight) lists per reward profile; see evaluate_milestone_sequence.
    SCENARIO_MILESTONES: dict[str, list[tuple[str, float]]] = {
        "tension_pneumothorax": [
            ("respiratory_assessment", 0.3),
            ("needle_decompression", 0.4),
            ("crystalloid_after_decomp", 0.2),
            ("pressor_if_needed", 0.1),
        ],
        "hemorrhagic_shock": [
            ("bleeding_control", 0.4),
            ("crystalloid_bolus", 0.3),
            ("pressor_if_needed", 0.2),
            ("blood_transfusion_if_severe", 0.1),
        ],
        "cardiac_tamponade": [
            ("cardiac_assessment", 0.3),
            ("pericardiocentesis", 0.5),
            ("crystalloid_bolus", 0.2),
        ],
        "polytrauma": [
            ("respiratory_assessment", 0.15),
            ("needle_decompression", 0.25),
            ("bleeding_control", 0.25),
            ("blood_transfusion_if_severe", 0.2),
            ("pressor_if_needed", 0.15),
        ],
    }

    def start_episode(self, scenario: ScenarioDefinition, initial_state: PatientState) -> RewardTracker:
        """Create a fresh tracker for a new episode.

        The action budget is one action per simulated minute of ``scenario.max_time_s``
        (rounded), with a floor of 10. If the initial state already satisfies
        ``_is_stabilized``, its sim time is recorded as the time to stabilize.
        """
        action_budget = max(10, int(round(scenario.max_time_s / 60.0)))
        tracker = RewardTracker(
            reward_profile=scenario.reward_profile,
            action_budget_remaining=action_budget,
        )
        if self._is_stabilized(initial_state):
            tracker.time_to_stabilize_s = initial_state.sim_time_s
        return tracker

    def score_step(
        self,
        tracker: RewardTracker,
        *,
        scenario: ScenarioDefinition,
        before: PatientState,
        after: PatientState,
        action: PulsePhysiologyAction,
        success: bool,
        had_error: bool,
        time_pressure_multiplier: float = 1.0,
    ) -> RewardBreakdown:
        """Score one action and update ``tracker`` in place.

        Args:
            tracker: Episode bookkeeping; the action is appended to its history first.
            scenario: Provides the difficulty (must be a key of DIFFICULTY_MULTIPLIER).
            before: Patient state before the action.
            after: Patient state after the action; ``after.done`` triggers terminal scoring.
            action: The tool call that was executed.
            success: Whether the action succeeded.
            had_error: Whether the action raised/returned an error.
            time_pressure_multiplier: Values above 1.0 enable the time-pressure penalty.

        Returns:
            A RewardBreakdown whose ``total`` is clipped to [-30, 30].

        Raises:
            KeyError: if ``scenario.difficulty`` is not in DIFFICULTY_MULTIPLIER.
        """
        tool_name = action.tool_name.strip()
        arguments = dict(action.arguments)
        # Tracker is updated before component rewards, so action_history already ends
        # with the current action when the helpers below inspect it.
        self._update_tracker(tracker, before, after, tool_name, arguments, success)
        breakdown = RewardBreakdown(
            reward_profile=tracker.reward_profile,
            difficulty_multiplier=self.DIFFICULTY_MULTIPLIER[scenario.difficulty],
            action_budget_remaining=tracker.action_budget_remaining,
            same_tool_called_consecutively=tracker.same_tool_called_consecutively,
            steps_since_last_diagnostic_review=tracker.steps_since_last_diagnostic_review,
        )
        breakdown.r_map_stability = self._reward_map_stability(before, after)
        breakdown.r_spo2_efficiency = self._reward_spo2_efficiency(before, after)
        breakdown.r_lactate_trend = self._reward_lactate_trend(after)
        breakdown.r_intervention_safety = self._reward_intervention_safety(tracker, before, after, tool_name, arguments)
        breakdown.r_diagnostic_timeliness = self._reward_diagnostic_timeliness(tracker, before, after, tool_name, arguments)
        breakdown.r_anti_exploitation = self._reward_anti_exploitation(tracker, after, success, had_error)
        breakdown.r_time_pressure = self._reward_time_pressure(after, tool_name, time_pressure_multiplier)
        # Physiology/quality terms are weighted; penalty terms are added at full strength.
        breakdown.dense_total = (
            0.35 * breakdown.r_map_stability
            + 0.25 * breakdown.r_spo2_efficiency
            + 0.20 * breakdown.r_lactate_trend
            + 0.10 * breakdown.r_intervention_safety
            + 0.10 * breakdown.r_diagnostic_timeliness
            + breakdown.r_anti_exploitation
            + breakdown.r_time_pressure
        )
        if after.done:
            breakdown.terminal_applied = True
            breakdown.survival_bonus = 5.0 if self._is_alive(after) else -5.0
            breakdown.time_efficiency_bonus = self._time_efficiency_bonus(tracker, scenario, after)
            breakdown.sequence_quality_bonus = self.evaluate_milestone_sequence(
                tracker.action_history,
                tracker.reward_profile,
            )
            breakdown.terminal_total = (
                breakdown.survival_bonus
                + breakdown.time_efficiency_bonus
                + breakdown.sequence_quality_bonus
            ) * breakdown.difficulty_multiplier
        breakdown.total = float(max(-30.0, min(30.0, breakdown.dense_total + breakdown.terminal_total)))
        return breakdown

    def evaluate_milestone_sequence(
        self,
        action_history: list[ActionRecord],
        reward_profile: str,
    ) -> float:
        """Score how well the history follows the profile's milestone order.

        Each milestone is located at its first successful occurrence. A milestone found
        after the previously credited one earns its full weight; one found earlier earns
        30% of its weight; a missing milestone earns nothing. Unknown profiles fall back
        to "polytrauma". The sum is doubled (maximum 2.0 when all weights total 1.0).
        """
        milestones = self.SCENARIO_MILESTONES.get(reward_profile, self.SCENARIO_MILESTONES["polytrauma"])
        score = 0.0
        last_idx = -1
        for milestone_name, weight in milestones:
            idx = self._find_milestone_index(action_history, milestone_name)
            if idx > last_idx:
                score += weight
                last_idx = idx
            elif idx != -1:
                score += weight * 0.3
        return score * 2.0

    def _update_tracker(
        self,
        tracker: RewardTracker,
        before: PatientState,
        after: PatientState,
        tool_name: str,
        arguments: dict[str, Any],
        success: bool,
    ) -> None:
        """Advance per-episode counters and append the current action to the history."""
        if tracker.last_tool_name == tool_name:
            tracker.same_tool_called_consecutively += 1
        else:
            tracker.last_tool_name = tool_name
            tracker.same_tool_called_consecutively = 1
        tracker.action_budget_remaining = max(0, tracker.action_budget_remaining - 1)
        diagnostic_key = self._canonical_diagnostic_key(tool_name, arguments)
        # Re-invoking a diagnostic whose result is ready counts as reviewing it.
        diagnostic_review = diagnostic_key is not None and diagnostic_key in before.ready_diagnostics
        if diagnostic_review:
            tracker.steps_since_last_diagnostic_review = 0
        elif after.ready_diagnostics:
            tracker.steps_since_last_diagnostic_review += 1
        else:
            tracker.steps_since_last_diagnostic_review = 0
        if (
            diagnostic_key is not None
            and diagnostic_key not in before.pending_diagnostics
            and diagnostic_key not in before.ready_diagnostics
        ):
            tracker.diagnostics_ordered.add(diagnostic_key)
        tags = self._extract_action_tags(tracker.action_history, tool_name, arguments, success)
        tracker.action_history.append(
            ActionRecord(
                tool_name=tool_name,
                arguments=arguments,
                sim_time_s=after.sim_time_s,
                success=success,
                tags=tags,
            )
        )
        if tracker.time_to_stabilize_s is None and self._is_stabilized(after):
            tracker.time_to_stabilize_s = after.sim_time_s

    def _reward_map_stability(self, before: PatientState, after: PatientState) -> float:
        """Reward MAP change relative to MAP_TARGET, plus 0.3 for crossing the target upward.

        NOTE(review): a missing (None) MAP is treated as 0.0, so a reading appearing or
        disappearing yields a large delta; confirm this is intended.
        """
        previous_map = before.mean_arterial_pressure_mmhg or 0.0
        current_map = after.mean_arterial_pressure_mmhg or 0.0
        reward = self._clip((current_map - previous_map) / self.MAP_TARGET, -1.0, 1.0)
        if previous_map < self.MAP_TARGET <= current_map:
            reward += 0.3
        return reward

    def _reward_spo2_efficiency(self, before: PatientState, after: PatientState) -> float:
        """Reward SpO2 change (x10), penalize hypoxia, bonus for reaching SPO2_TARGET; clipped to ±1.5.

        NOTE(review): None SpO2 is treated as 0.0 (same caveat as MAP).
        """
        previous_spo2 = before.spo2 or 0.0
        current_spo2 = after.spo2 or 0.0
        reward = (current_spo2 - previous_spo2) * 10.0
        if current_spo2 < self.HYPOXIA_THRESHOLD:
            reward -= 0.2
        if previous_spo2 < self.SPO2_TARGET <= current_spo2:
            reward += 0.2
        return self._clip(reward, -1.5, 1.5)

    @staticmethod
    def _reward_lactate_trend(after: PatientState) -> float:
        """Return -1.0 for a worsening lactate trend, 0.2 for improving, else 0.0."""
        if after.lactate_trend == "worsening":
            return -1.0
        if after.lactate_trend == "improving":
            return 0.2
        return 0.0

    def _reward_intervention_safety(
        self,
        tracker: RewardTracker,
        before: PatientState,
        after: PatientState,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> float:
        """Rule-based reward for whether the intervention fits the pre-action state; clipped to ±1.0."""
        reward = 0.0
        if tool_name in self.FLUID_TOOLS:
            fluid_type = self._normalize_fluid_type(tool_name, arguments)
            # Fluids before decompressing a suspected pneumothorax are penalized.
            if self._has_pneumothorax_signs(before) and not self._history_has_tag(tracker.action_history, "needle_decompression"):
                reward -= 0.8
            if fluid_type in self.BLOOD_PRODUCTS and self._is_severe_bleed(before):
                reward += 0.1
            if tool_name == "activate_massive_transfusion_protocol":
                reward += 0.15 if self._is_severe_bleed(before) else -0.2
        # Starting pressors in a volume-depleted patient is penalized.
        if self._is_pressor_activation(tool_name, arguments) and self._is_volume_depleted(before):
            reward -= 0.5
        if tool_name in self.DECOMPRESSION_TOOLS and not self._has_pneumothorax_signs(before):
            reward -= 0.4
        if tool_name in self.PERICARDIOCENTESIS_TOOLS:
            if "possible_cardiac_tamponade" in before.active_alerts or "active_pericardial_effusion" in before.active_alerts:
                reward += 0.2
            else:
                reward -= 0.6
        if tool_name in self.BLEEDING_CONTROL_TOOLS and before.active_hemorrhages:
            reward += 0.15
        if (
            tool_name in self.OXYGEN_TOOLS
            and before.spo2 is not None
            and before.spo2 < 0.92
            and after.spo2 is not None
            and after.spo2 > before.spo2
        ):
            reward += 0.1
        # Advanced airway on an alert, well-oxygenated patient is mildly penalized;
        # on a severely hypoxic patient whose SpO2 improves it is rewarded.
        if tool_name in self.ADVANCED_AIRWAY_TOOLS and before.spo2 is not None and before.spo2 >= 0.97 and before.mental_status == "alert":
            reward -= 0.1
        elif (
            tool_name in self.ADVANCED_AIRWAY_TOOLS
            and before.spo2 is not None
            and before.spo2 < 0.88
            and after.spo2 is not None
            and after.spo2 > before.spo2
        ):
            reward += 0.15
        if self._is_pressor_activation(tool_name, arguments) and after.mean_arterial_pressure_mmhg is not None and after.mean_arterial_pressure_mmhg >= self.MAP_TARGET:
            reward += 0.1
        if tool_name == "administer_epinephrine_bolus" and "cardiac_arrest" in before.active_alerts:
            reward += 0.3
        if tool_name in self.RSI_PARALYTIC_TOOLS:
            reward += self._reward_rsi_sequence(tracker, before)
        if tool_name in self.CPR_TOOLS:
            if "cardiac_arrest" in before.active_alerts:
                reward += 0.25
            else:
                reward -= 0.8
        if tool_name in self.RESTRICTED_TOOLS:
            reward -= 1.0
        return self._clip(reward, -1.0, 1.0)

    def _reward_diagnostic_timeliness(
        self,
        tracker: RewardTracker,
        before: PatientState,
        after: PatientState,
        tool_name: str,
        arguments: dict[str, Any],
    ) -> float:
        """Reward early diagnostic orders and reviews, penalize late or absent ones; clipped to ±0.3."""
        reward = 0.0
        diagnostic_key = self._canonical_diagnostic_key(tool_name, arguments)
        is_new_order = (
            diagnostic_key is not None
            and diagnostic_key not in before.pending_diagnostics
            and diagnostic_key not in before.ready_diagnostics
        )
        if is_new_order and before.sim_time_s < self.DIAGNOSTIC_ORDER_WINDOW_S:
            reward += 0.15
        elif is_new_order and before.sim_time_s > self.DIAGNOSTIC_NEGLECT_WINDOW_S:
            reward -= 0.05
        is_review = diagnostic_key is not None and diagnostic_key in before.ready_diagnostics
        if is_review:
            reward += 0.05
        # action_history already contains the current action, so "== 1" means this is
        # the only successful use of the tool so far.
        if (
            tool_name in self.BEDSIDE_ASSESSMENT_TOOLS
            and before.sim_time_s < self.DIAGNOSTIC_ORDER_WINDOW_S
            and sum(1 for record in tracker.action_history if record.success and record.tool_name == tool_name) == 1
        ):
            reward += 0.03
        if (
            after.sim_time_s > self.DIAGNOSTIC_NEGLECT_WINDOW_S
            and not after.pending_diagnostics
            and not after.ready_diagnostics
            and not tracker.diagnostics_ordered
        ):
            reward -= 0.1
        return self._clip(reward, -0.3, 0.3)

    def _reward_rsi_sequence(
        self,
        tracker: RewardTracker,
        before: PatientState,
    ) -> float:
        """Score a paralytic dose against the rapid-sequence-intubation ordering.

        Preoxygenation within the previous 3 actions (or existing assisted ventilation),
        induction within the previous 2, or intubation on the immediately previous action
        are checked. The current paralytic action is excluded from those windows: it has
        already been appended to ``tracker.action_history`` by ``_update_tracker``.
        """
        # Only actions that preceded the current paralytic dose count as "recent".
        prior_actions = tracker.action_history[:-1]
        preoxygenated = self._history_has_recent_tool(
            prior_actions,
            self.RSI_PREOXYGENATION_TOOLS,
            window=3,
        ) or before.airway_support in {
            "bag_valve_mask",
            "pressure_control_ventilation",
            "volume_control_ventilation",
            "cpap",
        }
        induction_started = self._history_has_recent_tool(
            prior_actions,
            self.RSI_INDUCTION_TOOLS,
            window=2,
        )
        intubation_already_underway = self._history_has_recent_tool(
            prior_actions,
            self.RSI_INTUBATION_TOOLS,
            window=1,
        ) or before.intubated
        urgent_airway = (
            before.mental_status in {"pain", "unresponsive"}
            or (before.spo2 is not None and before.spo2 < 0.9)
        )
        if before.intubated:
            return -0.2
        if not urgent_airway and not preoxygenated:
            return -0.9
        if preoxygenated and (induction_started or urgent_airway or intubation_already_underway):
            return 0.1
        if preoxygenated:
            return -0.2
        return -1.0

    def _reward_anti_exploitation(
        self,
        tracker: RewardTracker,
        after: PatientState,
        success: bool,
        had_error: bool,
    ) -> float:
        """Penalize tool spamming, budget exhaustion, ignored results, and failed/errored calls."""
        reward = 0.0
        # Penalty grows by 0.1 per repeat beyond the second consecutive call.
        if tracker.same_tool_called_consecutively >= 3:
            reward -= 0.1 * (tracker.same_tool_called_consecutively - 2)
        if tracker.action_budget_remaining < 5 and not self._is_stabilized(after):
            reward -= 0.2
        if after.ready_diagnostics and tracker.steps_since_last_diagnostic_review > self.READY_DIAGNOSTIC_GRACE_STEPS:
            reward -= 0.05
        if had_error:
            reward -= 0.75
        elif not success:
            reward -= 0.25
        return reward

    def _reward_time_pressure(
        self,
        after: PatientState,
        tool_name: str,
        time_pressure_multiplier: float,
    ) -> float:
        """Penalty in [-0.6, 0] while unstable under time pressure; heavier for ``advance_time``."""
        if time_pressure_multiplier <= 1.0 or self._is_stabilized(after):
            return 0.0
        penalty = -0.12 * (time_pressure_multiplier - 1.0)
        if tool_name == "advance_time":
            penalty -= 0.06 * (time_pressure_multiplier - 1.0)
        return self._clip(penalty, -0.6, 0.0)

    def _time_efficiency_bonus(
        self,
        tracker: RewardTracker,
        scenario: ScenarioDefinition,
        after: PatientState,
    ) -> float:
        """Return up to 2.0, scaled by the fraction of the time limit left when first stabilized.

        Returns 0.0 if the patient never stabilized or ``scenario.max_time_s`` is not
        positive (guards against division by zero).
        NOTE(review): the bonus uses the *first* stabilization time even if the patient
        later deteriorated; confirm that is intended.
        """
        time_to_stabilize = tracker.time_to_stabilize_s
        if time_to_stabilize is None and self._is_stabilized(after):
            time_to_stabilize = after.sim_time_s
        if time_to_stabilize is None:
            return 0.0
        max_time_s = scenario.max_time_s
        if max_time_s <= 0:
            return 0.0
        return max(0.0, (max_time_s - time_to_stabilize) / max_time_s) * 2.0

    def _find_milestone_index(self, action_history: list[ActionRecord], milestone_name: str) -> int:
        """Index of the first successful record tagged ``milestone_name``, or -1."""
        for idx, record in enumerate(action_history):
            if not record.success:
                continue
            if milestone_name in record.tags:
                return idx
        return -1

    def _extract_action_tags(
        self,
        action_history: list[ActionRecord],
        tool_name: str,
        arguments: dict[str, Any],
        success: bool,
    ) -> tuple[str, ...]:
        """Compute sorted milestone/category tags for an action.

        Failed actions are tagged only with their tool name, so they never satisfy
        milestones. ``action_history`` holds the actions *before* this one.
        """
        tags: set[str] = {tool_name}
        if tool_name in self.RESPIRATORY_ASSESSMENT_TOOLS:
            tags.add("respiratory_assessment")
        # NOTE(review): order_point_of_care_ultrasound is in CARDIAC_ASSESSMENT_TOOLS, so it
        # is tagged "cardiac_assessment" for every region, not only the cardiac ones below.
        if tool_name in self.CARDIAC_ASSESSMENT_TOOLS:
            tags.add("cardiac_assessment")
        if tool_name in self.DECOMPRESSION_TOOLS:
            tags.add("needle_decompression")
        if tool_name in self.BLEEDING_CONTROL_TOOLS:
            tags.add("bleeding_control")
        if tool_name in self.PERICARDIOCENTESIS_TOOLS:
            tags.add("pericardiocentesis")
        if tool_name == "order_point_of_care_ultrasound":
            region = str(arguments.get("region") or "cardiac").strip().lower().replace("-", "_").replace(" ", "_")
            if region in {"cardiac", "heart", "pericardial"}:
                tags.add("cardiac_assessment")
            if region in {"chest", "lung", "thoracic"}:
                tags.add("respiratory_assessment")
        if self._is_pressor_activation(tool_name, arguments):
            tags.add("pressor_if_needed")
            pressor = normalize_contract_token(arguments.get("pressor") or arguments.get("agent") or "")
            if not pressor and tool_name == "start_norepinephrine_infusion":
                pressor = "norepinephrine"
            elif not pressor and tool_name == "start_phenylephrine_infusion":
                pressor = "phenylephrine"
            elif not pressor and tool_name == "start_dopamine_infusion":
                pressor = "dopamine"
            if pressor == "norepinephrine":
                tags.add("norepinephrine")
        if tool_name in self.FLUID_TOOLS:
            fluid_type = self._normalize_fluid_type(tool_name, arguments)
            if fluid_type in self.CRYSTALLOIDS:
                tags.add("crystalloid_bolus")
            if fluid_type in self.BLOOD_PRODUCTS:
                tags.add("blood_transfusion_if_severe")
            if self._history_has_tag(action_history, "needle_decompression"):
                tags.add("crystalloid_after_decomp")
        return tuple(sorted(tags)) if success else tuple(sorted({tool_name}))

    def _history_has_tag(self, action_history: list[ActionRecord], tag: str) -> bool:
        """True if any successful record carries ``tag``."""
        return any(record.success and tag in record.tags for record in action_history)

    @staticmethod
    def _history_has_recent_tool(
        action_history: list[ActionRecord],
        tool_names: frozenset[str],
        *,
        window: int,
    ) -> bool:
        """True if one of the last ``window`` records is a successful call to a tool in ``tool_names``."""
        if window <= 0:
            return False
        return any(
            record.success and record.tool_name in tool_names
            for record in action_history[-window:]
        )

    @classmethod
    def _normalize_fluid_type(cls, tool_name: str, arguments: dict[str, Any]) -> str:
        """Map a fluid tool call to a fluid-type token (explicit for dedicated tools)."""
        if tool_name == "administer_crystalloid_bolus":
            return "crystalloid"
        if tool_name == "administer_blood_transfusion":
            return "packed_rbc"
        if tool_name == "administer_plasma":
            return "plasma"
        if tool_name == "activate_massive_transfusion_protocol":
            return "packed_rbc"
        # Generic give_fluids: read the argument, defaulting to saline.
        return normalize_contract_token(arguments.get("fluid_type") or arguments.get("fluid") or "saline")

    @classmethod
    def _canonical_diagnostic_key(cls, tool_name: str, arguments: dict[str, Any]) -> str | None:
        """Canonical diagnostic key for a tool call, or None if it is not a diagnostic."""
        if tool_name in cls.DIAGNOSTIC_TOOL_ALIASES:
            return cls.DIAGNOSTIC_TOOL_ALIASES[tool_name]
        if tool_name == "order_point_of_care_ultrasound":
            region = normalize_contract_token(arguments.get("region") or "cardiac")
            return f"order_point_of_care_ultrasound:{region}"
        return None

    @staticmethod
    def _is_pressor_activation(tool_name: str, arguments: dict[str, Any]) -> bool:
        """True for pressor calls that start/adjust an infusion (not stop it).

        For ``give_pressor`` a truthy ``stop`` argument means deactivation; an
        unparseable value is treated as activation.
        """
        if tool_name not in RewardEngine.PRESSOR_TOOLS or tool_name == "stop_infusion":
            return False
        if tool_name != "give_pressor":
            return True
        try:
            return not coerce_boolean_argument(arguments.get("stop", False))
        except ValueError:
            return True

    @staticmethod
    def _has_pneumothorax_signs(state: PatientState) -> bool:
        """True if any pneumothorax-related alert is active."""
        return any(
            alert in state.active_alerts
            for alert in ("possible_tension_pneumothorax", "unilateral_absent_breath_sounds", "bilateral_absent_breath_sounds")
        )

    @staticmethod
    def _is_severe_bleed(state: PatientState) -> bool:
        """True if active hemorrhage values sum to at least 120.0 (units as reported by PatientState)."""
        return bool(state.active_hemorrhages) and sum(state.active_hemorrhages.values()) >= 120.0

    @staticmethod
    def _is_volume_depleted(state: PatientState) -> bool:
        """True if blood volume < 3000 mL, or shock index > 1.2 with active hemorrhage."""
        if state.blood_volume_ml is not None and state.blood_volume_ml < 3000.0:
            return True
        if state.shock_index is not None and state.shock_index > 1.2 and state.active_hemorrhages:
            return True
        return False

    @classmethod
    def _is_alive(cls, state: PatientState) -> bool:
        """False on cardiac-arrest alert, MAP <= 10, or heart rate <= 0.1; missing vitals count as alive."""
        if "cardiac_arrest" in state.active_alerts:
            return False
        if state.mean_arterial_pressure_mmhg is not None and state.mean_arterial_pressure_mmhg <= 10.0:
            return False
        if state.heart_rate_bpm is not None and state.heart_rate_bpm <= 0.1:
            return False
        return True

    @classmethod
    def _is_stabilized(cls, state: PatientState) -> bool:
        """Alive, MAP and SpO2 at target, alert/verbal, lactate not worsening, no blocking alerts."""
        if not cls._is_alive(state):
            return False
        if state.mean_arterial_pressure_mmhg is None or state.mean_arterial_pressure_mmhg < cls.MAP_TARGET:
            return False
        if state.spo2 is None or state.spo2 < cls.SPO2_TARGET:
            return False
        if state.mental_status not in {"alert", "verbal"}:
            return False
        if state.lactate_trend == "worsening":
            return False
        blocking_alerts = {
            "active_hemorrhage",
            "possible_tension_pneumothorax",
            "possible_cardiac_tamponade",
            "cardiac_arrest",
        }
        return not any(alert in blocking_alerts for alert in state.active_alerts)

    @staticmethod
    def _clip(value: float, lower: float, upper: float) -> float:
        """Clamp ``value`` into [lower, upper]."""
        return max(lower, min(value, upper))