Multi-Agentic / ER_MAP /plotting.py
Uddiii's picture
kaggle: 8B Doctor + train-until-optimal early-stop
9c68ba6
"""
ER_MAP/plotting.py
==================
Per-phase + cross-phase visualization for ER-MAP GRPO training.
Reads the ``training_metrics.json`` file produced by
``ER_MAP/training/train_grpo.py`` and produces one comprehensive
multi-panel PNG per curriculum phase plus a cross-phase comparison
chart.
Output layout under ``output_dir``::
plots/
├── phase1_dashboard.png # 6-panel dashboard for Phase 1
├── phase2_dashboard.png # 6-panel dashboard for Phase 2
├── phase3_dashboard.png # 6-panel dashboard for Phase 3
├── all_phases_overview.png # cross-phase overview (single plot)
└── all_phases_comparison.png # phase summary bar charts
Each phase dashboard packs:
1. Reward growth — raw + rolling mean
2. Win-rate evolution
3. Outcome distribution (stacked bars over episode bins)
4. Reward components (mean per component within the phase)
5. GRPO loss + KL divergence (per group update)
6. Episode length distribution
Usage (from anywhere):
from ER_MAP.plotting import plot_per_phase_dashboards
plot_per_phase_dashboards(
"er_map_grpo_checkpoints/training_metrics.json",
"er_map_grpo_checkpoints/plots",
)
Or from the CLI::
python -m ER_MAP.plotting \\
--metrics er_map_grpo_checkpoints/training_metrics.json \\
--out er_map_grpo_checkpoints/plots
"""
from __future__ import annotations
import json
import os
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
# Matplotlib import is deferred to plotting time so the module can be
# imported in environments without matplotlib (e.g. for type checking).
# ---------------------------------------------------------------------------
# Color scheme — consistent across every chart for visual continuity
# ---------------------------------------------------------------------------
OUTCOME_COLORS = {
# Trainer-style outcome names (used by ER_MAP/training/train_grpo.py)
"WIN": "#10b981", # emerald
"PARTIAL": "#3b82f6", # blue
"INCORRECT": "#f59e0b", # amber
"AMA_LOSS": "#a855f7", # violet
"FATAL_LOSS": "#ef4444", # red
# Baseline-eval outcome names (used by ER_MAP/evaluate_baseline.py)
"AMA": "#a855f7", # violet — same as AMA_LOSS
"WRONG": "#f59e0b", # amber — same as INCORRECT
"FATAL": "#ef4444", # red — same as FATAL_LOSS
# Raw env-event names (used when the script falls through to the
# default branch — e.g. `terminal_partial` for partial-credit wins).
"terminal_win": "#10b981", # emerald
"terminal_partial": "#3b82f6", # blue (partial credit)
"terminal_incorrect": "#f59e0b", # amber
"terminal_ama": "#a855f7", # violet
"terminal_fatal": "#ef4444", # red
"TRUNCATED": "#9ca3af", # gray
"MAX_STEPS": "#6b7280", # darker gray
"ERROR": "#1f2937", # near-black
"done": "#9ca3af", # gray
"unknown": "#6b7280", # gray
}
OUTCOME_ORDER = [
"WIN", "terminal_win",
"PARTIAL", "terminal_partial",
"INCORRECT", "WRONG", "terminal_incorrect",
"AMA_LOSS", "AMA", "terminal_ama",
"FATAL_LOSS", "FATAL", "terminal_fatal",
"TRUNCATED", "MAX_STEPS", "done", "ERROR", "unknown",
]
# Each reward component gets a distinctive color so the per-phase chart
# matches the colors used in the dashboard's "Live Rewards" panel.
COMPONENT_COLORS = {
"process": "#22d3ee", # cyan
"diagnosis": "#3b82f6", # blue
"plan": "#8b5cf6", # violet
"labs": "#06b6d4", # teal
"treatment": "#10b981", # emerald
"empathy": "#ec4899", # pink
"milestones": "#f59e0b", # amber
"consent": "#84cc16", # lime
"documentation": "#a855f7", # purple
"emergency_id": "#ef4444", # red
"penalties": "#64748b", # slate
}
PHASE_COLOR = {
1: "#06b6d4", # teal
2: "#3b82f6", # blue
3: "#a855f7", # purple
}
PHASE_NAME = {
1: "Phase 1: Tool Mastery",
2: "Phase 2: Clinical Reasoning",
3: "Phase 3: Empathetic Negotiation",
}
# ---------------------------------------------------------------------------
# Loading + utilities
# ---------------------------------------------------------------------------
def load_metrics(path: str) -> List[Dict[str, Any]]:
"""Load the per-episode metrics list dumped by ``train_grpo.train``."""
p = Path(path)
if not p.exists():
raise FileNotFoundError(f"Metrics file not found: {p}")
with p.open("r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError(f"Expected a JSON list at {p}; got {type(data).__name__}")
return data
def split_by_phase(metrics: List[Dict[str, Any]]) -> Dict[int, List[Dict[str, Any]]]:
"""Bucket episode records by their ``phase`` field (1, 2, or 3)."""
buckets: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
for m in metrics:
buckets[int(m.get("phase", 1))].append(m)
return dict(sorted(buckets.items()))
def rolling_mean(xs: Sequence[float], window: int = 10) -> List[float]:
"""Simple right-aligned rolling mean (matches what the env scheduler uses)."""
out: List[float] = []
buf: List[float] = []
for x in xs:
buf.append(x)
if len(buf) > window:
buf.pop(0)
out.append(sum(buf) / len(buf))
return out
# ---------------------------------------------------------------------------
# Per-phase dashboard
# ---------------------------------------------------------------------------
def _plot_phase_dashboard(phase_id: int, episodes: List[Dict[str, Any]],
out_path: str) -> bool:
"""
Render a single 6-panel dashboard PNG for one curriculum phase.
Returns True on success, False if the phase has no episodes.
"""
import matplotlib.pyplot as plt
import numpy as np
if not episodes:
return False
ep_idx = [m["episode"] for m in episodes]
raw = [m.get("raw_reward", 0.0) for m in episodes]
verified = [m.get("verified_reward", 0.0) for m in episodes]
win_rates = [m.get("rolling_win_rate", 0.0) for m in episodes]
avg_rewards = [m.get("rolling_avg_reward", 0.0) for m in episodes]
outcomes = [m.get("outcome", "unknown") for m in episodes]
steps = [m.get("steps", 0) for m in episodes]
components = [m.get("reward_components", {}) for m in episodes]
fig = plt.figure(figsize=(16, 10), constrained_layout=True)
gs = fig.add_gridspec(3, 3)
fig.patch.set_facecolor("white")
phase_color = PHASE_COLOR.get(phase_id, "#3b82f6")
fig.suptitle(
f"{PHASE_NAME.get(phase_id, f'Phase {phase_id}')} "
f"\u2014 {len(episodes)} episodes "
f"\u2014 win rate {sum(1 for o in outcomes if o == 'WIN') / len(episodes):.0%}",
fontsize=15, fontweight="bold", color=phase_color, y=1.02,
)
# ----- Panel 1: Reward growth ------------------------------------------------
ax1 = fig.add_subplot(gs[0, :2])
ax1.scatter(ep_idx, raw, alpha=0.35, s=18, color=phase_color, label="raw episode reward")
ax1.plot(ep_idx, rolling_mean(raw, window=10), linewidth=2.5, color="#111827",
label="rolling mean (w=10)")
ax1.plot(ep_idx, rolling_mean(verified, window=10), linewidth=2,
color="#10b981", linestyle="--", label="verified rolling mean")
ax1.axhline(0, color="gray", linestyle=":", alpha=0.5)
ax1.set_title("Reward growth", fontweight="bold")
ax1.set_xlabel("Episode")
ax1.set_ylabel("Reward")
ax1.legend(loc="upper left", fontsize=9)
ax1.grid(alpha=0.25)
# ----- Panel 2: Win-rate evolution ------------------------------------------
ax2 = fig.add_subplot(gs[0, 2])
ax2.plot(ep_idx, win_rates, linewidth=2.5, color="#10b981")
ax2.fill_between(ep_idx, 0, win_rates, color="#10b981", alpha=0.2)
ax2.set_ylim(0, 1.0)
ax2.set_title("Rolling win rate (w=20)", fontweight="bold")
ax2.set_xlabel("Episode")
ax2.set_ylabel("Win rate")
ax2.grid(alpha=0.25)
# ----- Panel 3: Outcome distribution over episode bins ----------------------
ax3 = fig.add_subplot(gs[1, :2])
bin_size = max(1, len(episodes) // 12) # ~12 bins per phase
bins = []
bin_labels = []
for start in range(0, len(episodes), bin_size):
chunk = outcomes[start:start + bin_size]
bins.append(Counter(chunk))
ep_start = episodes[start]["episode"]
ep_end = episodes[min(start + bin_size, len(episodes)) - 1]["episode"]
bin_labels.append(f"{ep_start}-{ep_end}")
bottom = np.zeros(len(bins))
x = np.arange(len(bins))
for outcome in OUTCOME_ORDER:
heights = np.array([b.get(outcome, 0) for b in bins])
if heights.sum() == 0:
continue
ax3.bar(x, heights, bottom=bottom, color=OUTCOME_COLORS[outcome],
label=outcome, edgecolor="white", linewidth=0.5)
bottom += heights
ax3.set_xticks(x)
ax3.set_xticklabels(bin_labels, rotation=40, ha="right", fontsize=8)
ax3.set_title("Outcome distribution over time", fontweight="bold")
ax3.set_ylabel("Episodes per bin")
ax3.legend(loc="upper right", fontsize=8, ncol=2)
ax3.grid(alpha=0.2, axis="y")
# ----- Panel 4: Reward components (mean within phase) -----------------------
ax4 = fig.add_subplot(gs[1, 2])
component_means: Dict[str, float] = defaultdict(float)
counts: Dict[str, int] = defaultdict(int)
for c in components:
for k, v in c.items():
component_means[k] += float(v)
counts[k] += 1
if component_means:
ordered = sorted(
component_means.keys(),
key=lambda k: component_means[k] / max(counts[k], 1),
reverse=True,
)
means = [component_means[k] / max(counts[k], 1) for k in ordered]
colors = [COMPONENT_COLORS.get(k, "#94a3b8") for k in ordered]
bars = ax4.barh(range(len(ordered)), means, color=colors,
edgecolor="white", linewidth=0.5)
ax4.set_yticks(range(len(ordered)))
ax4.set_yticklabels(ordered, fontsize=8)
ax4.axvline(0, color="gray", linestyle=":", alpha=0.5)
ax4.set_title("Reward components (mean / episode)", fontweight="bold")
ax4.set_xlabel("Reward")
ax4.grid(alpha=0.2, axis="x")
for bar, value in zip(bars, means):
ax4.text(value, bar.get_y() + bar.get_height() / 2,
f" {value:+.2f}", va="center",
ha="left" if value >= 0 else "right",
fontsize=7, color="#374151")
else:
ax4.text(0.5, 0.5, "No reward-component data",
transform=ax4.transAxes, ha="center", va="center", color="gray")
ax4.set_title("Reward components (mean / episode)", fontweight="bold")
# ----- Panel 5: GRPO loss + KL divergence per update ------------------------
ax5 = fig.add_subplot(gs[2, :2])
update_eps = [m["episode"] for m in episodes if m.get("grpo_update")]
losses = [m["grpo_update"]["loss"] for m in episodes if m.get("grpo_update")]
kls = [m["grpo_update"]["kl"] for m in episodes if m.get("grpo_update")]
if update_eps:
l1 = ax5.plot(update_eps, losses, marker="o", color="#3b82f6",
linewidth=1.8, label="loss")
ax5b = ax5.twinx()
l2 = ax5b.plot(update_eps, kls, marker="s", color="#ef4444",
linewidth=1.5, label="KL", alpha=0.85)
ax5.set_title("GRPO update statistics", fontweight="bold")
ax5.set_xlabel("Episode (last episode of group)")
ax5.set_ylabel("Loss", color="#3b82f6")
ax5b.set_ylabel("KL", color="#ef4444")
ax5.tick_params(axis="y", labelcolor="#3b82f6")
ax5b.tick_params(axis="y", labelcolor="#ef4444")
ax5.axhline(0, color="gray", linestyle=":", alpha=0.5)
ax5.grid(alpha=0.2)
lines = l1 + l2
ax5.legend(lines, [l.get_label() for l in lines], loc="upper right", fontsize=8)
else:
ax5.text(0.5, 0.5, "No GRPO update stats logged\n(set --no-dry-run or upgrade train_grpo.py)",
transform=ax5.transAxes, ha="center", va="center", color="gray")
ax5.set_title("GRPO update statistics", fontweight="bold")
# ----- Panel 6: Episode-length distribution ---------------------------------
ax6 = fig.add_subplot(gs[2, 2])
if steps:
ax6.hist(steps, bins=range(0, max(steps) + 2), color=phase_color,
alpha=0.85, edgecolor="white")
ax6.axvline(sum(steps) / len(steps), color="#111827", linestyle="--",
linewidth=1.5, label=f"mean={sum(steps) / len(steps):.1f}")
ax6.set_title("Episode length distribution", fontweight="bold")
ax6.set_xlabel("Steps")
ax6.set_ylabel("Episodes")
ax6.legend(fontsize=8)
ax6.grid(alpha=0.2, axis="y")
else:
ax6.set_visible(False)
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=140, bbox_inches="tight")
plt.close(fig)
return True
# ---------------------------------------------------------------------------
# Cross-phase overview (single chart, all 3 phases on the same axes)
# ---------------------------------------------------------------------------
def _plot_all_phases_overview(metrics: List[Dict[str, Any]], out_path: str) -> None:
import matplotlib.pyplot as plt
if not metrics:
return
eps = [m["episode"] for m in metrics]
raw = [m.get("raw_reward", 0.0) for m in metrics]
rolling = [m.get("rolling_avg_reward", 0.0) for m in metrics]
wr = [m.get("rolling_win_rate", 0.0) for m in metrics]
phase = [int(m.get("phase", 1)) for m in metrics]
fig, axes = plt.subplots(2, 1, figsize=(13, 7), sharex=True,
constrained_layout=True)
fig.patch.set_facecolor("white")
fig.suptitle("ER-MAP GRPO training — full curriculum overview",
fontsize=14, fontweight="bold", y=1.02)
# Reward axis
axes[0].scatter(eps, raw, alpha=0.30, s=14, color="#3b82f6",
label="raw episode reward")
axes[0].plot(eps, rolling, linewidth=2.5, color="#111827",
label="rolling avg reward (w=20)")
axes[0].axhline(0, color="gray", linestyle=":", alpha=0.5)
axes[0].set_ylabel("Reward")
axes[0].legend(loc="upper left", fontsize=9)
axes[0].grid(alpha=0.25)
# Win-rate axis
axes[1].plot(eps, wr, linewidth=2.5, color="#10b981")
axes[1].fill_between(eps, 0, wr, color="#10b981", alpha=0.18)
axes[1].set_ylim(0, 1.0)
axes[1].set_ylabel("Rolling win rate")
axes[1].set_xlabel("Episode")
axes[1].grid(alpha=0.25)
# Phase boundary lines + colored bands
for i in range(1, len(phase)):
if phase[i] != phase[i - 1]:
for ax in axes:
ax.axvline(eps[i], color="red", linestyle=":", alpha=0.6, linewidth=1.5)
axes[0].text(eps[i], axes[0].get_ylim()[1] * 0.92,
f" \u2192 Phase {phase[i]}", color="red",
fontsize=9, fontweight="bold")
# Faint background colour bands per phase
cur_phase = phase[0]
band_start = eps[0]
for i in range(1, len(phase) + 1):
if i == len(phase) or phase[i] != cur_phase:
band_end = eps[i - 1]
for ax in axes:
ax.axvspan(band_start, band_end,
color=PHASE_COLOR.get(cur_phase, "#3b82f6"),
alpha=0.06, zorder=0)
if i < len(phase):
cur_phase = phase[i]
band_start = eps[i]
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=140, bbox_inches="tight")
plt.close(fig)
# ---------------------------------------------------------------------------
# Cross-phase summary bars
# ---------------------------------------------------------------------------
def _plot_phase_comparison(buckets: Dict[int, List[Dict[str, Any]]],
out_path: str) -> None:
import matplotlib.pyplot as plt
import numpy as np
if not buckets:
return
phases = sorted(buckets.keys())
avg_reward = [sum(m.get("raw_reward", 0.0) for m in buckets[p]) / max(len(buckets[p]), 1) for p in phases]
win_rate = [sum(1 for m in buckets[p] if m.get("outcome") == "WIN") / max(len(buckets[p]), 1) for p in phases]
avg_steps = [sum(m.get("steps", 0) for m in buckets[p]) / max(len(buckets[p]), 1) for p in phases]
fatal_rate = [sum(1 for m in buckets[p] if m.get("outcome") == "FATAL_LOSS") / max(len(buckets[p]), 1) for p in phases]
fig, axes = plt.subplots(1, 4, figsize=(16, 4.2), constrained_layout=True)
fig.patch.set_facecolor("white")
fig.suptitle("Phase-by-phase comparison", fontsize=14, fontweight="bold", y=1.05)
labels = [f"Phase {p}" for p in phases]
colors = [PHASE_COLOR.get(p, "#3b82f6") for p in phases]
metrics = [
("Avg raw reward", avg_reward, "Reward", "{:+.2f}"),
("Win rate", win_rate, "Rate", "{:.0%}"),
("Avg steps", avg_steps, "Steps", "{:.1f}"),
("Fatal-loss rate", fatal_rate, "Rate", "{:.0%}"),
]
for ax, (title, values, ylabel, fmt) in zip(axes, metrics):
bars = ax.bar(labels, values, color=colors,
edgecolor="white", linewidth=0.6)
ax.set_title(title, fontweight="bold")
ax.set_ylabel(ylabel)
ax.grid(alpha=0.2, axis="y")
if title in ("Win rate", "Fatal-loss rate"):
ax.set_ylim(0, max(1.0, max(values) * 1.15))
for bar, v in zip(bars, values):
ax.text(bar.get_x() + bar.get_width() / 2,
bar.get_height(),
" " + fmt.format(v), ha="center", va="bottom",
fontsize=9, color="#111827")
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=140, bbox_inches="tight")
plt.close(fig)
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
def plot_per_phase_dashboards(
metrics_path: str,
output_dir: str,
*,
phases: Optional[Sequence[int]] = None,
) -> Dict[str, str]:
"""
Read a ``training_metrics.json`` file and produce:
- one 6-panel dashboard PNG per phase that has episodes
- one cross-phase overview PNG (all phases on shared axes)
- one phase-comparison bar chart PNG
Returns a dict mapping ``{logical_name: written_path}`` for every
PNG that was actually written.
"""
metrics = load_metrics(metrics_path)
buckets = split_by_phase(metrics)
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
written: Dict[str, str] = {}
# Per-phase dashboards
target_phases = list(phases) if phases else sorted(buckets.keys())
for p in target_phases:
eps = buckets.get(int(p), [])
if not eps:
continue
path = out_dir / f"phase{p}_dashboard.png"
ok = _plot_phase_dashboard(int(p), eps, str(path))
if ok:
written[f"phase{p}_dashboard"] = str(path)
# Cross-phase overview + comparison
overview = out_dir / "all_phases_overview.png"
_plot_all_phases_overview(metrics, str(overview))
written["all_phases_overview"] = str(overview)
comparison = out_dir / "all_phases_comparison.png"
_plot_phase_comparison(buckets, str(comparison))
written["all_phases_comparison"] = str(comparison)
return written
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Baseline-eval plots (clean single-panel episode-vs-reward histograms)
# ---------------------------------------------------------------------------
def plot_baseline_phase_histogram(
results: List[Dict[str, Any]],
phase_id: int,
out_path: str,
*,
title_suffix: str = "Baseline (no RL)",
) -> str:
"""
Render a single clean episode-vs-reward bar chart for one phase of a
baseline evaluation run. One bar per episode, colored by outcome.
Returns the absolute path of the PNG that was written.
"""
import matplotlib.pyplot as plt
if not results:
raise ValueError(f"No results to plot for phase {phase_id}")
episodes = [r.get("episode", i + 1) for i, r in enumerate(results)]
rewards = [float(r.get("total_reward", 0.0)) for r in results]
outcomes = [r.get("outcome", "unknown") for r in results]
colors = [OUTCOME_COLORS.get(o, OUTCOME_COLORS["unknown"]) for o in outcomes]
fig, ax = plt.subplots(figsize=(12, 5.5), constrained_layout=True)
fig.patch.set_facecolor("white")
ax.bar(episodes, rewards, color=colors, edgecolor="white", linewidth=0.6,
width=0.85)
ax.axhline(0, color="#374151", linewidth=1, linestyle="-", alpha=0.6)
avg = sum(rewards) / len(rewards)
win_rate = sum(1 for o in outcomes if o == "WIN") / len(outcomes)
mean_line = ax.axhline(avg, color="#111827", linewidth=1.5, linestyle="--",
alpha=0.85, label=f"mean = {avg:+.2f}")
ax.set_xticks(episodes)
ax.set_xticklabels([str(e) for e in episodes], fontsize=9)
ax.set_xlabel("Episode", fontsize=12, fontweight="bold")
ax.set_ylabel("Total reward", fontsize=12, fontweight="bold")
ax.set_title(
f"{title_suffix}{PHASE_NAME.get(phase_id, f'Phase {phase_id}')} "
f"(n={len(results)}, win rate {win_rate:.0%})",
fontsize=13, fontweight="bold",
color=PHASE_COLOR.get(phase_id, "#3b82f6"),
pad=12,
)
ax.grid(alpha=0.25, axis="y")
# Single, unified legend: outcome swatches + the mean line, all in
# one legend in the upper-right so the chart stays clean.
used_outcomes = [o for o in OUTCOME_ORDER if o in outcomes]
handles = [
plt.Rectangle((0, 0), 1, 1, color=OUTCOME_COLORS[o], label=o)
for o in used_outcomes
]
handles.append(mean_line)
ax.legend(handles=handles, loc="upper right", fontsize=9,
framealpha=0.92, ncol=1, title="Outcome / mean",
title_fontsize=9)
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=140, bbox_inches="tight")
plt.close(fig)
return str(Path(out_path).resolve())
def plot_baseline_phase_comparison(
results_by_phase: Dict[int, List[Dict[str, Any]]],
out_path: str,
) -> str:
"""
Cross-phase summary for a baseline run: 3 mini bar charts of
win-rate / avg-reward / fatal-rate per phase. Useful as a single
quick sanity glance before/after we plug in the trained model.
"""
import matplotlib.pyplot as plt
if not results_by_phase:
raise ValueError("No baseline results to compare")
phases = sorted(results_by_phase.keys())
win_rate = [sum(1 for r in results_by_phase[p] if r.get("outcome") == "WIN") /
max(len(results_by_phase[p]), 1) for p in phases]
avg_reward = [sum(float(r.get("total_reward", 0.0)) for r in results_by_phase[p]) /
max(len(results_by_phase[p]), 1) for p in phases]
fatal_rate = [sum(1 for r in results_by_phase[p]
if r.get("outcome") in ("FATAL", "FATAL_LOSS")) /
max(len(results_by_phase[p]), 1) for p in phases]
fig, axes = plt.subplots(1, 3, figsize=(13, 4.0), constrained_layout=True)
fig.patch.set_facecolor("white")
fig.suptitle("Baseline (no RL) — per-phase summary",
fontsize=14, fontweight="bold", y=1.06)
labels = [f"Phase {p}" for p in phases]
colors = [PHASE_COLOR.get(p, "#3b82f6") for p in phases]
for ax, (title, vals, ylabel, fmt) in zip(axes, [
("Win rate", win_rate, "Rate", "{:.0%}"),
("Avg reward", avg_reward, "Reward", "{:+.2f}"),
("Fatal/Wrong rate", fatal_rate, "Rate", "{:.0%}"),
]):
bars = ax.bar(labels, vals, color=colors, edgecolor="white", linewidth=0.6)
ax.set_title(title, fontweight="bold")
ax.set_ylabel(ylabel)
ax.grid(alpha=0.2, axis="y")
if "rate" in title.lower():
ax.set_ylim(0, max(1.0, (max(vals) if vals else 0) * 1.15))
for bar, v in zip(bars, vals):
ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
" " + fmt.format(v), ha="center", va="bottom",
fontsize=10, color="#111827")
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=140, bbox_inches="tight")
plt.close(fig)
return str(Path(out_path).resolve())
def _cli() -> None:
import argparse
parser = argparse.ArgumentParser(
description="Render per-phase + cross-phase ER-MAP training plots."
)
parser.add_argument(
"--metrics", default="er_map_grpo_checkpoints/training_metrics.json",
help="Path to training_metrics.json from train_grpo",
)
parser.add_argument(
"--out", default="er_map_grpo_checkpoints/plots",
help="Output directory for the rendered PNG files",
)
parser.add_argument(
"--phases", default="", help="Optional comma-separated subset of phases (e.g. 1,2)"
)
args = parser.parse_args()
phase_list = None
if args.phases.strip():
phase_list = [int(p.strip()) for p in args.phases.split(",") if p.strip()]
written = plot_per_phase_dashboards(args.metrics, args.out, phases=phase_list)
print("=" * 60)
print(f" Plotted {len(written)} chart(s):")
for name, path in written.items():
print(f" {name:<28s} -> {path}")
print("=" * 60)
if __name__ == "__main__":
_cli()