# Pulse_ER_env / server / pulse_physiology_env_environment.py
# Uploaded by KChad — "Add all docs_assets image assets to Hugging Face Space snapshot"
# Commit: 9b1756a
"""Pulse-backed OpenEnv environment implementation."""
from __future__ import annotations
import random
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import EnvironmentMetadata, State
try:
from ..models import PulsePhysiologyAction, PulsePhysiologyObservation
from ..patient_state import PatientState
from ..runtime_effects import NoisyObservation, TimePressureMechanic
from ..tool_catalog import canonicalize_tool_name, coerce_boolean_argument
from .atls_judge import ATLSJudge
from .pathology_architect import PathologyArchitect, PathologyBlueprint
from .patient_monitor import PatientMonitorVisualization
from .pulse_engine_adapter import PulseEngineAdapter
from .reward_engine import RewardBreakdown, RewardEngine, RewardTracker
from .scenarios import DEFAULT_SCENARIO_ID, PatientProfile, ScenarioDefinition, get_scenario_definition
from .tools import PulseToolExecutor
except ImportError:
from models import PulsePhysiologyAction, PulsePhysiologyObservation
from patient_state import PatientState
from runtime_effects import NoisyObservation, TimePressureMechanic
from tool_catalog import canonicalize_tool_name, coerce_boolean_argument
from server.atls_judge import ATLSJudge
from server.pathology_architect import PathologyArchitect, PathologyBlueprint
from server.patient_monitor import PatientMonitorVisualization
from server.pulse_engine_adapter import PulseEngineAdapter
from server.reward_engine import RewardBreakdown, RewardEngine, RewardTracker
from server.scenarios import DEFAULT_SCENARIO_ID, PatientProfile, ScenarioDefinition, get_scenario_definition
from server.tools import PulseToolExecutor
class PulsePhysiologyEnvironment(Environment):
    """A Pulse-backed tool environment for trauma and resuscitation workflows.

    Each episode loads a patient either from a registered scenario
    (``scenario_id``) or from a generated pathology blueprint. The agent then
    drives the Pulse engine through the tools exposed by ``PulseToolExecutor``.
    Observations may be perturbed by ``NoisyObservation``, and step rewards
    are scored with a multiplier from ``TimePressureMechanic``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        *,
        observation_noise_level: float = 0.0,
        time_pressure_enabled: bool = False,
        time_pressure_onset_s: float = 180.0,
        time_pressure_escalation_per_minute: float = 0.15,
    ) -> None:
        """Create the environment and its engine, scoring and monitoring helpers.

        Args:
            observation_noise_level: Noise level handed to ``NoisyObservation``.
            time_pressure_enabled: Whether the time-pressure mechanic is active.
            time_pressure_onset_s: ``onset_s`` handed to ``TimePressureMechanic``.
            time_pressure_escalation_per_minute: ``escalation_per_minute``
                handed to ``TimePressureMechanic``.

        These are only initial settings: ``reset`` can override any of them,
        and an override stays in effect for later resets that do not set it.
        """
        self._adapter = PulseEngineAdapter()
        self._tool_executor = PulseToolExecutor(self._adapter)
        self._reward_engine = RewardEngine()
        self._monitor = PatientMonitorVisualization()
        self._atls_judge = ATLSJudge()
        self._pathology_architect = PathologyArchitect()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._scenario: ScenarioDefinition = get_scenario_definition(DEFAULT_SCENARIO_ID)
        self._selected_patient: PatientProfile | None = None
        self._latest_patient_state: PatientState | None = None
        self._reward_tracker: RewardTracker | None = None
        self._last_reward_breakdown = RewardBreakdown()
        self._active_blueprint: PathologyBlueprint | None = None
        # Ground-truth states (fed to the ATLS judge) vs. what the agent saw
        # after observation noise (fed to the patient monitor).
        self._state_history: list[PatientState] = []
        self._observed_state_history: list[PatientState] = []
        self._observation_noise = NoisyObservation(observation_noise_level)
        self._time_pressure = TimePressureMechanic(
            enabled=time_pressure_enabled,
            onset_s=time_pressure_onset_s,
            escalation_per_minute=time_pressure_escalation_per_minute,
        )
        self._observation_rng = random.Random()

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: object,
    ) -> PulsePhysiologyObservation:
        """Reset the environment and initialize the requested Pulse scenario.

        Args:
            seed: Seed for patient selection. The observation-noise RNG is
                derived from the same stream. ``None`` gives a
                non-deterministic episode.
            episode_id: Identifier for the new episode. A random UUID is used
                when omitted.
            **kwargs: Either ``scenario_id``, or generated-case inputs
                (``pathology_blueprint``, or ``patient_id`` + ``severity`` +
                exactly one of ``injury_type``/``injury_types``). Also accepts
                optional runtime-effect overrides read by
                ``_configure_runtime_effects``.

        Returns:
            The initial observation, with reward 0.0 and no tool result.

        Raises:
            ValueError: If the scenario id is unknown, or the generated-case
                inputs are incomplete, conflicting or malformed.
        """
        self._configure_runtime_effects(kwargs)

        # Resolve the scenario into locals first. A validation failure then
        # leaves the previous episode's scenario/blueprint pair untouched
        # instead of half-updated.
        try:
            blueprint = self._resolve_generated_blueprint(kwargs)
        except (KeyError, TypeError, ValueError) as exc:
            # Keep the original guidance, but surface the specific reason:
            # the generic hint alone is misleading for conflicts such as
            # "scenario_id and blueprint both given".
            raise ValueError(
                "Invalid generated pathology inputs. Provide patient_id, injury_type, and severity. "
                f"Details: {exc}"
            ) from exc
        if blueprint is not None:
            scenario = self._pathology_architect.to_scenario_definition(blueprint)
        else:
            scenario_id = kwargs.get("scenario_id")
            try:
                scenario = get_scenario_definition(
                    str(scenario_id) if scenario_id is not None else DEFAULT_SCENARIO_ID
                )
            except KeyError as exc:
                raise ValueError(str(exc)) from exc
        self._active_blueprint = blueprint
        self._scenario = scenario

        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        rng = random.Random(seed)
        # Derive the noise RNG from the seeded stream so a fixed seed also
        # reproduces the observation noise.
        self._observation_rng = random.Random(rng.random())
        self._selected_patient = self._scenario.choose_patient(rng)
        patient_state = self._adapter.load_patient(
            state_file=self._selected_patient.state_file,
            scenario_id=self._scenario.scenario_id,
            scenario_difficulty=self._scenario.difficulty,
            patient_id=self._selected_patient.patient_id,
        )
        if self._scenario.setup is not None:
            self._scenario.setup(self._adapter)
            # Setup may alter the engine, so re-read the state afterwards.
            patient_state = self._adapter.get_full_state()
        patient_state = self._apply_episode_rules(patient_state)
        self._latest_patient_state = patient_state
        self._reward_tracker = self._reward_engine.start_episode(self._scenario, patient_state)
        self._last_reward_breakdown = RewardBreakdown(
            reward_profile=self._scenario.reward_profile,
            difficulty_multiplier=self._reward_engine.DIFFICULTY_MULTIPLIER[self._scenario.difficulty],
            action_budget_remaining=self._reward_tracker.action_budget_remaining,
        )
        self._state_history = [patient_state]
        self._observed_state_history = []
        return self._build_observation(
            patient_state,
            reward=0.0,
            tool_result=None,
            error=None,
        )

    def step(
        self,
        action: PulsePhysiologyAction,
        timeout_s: float | None = None,
        **kwargs: object,
    ) -> PulsePhysiologyObservation:
        """Execute a named tool against the current Pulse scenario.

        Args:
            action: The tool call. Its ``tool_name`` is canonicalized against
                the executor's available tools before execution.
            timeout_s: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The observation after the tool ran, carrying this step's reward.
        """
        del timeout_s, kwargs
        self._state.step_count += 1
        canonical_tool_name = canonicalize_tool_name(
            action.tool_name,
            allowed_tools=self._tool_executor.available_tools,
        )
        action = action.model_copy(update={"tool_name": canonical_tool_name})
        previous_state = self._latest_patient_state or self._apply_episode_rules(self._adapter.get_full_state())
        execution = self._tool_executor.execute(action)
        current_state = self._apply_episode_rules(execution.state)
        if self._reward_tracker is None:
            # step() called before reset(): start scoring from the pre-action state.
            self._reward_tracker = self._reward_engine.start_episode(self._scenario, previous_state)
        breakdown = self._reward_engine.score_step(
            self._reward_tracker,
            scenario=self._scenario,
            before=previous_state,
            after=current_state,
            action=action,
            success=execution.tool_result.success,
            had_error=execution.error is not None,
            time_pressure_multiplier=self._time_pressure.deterioration_multiplier(
                sim_time_s=current_state.sim_time_s,
                injury_severity=self._current_injury_severity(),
                unstable=self._is_state_unstable(current_state),
            ),
        )
        self._latest_patient_state = current_state
        self._last_reward_breakdown = breakdown
        self._state_history.append(current_state)
        return self._build_observation(
            current_state,
            reward=breakdown.total,
            tool_result=execution.tool_result,
            error=execution.error,
        )

    @property
    def state(self) -> State:
        """Get the current environment state (episode id and step count)."""
        return self._state

    def get_metadata(self) -> EnvironmentMetadata:
        """Return static descriptive metadata for this environment."""
        description = (
            "Pulse-backed trauma environment with engine-correct tools for airway, breathing, "
            "circulation, diagnostics, and decision support."
        )
        return EnvironmentMetadata(
            name="PulsePhysiologyEnvironment",
            description=description,
            version="0.4.0",
            author="OpenAI Codex",
        )

    def close(self) -> None:
        """Release the underlying Pulse engine adapter."""
        self._adapter.close()

    def _apply_episode_rules(self, state: PatientState) -> PatientState:
        """Return a copy of ``state`` with the scenario time limit enforced.

        Once ``sim_time_s`` reaches the scenario's ``max_time_s``, the copy is
        marked done and carries a single ``"time_limit_reached"`` alert.
        """
        time_limit_reached = state.sim_time_s >= self._scenario.max_time_s
        alerts = list(state.active_alerts)
        if time_limit_reached and "time_limit_reached" not in alerts:
            alerts.append("time_limit_reached")
        return state.model_copy(
            update={"done": state.done or time_limit_reached, "active_alerts": alerts}
        )

    def _build_observation(
        self,
        state: PatientState,
        *,
        reward: float,
        tool_result,
        error,
    ) -> PulsePhysiologyObservation:
        """Assemble the agent-facing observation for the true patient ``state``.

        The vitals the agent sees come from ``_observe_state`` (noise
        applied). The noisy state is appended to the observed history shown
        by the patient monitor. The ATLS judge is evaluated on the true,
        un-noised state and history.
        """
        observed_state, runtime_effect_metadata = self._observe_state(state)
        # Build a new list rather than appending in place; the result is
        # handed to the monitor below.
        self._observed_state_history = [*self._observed_state_history, observed_state]
        action_history = (
            self._reward_tracker.action_history if self._reward_tracker is not None else []
        )
        metadata = {
            "step_count": self._state.step_count,
            "scenario_description": self._scenario.description,
            "scenario_difficulty": self._scenario.difficulty,
            "reward_profile": self._scenario.reward_profile,
            "patient_pool_size": len(self._scenario.patient_pool),
            "selected_state_file": self._selected_patient.state_file if self._selected_patient is not None else None,
            "action_budget_remaining": self._reward_tracker.action_budget_remaining if self._reward_tracker is not None else None,
            "reward_breakdown": self._last_reward_breakdown.as_metadata(),
            "available_tools": self._tool_executor.available_tools,
            "patient_monitor": self._monitor.build(
                history=self._observed_state_history,
                action_history=action_history,
                current_state=observed_state,
            ).as_dict(),
            "atls_judge": self._atls_judge.evaluate(
                state=state,
                action_history=action_history,
                reward_profile=self._scenario.reward_profile,
                state_history=self._state_history,
            ).as_dict(),
            "pathology_blueprint": self._active_blueprint.as_dict() if self._active_blueprint is not None else None,
            **runtime_effect_metadata,
        }
        return PulsePhysiologyObservation.from_patient_state(
            observed_state,
            reward=reward,
            available_tools=self._tool_executor.available_tools,
            tool_result=tool_result,
            error=error,
            metadata=metadata,
        )

    def _resolve_generated_blueprint(self, kwargs: dict[str, object]) -> PathologyBlueprint | None:
        """Build a pathology blueprint from reset kwargs, if generated-case inputs are present.

        Accepted forms (a kwarg whose value is ``None`` counts as absent):
          * ``pathology_blueprint`` as a dict with ``patient_id``, ``severity``
            and ``injury_type`` and/or ``injury_types``, or as an
            already-built ``PathologyBlueprint``;
          * top-level ``patient_id`` + ``severity`` + exactly one of
            ``injury_type`` / ``injury_types``.

        Returns:
            The blueprint, or ``None`` when no generated-case inputs were given
            (the caller then falls back to ``scenario_id``).

        Raises:
            ValueError: On conflicting or incomplete inputs.
            TypeError: If ``pathology_blueprint`` is not a dict or
                ``PathologyBlueprint``.
        """
        has_scenario_id = kwargs.get("scenario_id") is not None
        raw_blueprint = kwargs.get("pathology_blueprint")
        base_keys = {"patient_id", "severity"}
        selector_keys = {"injury_type", "injury_types"}
        provided_base_keys = {key for key in base_keys if kwargs.get(key) is not None}
        provided_selector_keys = {key for key in selector_keys if kwargs.get(key) is not None}
        provided_authoring_keys = provided_base_keys | provided_selector_keys

        if has_scenario_id and (raw_blueprint is not None or provided_authoring_keys):
            raise ValueError(
                "Pass either scenario_id or generated-case authoring inputs, not both."
            )
        if raw_blueprint is not None and provided_authoring_keys:
            raise ValueError(
                "Pass either pathology_blueprint or patient_id/injury_type(s)/severity, not both."
            )

        if raw_blueprint is not None:
            if isinstance(raw_blueprint, dict):
                return self._blueprint_from_mapping(raw_blueprint, base_keys, selector_keys)
            if isinstance(raw_blueprint, PathologyBlueprint):
                # In-process callers may hand over a blueprint they already built.
                return raw_blueprint
            # Previously any other type was silently ignored and the default
            # scenario ran instead; fail loudly so the caller notices.
            raise TypeError(
                "pathology_blueprint must be a dict or PathologyBlueprint, "
                f"got {type(raw_blueprint).__name__}."
            )

        if not provided_authoring_keys:
            return None
        if not provided_base_keys:
            raise ValueError(
                "Generated-case reset requires patient_id and severity together with injury_type or injury_types."
            )
        if len(provided_selector_keys) > 1:
            raise ValueError(
                "Generated-case reset accepts exactly one of injury_type or injury_types."
            )
        if provided_base_keys != base_keys:
            missing = ", ".join(sorted(base_keys - provided_base_keys))
            raise ValueError(
                f"Generated-case reset requires patient_id and severity together. Missing: {missing}"
            )
        if not provided_selector_keys:
            raise ValueError(
                "Generated-case reset requires exactly one of injury_type or injury_types."
            )
        return self._pathology_architect.build_blueprint(
            patient_id=str(kwargs["patient_id"]),
            injury_type=str(kwargs["injury_type"]) if kwargs.get("injury_type") is not None else None,
            injury_types=kwargs.get("injury_types"),
            severity=float(kwargs["severity"]),
        )

    def _blueprint_from_mapping(
        self,
        raw_blueprint: dict[str, object],
        base_keys: set[str],
        selector_keys: set[str],
    ) -> PathologyBlueprint:
        """Validate a dict-form ``pathology_blueprint`` and build the blueprint from it."""
        missing_base_keys = {key for key in base_keys if raw_blueprint.get(key) is None}
        if missing_base_keys:
            missing = ", ".join(sorted(missing_base_keys))
            raise ValueError(
                f"pathology_blueprint is missing required keys: {missing}"
            )
        if all(raw_blueprint.get(key) is None for key in selector_keys):
            raise ValueError(
                "pathology_blueprint must include injury_type or injury_types."
            )
        injury_type = raw_blueprint.get("injury_type")
        injury_types = raw_blueprint.get("injury_types")
        return self._pathology_architect.build_blueprint(
            patient_id=str(raw_blueprint["patient_id"]),
            # When both selectors are present, injury_types takes precedence.
            injury_type=str(injury_type) if injury_type is not None and injury_types is None else None,
            injury_types=injury_types,
            severity=float(raw_blueprint["severity"]),
        )

    def _configure_runtime_effects(self, kwargs: dict[str, object]) -> None:
        """Rebuild the noise and time-pressure mechanics from reset kwargs.

        Each setting not present in ``kwargs`` (or passed as ``None``) keeps
        its current value. Overrides therefore persist into later resets. A
        string ``time_pressure_enabled`` is parsed with
        ``coerce_boolean_argument``.
        """

        def override(key: str, current: object) -> object:
            # Explicit None means "not provided". This matches how the
            # generated-case kwargs are treated and avoids float(None) /
            # bool(None) surprises.
            value = kwargs.get(key)
            return current if value is None else value

        noise_config = self._observation_noise.config
        pressure_config = self._time_pressure.config
        observation_noise_level = float(override("observation_noise_level", noise_config.noise_level))
        raw_time_pressure_enabled = override("time_pressure_enabled", pressure_config.enabled)
        if isinstance(raw_time_pressure_enabled, str):
            time_pressure_enabled = coerce_boolean_argument(raw_time_pressure_enabled)
        else:
            time_pressure_enabled = bool(raw_time_pressure_enabled)
        time_pressure_onset_s = float(override("time_pressure_onset_s", pressure_config.onset_s))
        time_pressure_escalation_per_minute = float(
            override("time_pressure_escalation_per_minute", pressure_config.escalation_per_minute)
        )
        self._observation_noise = NoisyObservation(observation_noise_level)
        self._time_pressure = TimePressureMechanic(
            enabled=time_pressure_enabled,
            onset_s=time_pressure_onset_s,
            escalation_per_minute=time_pressure_escalation_per_minute,
        )

    def _observe_state(self, state: PatientState) -> tuple[PatientState, dict[str, object]]:
        """Apply observation noise to ``state`` and collect runtime-effect metadata.

        Returns:
            ``(observed_state, metadata)``. The metadata has the keys
            ``"observation_noise"`` and ``"time_pressure"``. Time pressure is
            computed from the true (un-noised) state.
        """
        observed_state, noise_metadata = self._observation_noise.apply(
            state,
            rng=self._observation_rng,
        )
        time_pressure_metadata = self._time_pressure.as_metadata(
            sim_time_s=state.sim_time_s,
            injury_severity=self._current_injury_severity(),
            unstable=self._is_state_unstable(state),
        )
        return observed_state, {
            "observation_noise": noise_metadata,
            "time_pressure": time_pressure_metadata,
        }

    def _current_injury_severity(self) -> float:
        """Return the active blueprint's severity, else a per-difficulty default.

        Raises ``KeyError`` for a scenario difficulty other than
        ``"easy"``/``"medium"``/``"hard"``.
        """
        if self._active_blueprint is not None:
            return float(self._active_blueprint.severity)
        return {"easy": 0.35, "medium": 0.6, "hard": 0.85}[self._scenario.difficulty]

    @staticmethod
    def _is_state_unstable(state: PatientState) -> bool:
        """Heuristic instability check used by the time-pressure mechanic.

        The state counts as unstable if any alert is active, systolic BP is
        below 95 mmHg, SpO2 is below 0.92, or mental status is not "alert".
        Missing readings are treated as normal.
        """
        systolic = state.systolic_bp_mmhg if state.systolic_bp_mmhg is not None else 120.0
        spo2 = state.spo2 if state.spo2 is not None else 1.0
        # A missing mental status is treated as normal, like missing BP/SpO2
        # above, rather than as an automatic instability trigger.
        altered_mental_status = state.mental_status is not None and state.mental_status != "alert"
        return bool(state.active_alerts) or systolic < 95.0 or spo2 < 0.92 or altered_mental_status