Pulse_ER_env / generate_before_after_mock_plot.py
KChad's picture
Add all docs_assets image assets to Hugging Face Space snapshot
9b1756a
"""Generate a submission-friendly before/after mock policy comparison plot.
This script is intentionally honest about what "before" means:
- before: a weak random tool-calling baseline on the selected mock scenario
- after: the trained/expert policy on the same scenario
It also includes the verified mock benchmark scores for context so the figure
is useful both as a quick demo asset and as a reproducible benchmark artifact.
"""
from __future__ import annotations
import argparse
import textwrap
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pulse_physiology_env.demo_llm_policy import heuristic_infer_fn
from pulse_physiology_env.episode_runner import EpisodeRunner
from pulse_physiology_env.eval_mock import score_policy, score_random_policy
from pulse_physiology_env.policies import LLMPolicy, RandomPolicy, build_expert_policy, build_no_action_policy
from pulse_physiology_env.server.adapters import MockPulseAdapter
from pulse_physiology_env.server.mock_scenarios import MOCK_SCENARIOS
def _run_trace(policy, *, scenario_id: str):
    """Run ``policy`` for one episode on the mock backend and return the result.

    A new ``MockPulseAdapter`` (seed 0) is created for ``scenario_id`` and
    driven by an ``EpisodeRunner`` limited to 8 steps. Whatever
    ``EpisodeRunner.run`` returns is handed back to the caller unchanged.
    """
    adapter = MockPulseAdapter(default_scenario_id=scenario_id, seed=0)
    episode_runner = EpisodeRunner(backend=adapter, max_steps=8)
    try:
        trace = episode_runner.run(policy=policy, scenario_id=scenario_id)
    finally:
        # Best-effort cleanup: only call ``close`` when the adapter provides
        # one, and do it whether or not the episode raised.
        closer = getattr(adapter, "close", None)
        if callable(closer):
            closer()
    return trace
def _fill_keeping_newlines(text: str, width: int) -> str:
    """Wrap ``text`` to ``width`` columns while keeping its explicit line breaks.

    ``textwrap.fill`` replaces every whitespace character, newlines included,
    with a space before wrapping, so multi-line text collapses into a single
    paragraph. Wrapping each line on its own keeps one entry per line and
    preserves blank separator lines (an empty line wraps to an empty string).
    """
    return "\n".join(textwrap.fill(line, width=width) for line in text.split("\n"))


def main() -> None:
    """Render the before/after mock policy comparison figure and save it as a PNG.

    Command-line options:
        --scenario: mock scenario id to plot; must be a key of
            ``MOCK_SCENARIOS`` (default ``respiratory_distress``).
        --output: destination PNG path (default
            ``artifacts/mock_before_after_comparison.png``); missing parent
            directories are created.

    The figure is a 2x2 grid: a bar chart of the per-scenario reward for the
    expert, LLM demo, random and no-action policies; the SpO2 trajectory of
    the expert versus a seed-0 random policy on the chosen scenario; and two
    text panels summarising the before/after framing and the average rewards.
    The resolved output path is printed once the file has been written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--scenario", default="respiratory_distress", choices=sorted(MOCK_SCENARIOS))
    parser.add_argument(
        "--output",
        default="artifacts/mock_before_after_comparison.png",
        help="Output PNG path.",
    )
    args = parser.parse_args()

    # Benchmark scores for every policy; each result exposes ``per_scenario``
    # (indexed by scenario id) and ``average_reward``.
    expert_scores = score_policy(lambda scenario_id: build_expert_policy(), "expert")
    llm_scores = score_policy(
        lambda scenario_id: LLMPolicy(infer_fn=heuristic_infer_fn, name="llm_demo"),
        "llm_demo",
    )
    random_scores = score_random_policy()
    no_action_scores = score_policy(lambda scenario_id: build_no_action_policy(), "no_action")

    # Single-episode traces on the selected scenario for the trajectory panel.
    scenario_id = args.scenario
    expert_trace = _run_trace(build_expert_policy(), scenario_id=scenario_id)
    random_trace = _run_trace(RandomPolicy(seed=0), scenario_id=scenario_id)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fig = plt.figure(figsize=(14, 9), facecolor="white")
    gs = fig.add_gridspec(2, 2, height_ratios=[2.2, 1.2], hspace=0.35, wspace=0.25)

    # --- Top-left: per-scenario reward bars -------------------------------
    ax_bar = fig.add_subplot(gs[0, 0])
    labels = ["Expert", "LLM Demo", "Random", "No Action"]
    values = [
        expert_scores.per_scenario[scenario_id],
        llm_scores.per_scenario[scenario_id],
        random_scores.per_scenario[scenario_id],
        no_action_scores.per_scenario[scenario_id],
    ]
    colors = ["#1f9d73", "#4c84c4", "#d68c2f", "#cc4b5a"]
    bars = ax_bar.bar(labels, values, color=colors, edgecolor="#1d2736", linewidth=1.2)
    ax_bar.axhline(0.0, color="#364152", linewidth=1.1)
    ax_bar.set_title(f"Scenario Reward Comparison: {scenario_id.replace('_', ' ').title()}", fontsize=15, pad=12)
    ax_bar.set_ylabel("Episode reward")
    ax_bar.grid(axis="y", alpha=0.18)
    for bar, value in zip(bars, values, strict=True):
        # Put the label just beyond the end of the bar: above positive bars,
        # below negative ones.
        va = "bottom" if value >= 0 else "top"
        offset = 0.18 if value >= 0 else -0.18
        ax_bar.text(
            bar.get_x() + bar.get_width() / 2,
            value + offset,
            f"{value:+.3f}",
            ha="center",
            va=va,
            fontsize=11,
            fontweight="bold",
        )

    # --- Top-right: SpO2 trajectory ---------------------------------------
    ax_trace = fig.add_subplot(gs[0, 1])
    # x=0 is the initial observation; each recorded step is plotted at
    # step_index + 1 (presumably step_index is zero-based -- confirm).
    # spo2 is scaled by 100 for the percent axis (presumably stored as a
    # fraction -- confirm against the observation model).
    expert_steps = [0] + [step.step_index + 1 for step in expert_trace.steps]
    expert_spo2 = [expert_trace.initial_observation.spo2 * 100] + [step.observation.spo2 * 100 for step in expert_trace.steps]
    random_steps = [0] + [step.step_index + 1 for step in random_trace.steps]
    random_spo2 = [random_trace.initial_observation.spo2 * 100] + [step.observation.spo2 * 100 for step in random_trace.steps]
    ax_trace.plot(expert_steps, expert_spo2, marker="o", linewidth=2.6, color="#1f9d73", label="Expert / after")
    ax_trace.plot(random_steps, random_spo2, marker="o", linewidth=2.2, color="#d68c2f", label="Random seed=0 / before")
    ax_trace.set_title("SpO2 Trajectory on the Same Scenario", fontsize=15, pad=12)
    ax_trace.set_xlabel("Episode step")
    ax_trace.set_ylabel("SpO2 (%)")
    ax_trace.grid(alpha=0.18)
    ax_trace.legend(frameon=False, loc="lower right")

    # --- Bottom-left: before/after framing --------------------------------
    ax_text_left = fig.add_subplot(gs[1, 0])
    ax_text_left.axis("off")
    ax_text_left.text(
        0,
        1,
        "Before / after framing",
        fontsize=14,
        fontweight="bold",
        va="top",
    )
    ax_text_left.text(
        0,
        0.82,
        # Wrap line by line so the Before/After lines and the blank separator
        # survive; textwrap.fill alone would merge them into one paragraph.
        _fill_keeping_newlines(
            f"Before: random tool-calling baseline on {scenario_id} "
            f"finishes at {random_trace.total_reward:+.3f} reward.\n"
            f"After: expert policy on the same scenario finishes at {expert_trace.total_reward:+.3f} reward.\n\n"
            "This is a better submission artifact than reusing the same training curve twice, "
            "because it shows policy quality separation on an actual episode.",
            width=55,
        ),
        fontsize=11.5,
        color="#334155",
        va="top",
        linespacing=1.5,
        wrap=True,
    )

    # --- Bottom-right: average rewards ------------------------------------
    ax_text_right = fig.add_subplot(gs[1, 1])
    ax_text_right.axis("off")
    ax_text_right.text(
        0,
        1,
        "Verified mock benchmark ordering",
        fontsize=14,
        fontweight="bold",
        va="top",
    )
    ax_text_right.text(
        0,
        0.82,
        # One policy per line; see the note on the left panel.
        _fill_keeping_newlines(
            f"Expert average: {expert_scores.average_reward:+.3f}\n"
            f"LLM demo average: {llm_scores.average_reward:+.3f}\n"
            f"Random average: {random_scores.average_reward:+.3f}\n"
            f"No-action average: {no_action_scores.average_reward:+.3f}",
            width=32,
        ),
        fontsize=11.5,
        color="#334155",
        va="top",
        linespacing=1.6,
        wrap=True,
    )

    fig.suptitle(
        "Pulse-ER Mock Before vs After\nRandom Tool Calls vs Trained Policy",
        fontsize=20,
        fontweight="bold",
        y=0.98,
    )
    fig.text(
        0.5,
        0.03,
        "Before = random tool-calling baseline (seed 0). After = expert/trained policy. All values are generated from the current repo.",
        ha="center",
        fontsize=10.5,
        color="#475569",
    )
    fig.savefig(output_path, dpi=180, bbox_inches="tight")
    # Release the figure so pyplot does not keep it registered after saving.
    plt.close(fig)
    print(output_path.resolve())
# Generate the plot only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()