
WriteSAE: Sparse Autoencoders for Recurrent State

Jack Young

Paper | Website | Code

WriteSAE factors each decoder atom as a rank-1 outer product vᵢwᵢᵀ, matching the native kₜvₜᵀ write that Gated DeltaNet (GDN), Mamba-2, and RWKV-7 install into a dₖ × dᵥ matrix cache. Residual-stream SAEs cannot reach that write site; WriteSAE can. At Qwen3.5-0.8B L9 H4, atom substitution beats a matched-Frobenius-norm ablation on 92.4% of n = 4,851 firings, a closed-form expression predicts the measured logit shifts with R² = 0.98, and sustained three-position installs raise midrank target-in-continuation from 33.3% to 100% under greedy decoding. Across architectures, GDN rank-1 atoms transfer to Mamba-2-370M at 88.1% over 2,500 firings, with sharpness ordering GDN > RWKV-7 > Mamba-2.
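The shape argument above can be checked numerically. This is a minimal sketch with toy sizes, assuming a plain additive update S ← S + kₜvₜᵀ (gating, decay, and GDN's delta-rule correction are deliberately omitted). It shows that the per-token write and a WriteSAE-style atom are both rank-1 dₖ × dᵥ matrices, and that an atom's Frobenius norm factorizes as ‖vᵢ‖·‖wᵢ‖, the kind of quantity a matched-Frobenius-norm comparison has to equalize (our reading, not a statement of the paper's exact protocol).

```python
import torch

torch.manual_seed(0)
d_k, d_v = 4, 3  # toy sizes; real head dims come from the base model

# Matrix-valued recurrent state (the "matrix cache"), shape (d_k, d_v).
S = torch.zeros(d_k, d_v)

# One token's native write k_t v_t^T. Simplification: real GDN / Mamba-2 /
# RWKV-7 updates also apply gating or decay (and GDN a delta-rule correction);
# only the additive rank-1 term is shown here.
k_t = torch.randn(d_k)
v_t = torch.randn(d_v)
write = torch.outer(k_t, v_t)
S = S + write

# A WriteSAE-style decoder atom v_i w_i^T. Naming caveat: here v_i lives on
# the key side (length d_k) and w_i on the value side (length d_v), matching
# the (d_k,) / (d_v,) shapes in the quick-start snippet.
v_i = torch.randn(d_k)
w_i = torch.randn(d_v)
atom = torch.outer(v_i, w_i)

# Both are d_k x d_v matrices of rank 1 ...
print(write.shape, atom.shape)
print(torch.linalg.matrix_rank(write).item(), torch.linalg.matrix_rank(atom).item())

# ... and the Frobenius norm of a rank-1 outer product factorizes:
# ||v w^T||_F = ||v||_2 * ||w||_2.
fro = torch.linalg.matrix_norm(atom, ord="fro")
print(torch.allclose(fro, v_i.norm() * w_i.norm()))
```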

Quick start

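from huggingface_hub import snapshot_download
import torch

# Download only the primary Qwen3.5-0.8B, layer 9, head 4 cell.
ckpt_dir = snapshot_download(
    "JackYoung27/writesae-ckpts",
    allow_patterns=["writesae/qwen0p8b/L9_H4/*"],
)

# weights_only=False unpickles arbitrary Python objects (here, the saved SAE
# object): load only checkpoints you trust, with the defining classes importable.
ckpt = torch.load(
    f"{ckpt_dir}/writesae/qwen0p8b/L9_H4/best.pt",
    weights_only=False,
    map_location="cpu",
)

# Decoder atom 412: the paper's ERASE example.
v_412 = ckpt["sae"].decoder.v[412].detach()   # (d_k,)
w_412 = ckpt["sae"].decoder.w[412].detach()   # (d_v,)
atom = torch.outer(v_412, w_412)               # (d_k, d_v), rank 1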

A standalone, runnable version is in LOAD_EXAMPLE.py.

Variants

| Variant | Encoder | Decoder | Role |
|---|---|---|---|
| WriteSAE | bilinear, vᵢᵀ S wᵢ | rank-1, vᵢwᵢᵀ | All headline numbers |
| FlatSAE | linear on vec(S) | flat | Architectural-prior comparison |
| MatrixSAE | linear on vec(S) | full-rank | Ablation |
| BilinearSAE | bilinear | bilinear | Ablation |
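The first two rows of the table can be contrasted in code. Below is a minimal PyTorch sketch of a WriteSAE-style and a FlatSAE-style autoencoder over a batch of dₖ × dᵥ states. It is not the released implementation: the class names are hypothetical, the encoder is assumed to reuse the decoder factors vᵢ, wᵢ (as the table's notation suggests), and ReLU with a learned bias is assumed as the sparsifying nonlinearity.

```python
import torch
import torch.nn as nn


class WriteSAESketch(nn.Module):
    """Hypothetical WriteSAE-style autoencoder over (batch, d_k, d_v) states.

    Assumptions: encoder and decoder share the factors v_i, w_i, and a ReLU
    with a learned bias produces the sparse codes.
    """

    def __init__(self, d_k: int, d_v: int, n_atoms: int):
        super().__init__()
        self.v = nn.Parameter(torch.randn(n_atoms, d_k) / d_k**0.5)
        self.w = nn.Parameter(torch.randn(n_atoms, d_v) / d_v**0.5)
        self.b = nn.Parameter(torch.zeros(n_atoms))

    def encode(self, S: torch.Tensor) -> torch.Tensor:
        # Bilinear readout a_i = v_i^T S w_i for every atom at once.
        pre = torch.einsum("nk,bkv,nv->bn", self.v, S, self.w)
        return torch.relu(pre + self.b)

    def decode(self, a: torch.Tensor) -> torch.Tensor:
        # Sum of rank-1 atoms: S_hat = sum_i a_i v_i w_i^T.
        return torch.einsum("bn,nk,nv->bkv", a, self.v, self.w)

    def forward(self, S: torch.Tensor):
        a = self.encode(S)
        return self.decode(a), a


class FlatSAESketch(nn.Module):
    """Hypothetical FlatSAE-style baseline: an ordinary linear SAE on vec(S)."""

    def __init__(self, d_k: int, d_v: int, n_atoms: int):
        super().__init__()
        self.d_k, self.d_v = d_k, d_v
        self.enc = nn.Linear(d_k * d_v, n_atoms)
        self.dec = nn.Linear(n_atoms, d_k * d_v)

    def forward(self, S: torch.Tensor):
        a = torch.relu(self.enc(S.flatten(1)))  # flatten each state to vec(S)
        return self.dec(a).view(-1, self.d_k, self.d_v), a


# Usage on random toy states.
torch.manual_seed(0)
d_k, d_v, n_atoms = 8, 6, 32
S = torch.randn(2, d_k, d_v)
write_sae = WriteSAESketch(d_k, d_v, n_atoms)
S_hat, a = write_sae(S)
flat_sae = FlatSAESketch(d_k, d_v, n_atoms)
S_hat_flat, a_flat = flat_sae(S)
print(S_hat.shape, a.shape, S_hat_flat.shape)
```

One consequence of the rank-1 prior is parameter count: each WriteSAE decoder atom costs dₖ + dᵥ numbers, whereas each flat atom costs dₖ · dᵥ.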

Base models covered

Qwen3.5-0.8B (primary), Qwen3.5-4B, Qwen3.5-27B, Mamba-2-370M, RWKV-7-1.5B, DeltaNet-1.3B, GLA-1.3B. See MODEL_CARD.md for full layer / head coverage and training details.

Repository layout

writesae-ckpts/
  README.md
  MODEL_CARD.md
  manifest.json
  LOAD_EXAMPLE.py
  LICENSE

  writesae/<base-model>/<layer>_<head>/best.pt        # primary cells
  flat_baseline/<base-model>_<layer>_<head>/best.pt   # FlatSAE controls
  results/<test-name>/                                # JSON outputs per paper claim
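When pulling many cells, it can help to map checkpoint paths back to (model, layer, head). A small sketch, assuming every path follows the layout above and that layer and head tokens contain no underscores (as in L9_H4); parse_ckpt_path, list_cells, and Cell are hypothetical helpers, and list_cells needs network access because it calls huggingface_hub's list_repo_files.

```python
import re
from typing import List, NamedTuple, Optional

# Patterns mirror the layout above; the "<layer>_<head>" token format is
# inferred from the quick-start path (L9_H4) and may differ elsewhere.
PRIMARY = re.compile(
    r"^writesae/(?P<model>[^/]+)/(?P<layer>[^/_]+)_(?P<head>[^/_]+)/best\.pt$"
)
FLAT = re.compile(
    r"^flat_baseline/(?P<model>[^/]+)_(?P<layer>[^/_]+)_(?P<head>[^/_]+)/best\.pt$"
)


class Cell(NamedTuple):
    kind: str   # "writesae" or "flat_baseline"
    model: str
    layer: str
    head: str
    path: str


def parse_ckpt_path(path: str) -> Optional[Cell]:
    """Return the cell a checkpoint path belongs to, or None for other files."""
    for kind, pattern in (("writesae", PRIMARY), ("flat_baseline", FLAT)):
        m = pattern.match(path)
        if m:
            return Cell(kind, m["model"], m["layer"], m["head"], path)
    return None


def list_cells(repo_id: str = "JackYoung27/writesae-ckpts") -> List[Cell]:
    """List every checkpoint cell in the repo (requires network access)."""
    from huggingface_hub import list_repo_files

    cells = (parse_ckpt_path(p) for p in list_repo_files(repo_id))
    return [c for c in cells if c is not None]


print(parse_ckpt_path("writesae/qwen0p8b/L9_H4/best.pt"))
```

The FLAT pattern relies on regex backtracking, so base-model names that themselves contain underscores still split on the last two underscores.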

Limitations

The closed-form logit-shift prediction is accurate only on Gated DeltaNet (R² = 0.98 at Qwen3.5-0.8B L9 H4); applied to Mamba-2 or Qwen3.5-4B, it yields negative R². The substitution test itself transfers to Mamba-2 (88.1%); the analytical coefficient does not. Individual atom identities vary across SAE seeds, but the class-level register/bundle partition reproduces with a coefficient of variation (CV) of 4–12%.
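A negative R² is not a typo: the coefficient of determination drops below zero whenever a predictor does worse than always predicting the mean. A minimal sketch of both statistics quoted above, with toy numbers (not the paper's data), assuming the standard definitions R² = 1 − SS_res/SS_tot and CV = σ/μ; the population σ is used here, and the paper may use the sample version.

```python
from statistics import mean, pstdev


def r_squared(measured, predicted):
    """Coefficient of determination, 1 - SS_res / SS_tot.

    Negative whenever the predictions fit worse than the constant
    prediction mean(measured).
    """
    mu = mean(measured)
    ss_res = sum((y - p) ** 2 for y, p in zip(measured, predicted))
    ss_tot = sum((y - mu) ** 2 for y in measured)
    return 1.0 - ss_res / ss_tot


def coefficient_of_variation(values):
    """Population standard deviation divided by the mean."""
    return pstdev(values) / mean(values)


measured = [1.0, 2.0, 3.0, 4.0]
good = [1.1, 1.9, 3.0, 4.1]  # close to the measurements
bad = [4.0, 3.0, 2.0, 1.0]   # anti-correlated: worse than predicting the mean

print(r_squared(measured, good))  # close to 1
print(r_squared(measured, bad))   # below 0

# Toy per-seed sizes of one partition class (hypothetical numbers).
print(coefficient_of_variation([10, 11, 9, 10]))
```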

Citation

@article{young2026writesae,
  title  = {WriteSAE: Sparse Autoencoders for Recurrent State},
  author = {Young, Jack},
  year   = {2026},
  journal= {arXiv preprint arXiv:TBA},
  url    = {https://github.com/JackYoung27/writesae}
}

MIT license. Base models retain their upstream licenses; no base-model weights are redistributed.
