---
license: mit
library_name: sparse-readout-prism
tags:
- sparse-autoencoder
- mechanistic-interpretability
- interpretability
- unembedding
- logit-lens
base_model:
- Qwen/Qwen3.5-0.8B
- Qwen/Qwen3.5-2B
- Qwen/Qwen3.5-9B
- deepseek-ai/DeepSeek-R1-Distill-Llama-8B
- deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
- google/gemma-4-E2B-it
- google/gemma-4-E4B-it
- mistralai/Ministral-3-8B-Base-2512
---
# Sparse Readout Prism: pretrained readout-feature dictionaries
Pretrained dictionaries for **Sparse Readout Prism**, which factorizes a language
model's unembedding matrix (`W_U`) into reusable *readout features*, then
decomposes a selected vocabulary logit (or a logit contrast) into
```
h · W_U[token] ≈ base + Σ_i z_i (h · d_i) + residual
```
that is, signed per-feature contributions plus an explicit residual. It **withholds**
the explanation via a per-query fidelity gate when the sparse approximation fails
to preserve the held-out logit/margin. These are *final-readout* dictionaries
(trained on `W_U` rows), not residual-stream / per-layer SAEs.
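
The identity above follows from linearity: if a `W_U` row is written as an offset plus a sparse combination of dictionary directions plus a row residual, then dotting with `h` splits the logit into the same three parts. The sketch below checks this numerically on random data. The sparse code is a toy top-k projection, not the trained encoder, and `base_vec`, `dictionary`, and the shapes are illustrative assumptions rather than the library's actual tensors.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_features, k = 16, 64, 4

w_row = rng.normal(size=d_model)                     # stand-in for one W_U row
dictionary = rng.normal(size=(n_features, d_model))  # rows play the role of d_i
base_vec = 0.1 * rng.normal(size=d_model)            # assumed shared offset

# Toy sparse code: keep the k largest-magnitude projections (assumption,
# not the trained factorizer's encoder).
scores = dictionary @ (w_row - base_vec)
active = np.argsort(-np.abs(scores))[:k]
z = np.zeros(n_features)
z[active] = scores[active] / np.sum(dictionary[active] ** 2, axis=1)

# Whatever the sparse code misses is kept as an explicit row residual.
row_residual = w_row - base_vec - z @ dictionary

h = rng.normal(size=d_model)                         # stand-in for a final hidden state
logit = h @ w_row
base = h @ base_vec
contributions = z * (dictionary @ h)                 # z_i * (h . d_i); zero when inactive
residual = h @ row_residual

# The logit is recovered exactly once the residual is included.
assert np.isclose(logit, base + contributions.sum() + residual)
```

Because the residual is carried explicitly, the sum is exact for any sparse code; the quality of the code only determines how much of the logit ends up in `residual` rather than in the named features.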
- **Code:** https://github.com/hematteo/sparse-readout-prism
- **Paper:** *Sparse Readout Prism: A Sparse LM-Head Basis for Logit-Lens Readouts* (preprint forthcoming)
## Operating points
Most base models ship two dictionaries: a **fidelity** point (`k256`, 32× width,
k = 256) and a **strict-budget** point (`k128`, 16× width, or 8× for Qwen3.5-9B,
k = 128). Qwen3.5-9B additionally ships a 16×/`k256` capacity point, so it has
three. The exact width of each is in the `width` column below.
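
As a quick consistency check, the width multiplier and `d_features` in the Checkpoints table can be divided to see what base dimension each multiplier refers to. The card does not state what the multiplier is relative to; a natural reading is the base model's hidden size, but that is an assumption. The sketch only verifies that every operating point of a given model implies the same base dimension.

```python
# (path, width multiplier, d_features), copied from the Checkpoints table.
ROWS = [
    ("qwen3.5-0.8b/k128_16x", 16, 16384),
    ("qwen3.5-0.8b/k256_32x", 32, 32768),
    ("qwen3.5-2b/k128_16x", 16, 32768),
    ("qwen3.5-2b/k256_32x", 32, 65536),
    ("qwen3.5-9b/k128_8x", 8, 32768),
    ("qwen3.5-9b/k256_16x", 16, 65536),
    ("qwen3.5-9b/k256_32x", 32, 131072),
    ("gemma-4-e2b/k128_16x", 16, 24576),
    ("gemma-4-e2b/k256_32x", 32, 49152),
    ("gemma-4-e4b/k128_16x", 16, 40960),
    ("gemma-4-e4b/k256_32x", 32, 81920),
    ("ministral-3-8b/k128_16x", 16, 65536),
    ("ministral-3-8b/k256_32x", 32, 131072),
    ("r1-distill-qwen-7b/k128_16x", 16, 57344),
    ("r1-distill-qwen-7b/k256_32x", 32, 114688),
    ("r1-distill-llama-8b/k128_16x", 16, 65536),
    ("r1-distill-llama-8b/k256_32x", 32, 131072),
]


def implied_base_dims(rows):
    """Group d_features / width by model directory."""
    dims = {}
    for path, width, d_features in rows:
        model = path.split("/")[0]
        if d_features % width != 0:
            raise ValueError(f"{path}: d_features is not a multiple of width")
        dims.setdefault(model, set()).add(d_features // width)
    return dims


dims = implied_base_dims(ROWS)
for model, values in sorted(dims.items()):
    print(model, sorted(values))
```

Each model should map to a single value; more than one would point to a typo in either the `width` or the `d_features` column.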
## Checkpoints
Layout: `<model>/<operating_point>/checkpoint.pt`.
| Path | base model | width | d_features | k | rowEV | top1 | KL (bits) |
|---|---|---|---:|---:|---:|---:|---:|
| `qwen3.5-0.8b/k128_16x` | Qwen/Qwen3.5-0.8B | 16× | 16384 | 128 | 0.760 | 0.844 | n/a |
| `qwen3.5-0.8b/k256_32x` | Qwen/Qwen3.5-0.8B | 32× | 32768 | 256 | 0.877 | 0.891 | n/a |
| `qwen3.5-2b/k128_16x` | Qwen/Qwen3.5-2B | 16× | 32768 | 128 | 0.712 | 0.858 | n/a |
| `qwen3.5-2b/k256_32x` | Qwen/Qwen3.5-2B | 32× | 65536 | 256 | 0.847 | 0.887 | n/a |
| `qwen3.5-9b/k128_8x` | Qwen/Qwen3.5-9B | 8× | 32768 | 128 | 0.621 | 0.846 | 0.296 |
| `qwen3.5-9b/k256_16x` | Qwen/Qwen3.5-9B | 16× | 65536 | 256 | 0.761 | 0.874 | 0.167 |
| `qwen3.5-9b/k256_32x` | Qwen/Qwen3.5-9B | 32× | 131072 | 256 | 0.857 | 0.900 | 0.105 |
| `gemma-4-e2b/k128_16x` | google/gemma-4-E2B-it | 16× | 24576 | 128 | n/a | n/a | n/a |
| `gemma-4-e2b/k256_32x` | google/gemma-4-E2B-it | 32× | 49152 | 256 | n/a | n/a | n/a |
| `gemma-4-e4b/k128_16x` | google/gemma-4-E4B-it | 16× | 40960 | 128 | n/a | n/a | n/a |
| `gemma-4-e4b/k256_32x` | google/gemma-4-E4B-it | 32× | 81920 | 256 | n/a | n/a | n/a |
| `ministral-3-8b/k128_16x` | mistralai/Ministral-3-8B-Base-2512 | 16× | 65536 | 128 | 0.806 | 0.885 | 0.130 |
| `ministral-3-8b/k256_32x` | mistralai/Ministral-3-8B-Base-2512 | 32× | 131072 | 256 | 0.888 | 0.904 | 0.087 |
| `r1-distill-qwen-7b/k128_16x` | deepseek-ai/DeepSeek-R1-Distill-Qwen-7B | 16× | 57344 | 128 | 0.709 | 0.695 | 0.777 |
| `r1-distill-qwen-7b/k256_32x` | deepseek-ai/DeepSeek-R1-Distill-Qwen-7B | 32× | 114688 | 256 | 0.844 | 0.760 | 0.489 |
| `r1-distill-llama-8b/k128_16x` | deepseek-ai/DeepSeek-R1-Distill-Llama-8B | 16× | 65536 | 128 | 0.796 | 0.725 | 0.536 |
| `r1-distill-llama-8b/k256_32x` | deepseek-ai/DeepSeek-R1-Distill-Llama-8B | 32× | 131072 | 256 | 0.888 | 0.754 | 0.434 |

`rowEV` = row-reconstruction explained variance; `top1` = top-token agreement
after replacing `W_U` with its reconstruction on held-out hidden states; `KL` =
readout KL (bits); `n/a` = not reported for that checkpoint. Qwen numbers are the
Appendix-K figures; Ministral / R1-Distill are the checkpoints' held-out eval.
**Gemma dictionaries are provisional**: the final-logit softcap eval layer is not
yet applied, so top1/KL are withheld (see the paper).
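
For readers who want to reproduce numbers of this kind, here is a minimal sketch of the three metrics on toy data. The card does not spell out the exact definitions, so the choices below are assumptions: explained variance is taken over all rows around the mean row, and the KL is from the original readout distribution to the reconstructed one, averaged over hidden states.

```python
import numpy as np


def row_explained_variance(w, w_hat):
    """1 - SSE / total variance around the mean row (one common EV definition)."""
    sse = np.sum((w - w_hat) ** 2)
    sst = np.sum((w - w.mean(axis=0, keepdims=True)) ** 2)
    return float(1.0 - sse / sst)


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def top1_agreement(h, w, w_hat):
    """Fraction of hidden states whose argmax token is unchanged."""
    return float(np.mean((h @ w.T).argmax(-1) == (h @ w_hat.T).argmax(-1)))


def readout_kl_bits(h, w, w_hat):
    """Mean KL(p_original || p_reconstructed) over hidden states, in bits."""
    log_p = log_softmax(h @ w.T)
    log_q = log_softmax(h @ w_hat.T)
    kl_nats = np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)
    return float(np.mean(kl_nats) / np.log(2.0))


# Toy stand-ins: 50 "tokens", 8-dim hidden states, a noisy reconstruction.
rng = np.random.default_rng(0)
w = rng.normal(size=(50, 8))
w_hat = w + 0.1 * rng.normal(size=w.shape)
h = rng.normal(size=(32, 8))

ev = row_explained_variance(w, w_hat)
agree = top1_agreement(h, w, w_hat)
kl = readout_kl_bits(h, w, w_hat)
```

The three metrics can disagree: a reconstruction with high `rowEV` can still flip the top token on hidden states where the leading logits are close, which is why the table reports them separately.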
## How they were trained
TopK factorizer on the **centered + row-normalized** `W_U` rows. Shared "converged
finalist" recipe: 20k steps, batch 4096, AdamW lr 1e-3 (warmup → cosine), prism
penalty `lambda_prism = 1e-3` with a delayed linear ramp, hybrid (50% frequency /
50% uniform) row sampling, row-seeded init. The operating-point `k` is the
audit-`k` used for decomposition.
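
To make "centered + row-normalized" and "TopK" concrete, the sketch below shows one plausible reading of both steps in NumPy. It is not the repository's `preprocess_rows` or its factorizer: the function names, the `eps` guard, the order of operations, and applying top-k directly to raw pre-activations are all assumptions for illustration.

```python
import numpy as np


def preprocess_rows_sketch(w_u, eps=1e-8):
    """Subtract the mean row, then scale each row to unit L2 norm."""
    row_mean = w_u.mean(axis=0, keepdims=True)                   # (1, d_model)
    centered = w_u - row_mean
    row_norms = np.linalg.norm(centered, axis=1, keepdims=True)  # (vocab, 1)
    return centered / (row_norms + eps), row_mean, row_norms


def topk_activation(pre_acts, k):
    """Keep the k largest pre-activations in each row and zero the rest."""
    idx = np.argpartition(-pre_acts, k - 1, axis=1)[:, :k]
    out = np.zeros_like(pre_acts)
    np.put_along_axis(out, idx, np.take_along_axis(pre_acts, idx, axis=1), axis=1)
    return out


# Toy stand-ins: 100 "vocabulary rows" of width 16, 64 dictionary features.
rng = np.random.default_rng(0)
w_u = rng.normal(size=(100, 16)) + 3.0
x, row_mean, row_norms = preprocess_rows_sketch(w_u)
codes = topk_activation(rng.normal(size=(100, 64)), k=8)

# The stored mean and norms are enough to map back to the original rows.
restored = x * (row_norms + 1e-8) + row_mean
```

Keeping `row_mean` and `row_norms` around matters because any decomposition computed in the preprocessed space has to be mapped back through the same statistics to describe the original logits.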
## Usage
```python
from huggingface_hub import hf_hub_download
from sparse_readout_prism import load_factorizer # pip install -e . from the GitHub repo
path = hf_hub_download("matteohe/sparse-readout-prism", "qwen3.5-2b/k256_32x/checkpoint.pt")
sae = load_factorizer(path, freeze=True) # rebuild + load_state_dict + eval, one call
```
Each `checkpoint.pt` is a `weights_only=True`-loadable dict with `model_state_dict`
(encoder / decoder / biases) and the factorizer config (`architecture`, `k`,
`d_features`); `load_factorizer` resolves it whether the config sits under
`factorizer` or `config.factorizer`. To decompose, you also need the centered /
row-normalized preprocessing: recompute it from the model's `W_U` with
`preprocess_rows(W_U)` (the Ministral / R1-Distill checkpoints additionally embed
`row_mean` / `row_norms`). Decomposing against a different preprocessing breaks the
identity. See the GitHub README quickstart for the full decomposition snippet.
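
If you want to inspect a checkpoint without going through `load_factorizer`, the two config locations described above can be handled with a small helper. This is a hypothetical stand-in that mirrors the described behaviour, not the library's own code; the `"topk"` architecture string in the toy dicts is a placeholder value.

```python
from typing import Any, Mapping


def resolve_factorizer_config(ckpt: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the factorizer config from either documented location."""
    if "factorizer" in ckpt:
        return ckpt["factorizer"]
    config = ckpt.get("config") or {}
    if "factorizer" in config:
        return config["factorizer"]
    raise KeyError("no factorizer config under 'factorizer' or 'config.factorizer'")


# Two toy checkpoints carrying the same config in the two layouts.
cfg = {"architecture": "topk", "k": 256, "d_features": 65536}  # placeholder values
flat = {"model_state_dict": {}, "factorizer": cfg}
nested = {"model_state_dict": {}, "config": {"factorizer": cfg}}

for ckpt in (flat, nested):
    resolved = resolve_factorizer_config(ckpt)
    print(resolved["k"], resolved["d_features"])
```

With a real file, the dict would come from `torch.load(path, weights_only=True)` on the downloaded `checkpoint.pt`.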
## Intended use & limitations
Research artifact for mechanistic interpretability of the **final readout**.
A decomposition is interpretable only when its local query passes the
residual/sign fidelity gate; high `rowEV` alone does not license feature-level
claims. These dictionaries say nothing about *why* a hidden state arose (no
residual-stream / circuit attribution). Gemma results are provisional as noted
above.
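
To illustrate what a residual/sign gate could look like, here is a minimal sketch. The actual gate is defined in the paper and repository; the function name, the two checks, and the `max_rel_residual` threshold below are hypothetical choices made only to show the idea of withholding an explanation per query.

```python
from dataclasses import dataclass


@dataclass
class GateResult:
    passed: bool
    reason: str


def fidelity_gate(true_margin: float, approx_margin: float,
                  max_rel_residual: float = 0.25) -> GateResult:
    """Pass only if the sparse approximation keeps the margin's sign and size.

    true_margin:   the model's actual logit (or logit contrast) for the query.
    approx_margin: base plus the summed feature contributions, residual excluded.
    """
    residual = true_margin - approx_margin
    if true_margin * approx_margin <= 0:
        return GateResult(False, "sign not preserved")
    if abs(residual) > max_rel_residual * abs(true_margin):
        return GateResult(False, "residual too large")
    return GateResult(True, "ok")


print(fidelity_gate(2.0, 1.8))   # small residual, same sign
print(fidelity_gate(2.0, -0.5))  # sign flipped
print(fidelity_gate(2.0, 1.0))   # same sign, but half the margin is unexplained
```

The point of a per-query check is that a dictionary with good aggregate metrics can still fail on an individual logit, in which case the feature contributions should not be read as an explanation.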
## Citation
```bibtex
@misc{he2026sparsereadoutprism,
title = {Sparse Readout Prism: A Sparse LM-Head Basis for Logit-Lens Readouts},
author = {He, Matteo and Shen, William F. and Qiu, Xinchi and Lane, Nicholas D.},
year = {2026},
note = {Preprint forthcoming; see the repository for the up-to-date reference},
}
```
License: MIT.