---
license: mit
library_name: sparse-readout-prism
tags:
- sparse-autoencoder
- mechanistic-interpretability
- interpretability
- unembedding
- logit-lens
base_model:
- Qwen/Qwen3.5-0.8B
- Qwen/Qwen3.5-2B
- Qwen/Qwen3.5-9B
- deepseek-ai/DeepSeek-R1-Distill-Llama-8B
- deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
- google/gemma-4-E2B-it
- google/gemma-4-E4B-it
- mistralai/Ministral-3-8B-Base-2512
---

# Sparse Readout Prism: pretrained readout-feature dictionaries

Pretrained dictionaries for **Sparse Readout Prism**, which factorizes a language model's unembedding matrix (`W_U`) into reusable *readout features*, then decomposes a selected vocabulary logit (or a logit contrast) into

```
h · W_U[token] ≈ base + Σ_i z_i (h · d_i) + residual
```

that is, signed per-feature contributions plus an explicit residual. A per-query fidelity gate **withholds** the explanation when the sparse approximation fails to preserve the held-out logit/margin. These are *final-readout* dictionaries (trained on `W_U` rows), not residual-stream / per-layer SAEs.

- **Code:** https://github.com/hematteo/sparse-readout-prism
- **Paper:** *Sparse Readout Prism: A Sparse LM-Head Basis for Logit-Lens Readouts* (preprint forthcoming)

## Operating points

Most base models ship two dictionaries: a **fidelity** point (`k256`, 32× width, k = 256) and a **strict-budget** point (`k128`, 16× width, or 8× for Qwen-3.5-9B; k = 128). Qwen-3.5-9B additionally ships a 16×/`k256` capacity point, so it has three. The exact width of each is in the `width` column below.

## Checkpoints

Layout: `//checkpoint.pt`.

| Path | base model | width | d_features | k | rowEV | top1 | KL (bits) |
|---|---|---|---:|---:|---:|---:|---:|
| `qwen3.5-0.8b/k128_16x` | Qwen/Qwen3.5-0.8B | 16× | 16384 | 128 | 0.760 | 0.844 | n/a |
| `qwen3.5-0.8b/k256_32x` | Qwen/Qwen3.5-0.8B | 32× | 32768 | 256 | 0.877 | 0.891 | n/a |
| `qwen3.5-2b/k128_16x` | Qwen/Qwen3.5-2B | 16× | 32768 | 128 | 0.712 | 0.858 | n/a |
| `qwen3.5-2b/k256_32x` | Qwen/Qwen3.5-2B | 32× | 65536 | 256 | 0.847 | 0.887 | n/a |
| `qwen3.5-9b/k128_8x` | Qwen/Qwen3.5-9B | 8× | 32768 | 128 | 0.621 | 0.846 | 0.296 |
| `qwen3.5-9b/k256_16x` | Qwen/Qwen3.5-9B | 16× | 65536 | 256 | 0.761 | 0.874 | 0.167 |
| `qwen3.5-9b/k256_32x` | Qwen/Qwen3.5-9B | 32× | 131072 | 256 | 0.857 | 0.900 | 0.105 |
| `gemma-4-e2b/k128_16x` | google/gemma-4-E2B-it | 16× | 24576 | 128 | n/a | n/a | n/a |
| `gemma-4-e2b/k256_32x` | google/gemma-4-E2B-it | 32× | 49152 | 256 | n/a | n/a | n/a |
| `gemma-4-e4b/k128_16x` | google/gemma-4-E4B-it | 16× | 40960 | 128 | n/a | n/a | n/a |
| `gemma-4-e4b/k256_32x` | google/gemma-4-E4B-it | 32× | 81920 | 256 | n/a | n/a | n/a |
| `ministral-3-8b/k128_16x` | mistralai/Ministral-3-8B-Base-2512 | 16× | 65536 | 128 | 0.806 | 0.885 | 0.130 |
| `ministral-3-8b/k256_32x` | mistralai/Ministral-3-8B-Base-2512 | 32× | 131072 | 256 | 0.888 | 0.904 | 0.087 |
| `r1-distill-qwen-7b/k128_16x` | deepseek-ai/DeepSeek-R1-Distill-Qwen-7B | 16× | 57344 | 128 | 0.709 | 0.695 | 0.777 |
| `r1-distill-qwen-7b/k256_32x` | deepseek-ai/DeepSeek-R1-Distill-Qwen-7B | 32× | 114688 | 256 | 0.844 | 0.760 | 0.489 |
| `r1-distill-llama-8b/k128_16x` | deepseek-ai/DeepSeek-R1-Distill-Llama-8B | 16× | 65536 | 128 | 0.796 | 0.725 | 0.536 |
| `r1-distill-llama-8b/k256_32x` | deepseek-ai/DeepSeek-R1-Distill-Llama-8B | 32× | 131072 | 256 | 0.888 | 0.754 | 0.434 |

`rowEV` = row-reconstruction explained variance; `top1` = top-token agreement after replacing `W_U` with its reconstruction on held-out hidden states; `KL` = readout KL (bits).
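As a rough illustration of what the three metric columns measure, the sketch below computes them on toy data in pure Python. The function names, the explained-variance normalization (variance taken around the per-column mean), and the KL direction (original readout relative to reconstructed readout) are assumptions made for illustration; the repository's evaluation code may define them differently.

```python
import math

def explained_variance(rows, recon):
    """1 - SSE/SST over all entries; SST is taken around the per-column mean (assumed convention)."""
    n, d = len(rows), len(rows[0])
    col_mean = [sum(r[j] for r in rows) / n for j in range(d)]
    sse = sum((r[j] - q[j]) ** 2 for r, q in zip(rows, recon) for j in range(d))
    sst = sum((r[j] - col_mean[j]) ** 2 for r in rows for j in range(d))
    return 1.0 - sse / sst

def logits(h, W):
    """Readout logits: dot product of hidden state h with each vocabulary row of W."""
    return [sum(a * b for a, b in zip(h, row)) for row in W]

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]

def top1_agreement(hiddens, W, W_hat):
    """Fraction of hidden states whose argmax token is unchanged by the reconstruction."""
    same = 0
    for h in hiddens:
        a, b = logits(h, W), logits(h, W_hat)
        same += a.index(max(a)) == b.index(max(b))
    return same / len(hiddens)

def readout_kl_bits(hiddens, W, W_hat):
    """Mean KL(p_original || p_reconstructed) in bits (the direction is an assumption)."""
    total = 0.0
    for h in hiddens:
        p, q = softmax(logits(h, W)), softmax(logits(h, W_hat))
        total += sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return total / len(hiddens)

# Toy example: 3-token vocabulary, 2-dimensional hidden states.
W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
W_hat = [[0.9, 0.1], [0.0, 1.0], [-1.0, -0.8]]
hiddens = [[2.0, 0.5], [0.1, 1.5], [-1.0, -1.0]]

print(explained_variance(W, W_hat))
print(top1_agreement(hiddens, W, W_hat))
print(readout_kl_bits(hiddens, W, W_hat))
```

On real checkpoints the same quantities would be computed over the full vocabulary and a held-out set of hidden states, which is why a high `rowEV` and a high `top1` need not move together.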
Qwen numbers are the Appendix-K figures; Ministral / R1-Distill are the checkpoints' held-out eval. **Gemma dictionaries are provisional**: the final-logit softcap eval layer is not yet applied, so top1/KL are withheld (see the paper).

## How they were trained

TopK factorizer on the **centered + row-normalized** `W_U` rows. Shared "converged finalist" recipe: 20k steps, batch 4096, AdamW lr 1e-3 (warmup → cosine), prism penalty `lambda_prism = 1e-3` with a delayed linear ramp, hybrid (50% frequency / 50% uniform) row sampling, row-seeded init. The operating-point `k` is the audit-`k` used for decomposition.

## Usage

```python
from huggingface_hub import hf_hub_download
from sparse_readout_prism import load_factorizer  # pip install -e . from the GitHub repo

path = hf_hub_download("matteohe/sparse-readout-prism", "qwen3.5-2b/k256_32x/checkpoint.pt")
sae = load_factorizer(path, freeze=True)  # rebuild + load_state_dict + eval, one call
```

Each `checkpoint.pt` is a `weights_only=True`-loadable dict with `model_state_dict` (encoder / decoder / biases) and the factorizer config (`architecture`, `k`, `d_features`); `load_factorizer` resolves it whether the config sits under `factorizer` or `config.factorizer`.

To decompose you also need the centered / row-normalized preprocessing: recompute it from the model's `W_U` with `preprocess_rows(W_U)` (the Ministral / R1-Distill checkpoints additionally embed `row_mean` / `row_norms`). Decomposing against a different preprocessing breaks the identity. See the GitHub README quickstart for the full decomposition snippet.

## Intended use & limitations

Research artifact for mechanistic interpretability of the **final readout**. A decomposition is interpretable only when its local query passes the residual/sign fidelity gate; high `rowEV` alone does not license feature-level claims. These dictionaries say nothing about *why* a hidden state arose (no residual-stream / circuit attribution).
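The gating idea described above can be sketched on a toy query. This is a minimal pure-Python illustration, not the library's implementation: `decompose_logit`, `passes_gate`, and the `max_rel_residual` threshold are hypothetical names and values, and the toy row is assumed to already be in the dictionary's preprocessed coordinates.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def decompose_logit(h, w_row, z, D, bias):
    """Split h . w_row into base + per-feature contributions + residual.

    z maps feature index -> sparse code (active features only),
    D maps feature index -> decoder direction d_i,
    bias is the decoder bias (its readout is the 'base' term here).
    """
    logit = dot(h, w_row)
    base = dot(h, bias)
    contributions = {i: z_i * dot(h, D[i]) for i, z_i in z.items()}
    approx = base + sum(contributions.values())
    return {
        "logit": logit,
        "base": base,
        "contributions": contributions,
        "approx": approx,
        "residual": logit - approx,  # whatever the sparse part fails to capture
    }

def passes_gate(parts, max_rel_residual=0.1):
    """Hypothetical gate: residual must be small and the logit's sign preserved."""
    small = abs(parts["residual"]) <= max_rel_residual * abs(parts["logit"])
    same_sign = (parts["logit"] >= 0) == (parts["approx"] >= 0)
    return small and same_sign

# Toy query: 3-dimensional hidden state, two active features.
h = [1.0, 2.0, -1.0]
w_row = [0.5, 1.0, 0.1]
bias = [0.1, 0.0, 0.0]
D = {0: [1.0, 0.0, 0.0], 1: [0.0, 1.0, 0.0]}
z = {0: 0.4, 1: 1.0}

parts = decompose_logit(h, w_row, z, D, bias)
if passes_gate(parts):
    print("report contributions:", parts["contributions"])
else:
    print("explanation withheld; residual =", parts["residual"])
```

Because the residual is defined as the gap between the true logit and the sparse approximation, the identity holds exactly by construction; the gate only decides whether the feature-level part is trustworthy enough to report for this particular query.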
Gemma results are provisional as noted above.

## Citation

```bibtex
@misc{he2026sparsereadoutprism,
  title  = {Sparse Readout Prism: A Sparse LM-Head Basis for Logit-Lens Readouts},
  author = {He, Matteo and Shen, William F. and Qiu, Xinchi and Lane, Nicholas D.},
  year   = {2026},
  note   = {Preprint forthcoming; see the repository for the up-to-date reference},
}
```

License: MIT.