File size: 5,472 Bytes
253d988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
For a given seed (43 or 44), run capture_activations + autointerp on its top-5
induction features. Lets the writeup's multi-seed section claim qualitative
(not just quantitative) replication.

    python scripts/autointerp_seed.py --seed 43

Steps:
1. Convert models/sae_main_dl_seed{N}/trainer_0/ae.pt -> models/sae_seed{N}/ (SAELens format)
2. Run src/sae_gemma/capture_activations.py with --sae-path models/sae_seed{N}/
   and --output results/seed{N}_top_snippets.parquet (250k tokens, faster than 1M)
3. Build results/seed{N}_top{K}.json (K = --top-n, default 5) from the "top20_ids" list in results/seed{N}_replication.json
4. Run src/sae_gemma/autointerp.py with seed-specific snippets/features/cache
"""
import argparse
import json
import os
import subprocess
import sys
from pathlib import Path

# Repository root: the parent of the directory that contains this script
# (the script is invoked as scripts/autointerp_seed.py).
REPO_ROOT = Path(__file__).resolve().parents[1]


def convert_dl_to_saelens(seed: int):
    """Re-export a seed's ``ae.pt`` checkpoint as an SAELens-format folder.

    Reads ``ae.pt`` and ``config.json`` from
    ``models/sae_main_dl_seed{seed}/trainer_0/`` and writes
    ``sae_weights.safetensors`` plus ``cfg.json`` into
    ``models/sae_seed{seed}/`` (created if it does not exist).

    Args:
        seed: Training seed used to locate the input and name the output.

    Returns:
        The output directory (``models/sae_seed{seed}``) as a ``Path``.
    """
    import torch
    from safetensors.torch import save_file

    models_root = REPO_ROOT / "models"
    src_dir = models_root / f"sae_main_dl_seed{seed}" / "trainer_0"
    out_dir = models_root / f"sae_seed{seed}"
    out_dir.mkdir(parents=True, exist_ok=True)

    state = torch.load(src_dir / "ae.pt", map_location="cpu", weights_only=False)
    trainer_cfg = json.loads(
        (src_dir / "config.json").read_text(encoding="utf-8")
    ).get("trainer", {})
    # Fall back to k=100 when the training config does not record it.
    k = trainer_cfg.get("k", 100)
    d_sae, d_in = state["encoder.weight"].shape

    # (output name, checkpoint key, transpose?) -- both weight matrices are
    # stored transposed relative to the output layout; biases are copied as-is.
    layout = (
        ("W_enc", "encoder.weight", True),
        ("W_dec", "decoder.weight", True),
        ("b_enc", "encoder.bias", False),
        ("b_dec", "b_dec", False),
    )
    tensors = {}
    for out_name, src_name, transpose in layout:
        tensor = state[src_name]
        if transpose:
            tensor = tensor.T.contiguous()
        tensors[out_name] = tensor.float()
    save_file(tensors, str(out_dir / "sae_weights.safetensors"))

    # Fixed description of the model/hook/dataset written alongside the weights.
    metadata = {
        "sae_lens_version": "6.43.0",
        "sae_lens_training_version": "6.43.0+dictionary_learning",
        "dataset_path": "local-pile-cache",
        "hook_name": "blocks.12.hook_resid_post",
        "model_name": "google/gemma-2-2b",
        "model_class_name": "HookedTransformer",
        "hook_head_index": None,
        "context_size": 1024,
        "seqpos_slice": [None],
        "model_from_pretrained_kwargs": {"dtype": "bfloat16"},
        "prepend_bos": True,
        "exclude_special_tokens": True,
        "sequence_separator_token": "bos",
        "disable_concat_sequences": False,
    }
    # Key order is kept stable so the emitted cfg.json does not change.
    sae_cfg = {
        "apply_b_dec_to_input": True,
        "metadata": metadata,
        "dtype": "float32",
        "device": "cuda",
        "d_in": int(d_in),
        "normalize_activations": "expected_average_only_in",
        "k": int(k),
        "rescale_acts_by_decoder_norm": False,
        "d_sae": int(d_sae),
        "reshape_activations": "none",
        "architecture": "topk",
    }
    cfg_path = out_dir / "cfg.json"
    cfg_path.write_text(json.dumps(sae_cfg, indent=2), encoding="utf-8")

    print(f"[seed{seed}] converted -> {out_dir}", flush=True)
    return out_dir


def _run_step(name, argv, env=None):
    """Run one pipeline script with this interpreter from the repo root; exit on failure."""
    result = subprocess.run([sys.executable, *argv], cwd=str(REPO_ROOT), env=env)
    if result.returncode != 0:
        sys.exit(f"{name} failed (exit {result.returncode})")


def main():
    """Run the per-seed pipeline: convert, capture activations, pick top-N, autointerp.

    Command-line arguments:
        --seed      Seed whose SAE is processed (required).
        --top-n     Number of top induction features to label (default 5).
        --n-tokens  Token budget passed to capture_activations (default 250,000).

    Exits with an error message if ``--top-n`` is not positive, if
    ``results/seed{N}_replication.json`` is missing or has no ``"top20_ids"``
    entry, or if either subprocess returns a non-zero exit code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--top-n", type=int, default=5, help="How many top induction features to label")
    parser.add_argument("--n-tokens", type=int, default=250_000)
    args = parser.parse_args()
    if args.top_n < 1:
        # A zero/negative value would silently yield an empty or wrong slice below.
        parser.error("--top-n must be a positive integer")

    results_dir = REPO_ROOT / "results"

    # Pre-flight: load the feature ids before any expensive work so a missing
    # or malformed replication file fails immediately rather than after capture.
    rep_path = results_dir / f"seed{args.seed}_replication.json"
    try:
        rep = json.loads(rep_path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        sys.exit(f"replication results not found: {rep_path}")
    try:
        top_ids = rep["top20_ids"][: args.top_n]
    except KeyError:
        sys.exit(f'{rep_path} has no "top20_ids" entry')

    # 1. Convert
    sae_dir = convert_dl_to_saelens(args.seed)

    # 2. Capture activations
    snippets_path = results_dir / f"seed{args.seed}_top_snippets.parquet"
    print(f"[seed{args.seed}] capture_activations on {args.n_tokens:,} tokens ...", flush=True)
    _run_step(
        "capture_activations",
        ["src/sae_gemma/capture_activations.py",
         "--sae-path", str(sae_dir),
         "--output", str(snippets_path),
         "--n-tokens", str(args.n_tokens),
         "--top-k", "20"],
    )

    # 3. Write the top-N ids file consumed by autointerp
    top_path = results_dir / f"seed{args.seed}_top{args.top_n}.json"
    top_path.write_text(json.dumps(top_ids), encoding="utf-8")
    print(f"[seed{args.seed}] top-{args.top_n} ids: {top_ids}", flush=True)

    # 4. Autointerp with seed-specific paths; ensure claude CLI is on PATH.
    # Prepend the current user's ~/.local/bin using the platform's separator
    # instead of a hard-coded, user-specific Windows path.
    env = os.environ.copy()
    local_bin = str(Path.home() / ".local" / "bin")
    env["PATH"] = os.pathsep.join(p for p in (local_bin, env.get("PATH", "")) if p)

    # Start from an empty label cache so the summary reflects this run only.
    labels_path = results_dir / f"seed{args.seed}_labels.json"
    labels_path.write_text("{}", encoding="utf-8")
    print(f"[seed{args.seed}] autointerp Sonnet on top-{args.top_n} ...", flush=True)
    _run_step(
        "autointerp",
        ["src/sae_gemma/autointerp.py",
         "--snippets", str(snippets_path),
         "--cache", str(labels_path),
         "--features", str(top_path),
         "--model", "claude-sonnet-4-5",
         "--workers", "2",
         "--timeout", "90"],
        env=env,
    )

    # Print summary: first line of each label, truncated to 160 characters.
    labels = json.loads(labels_path.read_text(encoding="utf-8"))
    print(f"\n=== seed {args.seed} top-{args.top_n} labels ===", flush=True)
    for fid in top_ids:
        label = labels.get(str(fid), "(missing)").split("\n")[0][:160]
        print(f"  F{fid}: {label}", flush=True)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()