# Pulse_ER_env / eval_mock.py
# Hugging Face file-viewer header (avatar caption: "KChad's picture").
# Latest commit message: "Add all docs_assets image assets to Hugging Face Space snapshot"
# Commit: 9b1756a
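"""Evaluation harness for comparing mock policies.

Runs the consumer regression checks first, then scores the expert, llm_demo,
random, and no-action policies on every mock scenario and verifies their
expected ranking.

Run from the repo root with:

    python -m pulse_physiology_env.eval_mock
"""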
from __future__ import annotations
import warnings
from dataclasses import dataclass
from pulse_physiology_env.demo_llm_policy import heuristic_infer_fn
from pulse_physiology_env.episode_runner import EpisodeRunner, EpisodeTerminationReason
from pulse_physiology_env.gym_env import PulseGymEnv
from pulse_physiology_env.models import EnvironmentResponse, ObservationMetadata, PulsePhysiologyObservation, ToolError, ToolResult
from pulse_physiology_env.policies import (
LLMPolicy,
RandomPolicy,
action,
build_expert_policy,
build_no_action_policy,
)
from pulse_physiology_env.prompt_builder import build_policy_prompt
from pulse_physiology_env.tool_availability import ToolAvailabilityError
from pulse_physiology_env.tool_parser import ParseError, ParseWarning, parse_tool_action, parse_with_fallback
from pulse_physiology_env.server.adapters import MockPulseAdapter
from pulse_physiology_env.server.mock_scenarios import MOCK_SCENARIOS
from pulse_physiology_env.patient_state import PatientState
@dataclass(frozen=True)
class PolicyScore:
    """Summary score for one policy over all scenarios.

    Attributes:
        policy_name: Label printed alongside the scores.
        per_scenario: Reward keyed by scenario id (a single episode total or
            a mean over several seeded episodes, depending on the producer).
    """

    policy_name: str
    per_scenario: dict[str, float]

    @property
    def average_reward(self) -> float:
        """Return the mean per-scenario reward rounded to three decimals.

        Returns ``0.0`` when no scenarios were scored, so that an empty
        scenario catalog does not crash with ``ZeroDivisionError``.
        """
        if not self.per_scenario:
            return 0.0
        return round(sum(self.per_scenario.values()) / len(self.per_scenario), 3)
def _make_observation(*, available_tools: list[str], **overrides) -> PulsePhysiologyObservation:
    """Build a small observation fixture for consumer-side regression checks.

    Args:
        available_tools: Tool names to expose on the observation.
        **overrides: Extra or replacement observation fields; they win over
            the baseline values below.

    Returns:
        A ``PulsePhysiologyObservation`` built from the merged field values.
    """
    baseline_fields = {
        "scenario_id": "baseline_stable",
        "patient_id": "regression_patient",
        "sim_time_s": 0.0,
        "heart_rate_bpm": 72.0,
        "systolic_bp_mmhg": 118.0,
        "diastolic_bp_mmhg": 76.0,
        "spo2": 0.98,
        "respiration_rate_bpm": 14.0,
        "blood_volume_ml": 5500.0,
        "available_tools": available_tools,
    }
    # Later entries win, so caller overrides replace the baseline values.
    return PulsePhysiologyObservation(**{**baseline_fields, **overrides})
def _make_response(
    observation: PulsePhysiologyObservation,
    *,
    reward: float,
    error: ToolError | None = None,
    tool_result: ToolResult | None = None,
) -> EnvironmentResponse:
    """Package a fixture observation into an ``EnvironmentResponse``.

    Args:
        observation: Observation to wrap; its ``done`` flag and
            ``available_tools`` are mirrored into the envelope.
        reward: Reward reported for the step.
        error: Optional tool error to attach.
        tool_result: Optional tool result to attach.

    Returns:
        The response envelope, with metadata fixed at ``step_count=0``.
    """
    episode_done = observation.done
    # Copy the tool list so the metadata does not alias the observation's list.
    envelope_metadata = ObservationMetadata(
        step_count=0,
        available_tools=list(observation.available_tools),
    )
    return EnvironmentResponse(
        observation=observation,
        reward=reward,
        done=episode_done,
        metadata=envelope_metadata,
        tool_result=tool_result,
        error=error,
    )
class _FixedActionPolicy:
    """Policy stub that always replays one pre-built tool action.

    The regression checks use it to probe episode-runner edge cases without
    any decision logic getting in the way.
    """

    name = "fixed_action"

    def __init__(self, tool_name: str = "get_vitals") -> None:
        """Build the action for ``tool_name`` once, up front."""
        self._action = action(tool_name)

    def reset(self, scenario_id: str) -> None:
        """Do nothing; the policy keeps no per-episode state."""

    def select_action(self, observation: PulsePhysiologyObservation):
        """Return the stored action regardless of ``observation``."""
        return self._action

    def observe_outcome(self, action, result) -> None:
        """Do nothing; step feedback is ignored."""
class _TransientFailureBackend:
    """Backend stub whose first two ``step`` calls fail with a retryable error.

    From the third call onward ``step`` reports success.
    """

    def __init__(self) -> None:
        # Counts step() calls; it decides when the stub starts succeeding.
        self.attempts = 0
        self._state = PatientState(
            scenario_id="baseline_stable",
            patient_id="retry_backend",
            heart_rate_bpm=72.0,
            systolic_bp_mmhg=118.0,
            diastolic_bp_mmhg=76.0,
            spo2=0.98,
            respiration_rate_bpm=14.0,
            blood_volume_ml=5500.0,
        )

    def reset(self, scenario_id: str | None = None) -> EnvironmentResponse:
        """Return a baseline response with zero reward; ``scenario_id`` is unused."""
        baseline = _make_observation(available_tools=["get_vitals", "advance_time"])
        return _make_response(baseline, reward=0.0)

    def step(self, tool_action) -> EnvironmentResponse:
        """Fail with ``TEMPORARY_BUSY`` on calls one and two, then succeed."""
        self.attempts += 1
        observation = _make_observation(available_tools=["get_vitals", "advance_time"])

        if self.attempts >= 3:
            recovered = ToolResult(
                tool_name=tool_action.tool_name,
                success=True,
                message="Recovered after retries.",
                state_changed=False,
                changed_fields=[],
            )
            return _make_response(observation, reward=0.25, tool_result=recovered)

        busy_error = ToolError(
            code="TEMPORARY_BUSY",
            message="Transient backend failure.",
            retryable=True,
        )
        failed = ToolResult(
            tool_name=tool_action.tool_name,
            success=False,
            message="Transient backend failure.",
            state_changed=False,
            changed_fields=[],
        )
        return _make_response(
            observation,
            reward=-1.0,
            error=busy_error,
            tool_result=failed,
        )

    def get_state(self) -> PatientState:
        """Return the fixed patient state created in ``__init__``."""
        return self._state
class _MissingToolsBackend:
    """Backend stub whose reset response advertises no available tools.

    This breaks the available-tools contract on purpose, so the runner's
    fail-closed handling can be exercised.
    """

    def __init__(self) -> None:
        self._state = PatientState(
            scenario_id="baseline_stable",
            patient_id="missing_tools_backend",
            heart_rate_bpm=72.0,
            systolic_bp_mmhg=118.0,
            diastolic_bp_mmhg=76.0,
            spo2=0.98,
            respiration_rate_bpm=14.0,
            blood_volume_ml=5500.0,
        )

    def reset(self, scenario_id: str | None = None) -> EnvironmentResponse:
        """Return a zero-reward response with an empty tool list."""
        toolless = _make_observation(available_tools=[])
        return _make_response(toolless, reward=0.0)

    def step(self, tool_action) -> EnvironmentResponse:
        """Always raise: the runner is expected to stop before stepping."""
        raise RuntimeError("step() should not be called when available_tools is missing.")

    def get_state(self) -> PatientState:
        """Return the fixed patient state created in ``__init__``."""
        return self._state
class _TerminalRetryBackend:
    """Backend stub that pairs a cardiac-arrest state with a retryable error.

    Every ``step`` call returns the same failing response; ``attempts``
    records how many times the runner actually called ``step``.
    """

    def __init__(self) -> None:
        self.attempts = 0
        self._state = PatientState(
            scenario_id="baseline_stable",
            patient_id="terminal_retry_backend",
            heart_rate_bpm=20.0,
            systolic_bp_mmhg=40.0,
            diastolic_bp_mmhg=20.0,
            spo2=0.7,
            respiration_rate_bpm=4.0,
            blood_volume_ml=3000.0,
            mental_status="unresponsive",
            active_alerts=["cardiac_arrest"],
        )

    def reset(self, scenario_id: str | None = None) -> EnvironmentResponse:
        """Return an initial response built from the state with its alerts cleared."""
        calm_state = self._state.model_copy(update={"active_alerts": [], "done": False})
        initial = PulsePhysiologyObservation.from_patient_state(
            calm_state,
            reward=0.0,
            available_tools=["advance_time"],
            metadata={"step_count": 0, "available_tools": ["advance_time"]},
        )
        return _make_response(initial, reward=0.0)

    def step(self, tool_action) -> EnvironmentResponse:
        """Count the call and return the arrest state with a retryable ``ENGINE_ERROR``."""
        self.attempts += 1
        # Only ``done`` is overridden, so the state's cardiac_arrest alert stays on the copy.
        arrest_state = self._state.model_copy(update={"done": False})
        engine_error = ToolError(
            code="ENGINE_ERROR",
            message="Cardiac arrest prevents further simulation advance.",
            retryable=True,
        )
        arrest_result = ToolResult(
            tool_name=tool_action.tool_name,
            success=False,
            message="Terminal arrest reached.",
            state_changed=True,
            changed_fields=["active_alerts"],
        )
        observation = PulsePhysiologyObservation.from_patient_state(
            arrest_state,
            reward=-9.0,
            available_tools=["advance_time"],
            error=engine_error,
            tool_result=arrest_result,
            metadata={"step_count": 1, "available_tools": ["advance_time"]},
        )
        # Forward the error and result as stored on the observation.
        return _make_response(
            observation,
            reward=-9.0,
            error=observation.error,
            tool_result=observation.tool_result,
        )

    def get_state(self) -> PatientState:
        """Return the fixed cardiac-arrest patient state."""
        return self._state
def _regression_check_retryable_runner_behavior() -> None:
    """Check that retryable backend failures are retried rather than fatal.

    Runs one step against ``_TransientFailureBackend`` with three retry
    attempts allowed.

    Raises:
        SystemExit: If the episode does not end on the step limit, the single
            recorded step carries an error, or the first retry is not logged.
    """
    flaky_backend = _TransientFailureBackend()
    runner = EpisodeRunner(
        backend=flaky_backend,
        max_steps=1,
        max_retry_attempts=3,
        retry_backoff_s=0.0,
    )
    trace = runner.run(policy=_FixedActionPolicy(), scenario_id="baseline_stable")

    if trace.termination_reason != EpisodeTerminationReason.MAX_TIMESTEPS:
        raise SystemExit(
            "Episode runner regression: retryable failures should recover without fatal termination."
        )

    recovered_cleanly = trace.num_steps == 1 and trace.steps[0].error is None
    if not recovered_cleanly:
        raise SystemExit(
            "Episode runner regression: the recovered step should be recorded as a successful action."
        )

    event_log = "\n".join(trace.events)
    if "Retryable error on get_vitals attempt 1/3" not in event_log:
        raise SystemExit("Episode runner regression: retryable attempts were not logged.")
def _regression_check_available_tools_fail_closed() -> None:
    """Check that an empty ``available_tools`` list fails closed.

    The prompt builder must raise ``ToolAvailabilityError`` for it, and the
    episode runner must end with a fatal backend error before any step runs.

    Raises:
        SystemExit: If either consumer accepts the empty tool list.
    """
    toolless_observation = _make_observation(available_tools=[])
    prompt_rejected = False
    try:
        build_policy_prompt(toolless_observation)
    except ToolAvailabilityError:
        prompt_rejected = True
    if not prompt_rejected:  # pragma: no cover - exercised by CLI validation
        raise SystemExit("Prompt builder regression: missing available_tools should fail closed.")

    runner = EpisodeRunner(backend=_MissingToolsBackend(), max_steps=1, retry_backoff_s=0.0)
    trace = runner.run(policy=_FixedActionPolicy(), scenario_id="baseline_stable")

    if trace.termination_reason != EpisodeTerminationReason.FATAL_BACKEND_ERROR:
        raise SystemExit(
            "Episode runner regression: missing available_tools must terminate as a fatal backend error."
        )
    if trace.num_steps != 0:
        raise SystemExit("Episode runner regression: missing available_tools should terminate before any step runs.")
def _regression_check_terminal_retry_short_circuit() -> None:
    """Check that a cardiac-arrest observation stops the retry loop at once.

    Raises:
        SystemExit: If the backend is stepped more than once, the episode
            does not end as a patient death, or the short-circuit is not
            logged in the trace events.
    """
    arrest_backend = _TerminalRetryBackend()
    runner = EpisodeRunner(
        backend=arrest_backend,
        max_steps=4,
        max_retry_attempts=3,
        retry_backoff_s=0.0,
    )
    trace = runner.run(policy=_FixedActionPolicy("advance_time"), scenario_id="baseline_stable")

    # Exactly one step() call means the retryable error was not retried.
    if arrest_backend.attempts != 1:
        raise SystemExit("Episode runner regression: terminal observations should short-circuit retry loops.")
    if trace.termination_reason != EpisodeTerminationReason.PATIENT_DEATH:
        raise SystemExit("Episode runner regression: terminal observation should map to patient_death.")

    for event in trace.events:
        if "Terminal physiology detected" in event:
            break
    else:
        raise SystemExit("Episode runner regression: terminal short-circuit should be logged.")
def _regression_check_tool_parser_hierarchy() -> None:
    """Check the tool parser on three kinds of model output.

    Covers a fenced JSON payload, a valid action followed by trailing noise,
    and an invalid payload that must raise ``ParseError`` and emit a
    ``ParseWarning``.

    Raises:
        SystemExit: If any of the parser expectations is not met.
    """
    fenced_text = 'Narration first\n```json\n{"tool_name":"get_vitals","arguments":{},"reasoning":"check now"}\n```'
    fenced_payload = parse_with_fallback(fenced_text)
    if fenced_payload["tool_name"] != "get_vitals":
        raise SystemExit("Tool parser regression: fenced JSON extraction failed.")

    noisy_text = (
        'Model says: {"tool_name":"give_oxygen","arguments":{"flow_lpm":15},"reasoning":"low saturation"} '
        'extra trailing text {"noise":true}'
    )
    parsed_action = parse_tool_action(noisy_text)
    if not (parsed_action.tool_name == "give_oxygen" and parsed_action.arguments.get("flow_lpm") == 15.0):
        raise SystemExit("Tool parser regression: schema-aware extraction failed before greedy fallback.")

    with warnings.catch_warnings(record=True) as caught:
        # Record every ParseWarning, even if an earlier filter would hide repeats.
        warnings.simplefilter("always", ParseWarning)
        parse_error_seen = False
        try:
            parse_with_fallback('Prefix only {"note": 1}', log_warnings=True)
        except ParseError as exc:
            parse_error_seen = True
            if "Raw output:" not in str(exc):
                raise SystemExit("Tool parser regression: parse errors must include the raw model output preview.")
        if not parse_error_seen:  # pragma: no cover - exercised by CLI validation
            raise SystemExit("Tool parser regression: invalid fallback payload should raise ParseError.")

    for recorded in caught:
        if isinstance(recorded.message, ParseWarning):
            break
    else:
        raise SystemExit("Tool parser regression: fallback parsing must emit ParseWarning.")
def _regression_check_real_expert_heuristics() -> None:
    """Check the expert policy's first three actions on a polytrauma fixture.

    Expected sequence: ``get_vitals``, then ``needle_decompression`` on the
    left, then ``control_bleeding`` at ``right_leg`` (the site with the larger
    value in ``active_hemorrhages``).

    Raises:
        SystemExit: If any of the three actions differs from the expectation.
    """
    trauma_tools = [
        "get_vitals",
        "check_deterioration",
        "needle_decompression",
        "control_bleeding",
        "give_fluids",
        "give_pressor",
        "give_oxygen",
        "advance_time",
    ]
    trauma_observation = _make_observation(
        available_tools=trauma_tools,
        scenario_id="polytrauma_demo",
        breath_sounds="absent left",
        spo2=0.89,
        systolic_bp_mmhg=82.0,
        diastolic_bp_mmhg=54.0,
        mean_arterial_pressure_mmhg=63.0,
        active_alerts=["active_hemorrhage", "shock_index_elevated"],
        active_hemorrhages={"right_leg": 140.0, "spleen": 45.0},
    )

    expert = build_expert_policy()
    expert.reset("polytrauma_demo")

    # Step 1: the policy should assess before it intervenes.
    opening_action = expert.select_action(trauma_observation)
    if opening_action.tool_name != "get_vitals":
        raise SystemExit("Expert regression: real trauma policy should start with get_vitals.")
    expert.observe_outcome(opening_action, _make_response(trauma_observation, reward=0.0))

    # Step 2: absent left breath sounds should trigger left-sided decompression.
    decompression = expert.select_action(trauma_observation)
    if not (decompression.tool_name == "needle_decompression" and decompression.arguments.get("side") == "left"):
        raise SystemExit("Expert regression: real trauma policy should prioritize left needle decompression.")

    # Step 3: with breath sounds restored, bleeding control comes next.
    post_needle_observation = trauma_observation.model_copy(update={"breath_sounds": "present bilateral"})
    expert.observe_outcome(decompression, _make_response(post_needle_observation, reward=0.0))
    bleeding_action = expert.select_action(post_needle_observation)
    if not (bleeding_action.tool_name == "control_bleeding" and bleeding_action.arguments.get("site") == "right_leg"):
        raise SystemExit("Expert regression: real trauma policy should target the highest-flow hemorrhage site.")
def _regression_check_gym_wrapper() -> None:
    """Check ``PulseGymEnv`` reset/step surfaces against the mock backend.

    Verifies the reset feature length and action-mask length, one valid
    ``get_vitals`` step, and one masked (unavailable) action. The environment
    is always closed afterwards.

    Raises:
        SystemExit: If any wrapper expectation is not met.
    """
    env = PulseGymEnv(
        backend_name="mock",
        scenario_id="respiratory_distress",
        max_episode_steps=4,
        seed=0,
    )
    try:
        features, reset_info = env.reset(seed=0)
        if len(features) != env.observation_space.shape[0]:
            raise SystemExit("Gym wrapper regression failed: reset feature length does not match observation_space.")
        if len(reset_info["action_mask"]) != env.action_space.n:
            raise SystemExit("Gym wrapper regression failed: action mask length does not match action space.")

        vitals_index = env.tool_names.index("get_vitals")
        _, reward, terminated, truncated, step_info = env.step(vitals_index)
        if not isinstance(reward, float):
            raise SystemExit("Gym wrapper regression failed: reward should be a float.")
        if terminated or truncated:
            raise SystemExit("Gym wrapper regression failed: get_vitals should not end the episode.")
        if step_info["tool_name"] != "get_vitals":
            raise SystemExit("Gym wrapper regression failed: expected get_vitals in step info.")

        # Narrow the current observation to a single tool so the invalid-action
        # path is exercised even if the backend exposes a broad tool set.
        env._current_observation = env._current_observation.model_copy(  # type: ignore[attr-defined]
            update={"available_tools": ["get_vitals"]}
        )
        masked_index = env.tool_names.index("give_pressor")
        _, masked_reward, masked_terminated, masked_truncated, masked_info = env.step(masked_index)
        if not masked_info["invalid_action"]:
            raise SystemExit("Gym wrapper regression failed: unavailable action should be marked invalid.")
        if masked_reward >= 0.0:
            raise SystemExit("Gym wrapper regression failed: invalid action should incur a penalty.")
        if masked_terminated or masked_truncated:
            raise SystemExit("Gym wrapper regression failed: invalid masked action should not immediately end the episode.")
    finally:
        env.close()
def score_policy(policy_factory, policy_name: str) -> PolicyScore:
    """Score one policy on every mock scenario.

    Args:
        policy_factory: Callable that takes a scenario id and returns a fresh
            policy for that scenario.
        policy_name: Label stored on the resulting score.

    Returns:
        A ``PolicyScore`` mapping each scenario id to the episode's total
        reward (episodes are capped at 8 steps).
    """

    def episode_return(scenario_id):
        # A fresh backend and runner per scenario keeps episodes independent.
        runner = EpisodeRunner(
            backend=MockPulseAdapter(default_scenario_id=scenario_id),
            max_steps=8,
        )
        trace = runner.run(policy=policy_factory(scenario_id), scenario_id=scenario_id)
        return trace.total_reward

    rewards_by_scenario: dict[str, float] = {
        scenario_id: episode_return(scenario_id) for scenario_id in MOCK_SCENARIOS
    }
    return PolicyScore(policy_name=policy_name, per_scenario=rewards_by_scenario)
def score_random_policy(num_seeds: int = 12) -> PolicyScore:
    """Evaluate the mean reward of seeded random policies.

    For every mock scenario, one episode (capped at 8 steps) is run per seed
    in ``range(num_seeds)`` and the total rewards are averaged.

    Args:
        num_seeds: Number of seeded episodes per scenario; must be at least 1.

    Returns:
        A ``PolicyScore`` named ``"random"`` whose per-scenario values are the
        seed-averaged rewards, rounded to three decimals.

    Raises:
        ValueError: If ``num_seeds`` is smaller than 1. Without this guard
            the mean would fail with an opaque ``ZeroDivisionError``.
    """
    if num_seeds < 1:
        raise ValueError(f"num_seeds must be at least 1, got {num_seeds}.")

    per_scenario: dict[str, float] = {}
    for scenario_id in MOCK_SCENARIOS:
        # A fresh backend, runner and policy per seed keeps episodes independent.
        rewards = [
            EpisodeRunner(
                backend=MockPulseAdapter(default_scenario_id=scenario_id),
                max_steps=8,
            )
            .run(policy=RandomPolicy(seed=seed), scenario_id=scenario_id)
            .total_reward
            for seed in range(num_seeds)
        ]
        per_scenario[scenario_id] = round(sum(rewards) / len(rewards), 3)
    return PolicyScore(policy_name="random", per_scenario=per_scenario)
def print_policy_score(score: PolicyScore) -> None:
    """Print one policy's per-scenario rewards, then its average.

    Output is a header line with the policy name, one line per scenario, and
    a final average line; all rewards use three decimal places.
    """
    print(f"{score.policy_name} policy")
    # The generator formats each line lazily, so lines are printed one by one.
    scenario_lines = (
        f" {scenario_id}: {reward:.3f}"
        for scenario_id, reward in score.per_scenario.items()
    )
    for line in scenario_lines:
        print(line)
    print(f" average: {score.average_reward:.3f}")
def main() -> None:
    """Run the regression checks, then compare the policy baselines.

    Each regression check runs in order and prints a PASS line on success.
    Afterwards the expert, llm_demo, random and no-action policies are scored
    and printed.

    Raises:
        SystemExit: If a regression check fails or the average rewards are
            not strictly ordered expert > llm_demo > random > no_action.
    """
    print("Consumer regression checks\n")
    regression_checks = (
        (_regression_check_retryable_runner_behavior, "PASS retryable runner handling"),
        (_regression_check_available_tools_fail_closed, "PASS available_tools fail-closed handling\n"),
        (_regression_check_terminal_retry_short_circuit, "PASS terminal retry short-circuit\n"),
        (_regression_check_tool_parser_hierarchy, "PASS tool parser hierarchy\n"),
        (_regression_check_real_expert_heuristics, "PASS real expert heuristics\n"),
        (_regression_check_gym_wrapper, "PASS gym wrapper reset/step surfaces\n"),
    )
    for run_check, pass_message in regression_checks:
        run_check()
        print(pass_message)

    # Listed from expected-best to expected-worst; the ranking check below relies on this order.
    ranked_scores = [
        score_policy(lambda scenario_id: build_expert_policy(), "expert"),
        score_policy(
            lambda scenario_id: LLMPolicy(infer_fn=heuristic_infer_fn, name="llm_demo"),
            "llm_demo",
        ),
        score_random_policy(),
        score_policy(lambda scenario_id: build_no_action_policy(), "no_action"),
    ]

    print("Mock policy evaluation\n")
    for policy_score in ranked_scores:
        print_policy_score(policy_score)
        print()

    averages = [policy_score.average_reward for policy_score in ranked_scores]
    strictly_descending = all(
        better > worse for better, worse in zip(averages, averages[1:])
    )
    if not strictly_descending:
        raise SystemExit(
            "Policy ranking check failed: expected expert > llm_demo > random > no_action on average."
        )
    print("PASS expert > llm_demo > random > no_action")
# Script entry point: run the checks and the policy comparison only when this
# module is executed directly (e.g. via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()