Spaces:
Sleeping
Sleeping
| """Backend adapters for swapping mock and real Pulse runtimes.""" | |
| from __future__ import annotations | |
| from abc import ABC, abstractmethod | |
| import random | |
| from ..models import ( | |
| INITIAL_TOOL_NAMES, | |
| EnvironmentResponse, | |
| ObservationMetadata, | |
| PatientState, | |
| PulsePhysiologyObservation, | |
| ToolAction, | |
| ToolError, | |
| ToolResult, | |
| ) | |
| from ..runtime_effects import NoisyObservation, TimePressureMechanic | |
| from ..rewards import compute_reward | |
| from ..tool_catalog import ( | |
| KNOWN_TOOL_NAMES, | |
| ToolValidationError, | |
| canonicalize_tool_name, | |
| coerce_boolean_argument, | |
| validate_tool_arguments, | |
| ) | |
| from .mock_scenarios import DEFAULT_MOCK_SCENARIO_ID, MOCK_SCENARIOS, MockScenarioDefinition | |
# Simulated seconds between ordering a lab and its result becoming readable.
# Each entry is counted down by ``advance_time`` while the order is pending.
MOCK_DIAGNOSTIC_DELAYS = {
    "get_blood_gas": 120,
    "get_cbc": 240,
    "get_bmp": 300,
}

# Tools dispatched to the read-only handler instead of the intervention path.
# Lab orders (the keys of MOCK_DIAGNOSTIC_DELAYS) belong here too: they only
# schedule or report a result and never apply intervention deltas.
MOCK_READ_ONLY_TOOLS = {
    "get_vitals",
    "summarize_state",
    "check_deterioration",
    "recommend_next_step",
    "get_respiratory_status",
    *MOCK_DIAGNOSTIC_DELAYS,
}

# Fallback intervention deltas, keyed tool -> scenario_id -> {PatientState field: delta}.
# Consulted only when a scenario's own ``tool_effects`` does not define the tool;
# each delta is multiplied by the adapter's intervention scale before being applied.
MOCK_EXTENDED_INTERVENTION_EFFECTS = {
    "give_oxygen": {
        "baseline_stable": {},
        "respiratory_distress": {
            "spo2": 0.04,
            "respiration_rate_bpm": -3.0,
            "heart_rate_bpm": -2.0,
        },
        "hemorrhagic_shock": {
            "spo2": 0.02,
            "respiration_rate_bpm": -1.0,
            "heart_rate_bpm": -1.0,
        },
    },
    "give_fluids": {
        "baseline_stable": {"heart_rate_bpm": 2.0},
        "respiratory_distress": {"systolic_bp_mmhg": 0.5, "heart_rate_bpm": 1.0},
        "hemorrhagic_shock": {
            "systolic_bp_mmhg": 8.0,
            "diastolic_bp_mmhg": 4.0,
            "heart_rate_bpm": -4.0,
            "blood_volume_ml": 300.0,
        },
    },
    "control_bleeding": {
        "baseline_stable": {},
        "respiratory_distress": {},
        "hemorrhagic_shock": {
            "systolic_bp_mmhg": 5.0,
            "diastolic_bp_mmhg": 2.5,
            "heart_rate_bpm": -3.0,
            "blood_volume_ml": 150.0,
        },
    },
    "position_patient": {
        "baseline_stable": {},
        "respiratory_distress": {"spo2": 0.01, "respiration_rate_bpm": -1.0},
        "hemorrhagic_shock": {
            "systolic_bp_mmhg": 2.0,
            "diastolic_bp_mmhg": 1.0,
            "respiration_rate_bpm": -1.0,
        },
    },
    "airway_support": {
        "baseline_stable": {"heart_rate_bpm": 1.0},
        "respiratory_distress": {
            "spo2": 0.04,
            "respiration_rate_bpm": -4.0,
            "heart_rate_bpm": -2.0,
        },
        "hemorrhagic_shock": {
            "spo2": 0.01,
            "respiration_rate_bpm": -1.0,
            "heart_rate_bpm": -0.5,
        },
    },
    "give_pressor": {
        "baseline_stable": {
            "systolic_bp_mmhg": 0.5,
            "diastolic_bp_mmhg": 0.2,
            "heart_rate_bpm": 3.0,
        },
        "respiratory_distress": {
            "systolic_bp_mmhg": 1.0,
            "diastolic_bp_mmhg": 0.5,
            "heart_rate_bpm": 2.0,
        },
        "hemorrhagic_shock": {
            "systolic_bp_mmhg": 4.0,
            "diastolic_bp_mmhg": 2.0,
            "heart_rate_bpm": -1.0,
        },
    },
    "needle_decompression": {
        "baseline_stable": {"systolic_bp_mmhg": -1.0, "heart_rate_bpm": 1.0},
        "respiratory_distress": {
            "spo2": 0.06,
            "respiration_rate_bpm": -6.0,
            "heart_rate_bpm": -4.0,
        },
        "hemorrhagic_shock": {"systolic_bp_mmhg": -1.0, "heart_rate_bpm": 0.5},
    },
    "pericardiocentesis": {
        "baseline_stable": {
            "systolic_bp_mmhg": -2.0,
            "diastolic_bp_mmhg": -1.0,
            "heart_rate_bpm": 2.0,
        },
        "respiratory_distress": {
            "systolic_bp_mmhg": -1.0,
            "diastolic_bp_mmhg": -0.5,
            "heart_rate_bpm": 1.0,
        },
        "hemorrhagic_shock": {
            "systolic_bp_mmhg": -1.0,
            "diastolic_bp_mmhg": -0.5,
            "heart_rate_bpm": 1.0,
        },
    },
}
class PatientBackend(ABC):
    """Stable interface between Person 2's stack and the backend runtime.

    Concrete backends must implement all three methods; the ``@abstractmethod``
    markers make ``PatientBackend`` (and any incomplete subclass) impossible to
    instantiate instead of silently returning ``None`` from every call.
    """

    @abstractmethod
    def reset(self, scenario_id: str | None = None, **kwargs: object) -> EnvironmentResponse:
        """Reset the environment and return the initial response."""

    @abstractmethod
    def step(self, action: ToolAction) -> EnvironmentResponse:
        """Apply one action and return the next response."""

    @abstractmethod
    def get_state(self) -> PatientState:
        """Return the latest patient state."""
class MockPulseAdapter(PatientBackend):
    """Deterministic backend used by Person 2 before real Pulse integration exists.

    Physiology is driven by the selected ``MockScenarioDefinition``:

    * ``advance_time`` applies the scenario's per-30-second deterioration deltas,
      damped by active supports and amplified by time pressure.
    * Interventions apply the scenario's ``tool_effects`` (falling back to
      ``MOCK_EXTENDED_INTERVENTION_EFFECTS``), scaled by ``_intervention_scale``.
    * ``_refresh_state`` clamps vitals and recomputes every derived field, alert
      and the terminal ``done`` flag after each step.

    Observation noise and time pressure are layered on through ``NoisyObservation``
    and ``TimePressureMechanic``.
    """

    def __init__(
        self,
        default_scenario_id: str = DEFAULT_MOCK_SCENARIO_ID,
        *,
        observation_noise_level: float = 0.0,
        time_pressure_enabled: bool = False,
        time_pressure_onset_s: float = 180.0,
        time_pressure_escalation_per_minute: float = 0.15,
        seed: int | None = None,
    ):
        """Create an adapter; ``reset()`` must be called before ``step()``.

        Args:
            default_scenario_id: Scenario loaded when ``reset()`` gets no id.
            observation_noise_level: Level passed to ``NoisyObservation``.
            time_pressure_enabled: Whether ``TimePressureMechanic`` is enabled.
            time_pressure_onset_s: Onset passed to ``TimePressureMechanic``.
            time_pressure_escalation_per_minute: Escalation rate passed to
                ``TimePressureMechanic``.
            seed: Seed for the master RNG that seeds each episode's observation RNG.
        """
        self._default_scenario_id = default_scenario_id
        self._scenario: MockScenarioDefinition | None = None
        self._state: PatientState | None = None
        self._step_count = 0
        # Interventions applied this episode; they damp deterioration and shape derived fields.
        self._active_supports: set[str] = set()
        self._tool_counts: dict[str, int] = {}
        self._last_tool_name: str | None = None
        self._same_tool_called_consecutively = 0
        self._rng = random.Random(seed)
        self._observation_noise = NoisyObservation(observation_noise_level)
        self._time_pressure = TimePressureMechanic(
            enabled=time_pressure_enabled,
            onset_s=time_pressure_onset_s,
            escalation_per_minute=time_pressure_escalation_per_minute,
        )
        self._episode_observation_rng = random.Random(self._rng.random())

    def reset(self, scenario_id: str | None = None, **kwargs: object) -> EnvironmentResponse:
        """Load a scenario, clear episode bookkeeping and return the initial response.

        Args:
            scenario_id: Key into ``MOCK_SCENARIOS``; falls back to the default id.
            **kwargs: Optional runtime-effect overrides read by
                ``_configure_runtime_effects`` (``observation_noise_level``,
                ``time_pressure_enabled``, ``time_pressure_onset_s``,
                ``time_pressure_escalation_per_minute``).

        Raises:
            ValueError: If the scenario id is unknown.
        """
        selected_scenario_id = scenario_id or self._default_scenario_id
        if selected_scenario_id not in MOCK_SCENARIOS:
            valid = ", ".join(sorted(MOCK_SCENARIOS))
            raise ValueError(
                f"Unknown mock scenario_id '{selected_scenario_id}'. Expected one of: {valid}"
            )
        scenario = MOCK_SCENARIOS[selected_scenario_id]
        # Parse overrides first so an invalid value raises before the scenario and
        # state are swapped out.
        self._configure_runtime_effects(kwargs)
        self._scenario = scenario
        # Clear per-episode bookkeeping BEFORE refreshing the initial state:
        # _refresh_state reads _active_supports, and supports left over from the
        # previous episode would otherwise leak into the new episode's breath
        # sounds, hemorrhage flow and lactate trend.
        self._step_count = 0
        self._active_supports = set()
        self._tool_counts = {}
        self._last_tool_name = None
        self._same_tool_called_consecutively = 0
        self._state = self._refresh_state(scenario.initial_state.model_copy(deep=True))
        self._episode_observation_rng = random.Random(self._rng.random())
        return self._build_response(
            reward=0.0,
            tool_result=ToolResult(
                tool_name="load_scenario",
                success=True,
                message=f"Scenario '{scenario.scenario_id}' loaded.",
                state_changed=True,
                changed_fields=list(self._state.model_dump().keys()),
            ),
        )

    def step(self, action: ToolAction) -> EnvironmentResponse:
        """Validate and apply one tool call, then score the transition.

        Calls before ``reset()``, unknown tools and invalid arguments return an
        error response (reward -1.0) before the step counter advances. Otherwise
        the call is dispatched to ``advance_time``, a read-only tool or an
        intervention; on success the state is refreshed, repetition bookkeeping
        is updated and the reward comes from ``compute_reward``.
        """
        if self._state is None or self._scenario is None:
            return self._error_response(
                code="NOT_INITIALIZED",
                message="Call reset() before step().",
                retryable=True,
                tool_name=action.tool_name,
            )
        canonical_tool_name = canonicalize_tool_name(action.tool_name, allowed_tools=list(KNOWN_TOOL_NAMES))
        action = action.model_copy(update={"tool_name": canonical_tool_name})
        if action.tool_name not in KNOWN_TOOL_NAMES:
            return self._error_response(
                code="UNKNOWN_TOOL",
                message=f"Unsupported tool '{action.tool_name}'.",
                retryable=False,
                tool_name=action.tool_name,
            )
        try:
            normalized_arguments = validate_tool_arguments(
                action.tool_name,
                action.arguments,
                allowed_tools=KNOWN_TOOL_NAMES,
            )
        except ToolValidationError as exc:
            return self._error_response(
                code="INVALID_ARGUMENT",
                message=str(exc),
                retryable=False,
                tool_name=action.tool_name,
            )
        action = action.model_copy(update={"arguments": normalized_arguments})
        previous_state = self._state.model_copy(deep=True)
        # Counted before dispatch, so handler-level errors (bad advance_time
        # duration, intervention not modeled for the scenario) still use a step.
        self._step_count += 1
        if action.tool_name == "advance_time":
            result = self._advance_time(action)
        elif action.tool_name in MOCK_READ_ONLY_TOOLS:
            result = self._read_only_tool(action.tool_name)
        else:
            result = self._apply_intervention(action)
        if result.error is not None:
            return result
        self._state = self._refresh_state(self._state)
        changed_fields = self._changed_fields(previous_state, self._state)
        tool_usage_count = self._tool_counts.get(action.tool_name, 0) + 1
        if self._last_tool_name == action.tool_name:
            self._same_tool_called_consecutively += 1
        else:
            self._last_tool_name = action.tool_name
            self._same_tool_called_consecutively = 1
        self._tool_counts[action.tool_name] = tool_usage_count
        reward = compute_reward(
            previous_state,
            self._state,
            action.tool_name,
            self._scenario.recommended_actions,
            tool_usage_count=tool_usage_count,
            same_tool_called_consecutively=self._same_tool_called_consecutively,
            state_changed=bool(changed_fields),
            time_pressure_multiplier=self._current_time_pressure_multiplier(self._state),
        ).total
        tool_result = result.tool_result or ToolResult(
            tool_name=action.tool_name,
            success=True,
            message=f"{action.tool_name} executed.",
            state_changed=bool(changed_fields),
            changed_fields=changed_fields,
        )
        # Handlers cannot know the post-refresh diff, so overwrite it here.
        tool_result.changed_fields = changed_fields
        tool_result.state_changed = bool(changed_fields)
        return self._build_response(reward=reward, tool_result=tool_result)

    def get_state(self) -> PatientState:
        """Return a deep copy of the true (noise-free) patient state.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
        """
        if self._state is None:
            raise RuntimeError("MockPulseAdapter has not been reset yet.")
        return self._state.model_copy(deep=True)

    def _advance_time(self, action: ToolAction) -> EnvironmentResponse:
        """Advance simulated time, applying deterioration and lab turnaround.

        Scenario deterioration deltas are defined per 30 s and scaled linearly by
        ``seconds / 30``; each is damped by ``_deterioration_delta`` and multiplied
        by the current time-pressure factor. Pending diagnostics count down and
        move to ``ready_diagnostics`` once their remaining time reaches zero.
        """
        import math

        assert self._state is not None
        assert self._scenario is not None
        seconds = float(action.arguments.get("seconds", 30))
        # Also rejects NaN/inf, which would otherwise poison every deteriorating vital.
        if not (math.isfinite(seconds) and seconds > 0):
            return self._error_response(
                code="INVALID_ARGUMENT",
                message="seconds must be greater than 0",
                retryable=False,
                tool_name=action.tool_name,
            )
        scale = seconds / 30.0
        updates = self._state.model_dump()
        deterioration_multiplier = self._current_time_pressure_multiplier(self._state)
        for field_name, delta in self._scenario.deterioration_per_30s.items():
            adjusted_delta = self._deterioration_delta(field_name, delta) * deterioration_multiplier
            current_value = updates.get(field_name)
            if current_value is None:
                continue
            updates[field_name] = current_value + adjusted_delta * scale
        updates["sim_time_s"] = self._state.sim_time_s + seconds
        pending_diagnostics = dict(self._state.pending_diagnostics)
        ready_diagnostics = list(self._state.ready_diagnostics)
        for tool_name, remaining_seconds in list(pending_diagnostics.items()):
            # Round up: truncating would release results early on fractional steps.
            remaining_after_step = max(0, math.ceil(remaining_seconds - seconds))
            if remaining_after_step <= 0:
                pending_diagnostics.pop(tool_name, None)
                if tool_name not in ready_diagnostics:
                    ready_diagnostics.append(tool_name)
            else:
                pending_diagnostics[tool_name] = remaining_after_step
        updates["pending_diagnostics"] = pending_diagnostics
        updates["ready_diagnostics"] = ready_diagnostics
        self._state = PatientState(**updates)
        return self._build_response(
            reward=0.0,
            tool_result=ToolResult(
                tool_name=action.tool_name,
                success=True,
                message=f"Advanced simulation by {seconds:.0f} seconds.",
                state_changed=True,
                changed_fields=[],
            ),
        )

    def _apply_intervention(self, action: ToolAction) -> EnvironmentResponse:
        """Apply an intervention's scaled deltas and side effects to the state.

        ``give_pressor`` with ``stop=True`` only removes the infusion and the
        active support. Tools without effects for the current scenario return an
        ``UNSUPPORTED_IN_SCENARIO`` error.
        """
        assert self._state is not None
        assert self._scenario is not None
        if action.tool_name == "give_pressor" and action.arguments.get("stop") is True:
            updates = self._state.model_dump()
            self._apply_tool_side_effects(action, updates, effect_scale=0.0)
            self._active_supports.discard("give_pressor")
            self._state = PatientState(**updates)
            return self._build_response(
                reward=0.0,
                tool_result=ToolResult(
                    tool_name=action.tool_name,
                    success=True,
                    message="Vasopressor support stopped.",
                    state_changed=True,
                    changed_fields=[],
                ),
            )
        effects = self._scenario.tool_effects.get(action.tool_name)
        if effects is None:
            effects = MOCK_EXTENDED_INTERVENTION_EFFECTS.get(action.tool_name, {}).get(self._scenario.scenario_id)
        if effects is None:
            return self._error_response(
                code="UNSUPPORTED_IN_SCENARIO",
                message=f"{action.tool_name} is not modeled for scenario '{self._scenario.scenario_id}'.",
                retryable=False,
                tool_name=action.tool_name,
            )
        updates = self._state.model_dump()
        # Computed before the tool is recorded as active, so the first
        # application is not penalised as a repeat.
        effect_scale = self._intervention_scale(action.tool_name)
        for field_name, delta in effects.items():
            current_value = updates.get(field_name)
            if current_value is None:
                continue
            updates[field_name] = current_value + (delta * effect_scale)
        self._apply_tool_side_effects(action, updates, effect_scale)
        self._active_supports.add(action.tool_name)
        self._state = PatientState(**updates)
        return self._build_response(
            reward=0.0,
            tool_result=ToolResult(
                tool_name=action.tool_name,
                success=True,
                message=self._tool_message(action.tool_name),
                state_changed=True,
                changed_fields=[],
            ),
        )

    def _intervention_scale(self, tool_name: str) -> float:
        """Return the multiplier applied to a tool's intervention deltas.

        Tools matched to the current alerts act at full strength, and mismatched
        use is weakened. A tool already in ``_active_supports`` gets 0.7x. The
        result is multiplied by the time-pressure effectiveness factor. Tools
        without a rule here use a base of 1.0.
        """
        assert self._state is not None
        assert self._scenario is not None
        alerts = set(self._state.active_alerts)
        scale = 1.0
        intervention_multiplier = self._time_pressure.intervention_effectiveness_multiplier(
            sim_time_s=self._state.sim_time_s,
            injury_severity=self._scenario.injury_severity,
            unstable=self._is_state_unstable(self._state),
        )
        if tool_name == "control_bleeding":
            if "blood_loss" in alerts:
                scale = 1.0
            elif "hypotension" in alerts:
                scale = 0.5
            else:
                scale = 0.15
        elif tool_name == "give_fluids":
            if {"hypotension", "blood_loss"} & alerts:
                scale = 1.0
            elif "tachycardia" in alerts:
                scale = 0.6
            else:
                scale = 0.2
        elif tool_name == "give_oxygen":
            if {"hypoxemia", "tachypnea"} & alerts:
                scale = 1.0
            elif "tachycardia" in alerts:
                scale = 0.4
            else:
                scale = 0.15
        elif tool_name == "position_patient":
            if {"tachypnea", "hypotension"} & alerts:
                scale = 1.0
            else:
                scale = 0.25
        elif tool_name == "airway_support":
            if {"hypoxemia", "tachypnea"} & alerts:
                scale = 1.0
            else:
                scale = 0.2
        if tool_name in self._active_supports:
            scale *= 0.7
        return scale * intervention_multiplier

    def _read_only_tool(self, tool_name: str) -> EnvironmentResponse:
        """Answer an inspection tool; lab tools may schedule or report a diagnostic."""
        assert self._state is not None
        assert self._scenario is not None
        if tool_name == "summarize_state":
            message = (
                f"{self._scenario.description} HR {self._state.heart_rate_bpm:.0f}, "
                f"BP {self._state.systolic_bp_mmhg:.0f}/{self._state.diastolic_bp_mmhg:.0f}, "
                f"SpO2 {self._state.spo2:.2f}."
            )
        elif tool_name == "check_deterioration":
            message = "Deterioration ongoing." if self._state.active_alerts else "Patient currently stable."
        elif tool_name == "recommend_next_step":
            message = f"Recommended next step: {self._scenario.recommended_actions[0]}."
        elif tool_name == "get_respiratory_status":
            message = (
                f"Breath sounds {self._state.breath_sounds}, SpO2 {self._state.spo2:.2f}, "
                f"RR {self._state.respiration_rate_bpm:.0f}, airway support {self._state.airway_support or 'none'}."
            )
        elif tool_name in MOCK_DIAGNOSTIC_DELAYS:
            message = self._handle_diagnostic_read(tool_name)
        else:
            message = "Current vitals retrieved."
        return self._build_response(
            reward=0.0,
            tool_result=ToolResult(
                tool_name=tool_name,
                success=True,
                message=message,
                state_changed=False,
                changed_fields=[],
            ),
        )

    def _build_response(
        self,
        reward: float,
        tool_result: ToolResult | None = None,
        error: ToolError | None = None,
    ) -> EnvironmentResponse:
        """Wrap the (noise-applied) observed state into an ``EnvironmentResponse``.

        Draws from the episode observation RNG on every call via
        ``_build_observed_state``.
        """
        assert self._state is not None
        available_tools = self._available_tools()
        observed_state, runtime_metadata = self._build_observed_state()
        return EnvironmentResponse(
            observation=PulsePhysiologyObservation.from_patient_state(
                observed_state,
                reward=reward,
                available_tools=available_tools,
                tool_result=tool_result,
                error=error,
                metadata={
                    "step_count": self._step_count,
                    **runtime_metadata,
                },
            ),
            reward=reward,
            done=observed_state.done,
            metadata=ObservationMetadata(
                step_count=self._step_count,
                available_tools=available_tools,
            ),
            tool_result=tool_result,
            error=error,
        )

    def _error_response(
        self,
        code: str,
        message: str,
        retryable: bool,
        tool_name: str,
    ) -> EnvironmentResponse:
        """Build a failed-tool response with reward -1.0.

        Before the first ``reset()`` there is no state, so the default
        scenario's initial state is reported without noise.
        """
        state = (
            self._state
            if self._state is not None
            else MOCK_SCENARIOS[self._default_scenario_id].initial_state
        )
        available_tools = self._available_tools()
        if self._state is not None:
            observed_state, runtime_metadata = self._build_observed_state()
        else:
            observed_state = state
            runtime_metadata = {
                "observation_noise": {
                    "enabled": self._observation_noise.config.enabled,
                    "noise_level": round(self._observation_noise.config.normalized_level, 3),
                    "masked_fields": [],
                    "perturbed_fields": [],
                },
                "time_pressure": self._time_pressure.as_metadata(
                    sim_time_s=float(state.sim_time_s or 0.0),
                    injury_severity=self._scenario.injury_severity if self._scenario is not None else 0.0,
                    unstable=self._is_state_unstable(state),
                ),
            }
        return EnvironmentResponse(
            observation=PulsePhysiologyObservation.from_patient_state(
                observed_state,
                reward=-1.0,
                available_tools=available_tools,
                error=ToolError(code=code, message=message, retryable=retryable),
                metadata={
                    "step_count": self._step_count,
                    **runtime_metadata,
                },
            ),
            reward=-1.0,
            done=observed_state.done,
            metadata=ObservationMetadata(
                step_count=self._step_count,
                available_tools=available_tools,
            ),
            tool_result=ToolResult(
                tool_name=tool_name,
                success=False,
                message=message,
                state_changed=False,
                changed_fields=[],
            ),
            error=ToolError(code=code, message=message, retryable=retryable),
        )

    def _deterioration_delta(self, field_name: str, delta: float) -> float:
        """Damp a per-30 s deterioration delta according to the active supports.

        Only the ``respiratory_distress`` and ``hemorrhagic_shock`` scenarios
        have damping rules; other scenarios return ``delta`` unchanged.
        """
        if self._scenario is None:
            return delta
        if self._scenario.scenario_id == "respiratory_distress":
            if field_name == "spo2" and "give_oxygen" in self._active_supports:
                return delta * 0.25
            if field_name == "respiration_rate_bpm" and "airway_support" in self._active_supports:
                return delta * 0.4
        if self._scenario.scenario_id == "hemorrhagic_shock":
            if field_name == "blood_volume_ml" and "control_bleeding" in self._active_supports:
                return delta * 0.1
            if field_name in {"systolic_bp_mmhg", "diastolic_bp_mmhg"} and "give_fluids" in self._active_supports:
                if "control_bleeding" in self._active_supports or "position_patient" in self._active_supports:
                    return delta * 0.15
                return delta * 0.4
            if field_name in {"systolic_bp_mmhg", "diastolic_bp_mmhg"} and "control_bleeding" in self._active_supports:
                return delta * 0.6
            if field_name == "heart_rate_bpm":
                if {"control_bleeding", "give_fluids"} <= self._active_supports:
                    return delta * 0.25
                if "control_bleeding" in self._active_supports or "give_fluids" in self._active_supports:
                    return delta * 0.5
            if field_name == "respiration_rate_bpm":
                if "give_oxygen" in self._active_supports and "position_patient" in self._active_supports:
                    return delta * 0.35
                if "give_oxygen" in self._active_supports or "position_patient" in self._active_supports:
                    return delta * 0.6
            if field_name == "spo2" and "give_oxygen" in self._active_supports:
                return delta * 0.3
        return delta

    def _apply_tool_side_effects(
        self,
        action: ToolAction,
        updates: dict,
        effect_scale: float,
    ) -> None:
        """Write tool-specific non-vital fields (devices, infusions, position) into ``updates``.

        Mutates ``updates`` in place.
        """
        tool_name = action.tool_name
        arguments = action.arguments
        if tool_name == "give_oxygen":
            updates["oxygen_device"] = str(arguments.get("device") or "nasal_cannula")
            updates["oxygen_flow_lpm"] = float(arguments.get("flow_lpm", 15))
        elif tool_name == "position_patient":
            updates["position"] = str(arguments.get("position") or updates.get("position") or "upright")
        elif tool_name == "airway_support":
            requested_mode = str(arguments.get("mode") or arguments.get("support_type") or "auto")
            normalized_mode = requested_mode.strip().lower().replace("-", "_").replace(" ", "_")
            # Generic requests defer to a device chosen from the current SpO2 / mental status.
            if normalized_mode in {"auto", "basic", "default", "standard", "support", "airway_support"}:
                updates["airway_support"] = self._suggest_airway_support_mode()
            else:
                updates["airway_support"] = normalized_mode
        elif tool_name == "give_fluids":
            active_infusions = dict(updates.get("active_infusions") or {})
            fluid_name = str(arguments.get("fluid_type") or arguments.get("fluid") or "saline")
            active_infusions[fluid_name] = float(arguments.get("rate_ml_per_min", 100))
            updates["active_infusions"] = active_infusions
        elif tool_name == "give_pressor":
            active_infusions = dict(updates.get("active_infusions") or {})
            pressor_name = str(arguments.get("pressor") or arguments.get("agent") or "norepinephrine")
            if arguments.get("stop") is True:
                active_infusions.pop(pressor_name, None)
            else:
                active_infusions[pressor_name] = float(arguments.get("rate_ml_per_min", 5))
            updates["active_infusions"] = active_infusions
        elif tool_name == "needle_decompression":
            updates["breath_sounds"] = "present bilateral"
        elif tool_name == "pericardiocentesis":
            # NOTE(review): _refresh_state rebuilds active_alerts from vitals after
            # every step, so this filter has no lasting effect in the step() flow.
            updates["active_alerts"] = [
                alert
                for alert in updates.get("active_alerts", [])
                if alert != "tamponade"
            ]
        if tool_name == "needle_decompression" and self._scenario is not None:
            if self._scenario.scenario_id == "respiratory_distress":
                updates["spo2"] = min(1.0, float(updates.get("spo2") or 0.0) + (0.02 * effect_scale))

    def _handle_diagnostic_read(self, tool_name: str) -> str:
        """Return a ready lab result, the pending countdown, or place a new order."""
        assert self._state is not None
        if tool_name in self._state.ready_diagnostics:
            return self._diagnostic_result_message(tool_name)
        if tool_name in self._state.pending_diagnostics:
            remaining = self._state.pending_diagnostics[tool_name]
            return f"{tool_name} is pending. {remaining} simulated seconds remaining before results are ready."
        updates = self._state.model_dump()
        pending_diagnostics = dict(self._state.pending_diagnostics)
        pending_diagnostics[tool_name] = MOCK_DIAGNOSTIC_DELAYS[tool_name]
        updates["pending_diagnostics"] = pending_diagnostics
        self._state = PatientState(**updates)
        return (
            f"Ordered {tool_name}. Results will be ready after about "
            f"{MOCK_DIAGNOSTIC_DELAYS[tool_name]} simulated seconds."
        )

    def _diagnostic_result_message(self, tool_name: str) -> str:
        """Store a freshly computed lab result on the state and format it.

        ``get_blood_gas`` and ``get_cbc`` are handled explicitly; any other tool
        name produces the BMP result.
        """
        assert self._state is not None
        updates = self._state.model_dump()
        if tool_name == "get_blood_gas":
            abg_result = self._build_abg_result()
            updates["abg_result"] = abg_result.model_dump()
            self._state = PatientState(**updates)
            return (
                f"ABG pH {abg_result.ph:.3f}, PaO2 {abg_result.partial_pressure_of_oxygen_mmhg:.1f} mmHg, "
                f"PaCO2 {abg_result.partial_pressure_of_carbon_dioxide_mmhg:.1f} mmHg, "
                f"lactate {abg_result.lactate_mg_per_dl:.1f} mg/dL."
            )
        if tool_name == "get_cbc":
            cbc_result = self._build_cbc_result()
            updates["cbc_result"] = cbc_result.model_dump()
            self._state = PatientState(**updates)
            return (
                f"CBC hemoglobin {cbc_result.hemoglobin_g_per_dl:.1f} g/dL, "
                f"hematocrit {cbc_result.hematocrit_fraction:.3f}, "
                f"WBC {cbc_result.white_blood_cell_count_per_u_l:.0f} /uL."
            )
        bmp_result = self._build_bmp_result()
        updates["bmp_result"] = bmp_result.model_dump()
        self._state = PatientState(**updates)
        return (
            f"BMP sodium {bmp_result.sodium_mmol_per_l:.1f} mmol/L, "
            f"potassium {bmp_result.potassium_mmol_per_l:.1f} mmol/L, "
            f"creatinine {bmp_result.creatinine_mg_per_dl:.1f} mg/dL, "
            f"glucose {bmp_result.glucose_mg_per_dl:.1f} mg/dL."
        )

    def _build_abg_result(self):
        """Synthesize an arterial blood gas from SpO2, systolic BP and respiration rate."""
        assert self._state is not None
        from ..patient_state import ArterialBloodGasResult

        spo2 = float(self._state.spo2 or 0.95)
        systolic = float(self._state.systolic_bp_mmhg or 110.0)
        ph = max(7.10, min(7.45, 7.40 - max(0.0, (95.0 - systolic) / 200.0)))
        pao2 = max(45.0, min(110.0, 40.0 + spo2 * 60.0))
        paco2 = max(28.0, min(60.0, 40.0 + max(0.0, (24.0 - float(self._state.respiration_rate_bpm or 16.0)) * 0.8)))
        lactate = max(8.0, min(40.0, 10.0 + max(0.0, (100.0 - systolic) * 0.15)))
        return ArterialBloodGasResult(
            ph=round(ph, 3),
            partial_pressure_of_oxygen_mmhg=round(pao2, 1),
            partial_pressure_of_carbon_dioxide_mmhg=round(paco2, 1),
            oxygen_saturation=round(spo2, 3),
            bicarbonate_meq_per_l=24.0,
            lactate_mg_per_dl=round(lactate, 1),
            base_excess_meq_per_l=-2.0 if systolic < 95.0 else 0.0,
            base_deficit_meq_per_l=2.0 if systolic < 95.0 else 0.0,
        )

    def _build_cbc_result(self):
        """Synthesize a complete blood count, scaling hemoglobin down with blood loss."""
        assert self._state is not None
        from ..patient_state import CompleteBloodCountResult

        blood_volume = float(self._state.blood_volume_ml or 5400.0)
        hemoglobin = max(7.5, min(15.0, 15.0 - max(0.0, (5400.0 - blood_volume) / 350.0)))
        hematocrit = max(0.24, min(0.45, hemoglobin / 33.0))
        return CompleteBloodCountResult(
            hemoglobin_g_per_dl=round(hemoglobin, 1),
            hematocrit_fraction=round(hematocrit, 3),
            white_blood_cell_count_per_u_l=9000.0 if self._state.active_alerts else 7000.0,
            platelet_count_per_u_l=250000.0,
            red_blood_cell_count_per_u_l=4800000.0,
        )

    def _build_bmp_result(self):
        """Synthesize a basic metabolic panel that shifts when SBP < 95 or alerts exist."""
        assert self._state is not None
        from ..patient_state import BasicMetabolicPanelResult

        systolic = float(self._state.systolic_bp_mmhg or 110.0)
        return BasicMetabolicPanelResult(
            sodium_mmol_per_l=142.0 if systolic >= 95.0 else 145.0,
            potassium_mmol_per_l=3.8 if systolic >= 95.0 else 4.2,
            calcium_mmol_per_l=2.2,
            creatinine_mg_per_dl=1.0 if systolic >= 95.0 else 1.4,
            glucose_mg_per_dl=105.0 if not self._state.active_alerts else 128.0,
        )

    def _changed_fields(self, previous_state: PatientState, new_state: PatientState) -> list[str]:
        """Return the names of model fields whose values differ between two states."""
        # Read model_fields from the class: instance access is deprecated in Pydantic 2.11.
        return [
            field_name
            for field_name in type(new_state).model_fields
            if getattr(previous_state, field_name) != getattr(new_state, field_name)
        ]

    def _refresh_state(self, state: PatientState) -> PatientState:
        """Clamp vitals and recompute every derived field, alert and the ``done`` flag.

        Derived fields: mean arterial pressure, shock index, breath sounds, active
        hemorrhages, lactate trend, active alerts and mental status. The episode
        ends (``done``) when the ``cardiovascular_collapse`` alert fires.
        """
        updates = state.model_dump()
        # Clamp vitals to fixed floors (and SpO2 into [0.5, 1.0]).
        updates["spo2"] = max(0.5, min(1.0, updates["spo2"]))
        updates["heart_rate_bpm"] = max(20.0, updates["heart_rate_bpm"])
        updates["systolic_bp_mmhg"] = max(40.0, updates["systolic_bp_mmhg"])
        updates["diastolic_bp_mmhg"] = max(20.0, updates["diastolic_bp_mmhg"])
        updates["respiration_rate_bpm"] = max(4.0, updates["respiration_rate_bpm"])
        if updates["blood_volume_ml"] is not None:
            updates["blood_volume_ml"] = max(2500.0, updates["blood_volume_ml"])
        if updates["systolic_bp_mmhg"] is not None and updates["diastolic_bp_mmhg"] is not None:
            # MAP estimate: (SBP + 2 * DBP) / 3.
            updates["mean_arterial_pressure_mmhg"] = (
                updates["systolic_bp_mmhg"] + 2 * updates["diastolic_bp_mmhg"]
            ) / 3.0
        if updates["heart_rate_bpm"] is not None and updates["systolic_bp_mmhg"] not in (None, 0):
            updates["shock_index"] = updates["heart_rate_bpm"] / updates["systolic_bp_mmhg"]
        if self._scenario is not None and self._scenario.scenario_id == "respiratory_distress":
            updates["breath_sounds"] = (
                "present bilateral" if "needle_decompression" in self._active_supports else "diminished bilateral"
            )
        elif updates.get("breath_sounds") in (None, ""):
            updates["breath_sounds"] = "present bilateral"
        if self._scenario is not None and self._scenario.scenario_id == "hemorrhagic_shock":
            if "control_bleeding" in self._active_supports:
                flow_rate = 25.0
            else:
                blood_volume = float(updates["blood_volume_ml"] or 4700.0)
                flow_rate = max(0.0, min(180.0, (5200.0 - blood_volume) * 0.4 + 60.0))
            updates["active_hemorrhages"] = {"right_leg": round(flow_rate, 1)} if flow_rate > 5.0 else {}
        else:
            updates["active_hemorrhages"] = {}
        # blood_volume_ml is optional (see the None checks above); a missing value
        # must not be compared against a float.
        blood_volume_ml = updates["blood_volume_ml"]
        under_perfused = updates["systolic_bp_mmhg"] < 95.0 or (
            blood_volume_ml is not None and blood_volume_ml < 5000.0
        )
        if under_perfused:
            updates["lactate_trend"] = "worsening"
        elif self._active_supports & {"give_fluids", "control_bleeding", "give_oxygen", "needle_decompression"}:
            updates["lactate_trend"] = "improving"
        else:
            updates["lactate_trend"] = "stable"
        alerts: list[str] = []
        if updates["spo2"] < 0.92:
            alerts.append("hypoxemia")
        if updates["heart_rate_bpm"] > 110:
            alerts.append("tachycardia")
        if updates["systolic_bp_mmhg"] < 95:
            alerts.append("hypotension")
        if blood_volume_ml is not None and blood_volume_ml < 5000:
            alerts.append("blood_loss")
        if updates["respiration_rate_bpm"] >= 24:
            alerts.append("tachypnea")
        if updates.get("shock_index") is not None and updates["shock_index"] >= 0.9:
            alerts.append("shock_index_elevated")
        if updates["systolic_bp_mmhg"] < 70 or updates["spo2"] < 0.75:
            alerts.append("cardiovascular_collapse")
        updates["active_alerts"] = alerts
        updates["mental_status"] = self._derive_mental_status(updates["spo2"], updates["systolic_bp_mmhg"])
        updates["done"] = "cardiovascular_collapse" in alerts
        return PatientState(**updates)

    def _available_tools(self) -> list[str]:
        """Return every known tool name; the mock does not gate tools by state."""
        return list(KNOWN_TOOL_NAMES)

    def _suggest_airway_support_mode(self) -> str:
        """Pick an airway device from the current SpO2, falling back to mental status."""
        assert self._state is not None
        if self._state.spo2 is not None and self._state.spo2 < 0.85:
            return "bag_valve_mask"
        if self._state.spo2 is not None and self._state.spo2 < 0.9:
            return "cpap"
        if self._state.mental_status in {"pain", "unresponsive"}:
            return "tracheal"
        if self._state.mental_status == "verbal":
            return "oropharyngeal"
        return "nasopharyngeal"

    def _derive_mental_status(self, spo2: float, systolic_bp_mmhg: float) -> str:
        """Map SpO2 and systolic BP onto an alert/verbal/pain/unresponsive level."""
        if spo2 < 0.75 or systolic_bp_mmhg < 60:
            return "unresponsive"
        if spo2 < 0.82 or systolic_bp_mmhg < 70:
            return "pain"
        if spo2 < 0.88 or systolic_bp_mmhg < 85:
            return "verbal"
        return "alert"

    def _tool_message(self, tool_name: str) -> str:
        """Return the success message shown after applying an intervention."""
        if tool_name == "give_oxygen":
            return "Supplemental oxygen started."
        if tool_name == "give_fluids":
            return "Fluid resuscitation initiated."
        if tool_name == "control_bleeding":
            return "Bleeding control measures applied."
        if tool_name == "position_patient":
            return "Patient repositioned for support."
        if tool_name == "airway_support":
            return "Airway support applied."
        if tool_name == "give_pressor":
            return "Vasopressor support initiated."
        if tool_name == "needle_decompression":
            return "Needle decompression performed."
        if tool_name == "pericardiocentesis":
            return "Pericardiocentesis performed."
        return f"{tool_name} executed."

    def _configure_runtime_effects(self, kwargs: dict[str, object]) -> None:
        """Rebuild the noise and time-pressure helpers from ``reset()`` kwargs.

        Missing keys fall back to the current configuration. All values are
        parsed before either helper is replaced.
        """
        observation_noise_level = float(kwargs.get("observation_noise_level", self._observation_noise.config.noise_level))
        raw_time_pressure_enabled = kwargs.get("time_pressure_enabled", self._time_pressure.config.enabled)
        # Strings like "false" must not become True through bool().
        if isinstance(raw_time_pressure_enabled, str):
            time_pressure_enabled = coerce_boolean_argument(raw_time_pressure_enabled)
        else:
            time_pressure_enabled = bool(raw_time_pressure_enabled)
        time_pressure_onset_s = float(kwargs.get("time_pressure_onset_s", self._time_pressure.config.onset_s))
        time_pressure_escalation_per_minute = float(
            kwargs.get(
                "time_pressure_escalation_per_minute",
                self._time_pressure.config.escalation_per_minute,
            )
        )
        self._observation_noise = NoisyObservation(observation_noise_level)
        self._time_pressure = TimePressureMechanic(
            enabled=time_pressure_enabled,
            onset_s=time_pressure_onset_s,
            escalation_per_minute=time_pressure_escalation_per_minute,
        )

    def _build_observed_state(self) -> tuple[PatientState, dict[str, object]]:
        """Return the noise-applied view of the state plus runtime-effect metadata."""
        assert self._state is not None
        observed_state, noise_metadata = self._observation_noise.apply(
            self._state,
            rng=self._episode_observation_rng,
        )
        time_pressure_metadata = self._time_pressure.as_metadata(
            sim_time_s=self._state.sim_time_s,
            injury_severity=self._scenario.injury_severity if self._scenario is not None else 0.0,
            unstable=self._is_state_unstable(self._state),
        )
        return observed_state, {
            "observation_noise": noise_metadata,
            "time_pressure": time_pressure_metadata,
        }

    def _current_time_pressure_multiplier(self, state: PatientState) -> float:
        """Return the time-pressure deterioration multiplier for ``state``."""
        assert self._scenario is not None
        return self._time_pressure.deterioration_multiplier(
            sim_time_s=state.sim_time_s,
            injury_severity=self._scenario.injury_severity,
            unstable=self._is_state_unstable(state),
        )

    @staticmethod
    def _is_state_unstable(state: PatientState) -> bool:
        """Return True if any alert is active, SBP < 95, SpO2 < 0.92 or mentation is not "alert".

        Missing SBP / SpO2 values are treated as normal (120 mmHg / 1.0).
        """
        systolic = state.systolic_bp_mmhg if state.systolic_bp_mmhg is not None else 120.0
        spo2 = state.spo2 if state.spo2 is not None else 1.0
        return bool(state.active_alerts) or systolic < 95.0 or spo2 < 0.92 or state.mental_status != "alert"