"""Pulse-backed OpenEnv environment implementation.""" from __future__ import annotations import random from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import EnvironmentMetadata, State try: from ..models import PulsePhysiologyAction, PulsePhysiologyObservation from ..patient_state import PatientState from ..runtime_effects import NoisyObservation, TimePressureMechanic from ..tool_catalog import canonicalize_tool_name, coerce_boolean_argument from .atls_judge import ATLSJudge from .pathology_architect import PathologyArchitect, PathologyBlueprint from .patient_monitor import PatientMonitorVisualization from .pulse_engine_adapter import PulseEngineAdapter from .reward_engine import RewardBreakdown, RewardEngine, RewardTracker from .scenarios import DEFAULT_SCENARIO_ID, PatientProfile, ScenarioDefinition, get_scenario_definition from .tools import PulseToolExecutor except ImportError: from models import PulsePhysiologyAction, PulsePhysiologyObservation from patient_state import PatientState from runtime_effects import NoisyObservation, TimePressureMechanic from tool_catalog import canonicalize_tool_name, coerce_boolean_argument from server.atls_judge import ATLSJudge from server.pathology_architect import PathologyArchitect, PathologyBlueprint from server.patient_monitor import PatientMonitorVisualization from server.pulse_engine_adapter import PulseEngineAdapter from server.reward_engine import RewardBreakdown, RewardEngine, RewardTracker from server.scenarios import DEFAULT_SCENARIO_ID, PatientProfile, ScenarioDefinition, get_scenario_definition from server.tools import PulseToolExecutor class PulsePhysiologyEnvironment(Environment): """A Pulse-backed tool environment for trauma and resuscitation workflows.""" SUPPORTS_CONCURRENT_SESSIONS: bool = True def __init__( self, *, observation_noise_level: float = 0.0, time_pressure_enabled: bool = False, time_pressure_onset_s: float = 180.0, time_pressure_escalation_per_minute: float = 0.15, ) -> None: self._adapter = PulseEngineAdapter() self._tool_executor = PulseToolExecutor(self._adapter) self._reward_engine = RewardEngine() self._monitor = PatientMonitorVisualization() self._atls_judge = ATLSJudge() self._pathology_architect = PathologyArchitect() self._state = State(episode_id=str(uuid4()), step_count=0) self._scenario: ScenarioDefinition = get_scenario_definition(DEFAULT_SCENARIO_ID) self._selected_patient: PatientProfile | None = None self._latest_patient_state: PatientState | None = None self._reward_tracker: RewardTracker | None = None self._last_reward_breakdown = RewardBreakdown() self._active_blueprint: PathologyBlueprint | None = None self._state_history: list[PatientState] = [] self._observed_state_history: list[PatientState] = [] self._observation_noise = NoisyObservation(observation_noise_level) self._time_pressure = TimePressureMechanic( enabled=time_pressure_enabled, onset_s=time_pressure_onset_s, escalation_per_minute=time_pressure_escalation_per_minute, ) self._observation_rng = random.Random() def reset( self, seed: int | None = None, episode_id: str | None = None, **kwargs: object, ) -> PulsePhysiologyObservation: """Reset the environment and initialize the requested Pulse scenario.""" self._configure_runtime_effects(kwargs) try: blueprint = self._resolve_generated_blueprint(kwargs) except (KeyError, TypeError, ValueError) as exc: raise ValueError( "Invalid generated pathology inputs. Provide patient_id, injury_type, and severity." ) from exc if blueprint is not None: self._active_blueprint = blueprint self._scenario = self._pathology_architect.to_scenario_definition(blueprint) else: self._active_blueprint = None scenario_id = kwargs.get("scenario_id") try: self._scenario = get_scenario_definition( str(scenario_id) if scenario_id is not None else DEFAULT_SCENARIO_ID ) except KeyError as exc: raise ValueError(str(exc)) from exc self._state = State(episode_id=episode_id or str(uuid4()), step_count=0) rng = random.Random(seed) self._observation_rng = random.Random(rng.random()) self._selected_patient = self._scenario.choose_patient(rng) patient_state = self._adapter.load_patient( state_file=self._selected_patient.state_file, scenario_id=self._scenario.scenario_id, scenario_difficulty=self._scenario.difficulty, patient_id=self._selected_patient.patient_id, ) if self._scenario.setup is not None: self._scenario.setup(self._adapter) patient_state = self._adapter.get_full_state() patient_state = self._apply_episode_rules(patient_state) self._latest_patient_state = patient_state self._reward_tracker = self._reward_engine.start_episode(self._scenario, patient_state) self._last_reward_breakdown = RewardBreakdown( reward_profile=self._scenario.reward_profile, difficulty_multiplier=self._reward_engine.DIFFICULTY_MULTIPLIER[self._scenario.difficulty], action_budget_remaining=self._reward_tracker.action_budget_remaining, ) self._state_history = [patient_state] self._observed_state_history = [] return self._build_observation( patient_state, reward=0.0, tool_result=None, error=None, ) def step( self, action: PulsePhysiologyAction, timeout_s: float | None = None, **kwargs: object, ) -> PulsePhysiologyObservation: """Execute a named tool against the current Pulse scenario.""" del timeout_s, kwargs self._state.step_count += 1 action = action.model_copy( update={ "tool_name": canonicalize_tool_name( action.tool_name, allowed_tools=self._tool_executor.available_tools, ) } ) previous_state = self._latest_patient_state or self._apply_episode_rules(self._adapter.get_full_state()) execution = self._tool_executor.execute(action) current_state = self._apply_episode_rules(execution.state) if self._reward_tracker is None: self._reward_tracker = self._reward_engine.start_episode(self._scenario, previous_state) breakdown = self._reward_engine.score_step( self._reward_tracker, scenario=self._scenario, before=previous_state, after=current_state, action=action, success=execution.tool_result.success, had_error=execution.error is not None, time_pressure_multiplier=self._time_pressure.deterioration_multiplier( sim_time_s=current_state.sim_time_s, injury_severity=self._current_injury_severity(), unstable=self._is_state_unstable(current_state), ), ) reward = breakdown.total self._latest_patient_state = current_state self._last_reward_breakdown = breakdown self._state_history.append(current_state) return self._build_observation( current_state, reward=reward, tool_result=execution.tool_result, error=execution.error, ) @property def state(self) -> State: """Get the current environment state.""" return self._state def get_metadata(self) -> EnvironmentMetadata: description = ( "Pulse-backed trauma environment with engine-correct tools for airway, breathing, " "circulation, diagnostics, and decision support." ) return EnvironmentMetadata( name="PulsePhysiologyEnvironment", description=description, version="0.4.0", author="OpenAI Codex", ) def close(self) -> None: self._adapter.close() def _apply_episode_rules(self, state: PatientState) -> PatientState: done = state.done or state.sim_time_s >= self._scenario.max_time_s alerts = list(state.active_alerts) if state.sim_time_s >= self._scenario.max_time_s and "time_limit_reached" not in alerts: alerts.append("time_limit_reached") return state.model_copy(update={"done": done, "active_alerts": alerts}) def _build_observation( self, state: PatientState, *, reward: float, tool_result, error, ) -> PulsePhysiologyObservation: observed_state, runtime_effect_metadata = self._observe_state(state) observed_history = [*self._observed_state_history, observed_state] self._observed_state_history = observed_history metadata = { "step_count": self._state.step_count, "scenario_description": self._scenario.description, "scenario_difficulty": self._scenario.difficulty, "reward_profile": self._scenario.reward_profile, "patient_pool_size": len(self._scenario.patient_pool), "selected_state_file": self._selected_patient.state_file if self._selected_patient is not None else None, "action_budget_remaining": self._reward_tracker.action_budget_remaining if self._reward_tracker is not None else None, "reward_breakdown": self._last_reward_breakdown.as_metadata(), "available_tools": self._tool_executor.available_tools, "patient_monitor": self._monitor.build( history=observed_history or [observed_state], action_history=self._reward_tracker.action_history if self._reward_tracker is not None else [], current_state=observed_state, ).as_dict(), "atls_judge": self._atls_judge.evaluate( state=state, action_history=self._reward_tracker.action_history if self._reward_tracker is not None else [], reward_profile=self._scenario.reward_profile, state_history=self._state_history, ).as_dict(), "pathology_blueprint": self._active_blueprint.as_dict() if self._active_blueprint is not None else None, **runtime_effect_metadata, } return PulsePhysiologyObservation.from_patient_state( observed_state, reward=reward, available_tools=self._tool_executor.available_tools, tool_result=tool_result, error=error, metadata=metadata, ) def _resolve_generated_blueprint(self, kwargs: dict[str, object]) -> PathologyBlueprint | None: has_scenario_id = kwargs.get("scenario_id") is not None raw_blueprint = kwargs.get("pathology_blueprint") base_keys = {"patient_id", "severity"} selector_keys = {"injury_type", "injury_types"} provided_base_keys = {key for key in base_keys if kwargs.get(key) is not None} provided_selector_keys = {key for key in selector_keys if kwargs.get(key) is not None} provided_authoring_keys = provided_base_keys | provided_selector_keys if has_scenario_id and (raw_blueprint is not None or provided_authoring_keys): raise ValueError( "Pass either scenario_id or generated-case authoring inputs, not both." ) if raw_blueprint is not None and provided_authoring_keys: raise ValueError( "Pass either pathology_blueprint or patient_id/injury_type(s)/severity, not both." ) if isinstance(raw_blueprint, dict): raw_base_keys = {key for key in base_keys if raw_blueprint.get(key) is not None} raw_selector_keys = {key for key in selector_keys if raw_blueprint.get(key) is not None} missing_base_keys = base_keys - raw_base_keys if missing_base_keys: missing = ", ".join(sorted(missing_base_keys)) raise ValueError( f"pathology_blueprint is missing required keys: {missing}" ) if not raw_selector_keys: raise ValueError( "pathology_blueprint must include injury_type or injury_types." ) return self._pathology_architect.build_blueprint( patient_id=str(raw_blueprint["patient_id"]), injury_type=( str(raw_blueprint["injury_type"]) if raw_blueprint.get("injury_type") is not None and raw_blueprint.get("injury_types") is None else None ), injury_types=raw_blueprint.get("injury_types"), severity=float(raw_blueprint["severity"]), ) if provided_authoring_keys and not provided_base_keys: raise ValueError( "Generated-case reset requires patient_id and severity together with injury_type or injury_types." ) if provided_selector_keys and len(provided_selector_keys) != 1: raise ValueError( "Generated-case reset accepts exactly one of injury_type or injury_types." ) if provided_base_keys and provided_base_keys != base_keys: missing = ", ".join(sorted(base_keys - provided_base_keys)) raise ValueError( f"Generated-case reset requires patient_id and severity together. Missing: {missing}" ) if provided_base_keys == base_keys and not provided_selector_keys: raise ValueError( "Generated-case reset requires exactly one of injury_type or injury_types." ) if provided_base_keys == base_keys and len(provided_selector_keys) == 1: return self._pathology_architect.build_blueprint( patient_id=str(kwargs["patient_id"]), injury_type=str(kwargs["injury_type"]) if kwargs.get("injury_type") is not None else None, injury_types=kwargs.get("injury_types"), severity=float(kwargs["severity"]), ) return None def _configure_runtime_effects(self, kwargs: dict[str, object]) -> None: observation_noise_level = float(kwargs.get("observation_noise_level", self._observation_noise.config.noise_level)) raw_time_pressure_enabled = kwargs.get("time_pressure_enabled", self._time_pressure.config.enabled) if isinstance(raw_time_pressure_enabled, str): time_pressure_enabled = coerce_boolean_argument(raw_time_pressure_enabled) else: time_pressure_enabled = bool(raw_time_pressure_enabled) time_pressure_onset_s = float(kwargs.get("time_pressure_onset_s", self._time_pressure.config.onset_s)) time_pressure_escalation_per_minute = float( kwargs.get( "time_pressure_escalation_per_minute", self._time_pressure.config.escalation_per_minute, ) ) self._observation_noise = NoisyObservation(observation_noise_level) self._time_pressure = TimePressureMechanic( enabled=time_pressure_enabled, onset_s=time_pressure_onset_s, escalation_per_minute=time_pressure_escalation_per_minute, ) def _observe_state(self, state: PatientState) -> tuple[PatientState, dict[str, object]]: observed_state, noise_metadata = self._observation_noise.apply( state, rng=self._observation_rng, ) time_pressure_metadata = self._time_pressure.as_metadata( sim_time_s=state.sim_time_s, injury_severity=self._current_injury_severity(), unstable=self._is_state_unstable(state), ) return observed_state, { "observation_noise": noise_metadata, "time_pressure": time_pressure_metadata, } def _current_injury_severity(self) -> float: if self._active_blueprint is not None: return float(self._active_blueprint.severity) return {"easy": 0.35, "medium": 0.6, "hard": 0.85}[self._scenario.difficulty] @staticmethod def _is_state_unstable(state: PatientState) -> bool: systolic = state.systolic_bp_mmhg if state.systolic_bp_mmhg is not None else 120.0 spo2 = state.spo2 if state.spo2 is not None else 1.0 return bool(state.active_alerts) or systolic < 95.0 or spo2 < 0.92 or state.mental_status != "alert"