File size: 12,339 Bytes
253d988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""
Compare our v9c SAE's induction features against a public Gemma-Scope SAE.

Loads a pre-trained Gemma-Scope SAE (Google DeepMind release) for
google/gemma-2-2b layer 12 residual stream, scores its features by the same
induction_score we use for v9c (mean activation on induction probes - mean
activation on matched controls, at the final probe position), and reports
overlap between the top-20 induction features of the two SAEs.

Outputs:
    results/saebench_induction_scores.parquet
        columns: feature_id, induction_mean, control_mean, induction_score, rank
    results/saebench_candidate_ids.json
        Top-N feature IDs ranked by induction_score (N = --top-n, default 100)

Usage:
    python scripts/sae_bench_comparison.py
"""

import argparse
import json
import random
import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from sae_gemma.model_utils import load_model
from sae_gemma.paths import (
    HOOK_NAME,
    INDUCTION_PROBES_PATH,
    REPO_ROOT,
    RESULTS_DIR,
)
from sae_gemma.induction_probes import _safe_vocab_range
from sae_gemma.find_induction_features import _get_final_pos_features

# v9c reference outputs (read-only -- never overwritten by this script)
V9C_SCORES_PATH = RESULTS_DIR / "induction_feature_scores.parquet"
V9C_CANDIDATE_IDS_PATH = RESULTS_DIR / "induction_candidate_ids.json"

# New outputs for the public SAE
SAEBENCH_SCORES_PATH = RESULTS_DIR / "saebench_induction_scores.parquet"
SAEBENCH_CANDIDATE_IDS_PATH = RESULTS_DIR / "saebench_candidate_ids.json"

# Gemma-Scope release on HuggingFace (Google DeepMind)
DEFAULT_RELEASE = "gemma-scope-2b-pt-res-canonical"
DEFAULT_SAE_ID = "layer_12/width_16k/canonical"
# Fallback if canonical isn't registered: non-canonical release uses an L0 suffix
FALLBACK_RELEASE = "gemma-scope-2b-pt-res"
FALLBACK_SAE_ID_PREFIX = "layer_12/width_16k/average_l0_"


def load_public_sae(device: str):
    """
    Fetch the public Gemma-Scope layer-12 / width-16k SAE through sae_lens.

    The canonical release entry is preferred. When it is not listed in the
    sae_lens pretrained directory, an ``average_l0_*`` id from the
    non-canonical release is used instead: the matching ids are ordered by
    the integer at the end of the id and the middle entry is selected.

    Args:
        device: Device string forwarded to ``SAE.from_pretrained``.

    Returns:
        ``(sae, release, sae_id)``: the loaded SAE (switched to eval mode)
        together with the release name and SAE id that were actually used.

    Raises:
        RuntimeError: If the canonical entry is unavailable and either the
            fallback release is missing from the directory or it contains no
            id with the expected prefix.
    """
    from sae_lens.saes.sae import SAE
    from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

    registry = get_pretrained_saes_directory()

    has_canonical = (
        DEFAULT_RELEASE in registry
        and DEFAULT_SAE_ID in registry[DEFAULT_RELEASE].saes_map
    )

    if not has_canonical:
        # Fallback path: choose among the average_l0_* ids of the non-canonical release.
        if FALLBACK_RELEASE not in registry:
            raise RuntimeError(
                f"Neither {DEFAULT_RELEASE} nor {FALLBACK_RELEASE} found in sae_lens "
                f"pretrained_saes_directory. Update sae_lens or check release names."
            )

        fallback_map = registry[FALLBACK_RELEASE].saes_map
        candidates = [
            name for name in fallback_map if name.startswith(FALLBACK_SAE_ID_PREFIX)
        ]
        if not candidates:
            width_16k_ids = [
                name for name in fallback_map if name.startswith("layer_12/width_16k")
            ]
            raise RuntimeError(
                f"No '{FALLBACK_SAE_ID_PREFIX}*' SAE found in release {FALLBACK_RELEASE}. "
                f"Available IDs starting with 'layer_12/width_16k': "
                f"{width_16k_ids}"
            )

        def _trailing_l0(name: str) -> int:
            # Ids whose suffix is not an integer are pushed to the end of the ordering.
            suffix = name.rpartition("_")[2]
            try:
                return int(suffix)
            except ValueError:
                return 10**9

        # Order by L0 and take the middle entry (the median-ish sparsity).
        ordered = sorted(candidates, key=_trailing_l0)
        release, sae_id = FALLBACK_RELEASE, ordered[len(ordered) // 2]
        print(f"[saebench] Canonical not available; using {release} / {sae_id}", flush=True)
    else:
        release, sae_id = DEFAULT_RELEASE, DEFAULT_SAE_ID
        print(f"[saebench] Using canonical release: {release} / {sae_id}", flush=True)

    # from_pretrained may hand back either the SAE itself or a tuple whose
    # first element is the SAE; accept both shapes.
    loaded = SAE.from_pretrained(release=release, sae_id=sae_id, device=device)
    sae = loaded[0] if isinstance(loaded, tuple) else loaded
    sae.eval()

    cfg = sae.cfg
    print(
        f"[saebench] Loaded SAE: d_in={cfg.d_in}, d_sae={cfg.d_sae}, "
        f"hook_name={cfg.metadata.hook_name}",
        flush=True,
    )
    hook_matches = cfg.metadata.hook_name == HOOK_NAME
    if not hook_matches:
        # Mismatch is reported but not fatal: callers keep using the project hook.
        print(
            f"[saebench] WARNING: SAE hook_name {cfg.metadata.hook_name} != project HOOK_NAME {HOOK_NAME}. "
            f"Continuing with project HOOK_NAME for activation extraction.",
            flush=True,
        )
    return sae, release, sae_id


def score_sae(model, sae, device: str, batch_size: int, n_controls: int, seed: int):
    """Score SAE features by induction-probe activation minus control activation.

    Reads the induction probe sequences from ``INDUCTION_PROBES_PATH``,
    generates ``n_controls`` random-token control sequences whose lengths
    cycle through the probe lengths, runs both sets through
    ``_get_final_pos_features`` at ``HOOK_NAME``, and averages over sequences.

    Args:
        model: Model passed to ``_get_final_pos_features``; its
            ``cfg.d_vocab`` bounds the random control tokens.
        sae: SAE passed to ``_get_final_pos_features``.
        device: Device string forwarded to ``_get_final_pos_features``.
        batch_size: Batch size forwarded to ``_get_final_pos_features``.
        n_controls: Number of random control sequences to generate (>= 1).
        seed: Seed for the control-token random generator.

    Returns:
        ``(induction_mean, control_mean, induction_score)`` where the score is
        ``induction_mean - control_mean``. Each is the axis-0 mean of the
        helper's output -- presumably one value per SAE feature; confirm
        against ``_get_final_pos_features``.

    Raises:
        ValueError: If ``n_controls`` is not positive or the probe file holds
            no sequences (previously an opaque ``ZeroDivisionError``).
    """
    if n_controls < 1:
        # With no controls the control mean (and hence the score) is undefined.
        raise ValueError(f"n_controls must be >= 1, got {n_controls}")

    df_probes = pd.read_parquet(INDUCTION_PROBES_PATH)
    print(f"[saebench] Loaded {len(df_probes)} induction probe sequences", flush=True)
    induction_seqs = [list(row) for row in df_probes["tokens"].tolist()]
    if not induction_seqs:
        raise ValueError(
            f"No induction probe sequences found in {INDUCTION_PROBES_PATH}"
        )

    vocab_lo, vocab_hi = _safe_vocab_range(model.cfg.d_vocab)

    print(f"[saebench] Generating {n_controls} control sequences ...", flush=True)
    rng_ctrl = random.Random(seed)
    n_probes = len(induction_seqs)
    # Control i copies the length of probe (i mod n_probes) so length
    # distributions match. randint includes both endpoints.
    # NOTE(review): assumes _safe_vocab_range returns an inclusive upper bound -- confirm.
    control_seqs = [
        [
            rng_ctrl.randint(vocab_lo, vocab_hi)
            for _ in range(len(induction_seqs[i % n_probes]))
        ]
        for i in range(n_controls)
    ]

    print("[saebench] Computing feature activations for induction probes ...", flush=True)
    t0 = time.monotonic()
    induction_acts = _get_final_pos_features(
        model, sae, HOOK_NAME, induction_seqs, device, batch_size
    )

    print("[saebench] Computing feature activations for control sequences ...", flush=True)
    control_acts = _get_final_pos_features(
        model, sae, HOOK_NAME, control_seqs, device, batch_size
    )
    print(f"[saebench] Activations computed in {time.monotonic() - t0:.0f}s", flush=True)

    induction_mean = induction_acts.mean(axis=0)
    control_mean = control_acts.mean(axis=0)
    induction_score = induction_mean - control_mean
    return induction_mean, control_mean, induction_score


def try_fetch_hf_labels(release: str, sae_id: str):
    """Look up sae_lens directory metadata for an SAE without ever raising.

    Args:
        release: Release name to look up in the sae_lens pretrained directory.
        sae_id: SAE id used to pick the Neuronpedia id, when the release entry
            exposes a ``neuronpedia_id`` dict.

    Returns:
        A dict with keys ``release``, ``sae_id``, ``repo_id``, ``model`` and
        ``neuronpedia_id`` (missing attributes become ``None``), or ``None``
        when the release is unknown or any error occurs (the error is printed).
    """
    try:
        from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

        entry = get_pretrained_saes_directory().get(release)
        if entry is None:
            return None

        # Only trust neuronpedia_id when it is a per-SAE mapping.
        np_map = getattr(entry, "neuronpedia_id", None)
        npid = np_map.get(sae_id) if isinstance(np_map, dict) else None

        return dict(
            release=release,
            sae_id=sae_id,
            repo_id=getattr(entry, "repo_id", None),
            model=getattr(entry, "model", None),
            neuronpedia_id=npid,
        )
    except Exception as exc:
        # Best-effort by design: metadata is optional, so report and move on.
        print(f"[saebench] Could not fetch HF metadata: {exc}", flush=True)
        return None


def main():
    """Score the public Gemma-Scope SAE and print a comparison against v9c.

    Parses the CLI options, loads the model and the public SAE, computes
    induction scores via ``score_sae``, writes the score table to
    ``SAEBENCH_SCORES_PATH`` and the top-N feature ids to
    ``SAEBENCH_CANDIDATE_IDS_PATH``, then prints overlap and score-strength
    comparisons against the existing v9c outputs (when present) plus any
    Neuronpedia links found in the sae_lens directory metadata.

    The v9c files are only read, never written.
    """
    parser = argparse.ArgumentParser(description="Compare v9c SAE vs public Gemma-Scope SAE")
    parser.add_argument("--top-n", type=int, default=100)
    parser.add_argument("--n-controls", type=int, default=2000)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--seed", type=int, default=123)
    args = parser.parse_args()
    if args.top_n < 1:
        # A zero/negative N would make the slices below select the wrong features.
        parser.error("--top-n must be >= 1")

    # python-dotenv is optional; load the repo .env only if it is installed.
    try:
        from dotenv import load_dotenv
        load_dotenv(REPO_ROOT / ".env")
    except ImportError:
        pass

    print("[saebench] Loading Gemma-2-2B ...", flush=True)
    model = load_model(device=args.device)

    print("[saebench] Loading public Gemma-Scope SAE ...", flush=True)
    sae, release, sae_id = load_public_sae(args.device)
    n_features = sae.cfg.d_sae

    induction_mean, control_mean, induction_score = score_sae(
        model, sae, args.device, args.batch_size, args.n_controls, args.seed
    )

    # Feature ids ordered from highest to lowest induction score.
    ranked_ids = np.argsort(-induction_score)
    # Invert that permutation to get each feature's rank (0 = strongest);
    # same result as a second argsort without sorting again.
    ranks = np.empty(len(ranked_ids), dtype=np.int32)
    ranks[ranked_ids] = np.arange(len(ranked_ids), dtype=np.int32)

    scores_df = pd.DataFrame({
        "feature_id": np.arange(n_features, dtype=np.int32),
        "induction_mean": induction_mean.astype(np.float32),
        "control_mean": control_mean.astype(np.float32),
        "induction_score": induction_score.astype(np.float32),
        "rank": ranks,
    })
    SAEBENCH_SCORES_PATH.parent.mkdir(parents=True, exist_ok=True)
    scores_df.to_parquet(SAEBENCH_SCORES_PATH, index=False)
    print(f"[saebench] Scores saved to {SAEBENCH_SCORES_PATH}", flush=True)

    top_ids = [int(x) for x in ranked_ids[:args.top_n].tolist()]
    with SAEBENCH_CANDIDATE_IDS_PATH.open("w", encoding="utf-8") as f:
        json.dump(top_ids, f, indent=2)
    print(f"[saebench] Top-{args.top_n} candidates saved to {SAEBENCH_CANDIDATE_IDS_PATH}", flush=True)

    # --- Comparison vs v9c ---------------------------------------------------
    print("\n[saebench] === Comparison: v9c vs public Gemma-Scope SAE ===", flush=True)

    if V9C_CANDIDATE_IDS_PATH.exists():
        with V9C_CANDIDATE_IDS_PATH.open("r", encoding="utf-8") as f:
            v9c_candidates = json.load(f)
    else:
        v9c_candidates = []
        print(f"[saebench] WARNING: {V9C_CANDIDATE_IDS_PATH} not found.", flush=True)

    v9c_top20 = v9c_candidates[:20]
    saebench_top20 = top_ids[:20]

    # Per-feature score view for the SAEBench top-20
    print("\n[saebench] Public SAE top-20 induction features:", flush=True)
    print(f"{'Rank':>5} {'FeatID':>8} {'Induction':>10} {'Control':>10} {'Score':>10}")
    for rank, fid in enumerate(saebench_top20):
        print(
            f"{rank:>5} {fid:>8} {induction_mean[fid]:>10.4f} "
            f"{control_mean[fid]:>10.4f} {induction_score[fid]:>10.4f}"
        )

    print("\n[saebench] v9c top-20 feature IDs:     ", v9c_top20, flush=True)
    print("[saebench] SAEBench top-20 feature IDs:", saebench_top20, flush=True)

    # Overlap is informational only -- feature IDs are NOT comparable across SAEs
    # (different SAEs learn different feature bases). Reported for completeness.
    overlap = sorted(set(v9c_top20) & set(saebench_top20))
    print(
        f"\n[saebench] Top-20 ID overlap (note: feature IDs are not aligned across SAEs): "
        f"{len(overlap)} -> {overlap}",
        flush=True,
    )
    # Label with the actual N so the message stays correct when --top-n != 100.
    overlap_top_n = sorted(set(v9c_candidates) & set(top_ids))
    print(f"[saebench] Top-{args.top_n} ID overlap: {len(overlap_top_n)}", flush=True)

    # Compare strength of the top induction signal across SAEs
    if V9C_SCORES_PATH.exists():
        v9c_scores = pd.read_parquet(V9C_SCORES_PATH)
        # Sort once and reuse for both the top-1 and top-20 statistics.
        v9c_sorted = v9c_scores["induction_score"].sort_values(ascending=False)
        v9c_top_score = v9c_sorted.iloc[0]
        sae_top_score = float(induction_score[ranked_ids[0]])
        print(
            f"\n[saebench] Top-feature induction_score:  v9c={v9c_top_score:.4f}  "
            f"SAEBench={sae_top_score:.4f}",
            flush=True,
        )
        v9c_top20_mean = v9c_sorted.iloc[:20].mean()
        sae_top20_mean = float(induction_score[ranked_ids[:20]].mean())
        print(
            f"[saebench] Top-20 mean induction_score:   v9c={v9c_top20_mean:.4f}  "
            f"SAEBench={sae_top20_mean:.4f}",
            flush=True,
        )

    # HF metadata (neuronpedia link is the closest thing to "prior labels")
    meta = try_fetch_hf_labels(release, sae_id)
    if meta is not None:
        print("\n[saebench] Public SAE metadata (from sae_lens directory):", flush=True)
        for k, v in meta.items():
            print(f"  {k}: {v}")
        if meta.get("neuronpedia_id"):
            print(
                f"  -> Neuronpedia base URL: "
                f"https://neuronpedia.org/{meta['neuronpedia_id']}/<feature_id>",
                flush=True,
            )
            print("  Top-20 SAEBench feature Neuronpedia URLs:", flush=True)
            for fid in saebench_top20:
                print(f"    f{fid}: https://neuronpedia.org/{meta['neuronpedia_id']}/{fid}")


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()