# Pulse_ER_env / real_backend.py
# Last commit 9b1756a by KChad: "Add all docs_assets image assets to Hugging Face Space snapshot"
"""Consumer-side wrapper that adapts the real OpenEnv runtime to PatientBackend.
The real backend returns ``PulsePhysiologyObservation`` objects directly, while
Person 2's runner and validation stack expect ``EnvironmentResponse`` envelopes.
This thin wrapper preserves Person 1's runtime behavior and reshapes it for the
consumer-side pipeline without touching engine internals.
"""
from __future__ import annotations
from typing import Callable
from .models import EnvironmentResponse, ObservationMetadata, PulsePhysiologyObservation, ToolAction
from .patient_state import PatientState
from .server.adapters import PatientBackend
from .tool_catalog import KNOWN_TOOL_NAMES
_REAL_SCENARIO_ALIASES = {
"baseline_stable": "polytrauma_demo",
"respiratory_distress": "trauma_easy_soldier",
"hemorrhagic_shock": "trauma_hard_underweight",
}
class RealPulseBackend(PatientBackend):
    """Adapt ``PulsePhysiologyEnvironment`` to the ``PatientBackend`` interface.

    Every observation produced by the wrapped environment is reshaped into an
    ``EnvironmentResponse`` envelope. Its public ``available_tools`` list is
    filtered down to ``KNOWN_TOOL_NAMES``, and the most recent wrapped
    observation is cached so ``get_state`` can rebuild a ``PatientState``.
    """

    def __init__(
        self,
        default_scenario_id: str | None = None,
        *,
        environment_factory: Callable[[], object] | None = None,
        observation_noise_level: float = 0.0,
        time_pressure_enabled: bool = False,
        time_pressure_onset_s: float = 180.0,
        time_pressure_escalation_per_minute: float = 0.15,
    ) -> None:
        """Create the wrapper and instantiate the underlying environment.

        Args:
            default_scenario_id: Scenario used by ``reset`` when the caller
                passes none. If this is also ``None``, ``reset`` calls the
                runtime without a ``scenario_id`` argument.
            environment_factory: Optional zero-argument callable that returns
                the environment to wrap. When omitted,
                ``PulsePhysiologyEnvironment`` is imported and constructed.
            observation_noise_level: Default forwarded to every ``reset``.
            time_pressure_enabled: Default forwarded to every ``reset``.
            time_pressure_onset_s: Default forwarded to every ``reset``.
            time_pressure_escalation_per_minute: Default forwarded to every
                ``reset``.

        Raises:
            RuntimeError: If no factory is given and the real environment
                cannot be imported.
        """
        self._default_scenario_id = default_scenario_id
        self._environment = self._build_environment(environment_factory)
        # Populated by ``_wrap_observation``; ``None`` until the first reset/step.
        self._latest_observation: PulsePhysiologyObservation | None = None
        # Baseline reset options; keyword arguments passed to ``reset`` win.
        self._reset_defaults = {
            "observation_noise_level": observation_noise_level,
            "time_pressure_enabled": time_pressure_enabled,
            "time_pressure_onset_s": time_pressure_onset_s,
            "time_pressure_escalation_per_minute": time_pressure_escalation_per_minute,
        }

    def reset(self, scenario_id: str | None = None, **kwargs: object) -> EnvironmentResponse:
        """Reset the real runtime and wrap its observation in the consumer envelope.

        Args:
            scenario_id: Scenario to load. A falsy value falls back to the
                constructor's ``default_scenario_id``. Known mock aliases are
                translated to real runtime IDs.
            **kwargs: Extra reset options; these override the constructor
                defaults key by key.

        Returns:
            The wrapped observation returned by the runtime's ``reset``.
        """
        selected_scenario_id = scenario_id or self._default_scenario_id
        reset_kwargs: dict[str, object] = {**self._reset_defaults, **kwargs}
        # ``scenario_id`` is a named parameter, so it can never already be in
        # ``kwargs``. Only forward it when a scenario was actually selected.
        if selected_scenario_id is not None:
            reset_kwargs["scenario_id"] = self._resolve_real_scenario_id(selected_scenario_id)
        observation = self._environment.reset(**reset_kwargs)
        return self._wrap_observation(observation)

    def step(self, action: ToolAction) -> EnvironmentResponse:
        """Execute one action against the real runtime and wrap the response."""
        observation = self._environment.step(action)
        return self._wrap_observation(observation)

    def get_state(self) -> PatientState:
        """Reconstruct a ``PatientState`` view from the latest wrapped observation.

        Only observation fields that ``PatientState`` declares are copied over.

        Raises:
            RuntimeError: If no observation has been wrapped yet.
        """
        if self._latest_observation is None:
            raise RuntimeError("RealPulseBackend has not been reset yet.")
        payload = self._latest_observation.model_dump()
        state_payload = {
            field_name: payload[field_name]
            for field_name in PatientState.model_fields
            if field_name in payload
        }
        return PatientState(**state_payload)

    def close(self) -> None:
        """Close the underlying environment when it exposes a close hook."""
        close_method = getattr(self._environment, "close", None)
        if callable(close_method):
            close_method()

    @staticmethod
    def _resolve_real_scenario_id(scenario_id: str) -> str:
        """Map consumer-side mock aliases onto the nearest real runtime scenarios.

        The Person 2 pipeline historically used mock scenario names such as
        ``baseline_stable``. The real runtime exposes a different set of
        scenario IDs. The wrapper therefore translates only the known aliases
        and passes any other value through unchanged.
        """
        return _REAL_SCENARIO_ALIASES.get(scenario_id, scenario_id)

    @staticmethod
    def _build_environment(environment_factory: Callable[[], object] | None) -> object:
        """Instantiate the real environment lazily to avoid import-time runtime coupling.

        Raises:
            RuntimeError: If no factory is given and importing
                ``PulsePhysiologyEnvironment`` fails.
        """
        if environment_factory is not None:
            return environment_factory()
        try:
            from .server.pulse_physiology_env_environment import PulsePhysiologyEnvironment
        except Exception as exc:  # pragma: no cover - depends on local Pulse/OpenEnv runtime
            # Deliberately broad: the native runtime stack can fail to import in
            # several ways, and all of them are surfaced as one chained error.
            raise RuntimeError(
                "Could not import PulsePhysiologyEnvironment. The real backend currently requires "
                "the Pulse/OpenEnv runtime stack and a working Python 3.12 installation."
            ) from exc
        return PulsePhysiologyEnvironment()

    def _wrap_observation(self, observation: PulsePhysiologyObservation) -> EnvironmentResponse:
        """Normalize a real observation into the standard ``EnvironmentResponse`` envelope.

        Inputs that are not already ``PulsePhysiologyObservation`` instances are
        validated into one first. The wrapped observation is cached for
        ``get_state``.
        """
        if not isinstance(observation, PulsePhysiologyObservation):
            observation = PulsePhysiologyObservation.model_validate(observation)

        metadata_dict = dict(observation.metadata or {})
        raw_available_tools = list(
            observation.available_tools
            or metadata_dict.get("available_tools")
            or []
        )

        # The real runtime may expose a richer clinical tool catalog than the
        # consumer-side runner currently understands. Filter the public
        # available-tools list down to the frozen contract so Person 2's stack
        # can fail closed on truly unknown tools without rejecting the entire
        # episode. Preserve the full runtime tool list in metadata for
        # debugging and future upgrades.
        available_tools = [
            tool_name
            for tool_name in raw_available_tools
            if tool_name in KNOWN_TOOL_NAMES
        ]
        metadata_dict["raw_available_tools"] = raw_available_tools
        # Separate copies: ``model_copy(update=...)`` stores values as-is, so
        # sharing one list would let a mutation through the observation leak
        # into the metadata (and vice versa).
        metadata_dict["available_tools"] = list(available_tools)

        # ``.get("step_count", 0)`` only covers a missing key; an explicit
        # ``None`` would crash ``int()``. Treat both as step 0.
        raw_step_count = metadata_dict.get("step_count")
        step_count = int(raw_step_count) if raw_step_count is not None else 0

        metadata = ObservationMetadata(
            step_count=step_count,
            available_tools=list(available_tools),
        )
        wrapped_observation = observation.model_copy(
            update={
                "available_tools": available_tools,
                "metadata": metadata_dict,
            }
        )
        self._latest_observation = wrapped_observation
        return EnvironmentResponse(
            observation=wrapped_observation,
            reward=float(wrapped_observation.reward or 0.0),
            done=wrapped_observation.done,
            metadata=metadata,
            tool_result=wrapped_observation.tool_result,
            error=wrapped_observation.error,
        )