sae-gemma / src /sae_gemma /activation_patching.py
senator1's picture
Sparse-feature audit of induction in Gemma-2-2B (full project)
253d988
"""
Phase 6 — Activation patching: causal head-feature link.
For each induction probe, runs Gemma-2-2B twice:
1. Baseline — capture target features' activations at the final position.
2. Per-head ablation (h=0..7) — zero head h's hook_z output at all positions; capture again.
Outputs (results/activation_patching.json + figure) report the mean reduction in
each target feature's activation when each head is ablated.
Hypothesis: head-6 ablation should cause the largest reduction in F15289 (and the
other top induction features), upgrading the head-correspondence claim from
correlational (Pearson r=0.12) to causal (% of feature activation contributed by head 6).
Usage:
python src/sae_gemma/activation_patching.py --n-probes 500
"""
import argparse
import json
import time
from functools import partial
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from transformer_lens import HookedTransformer
from sae_gemma.model_utils import load_model
from sae_gemma.paths import (
HOOK_NAME,
INDUCTION_PROBES_PATH,
REPO_ROOT,
RESULTS_DIR,
SAE_MAIN_DIR,
)
ATTN_Z_HOOK = "blocks.12.attn.hook_z"
LAYER = 12
N_HEADS = 8
TARGET_FEATURES = [15289, 11606, 14740, 7467] # top induction-clean features by score
def load_sae_local(sae_path: Path, device: str):
from sae_lens.saes.sae import SAE
sae = SAE.load_from_disk(str(sae_path), device=device)
sae.eval()
return sae
def zero_head_hook(value: torch.Tensor, hook, head_idx: int) -> torch.Tensor:
"""Zero out head_idx's z output at all positions."""
value = value.clone()
value[:, :, head_idx, :] = 0
return value
def mean_head_hook(value: torch.Tensor, hook, head_idx: int, mean_z: torch.Tensor) -> torch.Tensor:
"""Replace head_idx's z output at all positions with the precomputed mean vector.
`mean_z` is the per-head-dim mean over a representative distribution of probe positions,
shape [d_head]. Broadcast-replace.
"""
value = value.clone()
value[:, :, head_idx, :] = mean_z.to(value.dtype).to(value.device)
return value
@torch.no_grad()
def compute_head_z_mean(model, token_seqs, device, batch_size=8):
"""Compute per-head mean of hook_z across all positions of all probes.
Returns tensor of shape [n_heads, d_head]."""
accum = None
n = 0
for i in range(0, len(token_seqs), batch_size):
batch = token_seqs[i: i + batch_size]
max_len = max(len(seq) for seq in batch)
padded = torch.zeros(len(batch), max_len, dtype=torch.long, device=device)
seq_lens = []
for j, seq in enumerate(batch):
padded[j, : len(seq)] = seq
seq_lens.append(len(seq))
captured = {}
def cap_z(value, hook):
captured["z"] = value # [B, S, H, D_head]
return value
model.run_with_hooks(padded, fwd_hooks=[(ATTN_Z_HOOK, cap_z)])
z = captured["z"] # [B, S, H, D]
# Sum over batch and seq dims (only over real positions, not pad)
for j in range(len(batch)):
zj = z[j, : seq_lens[j], :, :] # [S_j, H, D]
if accum is None:
accum = zj.sum(dim=0).clone() # [H, D]
else:
accum += zj.sum(dim=0)
n += seq_lens[j]
return (accum / n) # [H, D]
@torch.no_grad()
def get_feature_activations(
model: HookedTransformer,
sae,
token_seqs: list[torch.Tensor],
target_features: list[int],
device: str,
head_ablate: int | None = None,
ablation_mode: str = "zero",
mean_z: torch.Tensor | None = None,
batch_size: int = 8,
) -> np.ndarray:
"""Run probes; return activations of target_features at each probe's final position. Shape [n_probes, n_features].
ablation_mode='zero': zero head_ablate's z output (default, OOD).
ablation_mode='mean': replace head_ablate's z with mean_z[head_ablate] at every position.
"""
all_acts = []
for i in range(0, len(token_seqs), batch_size):
batch = token_seqs[i: i + batch_size]
max_len = max(len(seq) for seq in batch)
padded = torch.zeros(len(batch), max_len, dtype=torch.long, device=device)
seq_lens = []
for j, seq in enumerate(batch):
padded[j, : len(seq)] = seq
seq_lens.append(len(seq))
captured: dict = {}
def capture_resid(value, hook):
captured["resid"] = value
return value
fwd_hooks = [(HOOK_NAME, capture_resid)]
if head_ablate is not None:
if ablation_mode == "zero":
fwd_hooks.append((ATTN_Z_HOOK, partial(zero_head_hook, head_idx=head_ablate)))
elif ablation_mode == "mean":
assert mean_z is not None, "mean ablation requires mean_z"
fwd_hooks.append((ATTN_Z_HOOK, partial(mean_head_hook, head_idx=head_ablate, mean_z=mean_z[head_ablate])))
else:
raise ValueError(f"Unknown ablation_mode {ablation_mode}")
model.run_with_hooks(padded, fwd_hooks=fwd_hooks)
resid = captured["resid"] # [batch, seq, d_model]
# Activation at each probe's actual final (non-pad) position.
final_resid = torch.stack(
[resid[j, seq_lens[j] - 1, :] for j in range(len(batch))]
) # [batch, d_model]
feature_acts = sae.encode(final_resid.float()) # [batch, d_sae]
all_acts.append(feature_acts[:, target_features].cpu().numpy())
return np.concatenate(all_acts, axis=0)
def main():
parser = argparse.ArgumentParser(description="Activation-patch each layer-12 head and measure effect on top induction features")
parser.add_argument("--n-probes", type=int, default=500, help="Subsample of induction probes to use (max 2000)")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--device", default="cuda")
parser.add_argument("--sae-path", type=Path, default=SAE_MAIN_DIR)
parser.add_argument("--output", type=Path, default=RESULTS_DIR / "activation_patching.json")
parser.add_argument("--ablation-mode", choices=["zero", "mean"], default="zero",
help="zero: replace head z with zeros (OOD). mean: replace with per-head mean over the probe distribution (more principled).")
args = parser.parse_args()
try:
from dotenv import load_dotenv
load_dotenv(REPO_ROOT / ".env")
except ImportError:
pass
print(f"[ap] Loading model ...", flush=True)
model = load_model(device=args.device)
print(f"[ap] Loading SAE from {args.sae_path} ...", flush=True)
sae = load_sae_local(args.sae_path, args.device)
print(f"[ap] Loading {args.n_probes} induction probes ...", flush=True)
probes_df = pd.read_parquet(INDUCTION_PROBES_PATH).head(args.n_probes)
probe_tokens = [torch.tensor(t, dtype=torch.long, device=args.device) for t in probes_df["tokens"]]
print(f"[ap] {len(probe_tokens)} probes loaded; first len={len(probe_tokens[0])}", flush=True)
t0 = time.monotonic()
print("[ap] Baseline (no ablation) ...", flush=True)
baseline = get_feature_activations(model, sae, probe_tokens, TARGET_FEATURES, args.device, head_ablate=None, batch_size=args.batch_size)
print(f"[ap] baseline mean across probes: {baseline.mean(0)}", flush=True)
mean_z = None
if args.ablation_mode == "mean":
print(f"[ap] Computing per-head z means across probes for mean-ablation ...", flush=True)
mean_z = compute_head_z_mean(model, probe_tokens, args.device, args.batch_size)
print(f"[ap] mean_z shape: {tuple(mean_z.shape)}, head-norms: {mean_z.norm(dim=-1).tolist()}", flush=True)
per_head: dict[int, np.ndarray] = {}
for h in range(N_HEADS):
print(f"[ap] {args.ablation_mode}-ablating head {h} ...", flush=True)
per_head[h] = get_feature_activations(
model, sae, probe_tokens, TARGET_FEATURES, args.device,
head_ablate=h, ablation_mode=args.ablation_mode, mean_z=mean_z, batch_size=args.batch_size,
)
print(f"[ap] head-{h}-ablated mean: {per_head[h].mean(0)}", flush=True)
elapsed = time.monotonic() - t0
print(f"[ap] All 9 passes done in {elapsed/60:.1f} min", flush=True)
# Assemble result
result = {
"target_features": TARGET_FEATURES,
"n_probes": len(probe_tokens),
"layer": LAYER,
"n_heads": N_HEADS,
"baseline_mean": baseline.mean(0).tolist(),
"baseline_std": baseline.std(0).tolist(),
"head_results": {},
}
for h in range(N_HEADS):
ablated_mean = per_head[h].mean(0)
reduction = baseline.mean(0) - ablated_mean
reduction_pct = reduction / np.maximum(baseline.mean(0), 1e-6) * 100
result["head_results"][str(h)] = {
"ablated_mean": ablated_mean.tolist(),
"ablated_std": per_head[h].std(0).tolist(),
"reduction": reduction.tolist(),
"reduction_pct": reduction_pct.tolist(),
}
args.output.parent.mkdir(parents=True, exist_ok=True)
with args.output.open("w", encoding="utf-8") as f:
json.dump(result, f, indent=2)
print(f"[ap] Saved results -> {args.output}", flush=True)
# Figure
fig, axes = plt.subplots(1, len(TARGET_FEATURES), figsize=(4 * len(TARGET_FEATURES), 3.5), sharey=True)
if len(TARGET_FEATURES) == 1:
axes = [axes]
for i, fid in enumerate(TARGET_FEATURES):
ax = axes[i]
reductions_pct = [result["head_results"][str(h)]["reduction_pct"][i] for h in range(N_HEADS)]
colors = ["crimson" if h == 6 else "steelblue" for h in range(N_HEADS)]
ax.bar(range(N_HEADS), reductions_pct, color=colors, edgecolor="white")
ax.set_title(f"F{fid}")
ax.set_xlabel("Head index (layer 12)")
if i == 0:
ax.set_ylabel("% reduction in feature activation\nwhen head is ablated")
ax.set_xticks(range(N_HEADS))
ax.axhline(0, color="black", linewidth=0.5)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plt.suptitle(f"Activation patching: how much each layer-12 head contributes to each induction feature\n(red = head 6, the head-level ablation winner)", y=1.02)
plt.tight_layout()
fig_path = RESULTS_DIR / "figures" / "activation_patching.png"
fig_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(fig_path, dpi=120, bbox_inches="tight")
plt.close(fig)
print(f"[ap] Figure saved -> {fig_path}", flush=True)
if __name__ == "__main__":
main()