| """Generate all figures referenced from WRITEUP.md. |
| |
| Reads everything from results/*.json and results/*.parquet; writes PNGs to |
| results/figures/. Idempotent. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
|
|
| ROOT = Path(__file__).resolve().parents[1] |
| RESULTS = ROOT / "results" |
| FIGS = RESULTS / "figures" |
| FIGS.mkdir(parents=True, exist_ok=True) |
|
|
|
|
| def load_json(name: str) -> dict: |
| with (RESULTS / name).open("r", encoding="utf-8") as f: |
| return json.load(f) |
|
|
|
|
| |
| |
| |
| def fig_ablation_curve() -> None: |
| abl = load_json("ablation_results.json") |
| rnd = load_json("random_feature_ablation.json") |
|
|
| baseline = abl["baseline_accuracy"] * 100 |
| Ns = abl["feature_ablation"]["N"] |
| acc = [a * 100 for a in abl["feature_ablation"]["accuracy"]] |
| ci_lo = [a * 100 for a in abl["feature_ablation"]["ci_low"]] |
| ci_hi = [a * 100 for a in abl["feature_ablation"]["ci_high"]] |
| heads = abl["head_ablation"]["heads"] |
| head_acc = [a * 100 for a in abl["head_ablation"]["accuracy"]] |
|
|
| rnd_mean_acc = rnd["random_mean_acc"] * 100 |
| rnd_std_drop = rnd["random_std_drop"] * 100 |
|
|
| fig, (ax_l, ax_r) = plt.subplots(1, 2, figsize=(13, 4.8)) |
|
|
| |
| ax_l.axhline(baseline, color="grey", linestyle="--", label=f"Baseline ({baseline:.1f}%)") |
| ax_l.axhspan( |
| rnd_mean_acc - rnd_std_drop, |
| rnd_mean_acc + rnd_std_drop, |
| color="tab:orange", |
| alpha=0.25, |
| label=f"Random 50 features (mean ±1σ across 5 seeds)", |
| ) |
| ax_l.axhline(rnd_mean_acc, color="tab:orange", linestyle=":", linewidth=1) |
| ax_l.fill_between(Ns, ci_lo, ci_hi, color="tab:blue", alpha=0.2) |
| ax_l.plot(Ns, acc, "o-", color="tab:blue", label="Top-N induction features ablated") |
| ax_l.set_xlabel("Number of top induction features ablated") |
| ax_l.set_ylabel("ICL top-1 accuracy (%)") |
| ax_l.set_title("SAE feature ablation vs random-feature control") |
| ax_l.set_ylim(40, 65) |
| ax_l.set_xticks(Ns) |
| ax_l.legend(loc="lower left", fontsize=9) |
| ax_l.grid(axis="y", alpha=0.3) |
|
|
| |
| colors = ["tab:red" if h == 6 else "lightcoral" for h in heads] |
| ax_r.bar(heads, head_acc, color=colors, edgecolor="black", linewidth=0.5) |
| ax_r.axhline(baseline, color="grey", linestyle="--", label=f"Baseline ({baseline:.1f}%)") |
| ax_r.set_xlabel("Layer-12 attention head") |
| ax_r.set_ylabel("ICL top-1 accuracy (%)") |
| ax_r.set_title("Head ablation (Olsson et al. baseline)") |
| ax_r.set_xticks(heads) |
| ax_r.set_ylim(40, 65) |
| ax_r.legend(loc="lower left", fontsize=9) |
| ax_r.grid(axis="y", alpha=0.3) |
|
|
| fig.tight_layout() |
| fig.savefig(FIGS / "ablation_curve.png", dpi=150, bbox_inches="tight") |
| plt.close(fig) |
| print("wrote ablation_curve.png") |
|
|
|
|
| |
| |
| |
| def fig_activation_patching() -> None: |
| zero = load_json("activation_patching.json") |
| mean = load_json("activation_patching_mean.json") |
| feat_ids_all = zero["target_features"] |
| |
| |
| keep = [15289, 14740] |
| keep_idx = [feat_ids_all.index(f) for f in keep] |
| feat_ids = keep |
| subtitles = { |
| 15289: "F15289 — rank-1 induction feature", |
| 14740: "F14740 — 'tokens in repeated/parallel structures'", |
| } |
|
|
| n_heads = 8 |
| heads = list(range(n_heads)) |
|
|
| def gather(payload: dict) -> np.ndarray: |
| arr = np.zeros((len(feat_ids), n_heads)) |
| for h in heads: |
| full = payload["head_results"][str(h)]["reduction_pct"] |
| arr[:, h] = [full[i] for i in keep_idx] |
| return arr |
|
|
| z = gather(zero) |
| m = gather(mean) |
|
|
| fig, axes = plt.subplots(1, len(feat_ids), figsize=(6.5 * len(feat_ids), 4.6), sharey=False) |
| width = 0.4 |
| x = np.arange(n_heads) |
|
|
| for i, ax in enumerate(axes): |
| ax.bar(x - width / 2, z[i], width, label="Zero-ablation (OOD)", color="tab:blue", edgecolor="black", linewidth=0.4) |
| ax.bar(x + width / 2, m[i], width, label="Mean-ablation (in-distribution)", color="tab:orange", edgecolor="black", linewidth=0.4) |
| ax.axhline(0, color="black", linewidth=0.6) |
| ax.set_title(subtitles[feat_ids[i]], fontsize=11) |
| ax.set_xticks(x) |
| ax.set_xticklabels([f"H{h}" for h in heads]) |
| ax.set_xlabel("Layer-12 attention head") |
| ax.set_ylabel("% reduction in feature activation when head is ablated") |
| ax.grid(axis="y", alpha=0.3) |
| ax.legend(loc="lower left", fontsize=9) |
|
|
| |
| |
| max_h = int(np.argmax(np.abs(z[i]))) |
| zv = z[i][max_h] |
| mv = m[i][max_h] |
| ax.text( |
| x[max_h] - width / 2, |
| zv + (3 if zv > 0 else -3), |
| f"{zv:+.0f}%", |
| ha="center", |
| va="bottom" if zv > 0 else "top", |
| fontsize=9, |
| fontweight="bold", |
| color="tab:blue", |
| ) |
| ax.text( |
| x[max_h] + width / 2, |
| mv + (3 if mv > 0 else -3), |
| f"{mv:+.0f}%", |
| ha="center", |
| va="bottom" if mv > 0 else "top", |
| fontsize=9, |
| fontweight="bold", |
| color="tab:orange", |
| ) |
|
|
| fig.suptitle( |
| "Zero vs mean ablation flips the 'which head matters' answer", |
| fontsize=12, |
| ) |
| fig.tight_layout() |
| fig.savefig(FIGS / "activation_patching.png", dpi=150, bbox_inches="tight") |
| plt.close(fig) |
| print("wrote activation_patching.png") |
|
|
|
|
| |
| |
| |
| def fig_induction_score_distribution() -> None: |
| df = pd.read_parquet(RESULTS / "induction_feature_scores.parquet") |
| scores = df["induction_score"].to_numpy() |
|
|
| target_ids = [15289, 11606, 14740, 7467] |
| target_scores = {fid: df.loc[df.feature_id == fid, "induction_score"].iloc[0] for fid in target_ids} |
|
|
| fig, ax = plt.subplots(1, 1, figsize=(9, 4.6)) |
| bins = np.linspace(scores.min(), scores.max(), 120) |
| ax.hist(scores, bins=bins, color="lightgrey", edgecolor="black", linewidth=0.3) |
| ax.set_yscale("log") |
| ax.set_xlabel("Induction score (mean act on induction probes − mean act on control)") |
| ax.set_ylabel("Number of SAE features (log scale)") |
| ax.set_title("Distribution of induction scores across all 16,384 v9c SAE features") |
|
|
| colors = ["tab:red", "tab:orange", "tab:green", "tab:purple"] |
| ymax = ax.get_ylim()[1] |
| for (fid, sc), color in zip(target_scores.items(), colors): |
| ax.axvline(sc, color=color, linewidth=1.5, linestyle="--") |
| ax.text( |
| sc, |
| ymax * 0.4, |
| f"F{fid}\n({sc:.2f})", |
| color=color, |
| fontsize=9, |
| rotation=90, |
| va="top", |
| ha="right", |
| ) |
|
|
| ax.grid(axis="y", alpha=0.3) |
| fig.tight_layout() |
| fig.savefig(FIGS / "induction_score_distribution.png", dpi=150, bbox_inches="tight") |
| plt.close(fig) |
| print("wrote induction_score_distribution.png") |
|
|
|
|
| |
| |
| |
| def fig_multi_seed() -> None: |
| seed43 = load_json("seed43_replication.json") |
| seed44 = load_json("seed44_replication.json") |
| |
| rows = [ |
| ("v9c (seed 42)", "F15289", 2.31, 0.79, 10.1), |
| ("seed 43", f"F{seed43['top_feature_id']}", seed43["top_induction_score"], seed43["top20_mean_score"], seed43["drop_pp"]), |
| ("seed 44", f"F{seed44['top_feature_id']}", seed44["top_induction_score"], seed44["top20_mean_score"], seed44["drop_pp"]), |
| ] |
|
|
| labels = [r[0] for r in rows] |
| top_scores = [r[2] for r in rows] |
| top20_means = [r[3] for r in rows] |
| drops = [r[4] for r in rows] |
| top_feat_labels = [r[1] for r in rows] |
|
|
| fig, (ax_l, ax_m, ax_r) = plt.subplots(1, 3, figsize=(13, 4.2)) |
|
|
| x = np.arange(len(labels)) |
|
|
| |
| ax_l.bar(x, top_scores, color="tab:blue", edgecolor="black") |
| ax_l.set_xticks(x) |
| ax_l.set_xticklabels(labels) |
| ax_l.set_ylabel("Top induction score") |
| ax_l.set_title("Rank-1 induction-score feature\n(IDs differ across seeds)") |
| for xi, sc, lab in zip(x, top_scores, top_feat_labels): |
| ax_l.text(xi, sc + 0.05, f"{lab}\n{sc:.2f}", ha="center", va="bottom", fontsize=9) |
| ax_l.set_ylim(0, max(top_scores) * 1.3) |
| ax_l.grid(axis="y", alpha=0.3) |
|
|
| |
| ax_m.bar(x, top20_means, color="tab:green", edgecolor="black") |
| mean20 = float(np.mean(top20_means)) |
| ax_m.axhline(mean20, color="black", linestyle="--", linewidth=1, label=f"Mean = {mean20:.2f}") |
| ax_m.set_xticks(x) |
| ax_m.set_xticklabels(labels) |
| ax_m.set_ylabel("Mean induction score of top-20 features") |
| ax_m.set_title("Top-20 mean induction score\n(replicates within ±0.05)") |
| ax_m.set_ylim(0, max(top20_means) * 1.3) |
| for xi, sc in zip(x, top20_means): |
| ax_m.text(xi, sc + 0.02, f"{sc:.2f}", ha="center", va="bottom", fontsize=9) |
| ax_m.legend(loc="lower right", fontsize=9) |
| ax_m.grid(axis="y", alpha=0.3) |
|
|
| |
| ax_r.bar(x, drops, color="tab:red", edgecolor="black") |
| mean_drop = float(np.mean(drops)) |
| std_drop = float(np.std(drops, ddof=1)) |
| ax_r.axhline(mean_drop, color="black", linestyle="--", linewidth=1, label=f"Mean = {mean_drop:.1f} ± {std_drop:.1f}pp") |
| ax_r.set_xticks(x) |
| ax_r.set_xticklabels(labels) |
| ax_r.set_ylabel("Top-50 ablation ICL drop (pp)") |
| ax_r.set_title("Top-50 ablation effect on ICL\n(replicates across seeds)") |
| ax_r.set_ylim(0, max(drops) * 1.3) |
| for xi, sc in zip(x, drops): |
| ax_r.text(xi, sc + 0.3, f"{sc:.1f}pp", ha="center", va="bottom", fontsize=9) |
| ax_r.legend(loc="lower right", fontsize=9) |
| ax_r.grid(axis="y", alpha=0.3) |
|
|
| fig.suptitle("Multi-seed replication of v9c SAE (seeds 42 / 43 / 44, identical training config)", fontsize=11) |
| fig.tight_layout() |
| fig.savefig(FIGS / "multi_seed_replication.png", dpi=150, bbox_inches="tight") |
| plt.close(fig) |
| print("wrote multi_seed_replication.png") |
|
|
|
|
| |
| |
| |
| def fig_cross_sae() -> None: |
| |
| |
| metrics = [ |
| ("Top induction score", 2.31, 1.72), |
| ("Top-20 mean induction score", 0.79, 0.78), |
| ] |
| labels = [m[0] for m in metrics] |
| v9c = [m[1] for m in metrics] |
| scope = [m[2] for m in metrics] |
| x = np.arange(len(labels)) |
| width = 0.36 |
|
|
| fig, ax = plt.subplots(1, 1, figsize=(8, 4.4)) |
| b1 = ax.bar(x - width / 2, v9c, width, label="v9c (mine, dictionary_learning)", color="tab:blue", edgecolor="black") |
| b2 = ax.bar(x + width / 2, scope, width, label="Gemma Scope (DeepMind, SAEBench)", color="tab:gray", edgecolor="black") |
| ax.set_xticks(x) |
| ax.set_xticklabels(labels) |
| ax.set_ylabel("Induction score") |
| ax.set_title("Cross-SAE: same selectivity, different SAEs") |
| for b, val in zip(b1, v9c): |
| ax.text(b.get_x() + b.get_width() / 2, val + 0.04, f"{val:.2f}", ha="center", va="bottom", fontsize=10) |
| for b, val in zip(b2, scope): |
| ax.text(b.get_x() + b.get_width() / 2, val + 0.04, f"{val:.2f}", ha="center", va="bottom", fontsize=10) |
| ax.legend(loc="upper right", fontsize=9) |
| ax.grid(axis="y", alpha=0.3) |
| ax.set_ylim(0, max(v9c) * 1.25) |
| fig.tight_layout() |
| fig.savefig(FIGS / "cross_sae_gemma_scope.png", dpi=150, bbox_inches="tight") |
| plt.close(fig) |
| print("wrote cross_sae_gemma_scope.png") |
|
|
|
|
| |
| |
| |
| def fig_mmlu_negative() -> None: |
| mmlu = load_json("mmlu_feature_activations.json") |
| targets = mmlu["target_features"] |
| idx_15289 = targets.index(15289) |
| few_shot = mmlu["few_shot_mean"][idx_15289] |
| shuffled = mmlu["shuffled_mean"][idx_15289] |
|
|
| |
| df = pd.read_parquet(RESULTS / "induction_feature_scores.parquet") |
| synth = df.loc[df.feature_id == 15289, "induction_mean"].iloc[0] |
| synth_ctrl = df.loc[df.feature_id == 15289, "control_mean"].iloc[0] |
|
|
| labels = [ |
| "Synthetic A-B-A\ninduction probe\n(final pos)", |
| "Synthetic\ncontrol\n(final pos)", |
| "MMLU 4-shot\nreal answers\n(final pos)", |
| "MMLU 4-shot\nshuffled answers\n(final pos)", |
| ] |
| values = [synth, synth_ctrl, few_shot, shuffled] |
| colors = ["tab:blue", "lightblue", "tab:red", "lightcoral"] |
|
|
| fig, ax = plt.subplots(1, 1, figsize=(8.5, 4.6)) |
| bars = ax.bar(labels, values, color=colors, edgecolor="black") |
| for b, v in zip(bars, values): |
| ax.text(b.get_x() + b.get_width() / 2, v + 0.1, f"{v:.2f}", ha="center", va="bottom", fontsize=10) |
| ax.set_ylabel("F15289 mean activation at final position") |
| ax.set_title( |
| "F15289 fires on synthetic token-copying induction, not on natural few-shot ICL\n" |
| f"(MMLU n={mmlu['n_questions']} questions, {mmlu['n_shots']}-shot)" |
| ) |
| ax.set_ylim(0, max(values) * 1.3 + 0.2) |
| ax.grid(axis="y", alpha=0.3) |
| fig.tight_layout() |
| fig.savefig(FIGS / "mmlu_negative_finding.png", dpi=150, bbox_inches="tight") |
| plt.close(fig) |
| print("wrote mmlu_negative_finding.png") |
|
|
|
|
| |
| |
| |
| def fig_top_feature_snippets() -> None: |
| """Two-panel: F15289 (second occurrence of repeated token) and F14740 |
| (tokens in parallel/repeated structures). Text is laid out via HPacker so |
| segment widths come from the actual rendered glyphs — no overlap from |
| bold-vs-normal width mismatches.""" |
| import re |
| from matplotlib.offsetbox import TextArea, HPacker, AnnotationBbox |
|
|
| df = pd.read_parquet(RESULTS / "top_snippets.parquet") |
|
|
| def extract_rows(feature_id: int, n_rows: int) -> list: |
| sub = df[df.feature_id == feature_id].nsmallest(20, "rank").reset_index(drop=True) |
| rows = [] |
| for _, row in sub.iterrows(): |
| token_clean = str(row["token"]).strip() |
| if not token_clean: |
| continue |
| context = str(row["context"]).replace("\n", " · ") |
| act = float(row["activation"]) |
| pattern = re.compile(r"\b" + re.escape(token_clean) + r"\b", re.IGNORECASE) |
| matches = list(pattern.finditer(context)) |
| if len(matches) < 2: |
| continue |
| s0, e0 = matches[0].span() |
| s1, e1 = matches[1].span() |
| lo = max(0, s0 - 24) |
| hi = min(len(context), e1 + 32) |
| prefix = ("… " if lo > 0 else "") + context[lo:s0] |
| first = context[s0:e0] |
| middle = context[e0:s1] |
| second = context[s1:e1] |
| suffix = context[e1:hi] + (" …" if hi < len(context) else "") |
| |
| max_chars = 100 |
| total = len(prefix) + len(first) + len(middle) + len(second) + len(suffix) |
| if total > max_chars: |
| overshoot = total - max_chars |
| trim = overshoot // 2 + 1 |
| if len(prefix) > trim + 3: |
| prefix = "… " + prefix.lstrip("… ")[trim:] |
| if len(suffix) > trim + 3: |
| suffix = suffix.rstrip(" …")[:-trim] + " …" |
| rows.append((act, prefix, first, middle, second, suffix)) |
| if len(rows) >= n_rows: |
| break |
| return rows |
|
|
| panels = [ |
| ( |
| 15289, |
| "F15289 — fires on the SECOND occurrence of a repeated token", |
| extract_rows(15289, 6), |
| ), |
| ( |
| 14740, |
| "F14740 — fires on tokens in repeated / parallel structures", |
| extract_rows(14740, 6), |
| ), |
| ] |
|
|
| font_kwargs = {"family": "DejaVu Sans Mono", "size": 10} |
|
|
| NBSP = " " |
|
|
| def make_line(prefix, first, middle, second, suffix): |
| boxes = [] |
| |
| |
| for text, color, weight in [ |
| (prefix.replace(" ", NBSP), "#555555", "normal"), |
| (first, "#222222", "bold"), |
| (middle.replace(" ", NBSP), "#555555", "normal"), |
| (second, "#c00000", "bold"), |
| (suffix.replace(" ", NBSP), "#555555", "normal"), |
| ]: |
| if not text: |
| continue |
| boxes.append( |
| TextArea( |
| text, |
| textprops={"color": color, "fontweight": weight, **font_kwargs}, |
| ) |
| ) |
| return HPacker(children=boxes, align="baseline", pad=0, sep=0) |
|
|
| n_panels = len(panels) |
| max_rows = max(len(p[2]) for p in panels) |
| fig, axes = plt.subplots( |
| n_panels, 1, figsize=(13, 0.55 * max_rows * n_panels + 1.6), squeeze=False |
| ) |
| axes = axes.ravel() |
|
|
| for ax, (fid, title, rows) in zip(axes, panels): |
| n = len(rows) |
| ax.set_xlim(0, 1) |
| ax.set_ylim(0, n + 1) |
| ax.invert_yaxis() |
| ax.axis("off") |
| ax.text(0.04, 0.35, "Act.", fontsize=10, fontweight="bold", ha="center") |
| ax.text(0.09, 0.35, title, fontsize=11, fontweight="bold", ha="left") |
| for i, (act, prefix, first, middle, second, suffix) in enumerate(rows): |
| y = i + 1.0 |
| ax.text(0.04, y, f"{act:.1f}", fontsize=11, ha="center", va="center", fontweight="bold") |
| packer = make_line(prefix, first, middle, second, suffix) |
| ab = AnnotationBbox( |
| packer, |
| xy=(0.09, y), |
| xycoords=("axes fraction", "data"), |
| box_alignment=(0.0, 0.5), |
| frameon=False, |
| pad=0, |
| ) |
| ax.add_artist(ab) |
|
|
| fig.suptitle( |
| "Top-activating C4 snippets — first occurrence in dark-bold, activating token in red-bold", |
| fontsize=12, |
| y=0.995, |
| ) |
| fig.tight_layout(rect=[0, 0, 1, 0.97]) |
| fig.savefig(FIGS / "top_feature_snippets.png", dpi=150, bbox_inches="tight") |
| plt.close(fig) |
| print("wrote top_feature_snippets.png") |
|
|
|
|
| if __name__ == "__main__": |
| fig_ablation_curve() |
| fig_activation_patching() |
| fig_induction_score_distribution() |
| fig_multi_seed() |
| fig_cross_sae() |
| fig_mmlu_negative() |
| fig_top_feature_snippets() |
|
|