sae-gemma / scripts /autointerp_seed.py
senator1's picture
Sparse-feature audit of induction in Gemma-2-2B (full project)
253d988
"""
For a given seed (43 or 44), run capture_activations + autointerp on its top-K
induction features (K = --top-n, default 5). Lets the writeup's multi-seed
section claim qualitative (not just quantitative) replication.

Usage:
    python scripts/autointerp_seed.py --seed 43

Steps:
    1. Convert models/sae_main_dl_seed{N}/trainer_0/ae.pt -> models/sae_seed{N}/ (SAELens format)
    2. Run src/sae_gemma/capture_activations.py with --sae-path models/sae_seed{N}/
       and --output results/seed{N}_top_snippets.parquet (250k tokens, faster than 1M)
    3. Build results/seed{N}_top{K}.json (top5 by default) from results/seed{N}_replication.json
    4. Run src/sae_gemma/autointerp.py with seed-specific snippets/features/cache
"""
import argparse
import json
import os
import subprocess
import sys
from pathlib import Path
# Repository root: two levels above this file's resolved location
# (this script's own directory, then that directory's parent).
REPO_ROOT = Path(__file__).resolve().parent.parent
def convert_dl_to_saelens(seed: int):
    """Convert a seed's trainer checkpoint (ae.pt) into an SAELens-format directory.

    Reads ``models/sae_main_dl_seed{seed}/trainer_0/ae.pt`` and the neighbouring
    ``config.json`` and writes ``sae_weights.safetensors`` plus ``cfg.json`` into
    ``models/sae_seed{seed}/`` (created if it does not exist).

    Args:
        seed: Seed number used to build the input and output directory names.

    Returns:
        The output directory as a ``Path``.
    """
    # Heavy third-party imports stay local to this function.
    import torch
    from safetensors.torch import save_file

    models_dir = REPO_ROOT / "models"
    src_dir = models_dir / f"sae_main_dl_seed{seed}" / "trainer_0"
    dst_dir = models_dir / f"sae_seed{seed}"
    dst_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): weights_only=False unpickles arbitrary objects; this is only
    # acceptable if ae.pt is a locally produced, trusted checkpoint -- confirm.
    state = torch.load(src_dir / "ae.pt", map_location="cpu", weights_only=False)

    trainer_json = (src_dir / "config.json").read_text(encoding="utf-8")
    trainer_section = json.loads(trainer_json).get("trainer", {})
    top_k = trainer_section.get("k", 100)  # falls back to 100 when the config has no k

    enc_weight = state["encoder.weight"]
    n_latents, n_inputs = enc_weight.shape  # (d_sae, d_in)

    def _transposed_f32(weight):
        # The output stores each weight matrix transposed, contiguous, in float32.
        return weight.T.contiguous().float()

    tensors = {}
    tensors["W_enc"] = _transposed_f32(enc_weight)
    tensors["W_dec"] = _transposed_f32(state["decoder.weight"])
    tensors["b_enc"] = state["encoder.bias"].float()
    tensors["b_dec"] = state["b_dec"].float()
    save_file(tensors, str(dst_dir / "sae_weights.safetensors"))

    metadata = {
        "sae_lens_version": "6.43.0",
        "sae_lens_training_version": "6.43.0+dictionary_learning",
        "dataset_path": "local-pile-cache",
        "hook_name": "blocks.12.hook_resid_post",
        "model_name": "google/gemma-2-2b",
        "model_class_name": "HookedTransformer",
        "hook_head_index": None,
        "context_size": 1024,
        "seqpos_slice": [None],
        "model_from_pretrained_kwargs": {"dtype": "bfloat16"},
        "prepend_bos": True,
        "exclude_special_tokens": True,
        "sequence_separator_token": "bos",
        "disable_concat_sequences": False,
    }
    # Key order is kept stable so the serialized cfg.json is unchanged.
    sae_cfg = {
        "apply_b_dec_to_input": True,
        "metadata": metadata,
        "dtype": "float32",
        "device": "cuda",
        "d_in": int(n_inputs),
        "normalize_activations": "expected_average_only_in",
        "k": int(top_k),
        "rescale_acts_by_decoder_norm": False,
        "d_sae": int(n_latents),
        "reshape_activations": "none",
        "architecture": "topk",
    }
    cfg_text = json.dumps(sae_cfg, indent=2)
    (dst_dir / "cfg.json").write_text(cfg_text, encoding="utf-8")

    print(f"[seed{seed}] converted -> {dst_dir}", flush=True)
    return dst_dir
def _run_step(script, cli_args, step_name, env=None):
    """Run one pipeline script from the repo root; abort the run on a non-zero exit."""
    result = subprocess.run(
        [sys.executable, script, *cli_args],
        cwd=str(REPO_ROOT),
        env=env,  # None means the child inherits this process's environment
    )
    if result.returncode != 0:
        sys.exit(f"{step_name} failed (exit {result.returncode})")


def main():
    """Command-line entry point: convert, capture activations, pick top-N ids, autointerp.

    Command-line arguments:
        --seed      (required int) seed whose SAE is processed.
        --top-n     (int, default 5) number of leading ids taken from ``top20_ids``;
                    must be >= 1.
        --n-tokens  (int, default 250000) forwarded to capture_activations.py.

    Files written under ``results/``: ``seed{seed}_top_snippets.parquet`` (by the
    capture subprocess), ``seed{seed}_top{top_n}.json`` and ``seed{seed}_labels.json``.

    Exits with an error message if either subprocess returns a non-zero status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--top-n", type=int, default=5, help="How many top induction features to label")
    parser.add_argument("--n-tokens", type=int, default=250_000)
    args = parser.parse_args()
    if args.top_n < 1:
        # 0 would label nothing and a negative value would slice from the wrong end.
        parser.error("--top-n must be >= 1")

    results_dir = REPO_ROOT / "results"

    # 1. Convert
    sae_dir = convert_dl_to_saelens(args.seed)

    # 2. Capture activations
    snippets_path = results_dir / f"seed{args.seed}_top_snippets.parquet"
    print(f"[seed{args.seed}] capture_activations on {args.n_tokens:,} tokens ...", flush=True)
    _run_step(
        "src/sae_gemma/capture_activations.py",
        [
            "--sae-path", str(sae_dir),
            "--output", str(snippets_path),
            "--n-tokens", str(args.n_tokens),
            "--top-k", "20",
        ],
        "capture_activations",
    )

    # 3. Build top-N ids file from replication.json
    rep_path = results_dir / f"seed{args.seed}_replication.json"
    rep = json.loads(rep_path.read_text(encoding="utf-8"))
    top_ids = rep["top20_ids"][: args.top_n]
    if len(top_ids) < args.top_n:
        print(
            f"[seed{args.seed}] warning: only {len(top_ids)} ids available "
            f"(requested {args.top_n})",
            flush=True,
        )
    top_path = results_dir / f"seed{args.seed}_top{args.top_n}.json"
    top_path.write_text(json.dumps(top_ids), encoding="utf-8")
    print(f"[seed{args.seed}] top-{args.top_n} ids: {top_ids}", flush=True)

    # 4. Autointerp with seed-specific paths; ensure claude CLI is on PATH.
    # Prepend the current user's ~/.local/bin using the platform's own PATH
    # separator, so this works for any user and on any OS.
    env = os.environ.copy()
    local_bin = Path.home() / ".local" / "bin"
    env["PATH"] = os.pathsep.join(
        part for part in (str(local_bin), env.get("PATH", "")) if part
    )
    labels_path = results_dir / f"seed{args.seed}_labels.json"
    # Start every run from an empty label cache (overwrites any previous file).
    labels_path.write_text("{}", encoding="utf-8")
    print(f"[seed{args.seed}] autointerp Sonnet on top-{args.top_n} ...", flush=True)
    _run_step(
        "src/sae_gemma/autointerp.py",
        [
            "--snippets", str(snippets_path),
            "--cache", str(labels_path),
            "--features", str(top_path),
            "--model", "claude-sonnet-4-5",
            "--workers", "2",
            "--timeout", "90",
        ],
        "autointerp",
        env=env,
    )

    # Print summary: first line of each label, truncated to 160 characters.
    labels = json.loads(labels_path.read_text(encoding="utf-8"))
    print(f"\n=== seed {args.seed} top-{args.top_n} labels ===", flush=True)
    for fid in top_ids:
        label = labels.get(str(fid), "(missing)").split("\n")[0][:160]
        print(f" F{fid}: {label}", flush=True)


if __name__ == "__main__":
    main()