""" Compare our v9c SAE's induction features against a public Gemma-Scope SAE. Loads a pre-trained Gemma-Scope SAE (Google DeepMind release) for google/gemma-2-2b layer 12 residual stream, scores its features by the same induction_score we use for v9c (mean activation on induction probes - mean activation on matched controls, at the final probe position), and reports overlap between the top-20 induction features of the two SAEs. Outputs: results/saebench_induction_scores.parquet columns: feature_id, induction_mean, control_mean, induction_score, rank results/saebench_candidate_ids.json Top-100 feature IDs Usage: python scripts/sae_bench_comparison.py """ import argparse import json import random import time from pathlib import Path import numpy as np import pandas as pd import torch from sae_gemma.model_utils import load_model from sae_gemma.paths import ( HOOK_NAME, INDUCTION_PROBES_PATH, REPO_ROOT, RESULTS_DIR, ) from sae_gemma.induction_probes import _safe_vocab_range from sae_gemma.find_induction_features import _get_final_pos_features # v9c reference outputs (read-only — never overwritten by this script) V9C_SCORES_PATH = RESULTS_DIR / "induction_feature_scores.parquet" V9C_CANDIDATE_IDS_PATH = RESULTS_DIR / "induction_candidate_ids.json" # New outputs for the public SAE SAEBENCH_SCORES_PATH = RESULTS_DIR / "saebench_induction_scores.parquet" SAEBENCH_CANDIDATE_IDS_PATH = RESULTS_DIR / "saebench_candidate_ids.json" # Gemma-Scope release on HuggingFace (Google DeepMind) DEFAULT_RELEASE = "gemma-scope-2b-pt-res-canonical" DEFAULT_SAE_ID = "layer_12/width_16k/canonical" # Fallback if canonical isn't registered: non-canonical release uses an L0 suffix FALLBACK_RELEASE = "gemma-scope-2b-pt-res" FALLBACK_SAE_ID_PREFIX = "layer_12/width_16k/average_l0_" def load_public_sae(device: str): """ Load the public Gemma-Scope SAE for layer 12, width 16k. Tries the canonical release first; falls back to picking any available average_l0_* variant from the non-canonical release if needed. """ from sae_lens.saes.sae import SAE from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory directory = get_pretrained_saes_directory() # 1) Try the canonical release first if DEFAULT_RELEASE in directory and DEFAULT_SAE_ID in directory[DEFAULT_RELEASE].saes_map: release, sae_id = DEFAULT_RELEASE, DEFAULT_SAE_ID print(f"[saebench] Using canonical release: {release} / {sae_id}", flush=True) else: # 2) Fallback: pick the smallest-L0 width_16k variant from the non-canonical release if FALLBACK_RELEASE not in directory: raise RuntimeError( f"Neither {DEFAULT_RELEASE} nor {FALLBACK_RELEASE} found in sae_lens " f"pretrained_saes_directory. Update sae_lens or check release names." ) candidates = [ sid for sid in directory[FALLBACK_RELEASE].saes_map if sid.startswith(FALLBACK_SAE_ID_PREFIX) ] if not candidates: raise RuntimeError( f"No '{FALLBACK_SAE_ID_PREFIX}*' SAE found in release {FALLBACK_RELEASE}. " f"Available IDs starting with 'layer_12/width_16k': " f"{[s for s in directory[FALLBACK_RELEASE].saes_map if s.startswith('layer_12/width_16k')]}" ) # Sort by the L0 number embedded in the id and take the median-ish one def _l0(s: str) -> int: try: return int(s.rsplit("_", 1)[-1]) except ValueError: return 10**9 candidates.sort(key=_l0) sae_id = candidates[len(candidates) // 2] release = FALLBACK_RELEASE print(f"[saebench] Canonical not available; using {release} / {sae_id}", flush=True) # SAE.from_pretrained returns (sae, cfg_dict, sparsity) in current sae_lens out = SAE.from_pretrained(release=release, sae_id=sae_id, device=device) if isinstance(out, tuple): sae = out[0] else: sae = out sae.eval() print( f"[saebench] Loaded SAE: d_in={sae.cfg.d_in}, d_sae={sae.cfg.d_sae}, " f"hook_name={sae.cfg.metadata.hook_name}", flush=True, ) if sae.cfg.metadata.hook_name != HOOK_NAME: print( f"[saebench] WARNING: SAE hook_name {sae.cfg.metadata.hook_name} != project HOOK_NAME {HOOK_NAME}. " f"Continuing with project HOOK_NAME for activation extraction.", flush=True, ) return sae, release, sae_id def score_sae(model, sae, device: str, batch_size: int, n_controls: int, seed: int): """Run probes + controls through Gemma+SAE and return induction scores per feature.""" df_probes = pd.read_parquet(INDUCTION_PROBES_PATH) print(f"[saebench] Loaded {len(df_probes)} induction probe sequences", flush=True) induction_seqs = [list(row) for row in df_probes["tokens"].tolist()] vocab_lo, vocab_hi = _safe_vocab_range(model.cfg.d_vocab) print(f"[saebench] Generating {n_controls} control sequences ...", flush=True) rng_ctrl = random.Random(seed) control_seqs = [] for i in range(n_controls): total_len = len(induction_seqs[i % len(induction_seqs)]) toks = [rng_ctrl.randint(vocab_lo, vocab_hi) for _ in range(total_len)] control_seqs.append(toks) print("[saebench] Computing feature activations for induction probes ...", flush=True) t0 = time.monotonic() induction_acts = _get_final_pos_features( model, sae, HOOK_NAME, induction_seqs, device, batch_size ) print("[saebench] Computing feature activations for control sequences ...", flush=True) control_acts = _get_final_pos_features( model, sae, HOOK_NAME, control_seqs, device, batch_size ) print(f"[saebench] Activations computed in {time.monotonic() - t0:.0f}s", flush=True) induction_mean = induction_acts.mean(axis=0) control_mean = control_acts.mean(axis=0) induction_score = induction_mean - control_mean return induction_mean, control_mean, induction_score def try_fetch_hf_labels(release: str, sae_id: str): """Best-effort fetch of any feature labels / neuronpedia metadata for the SAE.""" try: from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory directory = get_pretrained_saes_directory() info = directory.get(release) if info is None: return None # neuronpedia_id often present on Gemma-Scope releases npid = None if hasattr(info, "neuronpedia_id") and isinstance(info.neuronpedia_id, dict): npid = info.neuronpedia_id.get(sae_id) return { "release": release, "sae_id": sae_id, "repo_id": getattr(info, "repo_id", None), "model": getattr(info, "model", None), "neuronpedia_id": npid, } except Exception as e: print(f"[saebench] Could not fetch HF metadata: {e}", flush=True) return None def main(): parser = argparse.ArgumentParser(description="Compare v9c SAE vs public Gemma-Scope SAE") parser.add_argument("--top-n", type=int, default=100) parser.add_argument("--n-controls", type=int, default=2000) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--device", default="cuda") parser.add_argument("--seed", type=int, default=123) args = parser.parse_args() try: from dotenv import load_dotenv load_dotenv(REPO_ROOT / ".env") except ImportError: pass print("[saebench] Loading Gemma-2-2B ...", flush=True) model = load_model(device=args.device) print("[saebench] Loading public Gemma-Scope SAE ...", flush=True) sae, release, sae_id = load_public_sae(args.device) n_features = sae.cfg.d_sae induction_mean, control_mean, induction_score = score_sae( model, sae, args.device, args.batch_size, args.n_controls, args.seed ) ranked_ids = np.argsort(-induction_score) scores_df = pd.DataFrame({ "feature_id": np.arange(n_features, dtype=np.int32), "induction_mean": induction_mean.astype(np.float32), "control_mean": control_mean.astype(np.float32), "induction_score": induction_score.astype(np.float32), "rank": np.argsort(np.argsort(-induction_score)).astype(np.int32), }) SAEBENCH_SCORES_PATH.parent.mkdir(parents=True, exist_ok=True) scores_df.to_parquet(SAEBENCH_SCORES_PATH, index=False) print(f"[saebench] Scores saved to {SAEBENCH_SCORES_PATH}", flush=True) top_ids = [int(x) for x in ranked_ids[:args.top_n].tolist()] with SAEBENCH_CANDIDATE_IDS_PATH.open("w", encoding="utf-8") as f: json.dump(top_ids, f, indent=2) print(f"[saebench] Top-{args.top_n} candidates saved to {SAEBENCH_CANDIDATE_IDS_PATH}", flush=True) # ── Comparison vs v9c ────────────────────────────────────────────────────── print("\n[saebench] === Comparison: v9c vs public Gemma-Scope SAE ===", flush=True) if V9C_CANDIDATE_IDS_PATH.exists(): with V9C_CANDIDATE_IDS_PATH.open("r", encoding="utf-8") as f: v9c_top100 = json.load(f) else: v9c_top100 = [] print(f"[saebench] WARNING: {V9C_CANDIDATE_IDS_PATH} not found.", flush=True) v9c_top20 = v9c_top100[:20] saebench_top20 = top_ids[:20] # Per-feature score view for the SAEBench top-20 print("\n[saebench] Public SAE top-20 induction features:", flush=True) print(f"{'Rank':>5} {'FeatID':>8} {'Induction':>10} {'Control':>10} {'Score':>10}") for rank, fid in enumerate(saebench_top20): print( f"{rank:>5} {fid:>8} {induction_mean[fid]:>10.4f} " f"{control_mean[fid]:>10.4f} {induction_score[fid]:>10.4f}" ) print("\n[saebench] v9c top-20 feature IDs: ", v9c_top20, flush=True) print("[saebench] SAEBench top-20 feature IDs:", saebench_top20, flush=True) # Overlap is informational only — feature IDs are NOT comparable across SAEs # (different SAEs learn different feature bases). Reported for completeness. overlap = sorted(set(v9c_top20) & set(saebench_top20)) print( f"\n[saebench] Top-20 ID overlap (note: feature IDs are not aligned across SAEs): " f"{len(overlap)} -> {overlap}", flush=True, ) overlap100 = sorted(set(v9c_top100) & set(top_ids)) print(f"[saebench] Top-100 ID overlap: {len(overlap100)}", flush=True) # Compare strength of the top induction signal across SAEs if V9C_SCORES_PATH.exists(): v9c_scores = pd.read_parquet(V9C_SCORES_PATH) v9c_top_score = v9c_scores.sort_values("induction_score", ascending=False)["induction_score"].iloc[0] sae_top_score = float(induction_score[ranked_ids[0]]) print( f"\n[saebench] Top-feature induction_score: v9c={v9c_top_score:.4f} " f"SAEBench={sae_top_score:.4f}", flush=True, ) v9c_top20_mean = v9c_scores.sort_values("induction_score", ascending=False)["induction_score"].iloc[:20].mean() sae_top20_mean = float(induction_score[ranked_ids[:20]].mean()) print( f"[saebench] Top-20 mean induction_score: v9c={v9c_top20_mean:.4f} " f"SAEBench={sae_top20_mean:.4f}", flush=True, ) # HF metadata (neuronpedia link is the closest thing to "prior labels") meta = try_fetch_hf_labels(release, sae_id) if meta is not None: print("\n[saebench] Public SAE metadata (from sae_lens directory):", flush=True) for k, v in meta.items(): print(f" {k}: {v}") if meta.get("neuronpedia_id"): print( f" -> Neuronpedia base URL: " f"https://neuronpedia.org/{meta['neuronpedia_id']}/", flush=True, ) print(" Top-20 SAEBench feature Neuronpedia URLs:", flush=True) for fid in saebench_top20: print(f" f{fid}: https://neuronpedia.org/{meta['neuronpedia_id']}/{fid}") if __name__ == "__main__": main()