| #!/usr/bin/env python3 | |
| """Minimal example: download the WriteSAE checkpoint and build the decoder atom for feature 412. | |
| It does not itself run the substitution test behind the paper's headline 92.4% number (a single firing of feature 412); step 3 points to the script that does. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from huggingface_hub import snapshot_download | |
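| # 1. Download the primary-cell SAE checkpoint (L9_H4, 5 MB). | |
| #    best.pt stores a pickled SAE object (hence weights_only=False), so only load it from a source you trust. | |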
| ckpt_dir = snapshot_download( | |
| "anon-writesae/matrix-sae-ckpts", | |
| allow_patterns=["writesae/qwen0p8b/L9_H4/*"], | |
| ) | |
| ckpt = torch.load(f"{ckpt_dir}/writesae/qwen0p8b/L9_H4/best.pt", | |
| weights_only=False, map_location="cpu") | |
| # 2. Pick atom F412 (the paper's ERASE exemplar). Each atom is the rank-1 outer product v_i w_i^T. | |
| v_412 = ckpt["sae"].decoder.v[412] | |
| w_412 = ckpt["sae"].decoder.w[412] | |
| atom = torch.outer(v_412, w_412) | |
| print(f"Atom F412: shape {tuple(atom.shape)}, Frobenius norm = {atom.norm():.4f}") | |
| # 3. To run the actual substitution test (the atom replaces the native cache write at | |
| # one firing position and the forward KL is measured), see scripts/clean_amplified_kl.py | |
| # in the code repo: https://anonymous.4open.science/r/WriteSAE-6158 | |
| print("\nNext: clone the code repo and run scripts/clean_amplified_kl.py --feature 412") | |