Spaces:
Sleeping
Sleeping
| """Evaluation harness for comparing mock policies. | |
| Run from the repo root with: | |
| python -m pulse_physiology_env.eval_mock | |
| """ | |
| from __future__ import annotations | |
| import warnings | |
| from dataclasses import dataclass | |
| from pulse_physiology_env.demo_llm_policy import heuristic_infer_fn | |
| from pulse_physiology_env.episode_runner import EpisodeRunner, EpisodeTerminationReason | |
| from pulse_physiology_env.gym_env import PulseGymEnv | |
| from pulse_physiology_env.models import EnvironmentResponse, ObservationMetadata, PulsePhysiologyObservation, ToolError, ToolResult | |
| from pulse_physiology_env.policies import ( | |
| LLMPolicy, | |
| RandomPolicy, | |
| action, | |
| build_expert_policy, | |
| build_no_action_policy, | |
| ) | |
| from pulse_physiology_env.prompt_builder import build_policy_prompt | |
| from pulse_physiology_env.tool_availability import ToolAvailabilityError | |
| from pulse_physiology_env.tool_parser import ParseError, ParseWarning, parse_tool_action, parse_with_fallback | |
| from pulse_physiology_env.server.adapters import MockPulseAdapter | |
| from pulse_physiology_env.server.mock_scenarios import MOCK_SCENARIOS | |
| from pulse_physiology_env.patient_state import PatientState | |
@dataclass
class PolicyScore:
    """Summary score for one policy over all scenarios.

    Instances are created with keyword arguments and ``average_reward`` is
    read as a plain attribute, so the class is a dataclass and the average is
    exposed as a property.

    Attributes:
        policy_name: Label used when the score is printed.
        per_scenario: Mapping of scenario id to the reward for that scenario.
    """

    policy_name: str
    per_scenario: dict[str, float]

    @property
    def average_reward(self) -> float:
        """Return the mean per-scenario reward, rounded to three decimals.

        Returns:
            The rounded mean of ``per_scenario`` values, or ``0.0`` when no
            scenarios were scored (avoids a ``ZeroDivisionError``).
        """
        if not self.per_scenario:
            return 0.0
        return round(sum(self.per_scenario.values()) / len(self.per_scenario), 3)
def _make_observation(*, available_tools: list[str], **overrides) -> PulsePhysiologyObservation:
    """Build a compact observation fixture for consumer-side regression checks.

    The fixture starts from fixed default vitals for the ``baseline_stable``
    scenario; any keyword in ``overrides`` replaces a default field or adds a
    new one before the observation model is constructed.
    """
    defaults = {
        "scenario_id": "baseline_stable",
        "patient_id": "regression_patient",
        "sim_time_s": 0.0,
        "heart_rate_bpm": 72.0,
        "systolic_bp_mmhg": 118.0,
        "diastolic_bp_mmhg": 76.0,
        "spo2": 0.98,
        "respiration_rate_bpm": 14.0,
        "blood_volume_ml": 5500.0,
        "available_tools": available_tools,
    }
    # Later keys win, so caller overrides take precedence over the defaults.
    return PulsePhysiologyObservation(**{**defaults, **overrides})
def _make_response(
    observation: PulsePhysiologyObservation,
    *,
    reward: float,
    error: ToolError | None = None,
    tool_result: ToolResult | None = None,
) -> EnvironmentResponse:
    """Wrap a fixture observation in the standard environment response envelope.

    ``done`` is taken from the observation, and the metadata always reports
    ``step_count=0`` with a copy of the observation's available tools.
    """
    is_done = observation.done
    envelope_metadata = ObservationMetadata(
        step_count=0,
        available_tools=list(observation.available_tools),
    )
    return EnvironmentResponse(
        observation=observation,
        reward=reward,
        done=is_done,
        metadata=envelope_metadata,
        tool_result=tool_result,
        error=error,
    )
class _FixedActionPolicy:
    """Deterministic policy that always returns one pre-built action.

    Used to probe episode-runner edge cases.
    """

    name = "fixed_action"

    def __init__(self, tool_name: str = "get_vitals") -> None:
        # Build the action once; every select_action() call hands back this same object.
        self._action = action(tool_name)

    def reset(self, scenario_id: str) -> None:
        """Do nothing: the policy keeps no per-episode state."""

    def select_action(self, observation: PulsePhysiologyObservation):
        """Return the fixed action regardless of the observation."""
        return self._action

    def observe_outcome(self, action, result) -> None:
        """Do nothing: outcomes never change the fixed action."""
class _TransientFailureBackend:
    """Backend stub that succeeds only after transient retryable failures.

    The first two ``step()`` calls return a retryable ``TEMPORARY_BUSY`` error;
    the third and later calls return a successful result.
    """

    def __init__(self) -> None:
        self.attempts = 0
        self._state = PatientState(
            scenario_id="baseline_stable",
            patient_id="retry_backend",
            heart_rate_bpm=72.0,
            systolic_bp_mmhg=118.0,
            diastolic_bp_mmhg=76.0,
            spo2=0.98,
            respiration_rate_bpm=14.0,
            blood_volume_ml=5500.0,
        )

    @staticmethod
    def _fresh_observation() -> PulsePhysiologyObservation:
        """Return a new fixture observation exposing the two tools this stub offers."""
        return _make_observation(available_tools=["get_vitals", "advance_time"])

    def reset(self, scenario_id: str | None = None) -> EnvironmentResponse:
        """Return a zero-reward response; the requested scenario is ignored."""
        return _make_response(self._fresh_observation(), reward=0.0)

    def step(self, tool_action) -> EnvironmentResponse:
        """Count the attempt and fail with a retryable error until the third call."""
        self.attempts += 1
        observation = self._fresh_observation()

        if self.attempts >= 3:
            recovered_result = ToolResult(
                tool_name=tool_action.tool_name,
                success=True,
                message="Recovered after retries.",
                state_changed=False,
                changed_fields=[],
            )
            return _make_response(observation, reward=0.25, tool_result=recovered_result)

        busy_error = ToolError(
            code="TEMPORARY_BUSY",
            message="Transient backend failure.",
            retryable=True,
        )
        failed_result = ToolResult(
            tool_name=tool_action.tool_name,
            success=False,
            message="Transient backend failure.",
            state_changed=False,
            changed_fields=[],
        )
        return _make_response(
            observation,
            reward=-1.0,
            error=busy_error,
            tool_result=failed_result,
        )

    def get_state(self) -> PatientState:
        """Return the fixed patient state created at construction."""
        return self._state
class _MissingToolsBackend:
    """Backend stub that violates the available-tools contract on reset.

    ``reset()`` returns an observation with an empty tool list, and ``step()``
    raises if it is ever reached.
    """

    def __init__(self) -> None:
        baseline_vitals = {
            "heart_rate_bpm": 72.0,
            "systolic_bp_mmhg": 118.0,
            "diastolic_bp_mmhg": 76.0,
            "spo2": 0.98,
            "respiration_rate_bpm": 14.0,
            "blood_volume_ml": 5500.0,
        }
        self._state = PatientState(
            scenario_id="baseline_stable",
            patient_id="missing_tools_backend",
            **baseline_vitals,
        )

    def reset(self, scenario_id: str | None = None) -> EnvironmentResponse:
        """Return a zero-reward response whose observation lists no tools."""
        empty_tools_observation = _make_observation(available_tools=[])
        return _make_response(empty_tools_observation, reward=0.0)

    def step(self, tool_action) -> EnvironmentResponse:
        """Always raise: no step should run once reset reported no tools."""
        raise RuntimeError("step() should not be called when available_tools is missing.")

    def get_state(self) -> PatientState:
        """Return the fixed patient state created at construction."""
        return self._state
class _TerminalRetryBackend:
    """Backend stub that surfaces cardiac arrest together with a retryable error.

    The stored state carries a ``cardiac_arrest`` alert. ``reset()`` hides the
    alert, while every ``step()`` exposes it alongside a retryable
    ``ENGINE_ERROR`` and counts the attempt.
    """

    def __init__(self) -> None:
        self.attempts = 0
        self._state = PatientState(
            scenario_id="baseline_stable",
            patient_id="terminal_retry_backend",
            heart_rate_bpm=20.0,
            systolic_bp_mmhg=40.0,
            diastolic_bp_mmhg=20.0,
            spo2=0.7,
            respiration_rate_bpm=4.0,
            blood_volume_ml=3000.0,
            mental_status="unresponsive",
            active_alerts=["cardiac_arrest"],
        )

    def reset(self, scenario_id: str | None = None) -> EnvironmentResponse:
        """Return a zero-reward response built from the state with alerts cleared."""
        alert_free_state = self._state.model_copy(update={"active_alerts": [], "done": False})
        initial_observation = PulsePhysiologyObservation.from_patient_state(
            alert_free_state,
            reward=0.0,
            available_tools=["advance_time"],
            metadata={"step_count": 0, "available_tools": ["advance_time"]},
        )
        return _make_response(initial_observation, reward=0.0)

    def step(self, tool_action) -> EnvironmentResponse:
        """Count the attempt and return the arrest observation with a retryable error."""
        self.attempts += 1
        arrest_state = self._state.model_copy(update={"done": False})
        engine_error = ToolError(
            code="ENGINE_ERROR",
            message="Cardiac arrest prevents further simulation advance.",
            retryable=True,
        )
        arrest_result = ToolResult(
            tool_name=tool_action.tool_name,
            success=False,
            message="Terminal arrest reached.",
            state_changed=True,
            changed_fields=["active_alerts"],
        )
        arrest_observation = PulsePhysiologyObservation.from_patient_state(
            arrest_state,
            reward=-9.0,
            available_tools=["advance_time"],
            error=engine_error,
            tool_result=arrest_result,
            metadata={"step_count": 1, "available_tools": ["advance_time"]},
        )
        # Forward the error and result as stored on the observation, not the local objects.
        return _make_response(
            arrest_observation,
            reward=-9.0,
            error=arrest_observation.error,
            tool_result=arrest_observation.tool_result,
        )

    def get_state(self) -> PatientState:
        """Return the fixed patient state created at construction."""
        return self._state
def _regression_check_retryable_runner_behavior() -> None:
    """Ensure retryable backend failures do not terminate the episode immediately.

    Runs a single-step episode against a backend that fails twice before
    succeeding and exits with ``SystemExit`` if the runner did not recover,
    did not record a clean step, or did not log the retry attempts.
    """
    flaky_backend = _TransientFailureBackend()
    runner = EpisodeRunner(
        backend=flaky_backend,
        max_steps=1,
        max_retry_attempts=3,
        retry_backoff_s=0.0,
    )
    trace = runner.run(policy=_FixedActionPolicy(), scenario_id="baseline_stable")

    if trace.termination_reason != EpisodeTerminationReason.MAX_TIMESTEPS:
        raise SystemExit(
            "Episode runner regression: retryable failures should recover without fatal termination."
        )

    # Only look at the first step when exactly one step was recorded.
    if not (trace.num_steps == 1 and trace.steps[0].error is None):
        raise SystemExit(
            "Episode runner regression: the recovered step should be recorded as a successful action."
        )

    event_log = "\n".join(trace.events)
    if "Retryable error on get_vitals attempt 1/3" not in event_log:
        raise SystemExit("Episode runner regression: retryable attempts were not logged.")
def _regression_check_available_tools_fail_closed() -> None:
    """Ensure missing available tools stop policy execution instead of exposing the full catalog.

    Checks both the prompt builder (must raise ``ToolAvailabilityError``) and
    the episode runner (must end as a fatal backend error before any step).
    """
    empty_tools_observation = _make_observation(available_tools=[])
    failed_closed = False
    try:
        build_policy_prompt(empty_tools_observation)
    except ToolAvailabilityError:
        failed_closed = True
    if not failed_closed:  # pragma: no cover - exercised by CLI validation
        raise SystemExit("Prompt builder regression: missing available_tools should fail closed.")

    trace = EpisodeRunner(
        backend=_MissingToolsBackend(),
        max_steps=1,
        retry_backoff_s=0.0,
    ).run(policy=_FixedActionPolicy(), scenario_id="baseline_stable")

    if trace.termination_reason != EpisodeTerminationReason.FATAL_BACKEND_ERROR:
        raise SystemExit(
            "Episode runner regression: missing available_tools must terminate as a fatal backend error."
        )
    if trace.num_steps != 0:
        raise SystemExit("Episode runner regression: missing available_tools should terminate before any step runs.")
def _regression_check_terminal_retry_short_circuit() -> None:
    """Ensure cardiac-arrest observations stop retry loops immediately.

    Exits with ``SystemExit`` if the backend was stepped more than once, the
    episode did not end as a patient death, or the short-circuit was not logged.
    """
    arrest_backend = _TerminalRetryBackend()
    runner = EpisodeRunner(
        backend=arrest_backend,
        max_steps=4,
        max_retry_attempts=3,
        retry_backoff_s=0.0,
    )
    advance_policy = _FixedActionPolicy("advance_time")
    trace = runner.run(policy=advance_policy, scenario_id="baseline_stable")

    if arrest_backend.attempts != 1:
        raise SystemExit("Episode runner regression: terminal observations should short-circuit retry loops.")
    if trace.termination_reason != EpisodeTerminationReason.PATIENT_DEATH:
        raise SystemExit("Episode runner regression: terminal observation should map to patient_death.")

    short_circuit_logged = any("Terminal physiology detected" in event for event in trace.events)
    if not short_circuit_logged:
        raise SystemExit("Episode runner regression: terminal short-circuit should be logged.")
def _regression_check_tool_parser_hierarchy() -> None:
    """Ensure parser extraction prefers schema-aware JSON candidates over greedy braces.

    Covers three cases: a fenced JSON payload, a payload followed by unrelated
    trailing JSON, and an invalid payload that must raise ``ParseError`` with a
    raw-output preview while emitting a ``ParseWarning``.
    """
    fenced_text = 'Narration first\n```json\n{"tool_name":"get_vitals","arguments":{},"reasoning":"check now"}\n```'
    fenced_payload = parse_with_fallback(fenced_text)
    if fenced_payload["tool_name"] != "get_vitals":
        raise SystemExit("Tool parser regression: fenced JSON extraction failed.")

    noisy_text = (
        'Model says: {"tool_name":"give_oxygen","arguments":{"flow_lpm":15},"reasoning":"low saturation"} '
        'extra trailing text {"noise":true}'
    )
    parsed_action = parse_tool_action(noisy_text)
    wrong_tool = parsed_action.tool_name != "give_oxygen"
    if wrong_tool or parsed_action.arguments.get("flow_lpm") != 15.0:
        raise SystemExit("Tool parser regression: schema-aware extraction failed before greedy fallback.")

    with warnings.catch_warnings(record=True) as recorded:
        # Force ParseWarning to be recorded even if a filter would suppress it.
        warnings.simplefilter("always", ParseWarning)
        try:
            parse_with_fallback('Prefix only {"note": 1}', log_warnings=True)
        except ParseError as exc:
            error_text = str(exc)
        else:  # pragma: no cover - exercised by CLI validation
            raise SystemExit("Tool parser regression: invalid fallback payload should raise ParseError.")
        if "Raw output:" not in error_text:
            raise SystemExit("Tool parser regression: parse errors must include the raw model output preview.")

    warned = any(isinstance(entry.message, ParseWarning) for entry in recorded)
    if not warned:
        raise SystemExit("Tool parser regression: fallback parsing must emit ParseWarning.")
def _regression_check_real_expert_heuristics() -> None:
    """Ensure the adaptive real expert prioritizes decompression and targeted hemorrhage control.

    Feeds the expert policy a polytrauma observation and checks its first three
    actions: ``get_vitals``, left ``needle_decompression``, then
    ``control_bleeding`` at the ``right_leg`` site.
    """
    trauma_tools = [
        "get_vitals",
        "check_deterioration",
        "needle_decompression",
        "control_bleeding",
        "give_fluids",
        "give_pressor",
        "give_oxygen",
        "advance_time",
    ]
    trauma_observation = _make_observation(
        available_tools=trauma_tools,
        scenario_id="polytrauma_demo",
        breath_sounds="absent left",
        spo2=0.89,
        systolic_bp_mmhg=82.0,
        diastolic_bp_mmhg=54.0,
        mean_arterial_pressure_mmhg=63.0,
        active_alerts=["active_hemorrhage", "shock_index_elevated"],
        active_hemorrhages={"right_leg": 140.0, "spleen": 45.0},
    )

    policy = build_expert_policy()
    policy.reset("polytrauma_demo")

    # Step 1: the expert should begin by reading vitals.
    opening = policy.select_action(trauma_observation)
    if opening.tool_name != "get_vitals":
        raise SystemExit("Expert regression: real trauma policy should start with get_vitals.")
    policy.observe_outcome(opening, _make_response(trauma_observation, reward=0.0))

    # Step 2: absent left breath sounds should trigger a left-sided decompression.
    decompression = policy.select_action(trauma_observation)
    not_decompression = decompression.tool_name != "needle_decompression"
    if not_decompression or decompression.arguments.get("side") != "left":
        raise SystemExit("Expert regression: real trauma policy should prioritize left needle decompression.")

    # Step 3: once breath sounds are restored, the right_leg hemorrhage should be targeted.
    post_needle_observation = trauma_observation.model_copy(update={"breath_sounds": "present bilateral"})
    policy.observe_outcome(decompression, _make_response(post_needle_observation, reward=0.0))
    bleeding_control = policy.select_action(post_needle_observation)
    not_bleeding_control = bleeding_control.tool_name != "control_bleeding"
    if not_bleeding_control or bleeding_control.arguments.get("site") != "right_leg":
        raise SystemExit("Expert regression: real trauma policy should target the highest-flow hemorrhage site.")
def _regression_check_gym_wrapper() -> None:
    """Validate the training-facing Gym wrapper against the mock backend.

    Checks the reset surfaces (feature length, action-mask length), a valid
    ``get_vitals`` step, and a masked ``give_pressor`` step that must be
    flagged invalid, penalized, and non-terminal. The environment is always
    closed, even when a check fails.
    """
    env = PulseGymEnv(
        backend_name="mock",
        scenario_id="respiratory_distress",
        max_episode_steps=4,
        seed=0,
    )
    try:
        features, reset_info = env.reset(seed=0)
        if len(features) != env.observation_space.shape[0]:
            raise SystemExit("Gym wrapper regression failed: reset feature length does not match observation_space.")
        if len(reset_info["action_mask"]) != env.action_space.n:
            raise SystemExit("Gym wrapper regression failed: action mask length does not match action space.")

        vitals_index = env.tool_names.index("get_vitals")
        _, reward, terminated, truncated, step_info = env.step(vitals_index)
        if not isinstance(reward, float):
            raise SystemExit("Gym wrapper regression failed: reward should be a float.")
        if terminated or truncated:
            raise SystemExit("Gym wrapper regression failed: get_vitals should not end the episode.")
        if step_info["tool_name"] != "get_vitals":
            raise SystemExit("Gym wrapper regression failed: expected get_vitals in step info.")

        # Build a deterministic masked-action fixture so invalid-action behavior
        # is tested even when a backend exposes a broad tool set.
        masked_observation = env._current_observation.model_copy(  # type: ignore[attr-defined]
            update={"available_tools": ["get_vitals"]}
        )
        env._current_observation = masked_observation  # type: ignore[attr-defined]

        pressor_index = env.tool_names.index("give_pressor")
        _, masked_reward, masked_terminated, masked_truncated, masked_info = env.step(pressor_index)
        if not masked_info["invalid_action"]:
            raise SystemExit("Gym wrapper regression failed: unavailable action should be marked invalid.")
        if masked_reward >= 0.0:
            raise SystemExit("Gym wrapper regression failed: invalid action should incur a penalty.")
        if masked_terminated or masked_truncated:
            raise SystemExit("Gym wrapper regression failed: invalid masked action should not immediately end the episode.")
    finally:
        env.close()
def score_policy(policy_factory, policy_name: str) -> PolicyScore:
    """Evaluate one policy factory across all mock scenarios.

    For each scenario id in ``MOCK_SCENARIOS`` a fresh mock backend, an
    8-step runner, and a fresh policy (``policy_factory(scenario_id)``) are
    created; the episode's total reward is stored under that scenario id.
    """

    def _episode_reward(scenario_id):
        # Fresh backend, runner, and policy per scenario so episodes share no state.
        runner = EpisodeRunner(
            backend=MockPulseAdapter(default_scenario_id=scenario_id),
            max_steps=8,
        )
        return runner.run(policy=policy_factory(scenario_id), scenario_id=scenario_id).total_reward

    rewards_by_scenario: dict[str, float] = {
        scenario_id: _episode_reward(scenario_id) for scenario_id in MOCK_SCENARIOS
    }
    return PolicyScore(policy_name=policy_name, per_scenario=rewards_by_scenario)
def score_random_policy(num_seeds: int = 12) -> PolicyScore:
    """Evaluate the mean reward of seeded random policies.

    For every scenario id in ``MOCK_SCENARIOS``, one episode is run per seed in
    ``range(num_seeds)`` with a fresh mock backend and an 8-step runner. The
    per-scenario score is the mean total reward rounded to three decimals.

    Args:
        num_seeds: Number of seeded random policies to average per scenario.
            Must be at least 1.

    Returns:
        A ``PolicyScore`` named ``"random"`` holding the per-scenario means.

    Raises:
        ValueError: If ``num_seeds`` is less than 1. Without this check an
            empty reward list would surface as an opaque ``ZeroDivisionError``.
    """
    if num_seeds < 1:
        raise ValueError(f"num_seeds must be at least 1, got {num_seeds}")

    per_scenario: dict[str, float] = {}
    for scenario_id in MOCK_SCENARIOS:
        rewards = []
        for seed in range(num_seeds):
            # Fresh backend and runner per seed so episodes do not share state.
            backend = MockPulseAdapter(default_scenario_id=scenario_id)
            runner = EpisodeRunner(backend=backend, max_steps=8)
            trace = runner.run(policy=RandomPolicy(seed=seed), scenario_id=scenario_id)
            rewards.append(trace.total_reward)
        per_scenario[scenario_id] = round(sum(rewards) / len(rewards), 3)
    return PolicyScore(policy_name="random", per_scenario=per_scenario)
def print_policy_score(score: PolicyScore) -> None:
    """Pretty-print one policy summary.

    Prints a header line, one indented line per scenario reward, and a final
    indented average line; all numbers use three decimals.
    """
    print(f"{score.policy_name} policy")
    for scenario_id in score.per_scenario:
        scenario_reward = score.per_scenario[scenario_id]
        print(f"  {scenario_id}: {scenario_reward:.3f}")
    mean_reward = score.average_reward
    print(f"  average: {mean_reward:.3f}")
def main() -> None:
    """Compare expert, random, and no-action baselines.

    Runs the consumer regression checks first, then scores the expert,
    llm_demo, random, and no_action policies on the mock scenarios, prints
    each summary, and exits with ``SystemExit`` unless the averages are
    strictly ordered expert > llm_demo > random > no_action.
    """
    print("Consumer regression checks\n")
    # Each entry pairs a check with the exact banner printed once it passes.
    regression_checks = (
        (_regression_check_retryable_runner_behavior, "PASS retryable runner handling"),
        (_regression_check_available_tools_fail_closed, "PASS available_tools fail-closed handling\n"),
        (_regression_check_terminal_retry_short_circuit, "PASS terminal retry short-circuit\n"),
        (_regression_check_tool_parser_hierarchy, "PASS tool parser hierarchy\n"),
        (_regression_check_real_expert_heuristics, "PASS real expert heuristics\n"),
        (_regression_check_gym_wrapper, "PASS gym wrapper reset/step surfaces\n"),
    )
    for run_check, pass_banner in regression_checks:
        run_check()
        print(pass_banner)

    def _expert_factory(scenario_id):
        return build_expert_policy()

    def _llm_demo_factory(scenario_id):
        return LLMPolicy(infer_fn=heuristic_infer_fn, name="llm_demo")

    def _no_action_factory(scenario_id):
        return build_no_action_policy()

    expert = score_policy(_expert_factory, "expert")
    llm_demo = score_policy(_llm_demo_factory, "llm_demo")
    random_policy = score_random_policy()
    no_action = score_policy(_no_action_factory, "no_action")

    print("Mock policy evaluation\n")
    for summary in (expert, llm_demo, random_policy, no_action):
        print_policy_score(summary)
        print()

    ranking_holds = (
        expert.average_reward > llm_demo.average_reward > random_policy.average_reward > no_action.average_reward
    )
    if not ranking_holds:
        raise SystemExit(
            "Policy ranking check failed: expected expert > llm_demo > random > no_action on average."
        )
    print("PASS expert > llm_demo > random > no_action")


if __name__ == "__main__":
    main()