Spaces:
Sleeping
Sleeping
| """Consumer-side wrapper that adapts the real OpenEnv runtime to PatientBackend. | |
| The real backend returns ``PulsePhysiologyObservation`` objects directly, while | |
| Person 2's runner and validation stack expect ``EnvironmentResponse`` envelopes. | |
| This thin wrapper preserves Person 1's runtime behavior and reshapes it for the | |
| consumer-side pipeline without touching engine internals. | |
| """ | |
from __future__ import annotations

from typing import Any, Callable

from .models import EnvironmentResponse, ObservationMetadata, PulsePhysiologyObservation, ToolAction
from .patient_state import PatientState
from .server.adapters import PatientBackend
from .tool_catalog import KNOWN_TOOL_NAMES
| _REAL_SCENARIO_ALIASES = { | |
| "baseline_stable": "polytrauma_demo", | |
| "respiratory_distress": "trauma_easy_soldier", | |
| "hemorrhagic_shock": "trauma_hard_underweight", | |
| } | |
class RealPulseBackend(PatientBackend):
    """Adapt ``PulsePhysiologyEnvironment`` to the ``PatientBackend`` interface.

    ``reset`` and ``step`` forward to the wrapped environment and re-wrap each
    returned observation into an ``EnvironmentResponse``. The most recent
    wrapped observation is cached so ``get_state`` can rebuild a
    ``PatientState`` from it.
    """

    def __init__(
        self,
        default_scenario_id: str | None = None,
        *,
        environment_factory: Callable[[], object] | None = None,
        observation_noise_level: float = 0.0,
        time_pressure_enabled: bool = False,
        time_pressure_onset_s: float = 180.0,
        time_pressure_escalation_per_minute: float = 0.15,
    ) -> None:
        """Build the wrapped environment and record the default reset options.

        Args:
            default_scenario_id: Scenario used by ``reset`` when the caller does
                not supply one. When both are ``None`` the environment's
                ``reset`` is called without a ``scenario_id`` argument.
            environment_factory: Optional zero-argument callable that builds
                the environment to wrap. When omitted, the real
                ``PulsePhysiologyEnvironment`` is imported and instantiated.
            observation_noise_level: Default forwarded to every ``reset``.
            time_pressure_enabled: Default forwarded to every ``reset``.
            time_pressure_onset_s: Default forwarded to every ``reset``.
            time_pressure_escalation_per_minute: Default forwarded to every
                ``reset``.

        Raises:
            RuntimeError: If no factory is given and the real environment
                cannot be imported.
        """
        self._default_scenario_id = default_scenario_id
        # The runtime is used duck-typed (reset / step / optional close), so it
        # is annotated ``Any`` rather than ``object``.
        self._environment: Any = self._build_environment(environment_factory)
        self._latest_observation: PulsePhysiologyObservation | None = None
        # Per-call keyword arguments passed to ``reset`` override these key by key.
        self._reset_defaults: dict[str, object] = {
            "observation_noise_level": observation_noise_level,
            "time_pressure_enabled": time_pressure_enabled,
            "time_pressure_onset_s": time_pressure_onset_s,
            "time_pressure_escalation_per_minute": time_pressure_escalation_per_minute,
        }

    def reset(self, scenario_id: str | None = None, **kwargs: object) -> EnvironmentResponse:
        """Reset the real runtime and wrap its observation in the consumer envelope.

        Args:
            scenario_id: Scenario to load. A falsy value falls back to the
                default given at construction. Known mock aliases are
                translated to real runtime IDs; other values pass through.
            **kwargs: Extra reset options that override the constructor
                defaults key by key.

        Returns:
            The wrapped reset observation.
        """
        reset_kwargs: dict[str, object] = {**self._reset_defaults, **kwargs}
        selected_scenario_id = scenario_id or self._default_scenario_id
        # Only pass ``scenario_id`` when one is known; otherwise the
        # environment's own default applies.
        if selected_scenario_id is not None:
            reset_kwargs["scenario_id"] = self._resolve_real_scenario_id(selected_scenario_id)
        observation = self._environment.reset(**reset_kwargs)
        return self._wrap_observation(observation)

    def step(self, action: ToolAction) -> EnvironmentResponse:
        """Execute one action against the real runtime and wrap the response.

        The action is forwarded unchanged. The returned observation is wrapped
        and cached as the latest observation.
        """
        observation = self._environment.step(action)
        return self._wrap_observation(observation)

    def get_state(self) -> PatientState:
        """Reconstruct a ``PatientState`` view from the latest wrapped observation.

        Only fields that ``PatientState`` declares and that the observation's
        ``model_dump()`` contains are copied; all other fields are ignored.

        Raises:
            RuntimeError: If no observation has been wrapped yet, i.e. neither
                ``reset`` nor ``step`` has been called.
        """
        if self._latest_observation is None:
            raise RuntimeError("RealPulseBackend has not been reset yet.")
        payload = self._latest_observation.model_dump()
        state_payload = {
            field_name: payload[field_name]
            for field_name in PatientState.model_fields
            if field_name in payload
        }
        return PatientState(**state_payload)

    def close(self) -> None:
        """Close the underlying environment when it exposes a close hook."""
        close_method = getattr(self._environment, "close", None)
        if callable(close_method):
            close_method()

    @staticmethod
    def _resolve_real_scenario_id(scenario_id: str) -> str:
        """Map consumer-side mock aliases onto the nearest real runtime scenarios.

        The Person 2 pipeline historically used mock scenario names such as
        ``baseline_stable``. The real runtime exposes a different scenario ID
        set, so the wrapper translates only these known aliases and otherwise
        passes the provided value through untouched.
        """
        return _REAL_SCENARIO_ALIASES.get(scenario_id, scenario_id)

    @staticmethod
    def _build_environment(environment_factory: Callable[[], object] | None) -> object:
        """Instantiate the real environment lazily to avoid import-time runtime coupling.

        Args:
            environment_factory: If given, it is called and its result is
                returned without importing the real runtime.

        Raises:
            RuntimeError: If the real ``PulsePhysiologyEnvironment`` cannot be
                imported. The original exception is chained as the cause.
        """
        if environment_factory is not None:
            return environment_factory()
        try:
            from .server.pulse_physiology_env_environment import PulsePhysiologyEnvironment
        except Exception as exc:  # pragma: no cover - depends on local Pulse/OpenEnv runtime
            raise RuntimeError(
                "Could not import PulsePhysiologyEnvironment. The real backend currently requires "
                "the Pulse/OpenEnv runtime stack and a working Python 3.12 installation."
            ) from exc
        return PulsePhysiologyEnvironment()

    def _wrap_observation(self, observation: PulsePhysiologyObservation) -> EnvironmentResponse:
        """Normalize a real observation into the standard ``EnvironmentResponse`` envelope.

        Inputs that are not already ``PulsePhysiologyObservation`` instances
        are validated into one first. The wrapped observation is cached for
        ``get_state``. A ``None`` reward is reported as ``0.0``.
        """
        if not isinstance(observation, PulsePhysiologyObservation):
            observation = PulsePhysiologyObservation.model_validate(observation)
        metadata_dict = dict(observation.metadata or {})
        raw_available_tools = list(
            observation.available_tools
            or metadata_dict.get("available_tools")
            or []
        )
        # The real runtime may expose a richer clinical tool catalog than the
        # consumer-side runner currently understands. Filter the public
        # available-tools list down to the frozen contract so Person 2's stack
        # can fail closed on truly unknown tools without rejecting the entire
        # episode. Preserve the full runtime tool list in metadata for
        # debugging and future upgrades.
        available_tools = [
            tool_name
            for tool_name in raw_available_tools
            if tool_name in KNOWN_TOOL_NAMES
        ]
        metadata_dict["raw_available_tools"] = raw_available_tools
        # Give each consumer its own list; ``model_copy(update=...)`` does not
        # copy, so sharing one list would let a mutation leak across views.
        metadata_dict["available_tools"] = list(available_tools)
        # Treat a missing or ``None`` step_count as 0 instead of letting
        # ``int(None)`` raise.
        raw_step_count = metadata_dict.get("step_count")
        metadata = ObservationMetadata(
            step_count=int(raw_step_count) if raw_step_count is not None else 0,
            available_tools=list(available_tools),
        )
        wrapped_observation = observation.model_copy(
            update={
                "available_tools": available_tools,
                "metadata": metadata_dict,
            }
        )
        self._latest_observation = wrapped_observation
        return EnvironmentResponse(
            observation=wrapped_observation,
            reward=float(wrapped_observation.reward or 0.0),
            done=wrapped_observation.done,
            metadata=metadata,
            tool_result=wrapped_observation.tool_result,
            error=wrapped_observation.error,
        )