| """ |
| plot_training.py — Training Visualization Dashboard |
| |
| Reads train_log.jsonl and renders a clean, dark-mode training dashboard. |
| |
| Usage: |
| # Static plot of completed/current run |
| python plot_training.py --run_dir runs/run_001 |
| |
    # Live mode: refresh every 5 seconds while training runs (default interval: 10 s)
    python plot_training.py --run_dir runs/run_001 --live --interval 5
| |
| # Compare multiple runs |
| python plot_training.py --run_dir runs/run_001 runs/run_002 |
| |
| Dashboard panels: |
| 1. Training Loss (raw + EMA smoothed) |
| 2. Validation Loss (if available) |
| 3. Learning Rate schedule |
| 4. Tokens / second (throughput) |
| 5. VRAM usage (if logged) |
| 6. Gradient norm (if logged) |
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import argparse |
| from pathlib import Path |
|
|
| import matplotlib |
| import matplotlib.pyplot as plt |
| import matplotlib.gridspec as gridspec |
| import matplotlib.ticker as ticker |
| import numpy as np |
|
|
|
|
| |
| |
| |
|
|
# ---- Dark theme palette -----------------------------------------------------
# Surfaces
DARK_BG = "#0d1117"        # figure background
PANEL_BG = "#161b22"       # axes and legend background
GRID_COLOR = "#21262d"     # grid lines and axes edges

# Text
TEXT_COLOR = "#c9d1d9"     # titles, axis labels, legend text
MUTED_COLOR = "#6e7681"    # tick labels and "no data" placeholder messages

# Per-run / highlight accents
ACCENT_BLUE = "#58a6ff"
ACCENT_GREEN = "#3fb950"
ACCENT_ORANGE = "#d29922"
ACCENT_RED = "#f85149"     # reference lines (e.g. the VRAM limit)
ACCENT_PURPLE = "#bc8cff"
ACCENT_TEAL = "#39d353"    # NOTE: despite the name, this hex value is a green hue
|
|
| matplotlib.rcParams.update({ |
| "figure.facecolor": DARK_BG, |
| "axes.facecolor": PANEL_BG, |
| "axes.edgecolor": GRID_COLOR, |
| "axes.labelcolor": TEXT_COLOR, |
| "axes.titlecolor": TEXT_COLOR, |
| "xtick.color": MUTED_COLOR, |
| "ytick.color": MUTED_COLOR, |
| "grid.color": GRID_COLOR, |
| "grid.linestyle": "--", |
| "grid.linewidth": 0.5, |
| "grid.alpha": 0.7, |
| "legend.facecolor": PANEL_BG, |
| "legend.edgecolor": GRID_COLOR, |
| "legend.labelcolor": TEXT_COLOR, |
| "text.color": TEXT_COLOR, |
| "font.family": "DejaVu Sans", |
| "font.size": 10, |
| "axes.titlesize": 11, |
| "axes.labelsize": 10, |
| }) |
|
|
|
|
| |
| |
| |
|
|
# JSONL key -> dashboard series name. Insertion order fixes the key order of
# the dict returned by load_log: train, val, lr, tok, vram, grad.
_LOG_FIELDS = {
    "loss": "train",
    "val_loss": "val",
    "lr": "lr",
    "tok_per_sec": "tok",
    "vram_gb": "vram",
    "grad_norm": "grad",
}


def load_log(log_path: str) -> "dict | None":
    """
    Load a ``train_log.jsonl`` file into per-metric ``(steps, values)`` series.

    Each non-empty line is parsed as JSON. A line is skipped if it is not
    valid JSON (e.g. a half-written last line while training is still
    running), if it is not a JSON object, or if it has no ``"step"``. A metric
    is recorded for a line only when its key is present and its value is not
    ``null``.

    Args:
        log_path : path to the JSONL log file.

    Returns:
        ``None`` if the file does not exist; otherwise a dict with keys
        ``"train"``, ``"val"``, ``"lr"``, ``"tok"``, ``"vram"`` and ``"grad"``,
        each mapping to a ``(steps, values)`` tuple of equal-length lists.
    """
    series = {name: ([], []) for name in _LOG_FIELDS.values()}

    try:
        # errors="replace": a torn multi-byte character at the end of a log
        # that is still being written must not abort the whole load; the
        # damaged line just fails JSON parsing below and is skipped.
        f = open(log_path, "r", encoding="utf-8", errors="replace")
    except (FileNotFoundError, NotADirectoryError):
        return None

    with f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                entry = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (list, number, ...) has no
            # .get(); skip it instead of crashing.
            if not isinstance(entry, dict):
                continue

            step = entry.get("step")
            if step is None:
                continue

            for key, name in _LOG_FIELDS.items():
                value = entry.get(key)
                # A null value would later break np.mean / EMA / formatting.
                if value is None:
                    continue
                steps, values = series[name]
                steps.append(step)
                values.append(value)

    return series
|
|
|
|
def ema_smooth(values: list, alpha: float = 0.9) -> list:
    """
    Return an exponentially smoothed copy of *values*.

    The first output equals the first input; every later output is
    ``alpha * previous_output + (1 - alpha) * current_input``, so a larger
    ``alpha`` gives a smoother but more lagged curve. An empty input is
    returned as-is (the same object).
    """
    if not values:
        return values
    level = values[0]
    result = [level]
    for x in values[1:]:
        level = alpha * level + (1 - alpha) * x
        result.append(level)
    return result
|
|
|
|
| |
| |
| |
|
|
def _style_axis(ax, title, xlabel, ylabel, legend=True):
    """Apply the shared panel styling; add a legend only if an artist is labelled."""
    ax.set_title(title, fontweight="bold", pad=8)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.tick_params(which="both", length=3)
    if legend and ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=8, loc="upper right")


def _show_placeholder(ax, title, message):
    """Title an empty panel and print a muted, centred explanation inside it."""
    ax.text(0.5, 0.5, message, ha="center", va="center",
            transform=ax.transAxes, color=MUTED_COLOR, fontsize=11)
    ax.set_title(title, fontweight="bold", pad=8)


def _annotate_last(ax, x, y, color):
    """Write the value *y* (4 decimals) just to the right of the point (x, y)."""
    ax.annotate(
        f"{y:.4f}",
        xy=(x, y),
        xytext=(5, 0), textcoords="offset points",
        color=color, fontsize=8, va="center",
    )


def _plot_run(axes, run_name, data, color):
    """Draw one run's metric series onto the dashboard axes (keyed by panel name)."""
    steps, loss = data["train"]
    if steps:
        smoothed = ema_smooth(loss, alpha=0.92)
        # Faint raw curve underneath, bold EMA curve (the labelled one) on top.
        axes["loss"].plot(steps, loss, color=color, alpha=0.25, linewidth=0.8)
        axes["loss"].plot(steps, smoothed, color=color, alpha=1.0, linewidth=1.8,
                          label=run_name)
        _annotate_last(axes["loss"], steps[-1], smoothed[-1], color)

    vsteps, vloss = data["val"]
    if vsteps:
        axes["val"].plot(vsteps, vloss, color=color, linewidth=2, marker="o",
                         markersize=4, label=run_name)
        _annotate_last(axes["val"], vsteps[-1], vloss[-1], color)

    lsteps, lvals = data["lr"]
    if lsteps:
        axes["lr"].plot(lsteps, lvals, color=color, linewidth=1.5, label=run_name)

    tsteps, tvals = data["tok"]
    if tsteps:
        avg_tok = np.mean(tvals)
        axes["tok"].plot(tsteps, tvals, color=color, alpha=0.6, linewidth=1.0)
        axes["tok"].axhline(avg_tok, color=color, linewidth=1.5, linestyle="--",
                            label=f"{run_name} (avg {avg_tok:.0f})")

    msteps, mvals = data["vram"]
    if msteps:
        axes["vram"].plot(msteps, mvals, color=color, linewidth=1.5, label=run_name)

    gsteps, gvals = data["grad"]
    if gsteps:
        smoothed_g = ema_smooth(gvals, alpha=0.85)
        axes["grad"].plot(gsteps, gvals, color=color, alpha=0.2, linewidth=0.8)
        axes["grad"].plot(gsteps, smoothed_g, color=color, linewidth=1.5, label=run_name)


def make_dashboard(data_dict: dict, run_names: list, save_path: str = None,
                   vram_limit_gb: float = 4.0):
    """
    Render the six-panel training dashboard for one or more runs.

    Args:
        data_dict     : run display name -> metrics dict in the shape produced
                        by load_log(), or None for a run without a log. None
                        runs are skipped but still consume a colour slot, so
                        each run keeps the colour of its position.
        run_names     : kept for backward compatibility; legend labels come
                        from the keys of ``data_dict``.
        save_path     : if set, save the figure to this path and close it
                        instead of showing it interactively.
        vram_limit_gb : y-value of the dotted reference line on the VRAM panel
                        (default 4.0); pass None to hide the line.

    Returns:
        The matplotlib Figure that was drawn.
    """
    fig = plt.figure(figsize=(16, 10), facecolor=DARK_BG)
    fig.suptitle(
        "SLLM Training Dashboard",
        fontsize=16,
        fontweight="bold",
        color=TEXT_COLOR,
        y=0.98,
    )

    # 3 x 2 grid: loss | val, lr | throughput, vram | grad norm.
    gs = gridspec.GridSpec(3, 2, figure=fig, hspace=0.45, wspace=0.3,
                           left=0.06, right=0.97, top=0.93, bottom=0.06)
    axes = {
        "loss": fig.add_subplot(gs[0, 0]),
        "val": fig.add_subplot(gs[0, 1]),
        "lr": fig.add_subplot(gs[1, 0]),
        "tok": fig.add_subplot(gs[1, 1]),
        "vram": fig.add_subplot(gs[2, 0]),
        "grad": fig.add_subplot(gs[2, 1]),
    }

    colors = [ACCENT_BLUE, ACCENT_GREEN, ACCENT_ORANGE, ACCENT_PURPLE]
    for idx, (run_name, data) in enumerate(data_dict.items()):
        if data is None:
            continue
        _plot_run(axes, run_name, data, colors[idx % len(colors)])

    # Panels that can legitimately be empty get a placeholder instead of axes labels.
    has_val = any(d and d["val"][0] for d in data_dict.values())
    has_vram = any(d and d["vram"][0] for d in data_dict.values())
    has_grad = any(d and d["grad"][0] for d in data_dict.values())

    _style_axis(axes["loss"], "Training Loss (EMA smoothed)", "Step", "Loss")
    _style_axis(axes["lr"], "Learning Rate Schedule", "Step", "LR")
    _style_axis(axes["tok"], "Throughput", "Step", "Tokens / sec")

    if has_val:
        _style_axis(axes["val"], "Validation Loss", "Step", "Val Loss")
    else:
        _show_placeholder(axes["val"], "Validation Loss", "No validation data yet")

    if has_vram:
        if vram_limit_gb is not None:
            axes["vram"].axhline(vram_limit_gb, color=ACCENT_RED, linewidth=1,
                                 linestyle=":", alpha=0.6,
                                 label=f"{vram_limit_gb:g} GB limit")
        # Build the legend once, after the reference line exists, so it lists
        # both the runs and the limit.
        _style_axis(axes["vram"], "VRAM Usage", "Step", "GB", legend=False)
        axes["vram"].legend(fontsize=8)
    else:
        _show_placeholder(axes["vram"], "VRAM Usage", "No VRAM data\n(requires CUDA)")

    if has_grad:
        _style_axis(axes["grad"], "Gradient Norm (EMA smoothed)", "Step", "Norm")
    else:
        _show_placeholder(axes["grad"], "Gradient Norm", "No gradient norm data")

    # Learning rates are tiny; show them as mantissa x 10^k.
    axes["lr"].yaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))
    axes["lr"].ticklabel_format(style="sci", axis="y", scilimits=(0, 0))

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight", facecolor=DARK_BG)
        print(f"[PLOT] Saved to {save_path}")
        # Nothing will display this figure. Close it so repeated calls (e.g.
        # --live --save) neither accumulate figures nor get popped up by plt.pause.
        plt.close(fig)
    else:
        plt.show()
    return fig
|
|
|
|
| |
| |
| |
|
|
def parse_args():
    """
    Parse the dashboard's command-line options.

    Exits with an argparse usage error (SystemExit) if --interval is not a
    positive number of seconds. With a non-positive value, matplotlib's
    pause/event loop never times out (live mode would stop refreshing), and
    time.sleep rejects negative values.
    """
    p = argparse.ArgumentParser(description="SLLM Training Dashboard")
    p.add_argument("--run_dir", nargs="+", default=["runs/run_001"],
                   help="One or more run directories to plot")
    p.add_argument("--live", action="store_true",
                   help="Refresh plot every --interval seconds (live mode)")
    p.add_argument("--interval", type=float, default=10,
                   help="Refresh interval in seconds for --live mode")
    p.add_argument("--save", type=str, default=None,
                   help="Save plot to this path instead of showing interactively")
    args = p.parse_args()
    if args.interval <= 0:
        p.error("--interval must be a positive number of seconds")
    return args
|
|
|
|
def main():
    """
    CLI entry point: load each requested run's train_log.jsonl and render the
    dashboard once, or repeatedly every --interval seconds with --live.
    """
    args = parse_args()

    run_dirs = args.run_dir
    base_names = [Path(d).name or str(d) for d in run_dirs]
    # Runs are keyed by display name in data_dict. Two directories sharing a
    # basename (e.g. exp_a/run_001 and exp_b/run_001) would silently overwrite
    # each other, so label those by the path as given instead.
    run_names = [
        str(d) if base_names.count(name) > 1 else name
        for d, name in zip(run_dirs, base_names)
    ]

    def _reload_and_plot():
        """Re-read every run's log and redraw; skip drawing if no train steps exist yet."""
        data_dict = {}
        for name, run_dir in zip(run_names, run_dirs):
            log_path = os.path.join(run_dir, "train_log.jsonl")
            data = load_log(log_path)
            if data is None:
                print(f"[WARN] No log found at: {log_path}")
            data_dict[name] = data

        # Nothing to draw until at least one run has logged a training loss.
        total_steps = sum(len(d["train"][0]) for d in data_dict.values() if d)
        if total_steps == 0:
            print("[PLOT] No data logged yet. Waiting...")
            return

        steps_info = {n: len(d["train"][0]) for n, d in data_dict.items() if d}
        print(f"[PLOT] Plotting {steps_info} train steps")

        # Drop the previous refresh's figure so live mode does not pile up windows.
        plt.close("all")
        make_dashboard(data_dict, run_names, save_path=args.save)

    if args.live:
        print(f"[LIVE] Refreshing every {args.interval}s (Ctrl+C to stop)")
        if sys.platform == "win32":
            matplotlib.use("TkAgg")
        # Interactive mode makes plt.show() (called by make_dashboard) return
        # immediately instead of blocking until the window is closed, so this
        # loop actually keeps refreshing.
        plt.ion()
        try:
            while True:
                _reload_and_plot()
                plt.pause(args.interval)
        except KeyboardInterrupt:
            print("\n[LIVE] Stopped.")
    else:
        _reload_and_plot()


if __name__ == "__main__":
    main()
|
|