File size: 6,762 Bytes
9b1756a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""Generate a submission-friendly before/after mock policy comparison plot.

This script is intentionally honest about what "before" means:

- before: a weak random tool-calling baseline on the selected mock scenario
- after: the trained/expert policy on the same scenario

It also includes the verified mock benchmark scores for context so the figure
is useful both as a quick demo asset and as a reproducible benchmark artifact.
"""

from __future__ import annotations

import argparse
import textwrap
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

from pulse_physiology_env.demo_llm_policy import heuristic_infer_fn
from pulse_physiology_env.episode_runner import EpisodeRunner
from pulse_physiology_env.eval_mock import score_policy, score_random_policy
from pulse_physiology_env.policies import LLMPolicy, RandomPolicy, build_expert_policy, build_no_action_policy
from pulse_physiology_env.server.adapters import MockPulseAdapter
from pulse_physiology_env.server.mock_scenarios import MOCK_SCENARIOS


def _run_trace(policy, *, scenario_id: str):
    """Play one mock episode of *policy* on *scenario_id* and return the runner's result.

    A fresh ``MockPulseAdapter`` (seed 0) backs an ``EpisodeRunner`` limited to
    8 steps. Once the run finishes, successfully or not, the adapter's
    ``close`` attribute is invoked if it exists and is callable.
    """
    adapter = MockPulseAdapter(default_scenario_id=scenario_id, seed=0)
    episode_runner = EpisodeRunner(backend=adapter, max_steps=8)
    try:
        trace = episode_runner.run(policy=policy, scenario_id=scenario_id)
    finally:
        # Not every backend is guaranteed to expose close(); look it up defensively.
        shutdown = getattr(adapter, "close", None)
        if callable(shutdown):
            shutdown()
    return trace


def _wrap_paragraphs(text: str, *, width: int) -> str:
    """Wrap *text* to *width* columns while keeping its explicit line breaks.

    ``textwrap.fill`` on its own replaces every newline with a space before
    wrapping, which would merge intentionally separate lines into a single
    run-on paragraph. Wrapping each line independently keeps them apart; empty
    lines pass through unchanged as blank lines.
    """
    return "\n".join(textwrap.fill(line, width=width) for line in text.split("\n"))


def main() -> None:
    """Render the before/after comparison figure and save it as a PNG.

    Command-line options:
        --scenario: mock scenario id to plot (one of ``MOCK_SCENARIOS``,
            default ``respiratory_distress``).
        --output: destination PNG path; missing parent directories are created.

    The figure is a 2x2 grid: a bar chart of the selected scenario's reward for
    the expert, LLM-demo, random and no-action policies; the SpO2 trajectory of
    the expert and random (seed 0) episodes; and two text panels summarising
    the before/after framing and the average benchmark rewards. The resolved
    output path is printed after saving.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--scenario", default="respiratory_distress", choices=sorted(MOCK_SCENARIOS))
    parser.add_argument(
        "--output",
        default="artifacts/mock_before_after_comparison.png",
        help="Output PNG path.",
    )
    args = parser.parse_args()

    # Benchmark scores; per-scenario and average rewards are read from these below.
    expert_scores = score_policy(lambda scenario_id: build_expert_policy(), "expert")
    llm_scores = score_policy(
        lambda scenario_id: LLMPolicy(infer_fn=heuristic_infer_fn, name="llm_demo"),
        "llm_demo",
    )
    random_scores = score_random_policy()
    no_action_scores = score_policy(lambda scenario_id: build_no_action_policy(), "no_action")

    # Single-episode traces for the "after" (expert) and "before" (random) curves.
    scenario_id = args.scenario
    expert_trace = _run_trace(build_expert_policy(), scenario_id=scenario_id)
    random_trace = _run_trace(RandomPolicy(seed=0), scenario_id=scenario_id)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fig = plt.figure(figsize=(14, 9), facecolor="white")
    gs = fig.add_gridspec(2, 2, height_ratios=[2.2, 1.2], hspace=0.35, wspace=0.25)

    # --- Top-left: reward of each policy on the selected scenario ---
    ax_bar = fig.add_subplot(gs[0, 0])
    labels = ["Expert", "LLM Demo", "Random", "No Action"]
    values = [
        expert_scores.per_scenario[scenario_id],
        llm_scores.per_scenario[scenario_id],
        random_scores.per_scenario[scenario_id],
        no_action_scores.per_scenario[scenario_id],
    ]
    colors = ["#1f9d73", "#4c84c4", "#d68c2f", "#cc4b5a"]
    bars = ax_bar.bar(labels, values, color=colors, edgecolor="#1d2736", linewidth=1.2)
    ax_bar.axhline(0.0, color="#364152", linewidth=1.1)
    ax_bar.set_title(f"Scenario Reward Comparison: {scenario_id.replace('_', ' ').title()}", fontsize=15, pad=12)
    ax_bar.set_ylabel("Episode reward")
    ax_bar.grid(axis="y", alpha=0.18)
    for bar, value in zip(bars, values, strict=True):
        # Put the label just beyond the bar tip: above positive bars, below negative ones.
        va = "bottom" if value >= 0 else "top"
        offset = 0.18 if value >= 0 else -0.18
        ax_bar.text(
            bar.get_x() + bar.get_width() / 2,
            value + offset,
            f"{value:+.3f}",
            ha="center",
            va=va,
            fontsize=11,
            fontweight="bold",
        )

    # --- Top-right: SpO2 over the episode, initial observation plotted at step 0 ---
    ax_trace = fig.add_subplot(gs[0, 1])
    expert_steps = [0] + [step.step_index + 1 for step in expert_trace.steps]
    expert_spo2 = [expert_trace.initial_observation.spo2 * 100] + [step.observation.spo2 * 100 for step in expert_trace.steps]
    random_steps = [0] + [step.step_index + 1 for step in random_trace.steps]
    random_spo2 = [random_trace.initial_observation.spo2 * 100] + [step.observation.spo2 * 100 for step in random_trace.steps]
    ax_trace.plot(expert_steps, expert_spo2, marker="o", linewidth=2.6, color="#1f9d73", label="Expert / after")
    ax_trace.plot(random_steps, random_spo2, marker="o", linewidth=2.2, color="#d68c2f", label="Random seed=0 / before")
    ax_trace.set_title("SpO2 Trajectory on the Same Scenario", fontsize=15, pad=12)
    ax_trace.set_xlabel("Episode step")
    ax_trace.set_ylabel("SpO2 (%)")
    ax_trace.grid(alpha=0.18)
    ax_trace.legend(frameon=False, loc="lower right")

    # --- Bottom-left: before/after explanation ---
    ax_text_left = fig.add_subplot(gs[1, 0])
    ax_text_left.axis("off")
    ax_text_left.text(
        0,
        1,
        "Before / after framing",
        fontsize=14,
        fontweight="bold",
        va="top",
    )
    ax_text_left.text(
        0,
        0.82,
        # Wrap per line so the Before / After / rationale paragraphs stay separate.
        _wrap_paragraphs(
            f"Before: random tool-calling baseline on {scenario_id} "
            f"finishes at {random_trace.total_reward:+.3f} reward.\n"
            f"After: expert policy on the same scenario finishes at {expert_trace.total_reward:+.3f} reward.\n\n"
            "This is a better submission artifact than reusing the same training curve twice, "
            "because it shows policy quality separation on an actual episode.",
            width=55,
        ),
        fontsize=11.5,
        color="#334155",
        va="top",
        linespacing=1.5,
        wrap=True,
    )

    # --- Bottom-right: average benchmark reward per policy, one per line ---
    ax_text_right = fig.add_subplot(gs[1, 1])
    ax_text_right.axis("off")
    ax_text_right.text(
        0,
        1,
        "Verified mock benchmark ordering",
        fontsize=14,
        fontweight="bold",
        va="top",
    )
    ax_text_right.text(
        0,
        0.82,
        _wrap_paragraphs(
            f"Expert average: {expert_scores.average_reward:+.3f}\n"
            f"LLM demo average: {llm_scores.average_reward:+.3f}\n"
            f"Random average: {random_scores.average_reward:+.3f}\n"
            f"No-action average: {no_action_scores.average_reward:+.3f}",
            width=32,
        ),
        fontsize=11.5,
        color="#334155",
        va="top",
        linespacing=1.6,
        wrap=True,
    )

    fig.suptitle(
        "Pulse-ER Mock Before vs After\nRandom Tool Calls vs Trained Policy",
        fontsize=20,
        fontweight="bold",
        y=0.98,
    )
    fig.text(
        0.5,
        0.03,
        "Before = random tool-calling baseline (seed 0). After = expert/trained policy. All values are generated from the current repo.",
        ha="center",
        fontsize=10.5,
        color="#475569",
    )
    fig.savefig(output_path, dpi=180, bbox_inches="tight")
    # pyplot keeps figures registered until they are closed; release it once written.
    plt.close(fig)
    print(output_path.resolve())


# Generate the plot only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()