File size: 5,998 Bytes
253d988 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | """
Download SAE training metrics from W&B and plot them to results/figures/sae_training_metrics.png.
Falls back to parsing models/sae_main/log.txt if W&B is unavailable or has no matching run history.
Run as part of scripts/run_all.ps1 / run_all.sh.
"""
import json
import os
import sys
from pathlib import Path
# Repository root: two levels above this file (the script presumably lives in a
# direct subdirectory of the repo such as scripts/ — confirm if it is moved).
ROOT = Path(__file__).resolve().parent.parent
# Make the project's src/ directory importable. NOTE(review): nothing in this
# file imports from src/ — confirm whether this is still needed.
sys.path.insert(0, str(ROOT / "src"))
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, so figures
# can be rendered and saved without a display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Output directory for the metrics figure; created at import time so both
# plotting paths can save into it unconditionally.
FIGURES_DIR = ROOT / "results" / "figures"
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
# Optionally load environment variables from <repo root>/.env (presumably W&B
# settings such as WANDB_PROJECT / WANDB_ENTITY). python-dotenv is optional:
# if it is not installed, the process environment is used as-is.
try:
    from dotenv import load_dotenv
    load_dotenv(ROOT / ".env")
except ImportError:
    pass
def _fetch_latest_wandb_run():
    """Return ``(run, history)`` for the newest matching W&B run, or ``None``.

    Queries the project named by ``$WANDB_PROJECT`` (default
    ``"sae-gemma-induction"``), prefixed by ``$WANDB_ENTITY`` when that is set,
    for the most recently created run whose ``config.sae.d_sae`` is at least
    8192. It then samples up to 500 history rows of the four plotted metrics.

    Prints a reason and returns ``None`` when no run matches or the sampled
    history is empty. Import, authentication and network errors are not caught
    here; they propagate to the caller.
    """
    import wandb

    api = wandb.Api()
    project = os.environ.get("WANDB_PROJECT", "sae-gemma-induction")
    entity = os.environ.get("WANDB_ENTITY", None)
    path = f"{entity}/{project}" if entity else project
    # The main training run is identified by SAE width (d_sae >= 8192), newest first.
    runs = api.runs(path, filters={"config.sae.d_sae": {"$gte": 8192}}, order="-created_at")
    if not runs:
        print("[plot_metrics] No W&B runs found — falling back to log file", flush=True)
        return None
    run = runs[0]
    print(f"[plot_metrics] W&B run: {run.name} ({run.id})", flush=True)
    # NOTE(review): when explicit ``keys`` are given, W&B's sampled history may
    # omit rows in which not every key was logged. Confirm that these metrics
    # are logged at the same steps.
    history = run.history(
        keys=["metrics/l0", "losses/overall_loss", "metrics/explained_variance",
              "sparsity/dead_features"],
        samples=500,
    )
    if history.empty:
        print("[plot_metrics] W&B history empty — trying log file", flush=True)
        return None
    return run, history


def _sae_width_from_config(config):
    """Return ``config["sae"]["d_sae"]`` when present, else ``None``.

    A missing or empty config, or a non-dict ``"sae"`` entry, yields ``None``
    instead of raising.
    """
    sae_cfg = config.get("sae") if config else None
    return sae_cfg.get("d_sae") if isinstance(sae_cfg, dict) else None


def _plot_wandb_history(run_name, history, sae_width):
    """Draw the 2x2 training-metrics grid from a W&B history frame and save it.

    Steps performed:
    - Add a ``dead_feature_fraction`` column: dead-feature count divided by
      ``sae_width``, or NaN when either value is unavailable.
    - Plot each available metric against ``_step``, or against the first
      column if ``_step`` is absent.
    - Overlay the acceptance-criterion reference lines.

    Returns the path of the saved PNG. The figure is always closed, even if
    drawing or saving raises.
    """
    # Convert the dead-feature count into a fraction of the dictionary size.
    if sae_width and "sparsity/dead_features" in history.columns:
        history["dead_feature_fraction"] = history["sparsity/dead_features"] / sae_width
    else:
        history["dead_feature_fraction"] = float("nan")
    steps_col = "_step" if "_step" in history.columns else history.columns[0]

    fig, axes = plt.subplots(2, 2, figsize=(10, 6))
    try:
        fig.suptitle(f"SAE Training — {run_name}", fontsize=12)
        panels = [
            ("metrics/l0", "L0 sparsity", axes[0, 0]),
            ("losses/overall_loss", "Overall loss", axes[0, 1]),
            ("metrics/explained_variance", "Explained variance", axes[1, 0]),
            ("dead_feature_fraction", "Dead feature fraction", axes[1, 1]),
        ]
        for col, label, ax in panels:
            # A missing metric leaves its panel empty but still titled.
            if col in history.columns:
                ax.plot(history[steps_col], history[col], linewidth=1)
            ax.set_title(label)
            ax.set_xlabel("Training step")
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

        # Acceptance-criterion reference lines.
        axes[0, 0].axhline(20, color="green", linestyle="--", alpha=0.5, label="L0=20 (target min)")
        axes[0, 0].axhline(100, color="red", linestyle="--", alpha=0.5, label="L0=100 (target max)")
        axes[0, 0].legend(fontsize=7)
        axes[1, 0].axhline(0.6, color="green", linestyle="--", alpha=0.5, label="EV=0.6 (target)")
        axes[1, 0].legend(fontsize=7)
        axes[1, 1].axhline(0.25, color="red", linestyle="--", alpha=0.5, label="25% dead (limit)")
        axes[1, 1].legend(fontsize=7)

        fig.tight_layout()
        out = FIGURES_DIR / "sae_training_metrics.png"
        fig.savefig(out, dpi=150)
    finally:
        plt.close(fig)
    return out


def plot_from_wandb() -> bool:
    """Fetch training history from W&B and plot it. Returns True on success.

    This is best-effort: any failure (wandb not installed, not logged in,
    network error, no matching run, empty history, plotting error) is printed
    and reported as ``False``, so the caller can fall back to the log file.
    Fetch failures and plotting failures are reported separately, so a
    plotting bug is not mistaken for a W&B outage.
    """
    try:
        fetched = _fetch_latest_wandb_run()
    except Exception as exc:  # wandb raises many unrelated types; all mean "unavailable"
        print(f"[plot_metrics] W&B fetch failed: {exc}", flush=True)
        return False
    if fetched is None:
        return False
    run, history = fetched
    try:
        out = _plot_wandb_history(run.name, history, _sae_width_from_config(run.config))
    except Exception as exc:  # keep best-effort semantics so the log-file fallback still runs
        print(f"[plot_metrics] Plotting W&B metrics failed: {exc}", flush=True)
        return False
    print(f"[plot_metrics] Saved to {out}", flush=True)
    return True
def plot_from_logfile(log_path: "Path | None" = None) -> bool:
    """Parse a plain-text training log and plot whatever metrics it contains.

    A line contributes one data point when it contains both a step
    (``step=123``, ``step: 123`` or ``step 123``) and an L0 value
    (``l0=...``). A ``loss``/``l2`` value on the same line feeds the second
    panel; when it is absent, NaN is used. Matching is case-insensitive, and
    values may be written in decimal or scientific notation (e.g. ``3.2e-04``).

    Args:
        log_path: Log file to parse. Defaults to ``models/sae_main/log.txt``
            under the repository root.

    Returns:
        True if the figure was written to
        ``FIGURES_DIR/sae_training_metrics.png``. False if the log file is
        missing or no line yields a step/L0 pair.
    """
    import re

    if log_path is None:
        log_path = ROOT / "models" / "sae_main" / "log.txt"
    log_path = Path(log_path)
    if not log_path.exists():
        print(f"[plot_metrics] Log file not found: {log_path}", flush=True)
        return False

    # One float literal: optional sign, digits with an optional fraction (or a
    # bare fraction), and an optional exponent. At least one digit is required,
    # so tokens such as "..." never reach float() and raise ValueError.
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    step_re = re.compile(r"step[=:\s]+(\d+)", re.I)
    l0_re = re.compile(r"l0[=:\s]+" + number, re.I)
    loss_re = re.compile(r"(?:loss|l2)[=:\s]+" + number, re.I)

    steps, l0s, losses = [], [], []
    # Read as UTF-8 explicitly, replacing undecodable bytes. This keeps non-ASCII
    # characters in the log (e.g. progress-bar glyphs, if any) from raising
    # UnicodeDecodeError under a non-UTF-8 default locale such as Windows'.
    with log_path.open(encoding="utf-8", errors="replace") as f:
        for line in f:
            step_m = step_re.search(line)
            l0_m = l0_re.search(line)
            if not (step_m and l0_m):
                continue
            loss_m = loss_re.search(line)
            steps.append(int(step_m.group(1)))
            l0s.append(float(l0_m.group(1)))
            losses.append(float(loss_m.group(1)) if loss_m else float("nan"))

    if not steps:
        print("[plot_metrics] Could not parse any metrics from log file", flush=True)
        return False

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    try:
        ax1.plot(steps, l0s, linewidth=1, color="steelblue")
        ax1.axhline(20, color="green", linestyle="--", alpha=0.5, label="target min")
        ax1.axhline(100, color="red", linestyle="--", alpha=0.5, label="target max")
        ax1.set_title("L0 sparsity (from log)")
        ax1.set_xlabel("Step")
        ax1.legend(fontsize=8)
        ax2.plot(steps, losses, linewidth=1, color="salmon")
        ax2.set_title("Reconstruction loss (from log)")
        ax2.set_xlabel("Step")
        for ax in (ax1, ax2):
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
        fig.tight_layout()
        out = FIGURES_DIR / "sae_training_metrics.png"
        fig.savefig(out, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    print(f"[plot_metrics] Saved to {out} (from log file)", flush=True)
    return True
if __name__ == "__main__":
    # Best-effort: try W&B first, then the local log file. Any unexpected error
    # (e.g. an unreadable or malformed log) is reported rather than propagated,
    # so the script still exits 0 and run_all is not aborted.
    try:
        ok = plot_from_wandb() or plot_from_logfile()
    except Exception as exc:
        print(f"[plot_metrics] Unexpected error while plotting: {exc}", flush=True)
        ok = False
    if not ok:
        print("[plot_metrics] Could not generate training metrics plot — skipping.", flush=True)
    sys.exit(0)  # non-fatal: run_all should not abort if W&B is unavailable
|