| """ |
| For a given seed (43 or 44), run capture_activations + autointerp on its top-5 |
| induction features. Lets the writeup's multi-seed section claim qualitative |
| (not just quantitative) replication. |
| |
| python scripts/autointerp_seed.py --seed 43 |
| |
| Steps: |
| 1. Convert models/sae_main_dl_seed{N}/trainer_0/ae.pt -> models/sae_seed{N}/ (SAELens format) |
| 2. Run src/sae_gemma/capture_activations.py with --sae-path models/sae_seed{N}/ |
| and --output results/seed{N}_top_snippets.parquet (250k tokens, faster than 1M) |
| 3. Build results/seed{N}_top5.json from results/seed{N}_replication.json |
| 4. Run src/sae_gemma/autointerp.py with seed-specific snippets/features/cache |
| """ |
| import argparse |
| import json |
| import os |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
| REPO_ROOT = Path(__file__).resolve().parents[1] |
|
|
|
|
| def convert_dl_to_saelens(seed: int): |
| """Inline conversion of seed's ae.pt -> SAELens format in models/sae_seed{N}/.""" |
| import torch |
| from safetensors.torch import save_file |
|
|
| dl_dir = REPO_ROOT / "models" / f"sae_main_dl_seed{seed}" / "trainer_0" |
| out_dir = REPO_ROOT / "models" / f"sae_seed{seed}" |
| out_dir.mkdir(parents=True, exist_ok=True) |
|
|
| sd = torch.load(dl_dir / "ae.pt", map_location="cpu", weights_only=False) |
| dl_cfg = json.loads((dl_dir / "config.json").read_text(encoding="utf-8")) |
| k = dl_cfg.get("trainer", {}).get("k", 100) |
| d_sae, d_in = sd["encoder.weight"].shape |
|
|
| out = { |
| "W_enc": sd["encoder.weight"].T.contiguous().float(), |
| "W_dec": sd["decoder.weight"].T.contiguous().float(), |
| "b_enc": sd["encoder.bias"].float(), |
| "b_dec": sd["b_dec"].float(), |
| } |
| save_file(out, str(out_dir / "sae_weights.safetensors")) |
|
|
| cfg = { |
| "apply_b_dec_to_input": True, |
| "metadata": { |
| "sae_lens_version": "6.43.0", |
| "sae_lens_training_version": "6.43.0+dictionary_learning", |
| "dataset_path": "local-pile-cache", |
| "hook_name": "blocks.12.hook_resid_post", |
| "model_name": "google/gemma-2-2b", |
| "model_class_name": "HookedTransformer", |
| "hook_head_index": None, |
| "context_size": 1024, |
| "seqpos_slice": [None], |
| "model_from_pretrained_kwargs": {"dtype": "bfloat16"}, |
| "prepend_bos": True, |
| "exclude_special_tokens": True, |
| "sequence_separator_token": "bos", |
| "disable_concat_sequences": False, |
| }, |
| "dtype": "float32", |
| "device": "cuda", |
| "d_in": int(d_in), |
| "normalize_activations": "expected_average_only_in", |
| "k": int(k), |
| "rescale_acts_by_decoder_norm": False, |
| "d_sae": int(d_sae), |
| "reshape_activations": "none", |
| "architecture": "topk", |
| } |
| (out_dir / "cfg.json").write_text(json.dumps(cfg, indent=2), encoding="utf-8") |
| print(f"[seed{seed}] converted -> {out_dir}", flush=True) |
| return out_dir |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--seed", type=int, required=True) |
| parser.add_argument("--top-n", type=int, default=5, help="How many top induction features to label") |
| parser.add_argument("--n-tokens", type=int, default=250_000) |
| args = parser.parse_args() |
|
|
| |
| sae_dir = convert_dl_to_saelens(args.seed) |
|
|
| |
| snippets_path = REPO_ROOT / "results" / f"seed{args.seed}_top_snippets.parquet" |
| print(f"[seed{args.seed}] capture_activations on {args.n_tokens:,} tokens ...", flush=True) |
| r = subprocess.run( |
| [sys.executable, "src/sae_gemma/capture_activations.py", |
| "--sae-path", str(sae_dir), |
| "--output", str(snippets_path), |
| "--n-tokens", str(args.n_tokens), |
| "--top-k", "20"], |
| cwd=str(REPO_ROOT), |
| ) |
| if r.returncode != 0: |
| sys.exit(f"capture_activations failed (exit {r.returncode})") |
|
|
| |
| rep = json.loads((REPO_ROOT / "results" / f"seed{args.seed}_replication.json").read_text(encoding="utf-8")) |
| top_ids = rep["top20_ids"][: args.top_n] |
| top_path = REPO_ROOT / "results" / f"seed{args.seed}_top{args.top_n}.json" |
| top_path.write_text(json.dumps(top_ids), encoding="utf-8") |
| print(f"[seed{args.seed}] top-{args.top_n} ids: {top_ids}", flush=True) |
|
|
| |
| env = os.environ.copy() |
| env["PATH"] = r"C:\Users\sohum\.local\bin;" + env.get("PATH", "") |
| labels_path = REPO_ROOT / "results" / f"seed{args.seed}_labels.json" |
| labels_path.write_text("{}", encoding="utf-8") |
| print(f"[seed{args.seed}] autointerp Sonnet on top-{args.top_n} ...", flush=True) |
| r = subprocess.run( |
| [sys.executable, "src/sae_gemma/autointerp.py", |
| "--snippets", str(snippets_path), |
| "--cache", str(labels_path), |
| "--features", str(top_path), |
| "--model", "claude-sonnet-4-5", |
| "--workers", "2", |
| "--timeout", "90"], |
| cwd=str(REPO_ROOT), env=env, |
| ) |
| if r.returncode != 0: |
| sys.exit(f"autointerp failed (exit {r.returncode})") |
|
|
| |
| labels = json.loads(labels_path.read_text(encoding="utf-8")) |
| print(f"\n=== seed {args.seed} top-{args.top_n} labels ===", flush=True) |
| for fid in top_ids: |
| label = labels.get(str(fid), "(missing)").split("\n")[0][:160] |
| print(f" F{fid}: {label}", flush=True) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|