# sae-gemma / scripts/test_convert_dl.py
# senator1
# Sparse-feature audit of induction in Gemma-2-2B (full project)
# 253d988
"""
Pre-flight test for convert_dl_to_saelens.py.
If no trained ae.pt exists yet, fabricates a full-size (2304 x 16384, k=80)
dictionary_learning-style ae.pt with random weights; then runs the conversion
and loads the result with sae_lens.saes.sae.SAE to confirm it works.
This catches conversion bugs before v8 finishes training.
"""
import json
import shutil
import sys
import tempfile
from pathlib import Path
import torch
from safetensors.torch import save_file
# Repository root: this script sits one directory below it (in scripts/).
REPO_ROOT = Path(__file__).resolve().parent.parent
# Put the sibling scripts directory first on the import path.
sys.path.insert(0, str(REPO_ROOT.joinpath("scripts")))
def make_fake_ae_pt(out_dir: Path, d_in: int = 2304, d_sae: int = 16384, k: int = 80, seed: int = 0):
    """Write a fake dictionary_learning TopK checkpoint (``ae.pt`` + ``config.json``).

    The state_dict mimics the key layout of dictionary_learning's
    ``AutoEncoderTopK``; the weights are small random values, not a trained SAE.

    Args:
        out_dir: Directory to write into; created (with parents) if missing.
            Existing ``ae.pt`` / ``config.json`` files there are overwritten.
        d_in: Activation width (written as ``activation_dim``).
        d_sae: Dictionary size (written as ``dict_size``).
        k: Number of active latents per token; must satisfy ``0 < k <= d_sae``.
        seed: Seed for a private RNG, so the fake weights are reproducible and
            the global torch RNG state is left untouched.

    Raises:
        ValueError: if a dimension is non-positive or ``k`` is out of range.
            Checked before anything is written to disk.
    """
    if d_in <= 0 or d_sae <= 0:
        raise ValueError(f"d_in and d_sae must be positive, got d_in={d_in}, d_sae={d_sae}")
    if not 0 < k <= d_sae:
        raise ValueError(f"k must satisfy 0 < k <= d_sae={d_sae}, got k={k}")
    gen = torch.Generator().manual_seed(seed)
    sd = {
        "encoder.weight": torch.randn(d_sae, d_in, generator=gen) * 0.01,
        "encoder.bias": torch.zeros(d_sae),
        "decoder.weight": torch.randn(d_in, d_sae, generator=gen) * 0.01,
        "b_dec": torch.zeros(d_in),
        "threshold": torch.tensor(0.0),
        # NOTE(review): dictionary_learning registers `k` as an int32 buffer
        # (torch.tensor(k, dtype=torch.int)); matched here — confirm against
        # the installed version.
        "k": torch.tensor(k, dtype=torch.int32),
    }
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save(sd, out_dir / "ae.pt")
    # Only these trainer fields are written; presumably they are what the
    # converter reads — TODO confirm against convert_dl_to_saelens.py.
    cfg = {"trainer": {"k": k, "dict_size": d_sae, "activation_dim": d_in}}
    (out_dir / "config.json").write_text(json.dumps(cfg, indent=2), encoding="utf-8")
def _backup_file(path: Path, backup_dir: Path, name: str):
    """Copy ``path`` to ``backup_dir / name`` and return the copy's path.

    Returns None when ``path`` does not exist (nothing to restore later).
    """
    if not path.exists():
        return None
    dest = backup_dir / name
    shutil.copy2(str(path), str(dest))
    return dest


def _restore_file(target: Path, backup, *, delete_if_absent: bool) -> bool:
    """Return ``target`` to its pre-test state.

    Args:
        target: File the test may have created or overwritten.
        backup: Path returned by ``_backup_file`` for ``target``, or None if
            ``target`` did not exist before the test.
        delete_if_absent: When ``backup`` is None, delete any ``target`` the
            test left behind.

    Returns:
        True on success; False (after printing a warning) if the filesystem
        operation failed, so the caller can keep the backups for manual recovery.
    """
    try:
        if backup is not None:
            shutil.copy2(str(backup), str(target))
            print(f"[test_convert] restored {target}")
        elif delete_if_absent and target.exists():
            target.unlink()
            print(f"[test_convert] removed test-generated {target}")
    except OSError as e:
        print(f"[test_convert] WARNING: failed to restore {target}: {e}")
        return False
    return True


def _remove_fake_checkpoint(dl_dir: Path, trainer_dir: Path, dl_dir_existed: bool,
                            trainer_dir_existed: bool, trainer_cfg_backup) -> bool:
    """Undo ``make_fake_ae_pt`` without deleting anything that pre-dated the test.

    A directory is removed wholesale only if this test created it; otherwise
    only the fake ``ae.pt`` is deleted and ``config.json`` is restored (or
    deleted if it did not exist before). Returns True on success.
    """
    if not dl_dir_existed or not trainer_dir_existed:
        created = dl_dir if not dl_dir_existed else trainer_dir
        shutil.rmtree(created, ignore_errors=True)
        print(f"[test_convert] cleaned up fake {created}")
        return not created.exists()
    results = [
        _restore_file(trainer_dir / "ae.pt", None, delete_if_absent=True),
        _restore_file(trainer_dir / "config.json", trainer_cfg_backup, delete_if_absent=True),
    ]
    return all(results)


def _run_conversion() -> None:
    """Run convert_dl_to_saelens.py in a subprocess; raise RuntimeError on a non-zero exit."""
    import subprocess
    result = subprocess.run(
        [sys.executable, str(REPO_ROOT / "scripts" / "convert_dl_to_saelens.py")],
        capture_output=True, text=True, cwd=str(REPO_ROOT),
    )
    print(result.stdout)
    if result.returncode != 0:
        print("STDERR:", result.stderr)
        raise RuntimeError(f"Conversion failed (exit {result.returncode})")


def _check_converted_sae(saelens_dir: Path, weights: Path, cfg: Path) -> None:
    """Check the converter output exists, loads in SAELens, and encodes/decodes a batch.

    Kept in its own function so the loaded SAE (and any tensors that may still
    reference the weights file) go out of scope before the caller restores that
    file. NOTE(review): on Windows a still-open mapping of the weights file may
    be what makes overwriting it fail — confirm.

    Raises:
        AssertionError: if an output file is missing or the round-trip shape is wrong.
    """
    # Explicit raises rather than `assert`, which `python -O` strips.
    if not weights.exists():
        raise AssertionError(f"{weights} missing")
    if not cfg.exists():
        raise AssertionError(f"{cfg} missing")
    print("[test_convert] OK: weights + cfg written")
    from sae_lens.saes.sae import SAE
    sae = SAE.load_from_disk(str(saelens_dir), device="cpu")
    print(f"[test_convert] OK: SAELens.load_from_disk succeeded; cfg={sae.cfg}")
    print(f"[test_convert] SAE shapes: W_enc={tuple(sae.W_enc.shape)}, W_dec={tuple(sae.W_dec.shape)}")
    # Size the probe batch from the loaded encoder (SAELens stores W_enc as
    # (d_in, d_sae)) instead of hard-coding the model width.
    x = torch.randn(4, sae.W_enc.shape[0])
    z = sae.encode(x)
    x_hat = sae.decode(z)
    print(f"[test_convert] OK: encode->decode works. z.shape={tuple(z.shape)}, x_hat.shape={tuple(x_hat.shape)}")
    if tuple(x_hat.shape) != tuple(x.shape):
        raise AssertionError(f"decode shape {tuple(x_hat.shape)} != input shape {tuple(x.shape)}")
    # Informational only: TopK should give at most k nonzeros per row.
    nnz = (z != 0).sum(dim=-1).float().mean().item()
    print(f"[test_convert] sparsity check: mean nonzero per token = {nnz}")


def main():
    """Run the conversion pre-flight test against the real model directories.

    Writes a fake ``trainer_0/ae.pt`` if no trained one exists, runs
    ``convert_dl_to_saelens.py``, and checks the output with SAELens.

    The existing SAELens output files are copied to a temporary directory first
    and put back afterwards. Anything created only for the fake checkpoint is
    removed. If any restore step fails, the backups are kept and their location
    is printed.

    Raises:
        RuntimeError: if the conversion fails, or if restoring pre-test files fails.
        AssertionError: if an output check fails.
    """
    print("[test_convert] Setting up fake training output ...")
    real_dl_dir = REPO_ROOT / "models" / "sae_main_dl"
    real_saelens_dir = REPO_ROOT / "models" / "sae_main"
    real_weights = real_saelens_dir / "sae_weights.safetensors"
    real_cfg = real_saelens_dir / "cfg.json"
    fake_trainer_dir = real_dl_dir / "trainer_0"

    # Record what exists now so cleanup only removes what this test creates.
    dl_dir_existed = real_dl_dir.exists()
    trainer_dir_existed = fake_trainer_dir.exists()
    saelens_dir_existed = real_saelens_dir.exists()

    # Snapshot to disk rather than memory: the weights file can be hundreds of
    # MB, and an on-disk copy survives a failed restore for manual recovery.
    backup_dir = Path(tempfile.mkdtemp(prefix="test_convert_backup_"))
    backup_weights = _backup_file(real_weights, backup_dir, real_weights.name)
    backup_cfg = _backup_file(real_cfg, backup_dir, real_cfg.name)
    backup_trainer_cfg = None

    fake_made = False
    restored = False
    try:
        if not (fake_trainer_dir / "ae.pt").exists():
            print("[test_convert] Real ae.pt not found; making fake one")
            # A training run in progress may already have written config.json
            # here; make_fake_ae_pt overwrites it, so keep a copy to restore.
            backup_trainer_cfg = _backup_file(
                fake_trainer_dir / "config.json", backup_dir, "trainer_0_config.json"
            )
            fake_made = True
            make_fake_ae_pt(fake_trainer_dir, d_in=2304, d_sae=16384, k=80)
        _run_conversion()
        _check_converted_sae(real_saelens_dir, real_weights, real_cfg)
    finally:
        results = []
        if fake_made:
            results.append(_remove_fake_checkpoint(
                real_dl_dir, fake_trainer_dir, dl_dir_existed, trainer_dir_existed, backup_trainer_cfg
            ))
        # Output converted from a fake checkpoint is meaningless, so delete it
        # if it did not exist before; output from a real ae.pt is left in place.
        results.append(_restore_file(real_weights, backup_weights, delete_if_absent=fake_made))
        results.append(_restore_file(real_cfg, backup_cfg, delete_if_absent=fake_made))
        if fake_made and not saelens_dir_existed:
            shutil.rmtree(real_saelens_dir, ignore_errors=True)
        restored = all(results)
        if restored:
            shutil.rmtree(backup_dir, ignore_errors=True)
        else:
            print(f"[test_convert] WARNING: pre-test backups kept in {backup_dir}")
    if not restored:
        raise RuntimeError(
            f"Checks passed but restoring pre-test files failed; backups are in {backup_dir}"
        )
    print("\n=== ALL TESTS PASSED ===")
if __name__ == "__main__":
    # Script entry point: run the pre-flight check only when executed directly.
    main()