Spaces:
Sleeping
Sleeping
| """Pulse-backed OpenEnv environment implementation.""" | |
| from __future__ import annotations | |
| import random | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import EnvironmentMetadata, State | |
| try: | |
| from ..models import PulsePhysiologyAction, PulsePhysiologyObservation | |
| from ..patient_state import PatientState | |
| from ..runtime_effects import NoisyObservation, TimePressureMechanic | |
| from ..tool_catalog import canonicalize_tool_name, coerce_boolean_argument | |
| from .atls_judge import ATLSJudge | |
| from .pathology_architect import PathologyArchitect, PathologyBlueprint | |
| from .patient_monitor import PatientMonitorVisualization | |
| from .pulse_engine_adapter import PulseEngineAdapter | |
| from .reward_engine import RewardBreakdown, RewardEngine, RewardTracker | |
| from .scenarios import DEFAULT_SCENARIO_ID, PatientProfile, ScenarioDefinition, get_scenario_definition | |
| from .tools import PulseToolExecutor | |
| except ImportError: | |
| from models import PulsePhysiologyAction, PulsePhysiologyObservation | |
| from patient_state import PatientState | |
| from runtime_effects import NoisyObservation, TimePressureMechanic | |
| from tool_catalog import canonicalize_tool_name, coerce_boolean_argument | |
| from server.atls_judge import ATLSJudge | |
| from server.pathology_architect import PathologyArchitect, PathologyBlueprint | |
| from server.patient_monitor import PatientMonitorVisualization | |
| from server.pulse_engine_adapter import PulseEngineAdapter | |
| from server.reward_engine import RewardBreakdown, RewardEngine, RewardTracker | |
| from server.scenarios import DEFAULT_SCENARIO_ID, PatientProfile, ScenarioDefinition, get_scenario_definition | |
| from server.tools import PulseToolExecutor | |
class PulsePhysiologyEnvironment(Environment):
    """A Pulse-backed tool environment for trauma and resuscitation workflows.

    An episode is built either from a named scenario (``scenario_id``) or from
    a generated pathology blueprint. The agent acts by naming tools, which
    ``PulseToolExecutor`` runs against the ``PulseEngineAdapter``. Every step
    is scored by ``RewardEngine``. Observations carry patient-monitor and
    ATLS-judge summaries. ``NoisyObservation`` may perturb them, and
    ``TimePressureMechanic`` may annotate them.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        *,
        observation_noise_level: float = 0.0,
        time_pressure_enabled: bool = False,
        time_pressure_onset_s: float = 180.0,
        time_pressure_escalation_per_minute: float = 0.15,
    ) -> None:
        """Create the engine adapter, the helpers and the initial runtime effects.

        Args:
            observation_noise_level: Initial level passed to ``NoisyObservation``.
            time_pressure_enabled: Whether ``TimePressureMechanic`` starts enabled.
            time_pressure_onset_s: Onset passed to ``TimePressureMechanic``.
            time_pressure_escalation_per_minute: Escalation rate passed to
                ``TimePressureMechanic``.

        ``reset()`` may override all four values through its kwargs.
        """
        self._adapter = PulseEngineAdapter()
        self._tool_executor = PulseToolExecutor(self._adapter)
        self._reward_engine = RewardEngine()
        self._monitor = PatientMonitorVisualization()
        self._atls_judge = ATLSJudge()
        self._pathology_architect = PathologyArchitect()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._scenario: ScenarioDefinition = get_scenario_definition(DEFAULT_SCENARIO_ID)
        self._selected_patient: PatientProfile | None = None
        self._latest_patient_state: PatientState | None = None
        self._reward_tracker: RewardTracker | None = None
        self._last_reward_breakdown = RewardBreakdown()
        self._active_blueprint: PathologyBlueprint | None = None
        # True (noise-free) states, fed to the ATLS judge.
        self._state_history: list[PatientState] = []
        # States as shown to the agent (after observation noise).
        self._observed_state_history: list[PatientState] = []
        self._observation_noise = NoisyObservation(observation_noise_level)
        self._time_pressure = TimePressureMechanic(
            enabled=time_pressure_enabled,
            onset_s=time_pressure_onset_s,
            escalation_per_minute=time_pressure_escalation_per_minute,
        )
        self._observation_rng = random.Random()

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: object,
    ) -> PulsePhysiologyObservation:
        """Reset the environment and initialize the requested Pulse scenario.

        Args:
            seed: Seeds patient selection and the observation-noise RNG.
                ``None`` gives an unseeded episode.
            episode_id: Identifier for the new episode. A UUID4 is generated
                if omitted.
            **kwargs: Either ``scenario_id``, or generated-case inputs. Generated
                inputs are a ``pathology_blueprint`` dict, or ``patient_id`` plus
                ``severity`` plus exactly one of ``injury_type``/``injury_types``.
                Optional runtime-effect overrides may also be passed:
                ``observation_noise_level``, ``time_pressure_enabled``,
                ``time_pressure_onset_s`` and ``time_pressure_escalation_per_minute``.

        Returns:
            The initial observation, with ``reward=0.0``.

        Raises:
            ValueError: If the scenario id is unknown, or the generated-case
                inputs are conflicting or incomplete.
        """
        # Resolve every input before mutating episode state. A rejected reset
        # then leaves the environment, including runtime effects, untouched.
        try:
            blueprint = self._resolve_generated_blueprint(kwargs)
        except (KeyError, TypeError, ValueError) as exc:
            raise ValueError(
                "Invalid generated pathology inputs. Provide patient_id, injury_type, and severity. "
                f"Details: {exc}"
            ) from exc
        if blueprint is not None:
            scenario = self._pathology_architect.to_scenario_definition(blueprint)
        else:
            scenario_id = kwargs.get("scenario_id")
            try:
                scenario = get_scenario_definition(
                    str(scenario_id) if scenario_id is not None else DEFAULT_SCENARIO_ID
                )
            except KeyError as exc:
                raise ValueError(str(exc)) from exc
        self._configure_runtime_effects(kwargs)

        self._active_blueprint = blueprint
        self._scenario = scenario
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        # One seeded RNG drives patient selection. A child RNG seeded from its
        # first draw drives observation noise, so both follow from ``seed``.
        # Draw order matters for reproducibility.
        rng = random.Random(seed)
        self._observation_rng = random.Random(rng.random())
        self._selected_patient = self._scenario.choose_patient(rng)
        patient_state = self._adapter.load_patient(
            state_file=self._selected_patient.state_file,
            scenario_id=self._scenario.scenario_id,
            scenario_difficulty=self._scenario.difficulty,
            patient_id=self._selected_patient.patient_id,
        )
        if self._scenario.setup is not None:
            # Scenario-specific setup acts on the adapter, so re-read the state.
            self._scenario.setup(self._adapter)
            patient_state = self._adapter.get_full_state()
        patient_state = self._apply_episode_rules(patient_state)
        self._latest_patient_state = patient_state
        self._reward_tracker = self._reward_engine.start_episode(self._scenario, patient_state)
        self._last_reward_breakdown = RewardBreakdown(
            reward_profile=self._scenario.reward_profile,
            difficulty_multiplier=self._reward_engine.DIFFICULTY_MULTIPLIER[self._scenario.difficulty],
            action_budget_remaining=self._reward_tracker.action_budget_remaining,
        )
        self._state_history = [patient_state]
        self._observed_state_history = []
        return self._build_observation(
            patient_state,
            reward=0.0,
            tool_result=None,
            error=None,
        )

    def step(
        self,
        action: PulsePhysiologyAction,
        timeout_s: float | None = None,
        **kwargs: object,
    ) -> PulsePhysiologyObservation:
        """Execute a named tool against the current Pulse scenario.

        The tool name is first canonicalized against the executor's available
        tools. The step is scored on the transition from the previous state to
        the post-action state. The time-pressure multiplier is computed from
        the post-action state.

        If no reset happened yet, the previous state falls back to the
        adapter's current state and a reward tracker is started lazily.

        Args:
            action: Tool call to execute.
            timeout_s: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        del timeout_s, kwargs
        self._state.step_count += 1
        action = action.model_copy(
            update={
                "tool_name": canonicalize_tool_name(
                    action.tool_name,
                    allowed_tools=self._tool_executor.available_tools,
                )
            }
        )
        if self._latest_patient_state is not None:
            previous_state = self._latest_patient_state
        else:
            previous_state = self._apply_episode_rules(self._adapter.get_full_state())
        execution = self._tool_executor.execute(action)
        current_state = self._apply_episode_rules(execution.state)
        if self._reward_tracker is None:
            self._reward_tracker = self._reward_engine.start_episode(self._scenario, previous_state)
        breakdown = self._reward_engine.score_step(
            self._reward_tracker,
            scenario=self._scenario,
            before=previous_state,
            after=current_state,
            action=action,
            success=execution.tool_result.success,
            had_error=execution.error is not None,
            time_pressure_multiplier=self._time_pressure.deterioration_multiplier(
                sim_time_s=current_state.sim_time_s,
                injury_severity=self._current_injury_severity(),
                unstable=self._is_state_unstable(current_state),
            ),
        )
        self._latest_patient_state = current_state
        self._last_reward_breakdown = breakdown
        self._state_history.append(current_state)
        return self._build_observation(
            current_state,
            reward=breakdown.total,
            tool_result=execution.tool_result,
            error=execution.error,
        )

    @property
    def state(self) -> State:
        """Get the current environment state (episode id and step count)."""
        return self._state

    def get_metadata(self) -> EnvironmentMetadata:
        """Describe this environment for OpenEnv clients."""
        description = (
            "Pulse-backed trauma environment with engine-correct tools for airway, breathing, "
            "circulation, diagnostics, and decision support."
        )
        return EnvironmentMetadata(
            name="PulsePhysiologyEnvironment",
            description=description,
            version="0.4.0",
            author="OpenAI Codex",
        )

    def close(self) -> None:
        """Close the underlying Pulse engine adapter."""
        self._adapter.close()

    def _apply_episode_rules(self, state: PatientState) -> PatientState:
        """Return a copy of ``state`` with the scenario time limit enforced.

        Once ``sim_time_s`` reaches the scenario's ``max_time_s``, the copy is
        marked done and gains a single ``time_limit_reached`` alert.
        """
        time_up = state.sim_time_s >= self._scenario.max_time_s
        alerts = list(state.active_alerts)
        if time_up and "time_limit_reached" not in alerts:
            alerts.append("time_limit_reached")
        return state.model_copy(update={"done": state.done or time_up, "active_alerts": alerts})

    def _build_observation(
        self,
        state: PatientState,
        *,
        reward: float,
        tool_result,
        error,
    ) -> PulsePhysiologyObservation:
        """Wrap ``state`` into an agent-facing observation.

        The agent-facing fields and the patient monitor use the noisy copy from
        ``_observe_state``. The ATLS judge and the time-pressure metadata use
        the true ``state``. The observed copy is appended to
        ``_observed_state_history``.

        Args:
            state: True patient state, after episode rules were applied.
            reward: Reward to report for this step.
            tool_result: Result of the executed tool, or ``None`` on reset.
            error: Tool execution error, or ``None``.
        """
        observed_state, runtime_effect_metadata = self._observe_state(state)
        # A new list is built each time instead of appending in place, presumably
        # so history lists already handed to the monitor are never mutated later.
        observed_history = [*self._observed_state_history, observed_state]
        self._observed_state_history = observed_history
        action_history = self._reward_tracker.action_history if self._reward_tracker is not None else []
        metadata = {
            "step_count": self._state.step_count,
            "scenario_description": self._scenario.description,
            "scenario_difficulty": self._scenario.difficulty,
            "reward_profile": self._scenario.reward_profile,
            "patient_pool_size": len(self._scenario.patient_pool),
            "selected_state_file": self._selected_patient.state_file if self._selected_patient is not None else None,
            "action_budget_remaining": self._reward_tracker.action_budget_remaining if self._reward_tracker is not None else None,
            "reward_breakdown": self._last_reward_breakdown.as_metadata(),
            "available_tools": self._tool_executor.available_tools,
            "patient_monitor": self._monitor.build(
                history=observed_history,
                action_history=action_history,
                current_state=observed_state,
            ).as_dict(),
            "atls_judge": self._atls_judge.evaluate(
                state=state,
                action_history=action_history,
                reward_profile=self._scenario.reward_profile,
                state_history=self._state_history,
            ).as_dict(),
            "pathology_blueprint": self._active_blueprint.as_dict() if self._active_blueprint is not None else None,
            **runtime_effect_metadata,
        }
        return PulsePhysiologyObservation.from_patient_state(
            observed_state,
            reward=reward,
            available_tools=self._tool_executor.available_tools,
            tool_result=tool_result,
            error=error,
            metadata=metadata,
        )

    def _resolve_generated_blueprint(self, kwargs: dict[str, object]) -> PathologyBlueprint | None:
        """Build a pathology blueprint from reset kwargs, if any were supplied.

        A blueprint is built from either:
          * ``pathology_blueprint`` (a dict with ``patient_id``, ``severity``
            and ``injury_type`` or ``injury_types``), or
          * top-level ``patient_id`` + ``severity`` + exactly one of
            ``injury_type`` / ``injury_types``.

        Keys whose value is ``None`` count as absent.

        Returns:
            The built blueprint, or ``None`` when no generated-case inputs were
            given, i.e. the caller should use a named scenario.

        Raises:
            ValueError: On conflicting or incomplete inputs.
            TypeError: If ``pathology_blueprint`` is given but is not a dict.
        """
        has_scenario_id = kwargs.get("scenario_id") is not None
        raw_blueprint = kwargs.get("pathology_blueprint")
        base_keys = {"patient_id", "severity"}
        selector_keys = {"injury_type", "injury_types"}
        provided_base_keys = {key for key in base_keys if kwargs.get(key) is not None}
        provided_selector_keys = {key for key in selector_keys if kwargs.get(key) is not None}
        provided_authoring_keys = provided_base_keys | provided_selector_keys
        if has_scenario_id and (raw_blueprint is not None or provided_authoring_keys):
            raise ValueError(
                "Pass either scenario_id or generated-case authoring inputs, not both."
            )
        if raw_blueprint is not None and provided_authoring_keys:
            raise ValueError(
                "Pass either pathology_blueprint or patient_id/injury_type(s)/severity, not both."
            )
        if raw_blueprint is not None and not isinstance(raw_blueprint, dict):
            # Previously a non-dict blueprint was silently ignored and the
            # default scenario loaded instead.
            raise TypeError(
                f"pathology_blueprint must be a dict, got {type(raw_blueprint).__name__}."
            )
        if isinstance(raw_blueprint, dict):
            raw_base_keys = {key for key in base_keys if raw_blueprint.get(key) is not None}
            raw_selector_keys = {key for key in selector_keys if raw_blueprint.get(key) is not None}
            missing_base_keys = base_keys - raw_base_keys
            if missing_base_keys:
                missing = ", ".join(sorted(missing_base_keys))
                raise ValueError(
                    f"pathology_blueprint is missing required keys: {missing}"
                )
            if not raw_selector_keys:
                raise ValueError(
                    "pathology_blueprint must include injury_type or injury_types."
                )
            # Unlike the top-level kwargs path, a dict carrying both selectors is
            # accepted; injury_types then takes precedence.
            return self._pathology_architect.build_blueprint(
                patient_id=str(raw_blueprint["patient_id"]),
                injury_type=(
                    str(raw_blueprint["injury_type"])
                    if raw_blueprint.get("injury_type") is not None and raw_blueprint.get("injury_types") is None
                    else None
                ),
                injury_types=raw_blueprint.get("injury_types"),
                severity=float(raw_blueprint["severity"]),
            )

        if not provided_authoring_keys:
            return None
        if not provided_base_keys:
            raise ValueError(
                "Generated-case reset requires patient_id and severity together with injury_type or injury_types."
            )
        if len(provided_selector_keys) > 1:
            raise ValueError(
                "Generated-case reset accepts exactly one of injury_type or injury_types."
            )
        if provided_base_keys != base_keys:
            missing = ", ".join(sorted(base_keys - provided_base_keys))
            raise ValueError(
                f"Generated-case reset requires patient_id and severity together. Missing: {missing}"
            )
        if not provided_selector_keys:
            raise ValueError(
                "Generated-case reset requires exactly one of injury_type or injury_types."
            )
        return self._pathology_architect.build_blueprint(
            patient_id=str(kwargs["patient_id"]),
            injury_type=str(kwargs["injury_type"]) if kwargs.get("injury_type") is not None else None,
            injury_types=kwargs.get("injury_types"),
            severity=float(kwargs["severity"]),
        )

    def _configure_runtime_effects(self, kwargs: dict[str, object]) -> None:
        """Rebuild the noise and time-pressure mechanics from reset kwargs.

        A setting that is absent or ``None`` keeps its current value. Overrides
        therefore persist across resets until changed again. A string value for
        ``time_pressure_enabled`` is parsed with ``coerce_boolean_argument``.
        All values are parsed before either mechanic is replaced, so an
        unparsable value leaves both unchanged.
        """
        noise_config = self._observation_noise.config
        pressure_config = self._time_pressure.config

        def setting(key: str, current: object) -> object:
            value = kwargs.get(key)
            return current if value is None else value

        observation_noise_level = float(setting("observation_noise_level", noise_config.noise_level))
        raw_time_pressure_enabled = setting("time_pressure_enabled", pressure_config.enabled)
        if isinstance(raw_time_pressure_enabled, str):
            time_pressure_enabled = coerce_boolean_argument(raw_time_pressure_enabled)
        else:
            time_pressure_enabled = bool(raw_time_pressure_enabled)
        time_pressure_onset_s = float(setting("time_pressure_onset_s", pressure_config.onset_s))
        time_pressure_escalation_per_minute = float(
            setting("time_pressure_escalation_per_minute", pressure_config.escalation_per_minute)
        )
        self._observation_noise = NoisyObservation(observation_noise_level)
        self._time_pressure = TimePressureMechanic(
            enabled=time_pressure_enabled,
            onset_s=time_pressure_onset_s,
            escalation_per_minute=time_pressure_escalation_per_minute,
        )

    def _observe_state(self, state: PatientState) -> tuple[PatientState, dict[str, object]]:
        """Apply observation noise to ``state`` and collect runtime-effect metadata.

        Returns:
            A pair of (noisy observed state, metadata dict). The metadata dict
            has ``observation_noise`` and ``time_pressure`` entries. Time
            pressure is evaluated on the true ``state``, not the noisy copy.
        """
        observed_state, noise_metadata = self._observation_noise.apply(
            state,
            rng=self._observation_rng,
        )
        time_pressure_metadata = self._time_pressure.as_metadata(
            sim_time_s=state.sim_time_s,
            injury_severity=self._current_injury_severity(),
            unstable=self._is_state_unstable(state),
        )
        return observed_state, {
            "observation_noise": noise_metadata,
            "time_pressure": time_pressure_metadata,
        }

    def _current_injury_severity(self) -> float:
        """Return the blueprint's severity, or a per-difficulty fallback for named scenarios.

        Raises ``KeyError`` for a difficulty other than easy/medium/hard.
        """
        if self._active_blueprint is not None:
            return float(self._active_blueprint.severity)
        return {"easy": 0.35, "medium": 0.6, "hard": 0.85}[self._scenario.difficulty]

    @staticmethod
    def _is_state_unstable(state: PatientState) -> bool:
        """Return True if the patient looks unstable.

        The patient counts as unstable when any of these hold:
          * there is any active alert;
          * systolic BP is below 95 mmHg;
          * SpO2 is below 0.92 (compared as a fraction);
          * mental status is not ``"alert"``.

        Missing vitals are treated as normal: systolic 120, SpO2 1.0.
        """
        systolic = state.systolic_bp_mmhg if state.systolic_bp_mmhg is not None else 120.0
        spo2 = state.spo2 if state.spo2 is not None else 1.0
        return bool(state.active_alerts) or systolic < 95.0 or spo2 < 0.92 or state.mental_status != "alert"