# File size: 5,472 Bytes | 253d988
"""
For a given seed (43 or 44), run capture_activations + autointerp on its top-5
induction features. Lets the writeup's multi-seed section claim qualitative
(not just quantitative) replication.

Usage:
    python scripts/autointerp_seed.py --seed 43

Steps:
  1. Convert models/sae_main_dl_seed{N}/trainer_0/ae.pt -> models/sae_seed{N}/ (SAELens format)
  2. Run src/sae_gemma/capture_activations.py with --sae-path models/sae_seed{N}/
     and --output results/seed{N}_top_snippets.parquet (250k tokens, faster than 1M)
  3. Build results/seed{N}_top5.json from results/seed{N}_replication.json
     (top-5 is the default; the count follows --top-n)
  4. Run src/sae_gemma/autointerp.py with seed-specific snippets/features/cache
"""
import argparse
import json
import os
import subprocess
import sys
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[1]
def convert_dl_to_saelens(seed: int):
    """Re-export one seed's ``ae.pt`` checkpoint in SAELens on-disk layout.

    Reads ``models/sae_main_dl_seed{seed}/trainer_0/ae.pt`` together with the
    sibling ``config.json`` and writes ``sae_weights.safetensors`` plus
    ``cfg.json`` into ``models/sae_seed{seed}/`` (created when missing).

    Returns the output directory.
    """
    import torch
    from safetensors.torch import save_file

    models_root = REPO_ROOT / "models"
    source_dir = models_root / f"sae_main_dl_seed{seed}" / "trainer_0"
    target_dir = models_root / f"sae_seed{seed}"
    target_dir.mkdir(parents=True, exist_ok=True)

    # NOTE: weights_only=False lets torch unpickle arbitrary objects, so this
    # must only ever be pointed at checkpoints produced by a trusted run.
    state = torch.load(source_dir / "ae.pt", map_location="cpu", weights_only=False)
    trainer_cfg = json.loads(
        (source_dir / "config.json").read_text(encoding="utf-8")
    ).get("trainer", {})
    # Fall back to k=100 when the trainer config does not record it.
    top_k = trainer_cfg.get("k", 100)

    encoder_weight = state["encoder.weight"]
    d_sae, d_in = encoder_weight.shape

    # Both weight matrices are stored transposed relative to the checkpoint;
    # everything is cast to float32 to match the "dtype" written in cfg.json.
    tensors = {
        "W_enc": encoder_weight.T.contiguous().float(),
        "W_dec": state["decoder.weight"].T.contiguous().float(),
        "b_enc": state["encoder.bias"].float(),
        "b_dec": state["b_dec"].float(),
    }
    save_file(tensors, str(target_dir / "sae_weights.safetensors"))

    metadata = {
        "sae_lens_version": "6.43.0",
        "sae_lens_training_version": "6.43.0+dictionary_learning",
        "dataset_path": "local-pile-cache",
        "hook_name": "blocks.12.hook_resid_post",
        "model_name": "google/gemma-2-2b",
        "model_class_name": "HookedTransformer",
        "hook_head_index": None,
        "context_size": 1024,
        "seqpos_slice": [None],
        "model_from_pretrained_kwargs": {"dtype": "bfloat16"},
        "prepend_bos": True,
        "exclude_special_tokens": True,
        "sequence_separator_token": "bos",
        "disable_concat_sequences": False,
    }
    # Key order is kept stable so the serialized cfg.json does not change.
    sae_cfg = {
        "apply_b_dec_to_input": True,
        "metadata": metadata,
        "dtype": "float32",
        "device": "cuda",
        "d_in": int(d_in),
        "normalize_activations": "expected_average_only_in",
        "k": int(top_k),
        "rescale_acts_by_decoder_norm": False,
        "d_sae": int(d_sae),
        "reshape_activations": "none",
        "architecture": "topk",
    }
    cfg_text = json.dumps(sae_cfg, indent=2)
    (target_dir / "cfg.json").write_text(cfg_text, encoding="utf-8")

    print(f"[seed{seed}] converted -> {target_dir}", flush=True)
    return target_dir
def _run_or_exit(step, script_args, env=None):
    """Run ``python <script_args...>`` from the repo root; exit if it fails.

    ``step`` is only used to build the ``"<step> failed (exit N)"`` message.
    ``env=None`` makes the child inherit the current environment.
    """
    result = subprocess.run(
        [sys.executable, *script_args],
        cwd=str(REPO_ROOT),
        env=env,
    )
    if result.returncode != 0:
        sys.exit(f"{step} failed (exit {result.returncode})")


def _env_with_claude_cli():
    """Return a copy of ``os.environ`` with extra directories prepended to PATH.

    The directories are the current user's ``~/.local/bin`` and the legacy
    hard-coded Windows location this script used before; presumably that is
    where the ``claude`` CLI is installed (verify for new machines).
    """
    env = os.environ.copy()
    extra_dirs = []
    try:
        extra_dirs.append(str(Path.home() / ".local" / "bin"))
    except RuntimeError:
        # Home directory could not be resolved; rely on the legacy entry only.
        pass
    extra_dirs.append(r"C:\Users\sohum\.local\bin")

    # dict.fromkeys drops duplicates while keeping order.
    entries = list(dict.fromkeys(extra_dirs))
    inherited = env.get("PATH", "")
    if inherited:
        entries.append(inherited)
    # os.pathsep (";" on Windows, ":" on POSIX) keeps the inherited entries
    # intact; a hard-coded ";" would fuse with the first entry on POSIX.
    env["PATH"] = os.pathsep.join(entries)
    return env


def main():
    """Run the per-seed pipeline: convert, capture, pick top-N ids, autointerp.

    Command-line arguments:
        --seed      Seed whose SAE checkpoint and replication results are used.
        --top-n     Number of leading entries of ``top20_ids`` to label (>= 1).
        --n-tokens  Token budget forwarded to capture_activations.

    Exits with a message when either subprocess returns a non-zero code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--top-n", type=int, default=5, help="How many top induction features to label")
    parser.add_argument("--n-tokens", type=int, default=250_000)
    args = parser.parse_args()
    if args.top_n < 1:
        # A negative value would slice top20_ids from the wrong end.
        parser.error("--top-n must be >= 1")

    results_dir = REPO_ROOT / "results"

    # 1. Convert
    sae_dir = convert_dl_to_saelens(args.seed)

    # 2. Capture activations
    snippets_path = results_dir / f"seed{args.seed}_top_snippets.parquet"
    print(f"[seed{args.seed}] capture_activations on {args.n_tokens:,} tokens ...", flush=True)
    _run_or_exit(
        "capture_activations",
        [
            "src/sae_gemma/capture_activations.py",
            "--sae-path", str(sae_dir),
            "--output", str(snippets_path),
            "--n-tokens", str(args.n_tokens),
            "--top-k", "20",
        ],
    )

    # 3. Build top-N ids file from replication.json
    rep_path = results_dir / f"seed{args.seed}_replication.json"
    rep = json.loads(rep_path.read_text(encoding="utf-8"))
    top_ids = rep["top20_ids"][: args.top_n]
    if len(top_ids) < args.top_n:
        print(
            f"[seed{args.seed}] warning: only {len(top_ids)} ids available in "
            f"top20_ids (requested {args.top_n})",
            flush=True,
        )
    top_path = results_dir / f"seed{args.seed}_top{args.top_n}.json"
    top_path.write_text(json.dumps(top_ids), encoding="utf-8")
    print(f"[seed{args.seed}] top-{args.top_n} ids: {top_ids}", flush=True)

    # 4. Autointerp with seed-specific paths; ensure claude CLI is on PATH
    env = _env_with_claude_cli()
    labels_path = results_dir / f"seed{args.seed}_labels.json"
    # The cache file is overwritten with an empty object on every run, so
    # anything stored there by a previous run is discarded.
    labels_path.write_text("{}", encoding="utf-8")
    print(f"[seed{args.seed}] autointerp Sonnet on top-{args.top_n} ...", flush=True)
    _run_or_exit(
        "autointerp",
        [
            "src/sae_gemma/autointerp.py",
            "--snippets", str(snippets_path),
            "--cache", str(labels_path),
            "--features", str(top_path),
            "--model", "claude-sonnet-4-5",
            "--workers", "2",
            "--timeout", "90",
        ],
        env=env,
    )

    # Print summary: first line of each label, truncated to 160 characters.
    labels = json.loads(labels_path.read_text(encoding="utf-8"))
    print(f"\n=== seed {args.seed} top-{args.top_n} labels ===", flush=True)
    for fid in top_ids:
        label = labels.get(str(fid), "(missing)").split("\n")[0][:160]
        print(f" F{fid}: {label}", flush=True)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|