"""
plot_training.py — Parse a train_run_vN.log and emit publication-quality plots.

Usage (on the HF Jupyter server or Colab):
    python scripts/plot_training.py --log train_run_v4.log --out plots/

Outputs (in the --out directory):
    reward_curve.png       — rolling-mean total reward + raw scatter
    reward_components.png  — per-reward-function breakdown
    loss_curve.png         — HF Trainer loss
    training_summary.png   — 4-panel combined figure (perfect for slides / README)

The parser handles three log formats emitted by train_trl.py:
    1. [step N] r0=X.XXX | r_lt=Y.YYY | comp1=A | comp2=B ...  (custom reward log)
    2. {'loss': '...', 'reward': '...', 'rewards/fn/mean': '...', 'epoch': '...', ...}
       (HF/TRL Trainer dicts — values may be quoted strings OR bare floats)
    3. Raw JSON-lines in training_logs/generations.jsonl next to the log file
       (read only when no format-1 lines are found)
"""
from __future__ import annotations

import argparse
import ast
import json
import os
import re
import sys
from pathlib import Path
from typing import Any

import matplotlib
matplotlib.use("Agg")  # headless — safe on Jupyter too
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
# ── colour palette (dark theme) ──────────────────────────────────────────────
# Role name -> hex colour. "reward" … "loss" colour the data series; "bg",
# "grid" and "text" style the figure chrome; "band" tints the ±1σ shading.
C = dict(
    reward="#4C8BF5",     # Google-blue
    longterm="#34A853",   # Google-green
    task="#FBBC05",       # Google-yellow
    milestone="#EA4335",  # Google-red
    replan="#9B59B6",     # purple
    loss="#E67E22",       # orange
    bg="#0F0F0F",
    grid="#2A2A2A",
    text="#E0E0E0",
    band="#4C8BF5",
)

# Global matplotlib defaults so every figure below shares the dark theme.
_DARK_RC = {
    "figure.facecolor": C["bg"],
    "axes.facecolor":   C["bg"],
    "axes.edgecolor":   C["grid"],
    "axes.labelcolor":  C["text"],
    "xtick.color":      C["text"],
    "ytick.color":      C["text"],
    "text.color":       C["text"],
    "grid.color":       C["grid"],
    "legend.facecolor": "#1A1A1A",
    "legend.edgecolor": C["grid"],
    "figure.dpi":       150,
    "font.family":      "DejaVu Sans",
}
plt.rcParams.update(_DARK_RC)
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
# Parsing helpers
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
_STEP_RE = re.compile(
r"\[step\s+(\d+)\]\s+"
r"r0=([+-]?\d+\.\d+).*?"
r"r_lt=([+-]?\d+\.\d+)"
r"((?:\s*\|\s*\w+=[-+]?\d+\.\d+)*)"
)
_COMP_RE = re.compile(r"(\w+)=([-+]?\d+\.\d+)")
_DICT_RE = re.compile(r"\{[^{}]+\}")
def _safe_float(v) -> float | None:
"""Convert a value that may be a bare float or a quoted string like '-8.941e-09'."""
try:
return float(v)
except (TypeError, ValueError):
return None
def _parse_dict_line(raw: str) -> dict[str, Any] | None:
"""Parse a Python-dict-style log line where values may be quoted strings or bare floats."""
# Strategy 1: replace single-quotes β double-quotes for JSON
try:
js = raw.replace("'", '"').replace("True", "true").replace("False", "false").replace("None", "null")
return json.loads(js)
except Exception:
pass
# Strategy 2: eval (safe fallback β no builtins)
try:
return eval(raw, {"__builtins__": {}}) # noqa: S307
except Exception:
return None
# Known per-reward-function keys emitted by TRL GRPO trainer
_REWARD_FN_KEYS = [
"rewards/reward_format_fn/mean",
"rewards/reward_clean_eos_fn/mean",
"rewards/reward_route_target_fn/mean",
"rewards/reward_task_success_fn/mean",
"rewards/reward_milestone_fn/mean",
"rewards/reward_replan_fn/mean",
"rewards/reward_longterm_fn/mean",
"rewards/reward_compact_fn/mean",
"rewards/reward_human_feedback_fn/mean",
]
# Human-friendly short labels for the above keys
_REWARD_FN_LABELS = {
"rewards/reward_format_fn/mean": "Format",
"rewards/reward_clean_eos_fn/mean": "Clean EOS",
"rewards/reward_route_target_fn/mean": "Route/Target",
"rewards/reward_task_success_fn/mean": "Task success",
"rewards/reward_milestone_fn/mean": "Milestone",
"rewards/reward_replan_fn/mean": "Replan",
"rewards/reward_longterm_fn/mean": "Long-term",
"rewards/reward_compact_fn/mean": "Compact",
"rewards/reward_human_feedback_fn/mean": "Human feedback",
}
def _coerce_step(value, fallback: int) -> int:
    """Return *value* as an int step, or *fallback* if it is missing/non-numeric/non-finite."""
    as_float = _safe_float(value)
    if as_float is None:
        return fallback
    try:
        return int(as_float)
    except (ValueError, OverflowError):  # NaN / inf
        return fallback


def _parse_step_lines(text: str, series: dict[str, list]) -> None:
    """Append every Format-1 ``[step N] r0=... r_lt=...`` record in *text* to *series*."""
    for m in _STEP_RE.finditer(text):
        comps = dict(_COMP_RE.findall(m.group(4) or ""))
        series["step"].append(int(m.group(1)))
        series["reward"].append(float(m.group(2)))
        series["longterm"].append(float(m.group(3)))
        # "task_success" is preferred; "completion" is the fallback name.
        series["task_success"].append(
            float(comps.get("task_success", comps.get("completion", 0.0))))
        series["milestone"].append(float(comps.get("milestone", 0.0)))
        series["replan"].append(float(comps.get("replan", 0.0)))


def _parse_trainer_dicts(text: str, series: dict[str, list],
                         fn_series: dict[str, list]) -> None:
    """Append every Format-2 HF/TRL Trainer metrics dict in *text* to *series*/*fn_series*.

    ``train_step`` gets one entry per accepted dict. Every other list only
    grows when the dict carries that metric, so e.g. ``loss`` may be shorter
    than ``train_step``.
    """
    # (key in the Trainer dict, destination series)
    scalar_keys = (
        ("loss", "loss"),
        ("learning_rate", "lr"),
        ("epoch", "epoch"),
        ("reward", "trainer_reward"),
        ("reward_std", "reward_std"),
    )
    n_accepted = 0
    for m in _DICT_RE.finditer(text):
        d = _parse_dict_line(m.group(0))
        if not isinstance(d, dict):
            continue
        # Skip the final train_runtime summary dict (no per-step metrics) and
        # any dict without a loss or reward.
        if "train_runtime" in d or ("loss" not in d and "reward" not in d):
            continue
        n_accepted += 1
        # Explicit None checks so a logged step of 0 is kept.
        explicit = d.get("step")
        if explicit is None:
            explicit = d.get("global_step")
        series["train_step"].append(_coerce_step(explicit, n_accepted))
        for src, dst in scalar_keys:
            value = _safe_float(d.get(src))
            if value is not None:
                series[dst].append(value)
        for key, values in fn_series.items():
            value = _safe_float(d.get(key))
            if value is not None:
                values.append(value)


def _parse_generations_jsonl(jsonl_path: Path, series: dict[str, list]) -> None:
    """Append Format-3 records from a generations JSONL file to the Format-1 series.

    Blank and malformed lines are skipped. All six values of a record are
    computed before anything is appended, so a bad record never leaves the
    lists with different lengths.
    """
    fields = ("step", "reward", "longterm", "task_success", "milestone", "replan")
    for line in jsonl_path.read_text(errors="replace").splitlines():
        if not line.strip():
            continue
        try:
            d = json.loads(line)
            comps = (d.get("breakdown") or {}).get("components") or {}
            row = (
                int(d.get("step", 0)),
                float(d.get("reward", 0)),
                float(d.get("longterm_reward", 0)),
                float(comps.get("completion", comps.get("task_success", 0))),
                float(comps.get("milestone", 0)),
                float(comps.get("replan", 0)),
            )
        except (ValueError, TypeError, AttributeError, OverflowError):
            # ValueError also covers json.JSONDecodeError.
            continue
        for name, value in zip(fields, row):
            series[name].append(value)


def parse_log(log_path: str | Path) -> dict[str, list]:
    """Return parsed series from a train_run_vN.log file.

    * Format 1 (``[step N] r0=...`` lines) fills ``step``, ``reward``,
      ``longterm``, ``task_success``, ``milestone`` and ``replan``.
    * Format 2 (Trainer dict lines) fills ``train_step``, ``loss``, ``lr``,
      ``epoch``, ``trainer_reward``, ``reward_std`` and the per-function
      series stored under the ``"_fn_series"`` key.
    * Format 3 (``training_logs/generations.jsonl`` beside the log) fills the
      Format-1 series, and is read only when no Format-1 lines were found.

    Exits the process via ``sys.exit`` when the log is missing or when
    neither reward nor loss data could be parsed.
    """
    log_path = Path(log_path)
    if not log_path.exists():
        sys.exit(f"[plot_training] Log file not found: {log_path}")
    series: dict[str, list] = {
        # Format-1 fields (custom [step N] r0=... lines)
        "step": [], "reward": [], "longterm": [],
        "task_success": [], "milestone": [], "replan": [],
        # Format-2 fields (HF Trainer dict lines)
        "train_step": [], "loss": [], "lr": [], "epoch": [],
        "trainer_reward": [], "reward_std": [],
    }
    # Per-reward-function series (populated from Format-2 dicts)
    fn_series: dict[str, list] = {k: [] for k in _REWARD_FN_KEYS}

    text = log_path.read_text(errors="replace")
    _parse_step_lines(text, series)
    _parse_trainer_dicts(text, series, fn_series)

    jsonl_path = log_path.parent / "training_logs" / "generations.jsonl"
    if jsonl_path.exists() and not series["step"]:
        print(f"[plot_training] Loading supplemental JSONL: {jsonl_path}")
        _parse_generations_jsonl(jsonl_path, series)

    # Attach fn_series so callers can access it
    series["_fn_series"] = fn_series  # type: ignore[assignment]

    total_reward = len(series["step"]) or len(series["trainer_reward"])
    total_trainer = len(series["loss"])
    print(f"[plot_training] Parsed {len(series['step'])} custom-format steps, "
          f"{len(series['trainer_reward'])} trainer reward entries, "
          f"{total_trainer} loss entries.")
    active_fns = [k for k, v in fn_series.items() if v]
    if active_fns:
        print(f"[plot_training] Per-function reward series found: "
              f"{[_REWARD_FN_LABELS[k] for k in active_fns]}")
    if total_reward == 0 and total_trainer == 0:
        sys.exit("[plot_training] Nothing parsed. Check the log path and format.")
    return series
# ───────────────────────────────────────────────────────────────────────────────
# Smoothing
# ───────────────────────────────────────────────────────────────────────────────
def rolling_mean(arr: list[float], w: int = 20) -> np.ndarray:
    """Trailing moving average of *arr* with window *w*, same length as the input.

    The series is left-padded with ``w - 1`` copies of its first value, so
    the output keeps ``len(arr)`` points and early points average over that
    padding. Inputs with fewer than two points are returned unchanged (as a
    float array). A window below 1 is treated as 1 (no smoothing) instead of
    failing inside NumPy.
    """
    a = np.asarray(arr, dtype=float)
    if a.size < 2:
        return a.copy()
    w = max(1, int(w))
    kernel = np.full(w, 1.0 / w)
    padded = np.pad(a, (w - 1, 0), mode="edge")
    return np.convolve(padded, kernel, mode="valid")
# ───────────────────────────────────────────────────────────────────────────────
# Individual plots
# ───────────────────────────────────────────────────────────────────────────────
def _annotate_trend(ax, x, y_smooth):
    """Annotate *ax* with the net change (Δ) of a smoothed curve.

    Δ is the mean of the last 10 % of *y_smooth* minus the mean of its first
    10 % (at least one point each). The signed label is drawn with an arrow
    pointing at the final point ``(x[-1], y_smooth[-1])``. Curves with fewer
    than four points are left unannotated.
    """
    n = len(y_smooth)
    if n < 4:
        return
    k = max(1, n // 10)
    start_val = float(np.mean(y_smooth[:k]))
    end_val = float(np.mean(y_smooth[-k:]))
    delta = end_val - start_val
    ax.annotate(
        f"Δ = {delta:+.3f}",  # explicit sign; avoids the old "+-0.000" edge case
        xy=(x[-1], y_smooth[-1]),
        xytext=(-60, 12),
        textcoords="offset points",
        fontsize=9,
        color=C["text"],
        arrowprops=dict(arrowstyle="->", color=C["text"], lw=0.8),
    )
def plot_reward_curve(series: dict, out_dir: Path) -> Path | None:
    """Write reward_curve.png: per-step reward, rolling mean and a ±1σ band.

    Prefers the custom Format-1 immediate reward (r₀) and falls back to the
    HF Trainer mean reward. For trainer data, ``train_step`` supplies the x
    values only when it has exactly one entry per reward (otherwise 1..n is
    used, so x and y always match in length). The band uses the logged
    ``reward_std`` when it is aligned with the rewards, else a rolling
    standard deviation around the smoothed curve.

    Returns the written path, or ``None`` when there is no reward data.
    """
    # Prefer custom-format steps; fall back to trainer_reward from HF Trainer dicts
    if series["reward"]:
        steps = np.array(series["step"])
        rewards = np.array(series["reward"], dtype=float)
        xlabel = "Reward Call Step"
        ylabel = "Immediate Reward r₀"
        std_arr = None
    elif series["trainer_reward"]:
        n = len(series["trainer_reward"])
        # train_step also counts dicts that carry only a loss, so it may be
        # longer than trainer_reward; pairing them would crash scatter().
        steps = np.array(series["train_step"] if len(series["train_step"]) == n
                         else range(1, n + 1))
        rewards = np.array(series["trainer_reward"], dtype=float)
        xlabel = "Training Step"
        ylabel = "Mean Reward (per GRPO step)"
        std_arr = np.array(series["reward_std"]) if len(series["reward_std"]) == n else None
    else:
        print("[plot_training] No reward data — skipping reward_curve.png")
        return None

    w = max(2, len(steps) // 10)
    smooth = rolling_mean(rewards.tolist(), w)
    fig, ax = plt.subplots(figsize=(10, 4.5))
    ax.scatter(steps, rewards, s=30, alpha=0.6, color=C["reward"], label="Per-step reward", zorder=2)
    ax.plot(steps, smooth, lw=2.5, color=C["reward"], label=f"Rolling mean (w={w})", zorder=3)
    # Shade ±1σ: the logged per-step std around the raw rewards when available,
    # else a rolling std, sqrt(E[x²] - E[x]²), around the smoothed curve.
    if std_arr is not None:
        ax.fill_between(steps, rewards - std_arr, rewards + std_arr,
                        alpha=0.2, color=C["band"], label="±1 σ (logged)")
    elif len(rewards) >= w:
        sq_sm = rolling_mean((rewards ** 2).tolist(), w)
        std_sm = np.sqrt(np.maximum(0, sq_sm - smooth ** 2))
        ax.fill_between(steps, smooth - std_sm, smooth + std_sm,
                        alpha=0.18, color=C["band"], label="±1 σ band")
    _annotate_trend(ax, steps, smooth)
    ax.axhline(0, color=C["grid"], lw=0.8, ls="--")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title("LifeStack GRPO — Reward Curve (v4 run)", pad=12)
    ax.legend(framealpha=0.7, fontsize=9)
    ax.yaxis.set_major_formatter(mticker.FormatStrFormatter("%.3f"))
    ax.grid(True, lw=0.4)
    fig.tight_layout()
    out = out_dir / "reward_curve.png"
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
    print(f"[plot_training] Saved: {out}")
    return out
def plot_reward_components(series: dict, out_dir: Path) -> Path | None:
    """Write reward_components.png with one panel per reward component.

    Uses the per-reward-function Trainer series (``series["_fn_series"]``)
    when any of them is non-empty; otherwise falls back to the custom
    Format-1 components (long-term, task success, milestone, replan).
    Panels are laid out two per row. Each shows raw values, a rolling mean
    and the net change, or "No signal" when the component is empty or all
    zeros. Spare grid cells are hidden.

    Returns the written path, or ``None`` when there is no component data.
    """
    fn_series = series.get("_fn_series", {})
    # Build component dict: prefer per-fn rewards from HF Trainer; fall back to custom format
    active_fns = {k: v for k, v in fn_series.items() if v}
    use_custom = not active_fns and bool(series["step"])
    if not active_fns and not use_custom:
        print("[plot_training] No component reward data — skipping reward_components.png")
        return None

    if use_custom:
        steps_raw = series["step"]
        components = {
            "Long-term": (series["longterm"], C["longterm"]),
            "Task success": (series["task_success"], C["task"]),
            "Milestone": (series["milestone"], C["milestone"]),
            "Replan": (series["replan"], C["replan"]),
        }
    else:
        n = max(len(v) for v in active_fns.values())
        # Logged steps are only usable when they pair 1:1 with the values.
        steps_raw = (series["train_step"] if len(series["train_step"]) == n
                     else list(range(1, n + 1)))
        palette = [C["reward"], C["longterm"], C["task"], C["milestone"],
                   C["replan"], "#00BCD4", "#FF5722", "#9C27B0", "#607D8B"]
        components = {
            _REWARD_FN_LABELS[k]: (v, palette[i % len(palette)])
            for i, (k, v) in enumerate(active_fns.items())
        }

    steps = np.array(steps_raw)
    ncols = 2
    nrows = (len(components) + ncols - 1) // ncols
    w = max(2, len(steps) // 10)
    # squeeze=False keeps ``axes`` 2-D for every grid size, and ravel() gives a
    # real array that can be iterated more than once. ``axes.flat`` would be a
    # one-shot iterator, already exhausted by the plotting zip below.
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 3.5 * nrows),
                             sharex=True, squeeze=False)
    axes_flat = axes.ravel()
    fig.suptitle("LifeStack GRPO — Per-Function Reward Components (v4 run)",
                 y=1.01, fontsize=13)

    for ax, (label, (vals, color)) in zip(axes_flat, components.items()):
        vals_arr = np.array(vals[:len(steps)], dtype=float)
        ax.set_title(label, fontsize=10)
        if len(vals_arr) == 0 or (vals_arr == 0).all():
            ax.text(0.5, 0.5, "No signal", ha="center", va="center",
                    transform=ax.transAxes, color=C["text"], fontsize=11)
            continue
        xs = steps[:len(vals_arr)]
        smooth = rolling_mean(vals_arr.tolist(), w)
        ax.scatter(xs, vals_arr, s=25, alpha=0.5, color=color)
        ax.plot(xs, smooth, lw=2.0, color=color)
        ax.axhline(0, color=C["grid"], lw=0.6, ls="--")
        ax.set_ylabel("Reward")
        ax.grid(True, lw=0.3)
        _annotate_trend(ax, xs, smooth)

    n_used = len(components)
    for i, ax in enumerate(axes_flat):
        if i >= n_used:
            ax.set_visible(False)  # spare cell of an odd-sized grid
        elif i + ncols >= n_used:
            # Lowest visible panel in its column: label the x axis and restore
            # the tick labels that sharex=True hides on non-bottom rows.
            ax.set_xlabel("Training Step")
            ax.tick_params(axis="x", labelbottom=True)

    fig.tight_layout()
    out = out_dir / "reward_components.png"
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
    print(f"[plot_training] Saved: {out}")
    return out
def plot_loss_curve(series: dict, out_dir: Path) -> Path | None:
    """Write loss_curve.png: raw HF Trainer loss plus a rolling mean.

    ``train_step`` is used for x only when it has exactly one entry per loss
    value. It also counts dicts carrying just a reward, so it can be longer.
    Otherwise the x values are 1..n, matching the fallback used by the other
    plots.

    Returns the written path, or ``None`` when the log has no loss entries.
    """
    loss = series["loss"]
    if not loss:
        print("[plot_training] No loss data in log — skipping loss_curve.png")
        return None
    n = len(loss)
    steps = series["train_step"]
    x = np.array(steps if len(steps) == n else range(1, n + 1), dtype=float)
    y = np.array(loss, dtype=float)
    w = max(5, len(y) // 20)
    smooth = rolling_mean(y.tolist(), w)
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.scatter(x, y, s=8, alpha=0.3, color=C["loss"], label="Raw loss")
    ax.plot(x, smooth, lw=2.0, color=C["loss"], label=f"Rolling mean (w={w})")
    _annotate_trend(ax, x, smooth)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Loss")
    ax.set_title("LifeStack GRPO — Training Loss (v4 run)", pad=12)
    ax.legend(framealpha=0.7, fontsize=9)
    ax.grid(True, lw=0.4)
    fig.tight_layout()
    out = out_dir / "loss_curve.png"
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
    print(f"[plot_training] Saved: {out}")
    return out
def plot_summary_4panel(series: dict, out_dir: Path) -> Path | None:
    """Write training_summary.png, a 2x2 overview of the run.

    Panels:
      1. total reward (custom r₀, else Trainer mean reward);
      2. long-term reward (custom logs) or logged reward std-dev (Trainer logs);
      3. per-function reward lines, else stacked custom components;
      4. training loss.
    Panels without data show a placeholder message.

    Returns the written path, or ``None`` when there is neither reward nor
    loss data.
    """
    fn_series = series.get("_fn_series", {})
    # Determine reward source
    has_custom = len(series["step"]) > 0
    has_trainer = len(series["trainer_reward"]) > 0
    has_loss = len(series["loss"]) > 0
    if not has_custom and not has_trainer and not has_loss:
        return None

    # Build reward arrays
    if has_custom:
        steps_r = np.array(series["step"])
        rewards = np.array(series["reward"], dtype=float)
        longterm = np.array(series["longterm"], dtype=float)
        rlabel = "r₀"
    elif has_trainer:
        n = len(series["trainer_reward"])
        # Logged steps only when they pair 1:1 with the rewards; otherwise
        # x and y would differ in length and scatter() would raise.
        steps_r = np.array(series["train_step"] if len(series["train_step"]) == n
                           else range(1, n + 1))
        rewards = np.array(series["trainer_reward"], dtype=float)
        longterm = (np.array(series["reward_std"], dtype=float)
                    if len(series["reward_std"]) == n else None)
        rlabel = "Mean reward"
    else:
        steps_r = rewards = longterm = None
        rlabel = ""

    n_loss = len(series["loss"])
    steps_t = np.array(series["train_step"][:n_loss] if series["train_step"]
                       else range(1, n_loss + 1), dtype=float)
    w_r = max(2, len(steps_r) // 10) if steps_r is not None else 2
    w_l = max(2, n_loss // 10) if n_loss else 2

    fig = plt.figure(figsize=(14, 8))
    fig.suptitle("LifeStack GRPO Training Evidence — Run v4", fontsize=14, y=1.00)
    gs = fig.add_gridspec(2, 2, hspace=0.44, wspace=0.36)
    ax_r = fig.add_subplot(gs[0, 0])
    ax_lt = fig.add_subplot(gs[0, 1])
    ax_c = fig.add_subplot(gs[1, 0])
    ax_l = fig.add_subplot(gs[1, 1])

    # ── Panel 1: total reward ─────────────────────────────────────────────────
    if rewards is not None:
        sm = rolling_mean(rewards.tolist(), w_r)
        ax_r.scatter(steps_r, rewards, s=40, alpha=0.6, color=C["reward"])
        ax_r.plot(steps_r, sm, lw=2.5, color=C["reward"])
        ax_r.axhline(0, color=C["grid"], lw=0.7, ls="--")
        ax_r.set_title(f"Total Reward ({rlabel})", fontsize=10)
        ax_r.set_xlabel("Step"); ax_r.set_ylabel(rlabel)
        _annotate_trend(ax_r, steps_r, sm)
    else:
        ax_r.text(0.5, 0.5, "No reward data", ha="center", va="center",
                  transform=ax_r.transAxes, color=C["text"])

    # ── Panel 2: reward std OR longterm ───────────────────────────────────────
    if longterm is not None and len(longterm) == len(steps_r):
        sm_lt = rolling_mean(longterm.tolist(), w_r)
        ax_lt.scatter(steps_r, longterm, s=40, alpha=0.6, color=C["longterm"])
        ax_lt.plot(steps_r, sm_lt, lw=2.5, color=C["longterm"])
        ax_lt.axhline(0, color=C["grid"], lw=0.7, ls="--")
        p2_title = "Long-term Reward" if has_custom else "Reward Std Dev"
        ax_lt.set_title(p2_title, fontsize=10)
        ax_lt.set_xlabel("Step"); ax_lt.set_ylabel(p2_title)
        _annotate_trend(ax_lt, steps_r, sm_lt)
    else:
        ax_lt.text(0.5, 0.5, "No secondary reward data",
                   ha="center", va="center", transform=ax_lt.transAxes, color=C["text"])

    # ── Panel 3: per-function rewards (line chart) ────────────────────────────
    # Same 9-colour palette as plot_reward_components, so all nine reward
    # functions get distinct colours.
    palette = [C["reward"], C["longterm"], C["task"], C["milestone"],
               C["replan"], "#00BCD4", "#FF5722", "#9C27B0", "#607D8B"]
    active_fns = {k: v for k, v in fn_series.items() if v}
    if active_fns:
        n_fn = max(len(v) for v in active_fns.values())
        xs = np.array(series["train_step"][:n_fn] if series["train_step"]
                      else range(1, n_fn + 1))
        for i, (key, vals) in enumerate(active_fns.items()):
            col = palette[i % len(palette)]
            va = np.array(vals[:len(xs)], dtype=float)
            sm = rolling_mean(va.tolist(), max(2, len(va) // 10))
            ax_c.plot(xs[:len(va)], sm, lw=2.0, color=col,
                      label=_REWARD_FN_LABELS[key])
        ax_c.axhline(0, color=C["grid"], lw=0.6, ls="--")
        ax_c.set_title("Per-Function Rewards", fontsize=10)
        ax_c.set_xlabel("Step"); ax_c.set_ylabel("Reward")
        ax_c.legend(fontsize=7, framealpha=0.6)
    elif has_custom:
        comp_data = [
            ("Task success", np.array(series["task_success"], dtype=float), C["task"]),
            ("Milestone", np.array(series["milestone"], dtype=float), C["milestone"]),
            ("Replan", np.array(series["replan"], dtype=float), C["replan"]),
        ]
        bottoms = np.zeros(len(steps_r))
        # Bar width ~90 % of the mean step spacing, never narrower than 1.
        bar_w = max(1.0, (steps_r[-1] - steps_r[0]) / max(1, len(steps_r)) * 0.9)
        for label, vals, color in comp_data:
            clipped = np.clip(vals, 0, None)  # stacked bars need non-negative heights
            ax_c.bar(steps_r, clipped, bottom=bottoms, color=color,
                     alpha=0.7, width=bar_w, label=label)
            bottoms += clipped
        ax_c.set_title("Component Rewards (stacked)", fontsize=10)
        ax_c.set_xlabel("Step"); ax_c.set_ylabel("Stacked reward")
        ax_c.legend(fontsize=8, framealpha=0.6)
    else:
        ax_c.text(0.5, 0.5, "No component data", ha="center", va="center",
                  transform=ax_c.transAxes, color=C["text"])

    # ── Panel 4: loss ─────────────────────────────────────────────────────────
    if has_loss:
        ly = np.array(series["loss"], dtype=float)
        sm_l = rolling_mean(ly.tolist(), w_l)
        ax_l.scatter(steps_t, ly, s=40, alpha=0.5, color=C["loss"])
        ax_l.plot(steps_t, sm_l, lw=2.5, color=C["loss"])
        ax_l.set_title("Training Loss", fontsize=10)
        ax_l.set_xlabel("Step"); ax_l.set_ylabel("Loss")
        _annotate_trend(ax_l, steps_t, sm_l)
    else:
        ax_l.text(0.5, 0.5, "No loss entries in log",
                  ha="center", va="center", transform=ax_l.transAxes,
                  color=C["text"], fontsize=9)

    for ax in [ax_r, ax_lt, ax_c, ax_l]:
        ax.grid(True, lw=0.3)
    fig.tight_layout()
    out = out_dir / "training_summary.png"
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
    print(f"[plot_training] Saved: {out}")
    return out
# ───────────────────────────────────────────────────────────────────────────────
# CLI
# ───────────────────────────────────────────────────────────────────────────────
def main():
    """Parse ``--log`` and write every plot into ``--out`` (created if missing).

    ``--window`` is accepted for CLI compatibility. The plotting functions
    pick their own rolling-mean window sizes, so a non-zero value currently
    has no effect; a notice is printed instead of silently ignoring it.
    """
    ap = argparse.ArgumentParser(description="Generate training plots from a train_run_vN.log")
    ap.add_argument("--log", default="train_run_v4.log",
                    help="Path to the training log file (default: train_run_v4.log)")
    ap.add_argument("--out", default="plots",
                    help="Output directory for PNG files (default: ./plots/)")
    ap.add_argument("--window", type=int, default=0,
                    help="Rolling-mean window size (0 = auto)")
    args = ap.parse_args()
    if args.window:
        print(f"[plot_training] Note: --window={args.window} is currently ignored; "
              "each plot picks its rolling-mean window automatically.")
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    series = parse_log(args.log)
    plot_reward_curve(series, out_dir)
    plot_reward_components(series, out_dir)
    plot_loss_curve(series, out_dir)
    plot_summary_4panel(series, out_dir)
    print(f"\n[plot_training] All plots written to: {out_dir.resolve()}")


if __name__ == "__main__":
    main()