from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
# Known training series as (key in data["training_series"], legend label, line colour).
# The list order is the order series are plotted and listed in legends; the
# stage-specific keys come first, followed by the plain per-role keys.
SERIES_ORDER = [
    # stage-specific series
    ("center_sft", "center SFT", "#2563eb"),
    ("center_grpo", "center GRPO", "#7c3aed"),
    ("warehouse_sft", "warehouse SFT", "#0f766e"),
    ("warehouse_grpo", "warehouse GRPO", "#047857"),
    # plain per-role series (reuse the SFT colours)
    ("center", "center", "#2563eb"),
    ("warehouse", "warehouse", "#0f766e"),
]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Create static SupplyMind training/eval plots.")
parser.add_argument("--input", type=Path, default=Path("results/training_dashboard.json"))
parser.add_argument("--output-dir", type=Path, default=Path("results/plots"))
return parser.parse_args()
def load_json(path: Path) -> dict[str, Any]:
    """Decode the UTF-8 JSON document stored at ``path`` and return it."""
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
def value(row: dict[str, Any], *keys: str) -> float | None:
for key in keys:
raw = row.get(key)
if raw is None:
continue
try:
return float(raw)
except (TypeError, ValueError):
continue
return None
def active_series(data: dict[str, Any]) -> list[tuple[str, str, str]]:
    """Return the ``SERIES_ORDER`` entries whose key appears in ``data["training_series"]``.

    If any stage-specific series (``center_sft``, ``center_grpo``, ``warehouse_sft``,
    ``warehouse_grpo``) is present, only entries whose key ends in ``_sft`` or
    ``_grpo`` are kept, so the plain ``center``/``warehouse`` entries are dropped.
    Entries keep their ``SERIES_ORDER`` order.
    """
    available = data.get("training_series", {})
    present = [entry for entry in SERIES_ORDER if entry[0] in available]
    staged_keys = ("center_sft", "center_grpo", "warehouse_sft", "warehouse_grpo")
    if not any(name in available for name in staged_keys):
        return present
    return [entry for entry in present if entry[0].endswith(("_sft", "_grpo"))]
def line_plot(data: dict[str, Any], y_keys: tuple[str, ...], title: str, ylabel: str, output: Path) -> None:
    """Plot one metric against training step for every active series and save it.

    Args:
        data: Parsed dashboard JSON. Rows are read from
            ``data["training_series"][key]["steps"]`` for each key returned by
            ``active_series``.
        y_keys: Candidate metric keys, tried in order on each row via ``value``.
            Rows where none of them yields a number are skipped.
        title: Figure title.
        ylabel: Y-axis label.
        output: Path the PNG is written to (dpi=160).

    The x coordinate is the row's ``step``/``global_step`` when present, otherwise
    the row's 1-based position in the list. When no series has any point, a
    "No series available" placeholder is drawn instead of a legend.
    """
    series = data.get("training_series", {})
    plt.figure(figsize=(10, 5.2))
    plotted = False
    for key, label, color in active_series(data):
        # A series entry or its steps list may be serialized as null; treat it as empty.
        rows = (series.get(key) or {}).get("steps") or []
        xs: list[float] = []
        ys: list[float] = []
        for idx, row in enumerate(rows, start=1):
            y = value(row, *y_keys)
            if y is None:
                continue
            step = value(row, "step", "global_step")
            # Fall back to the row position only when no step was logged;
            # a plain `step or idx` would also discard a real step 0.
            xs.append(step if step is not None else idx)
            ys.append(y)
        if xs:
            plt.plot(xs, ys, label=label, color=color, linewidth=2)
            plotted = True
    plt.title(title)
    plt.xlabel("training step")
    plt.ylabel(ylabel)
    plt.grid(alpha=0.25)
    if plotted:
        plt.legend()
    else:
        plt.text(0.5, 0.5, "No series available", ha="center", va="center", transform=plt.gca().transAxes)
    plt.tight_layout()
    plt.savefig(output, dpi=160)
    plt.close()
def invalid_plot(data: dict[str, Any], output: Path) -> None:
    """Save a grouped bar chart of invalid payload / invalid action totals per series.

    For each active series, counts are summed over its ``reward_batches`` rows.
    Missing or non-numeric counts add 0. Series with no logged batches are left
    out. If nothing remains, a placeholder message is drawn instead of bars.
    """
    training = data.get("training_series", {})
    names: list[str] = []
    payload_totals: list[float] = []
    action_totals: list[float] = []
    for series_key, series_label, _color in active_series(data):
        batch_rows = training.get(series_key, {}).get("reward_batches", [])
        if batch_rows:
            names.append(series_label)
            payload_totals.append(sum(value(b, "invalid_payloads") or 0 for b in batch_rows))
            action_totals.append(sum(value(b, "invalid_actions") or 0 for b in batch_rows))

    plt.figure(figsize=(10, 5.2))
    if not names:
        plt.text(0.5, 0.5, "No invalid-action diagnostics available", ha="center", va="center", transform=plt.gca().transAxes)
    else:
        positions = list(range(len(names)))
        # Two side-by-side bars per series, offset either side of the tick.
        plt.bar([p - 0.18 for p in positions], payload_totals, width=0.36, label="invalid payloads", color="#c2410c")
        plt.bar([p + 0.18 for p in positions], action_totals, width=0.36, label="invalid env actions", color="#b7791f")
        plt.xticks(positions, names, rotation=20, ha="right")
        plt.legend()
    plt.title("Invalid Payloads / Actions")
    plt.ylabel("count across logged reward batches")
    plt.grid(axis="y", alpha=0.25)
    plt.tight_layout()
    plt.savefig(output, dpi=160)
    plt.close()
def heldout_plot(data: dict[str, Any], output: Path) -> None:
    """Save a grouped bar chart of held-out Base / SFT / GRPO scores.

    Each role ("center", "warehouse") contributes two x-axis groups: its
    ``mean_global_score`` and its role-specific score key. For every variant, the
    first entry in ``data["comparisons"]`` whose ``role`` and ``label`` match
    supplies the bar height. Groups with no match or no numeric value get no bar.
    The y axis is fixed to [0, 1] when anything is drawn.
    """
    comparisons = data.get("comparisons", [])
    role_score_key = {
        "center": "mean_center_role_score",
        "warehouse": "mean_warehouse_role_score",
    }
    # (role, metric caption, JSON key), in x-axis order.
    groups: list[tuple[str, str, str]] = [
        group
        for role in ("center", "warehouse")
        for group in ((role, "global", "mean_global_score"), (role, "role", role_score_key[role]))
    ]
    bar_width = 0.22
    variants = (("base", "#64748b"), ("sft", "#2563eb"), ("grpo", "#7c3aed"))
    offsets = (-bar_width, 0, bar_width)
    positions = list(range(len(groups)))

    plt.figure(figsize=(11, 5.5))
    drew_any = False
    for shift, (variant, color) in zip(offsets, variants, strict=True):
        heights: list[float | None] = []
        for role, _caption, score_key in groups:
            entry = next(
                (c for c in comparisons if c.get("role") == role and c.get("label") == variant),
                None,
            )
            heights.append(value(entry, score_key) if entry else None)
        kept = [(pos + shift, h) for pos, h in zip(positions, heights, strict=True) if h is not None]
        if kept:
            drew_any = True
            plt.bar([p for p, _ in kept], [h for _, h in kept], width=bar_width, label=variant.upper(), color=color)

    if drew_any:
        plt.xticks(positions, [f"{role}\n{metric}" for role, metric, _key in groups])
        plt.ylim(0, 1)
        plt.legend()
    else:
        plt.text(0.5, 0.5, "No held-out comparisons available", ha="center", va="center", transform=plt.gca().transAxes)
    plt.title("Held-out Scores: Base vs SFT vs GRPO")
    plt.ylabel("normalized score")
    plt.grid(axis="y", alpha=0.25)
    plt.tight_layout()
    plt.savefig(output, dpi=160)
    plt.close()
def main() -> None:
    """Load the dashboard JSON and write every plot into the output directory."""
    args = parse_args()
    dashboard = load_json(args.input)
    out_dir: Path = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # (metric keys tried in order, title, y-axis label, output file name)
    curves = (
        (("loss",), "Loss Over Step", "loss", "loss.png"),
        (("reward", "rewards/reward_completions/mean"), "Reward Over Step", "reward", "reward.png"),
        (("completions/clipped_ratio", "clipped_ratio"), "Clipped Ratio Over Step", "clipped ratio", "clipped_ratio.png"),
        (
            ("completions/mean_length", "completion_length", "mean_completion_length"),
            "Completion Length Over Step",
            "tokens",
            "completion_length.png",
        ),
    )
    for y_keys, title, ylabel, filename in curves:
        line_plot(dashboard, y_keys, title, ylabel, out_dir / filename)
    invalid_plot(dashboard, out_dir / "invalids.png")
    heldout_plot(dashboard, out_dir / "heldout_comparison.png")
    print(f"Wrote plots to {out_dir}")
if __name__ == "__main__":
    # Run the plotting pipeline only when executed as a script, not on import.
    main()