# DeltaLens: Selective Reading from Compressed Memory via Cross-Attention

DeltaLens replaces the read operation of linear attention with cross-attention over the compressed state matrix. Existing DeltaNet variants (Gated DeltaNet, KDA, DeltaProduct) focus on improving the **write** mechanism, while the **read** remains a single linear projection that mixes all stored associations. DeltaLens addresses this by making retrieval selective.

## Key Results (1.36B scale, 1B tokens C4)

| Model | Architecture | Params | Val PPL |
|-------|-------------|--------|---------|
| F0 | Transformer (unfactored) | 1,364M | 25.4 |
| F1 | Transformer + factored | 515M | 37.6 |
| DN-F1 | Gated DeltaNet + factored | 1,236M | 22.2 |
| **DeltaLens** | **DeltaLens + factored** | **751M** | **19.01** |

Factored training alone degrades the Transformer (validation perplexity 25.4 -> 37.6). DeltaLens overcomes this handicap: trained with the same factorization, it reaches 19.01, about 25% lower perplexity than the unfactored full Transformer (25.4).
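
The 25% figure is the relative reduction in validation perplexity and can be checked directly from the table. The snippet below is only an arithmetic check; the helper name `relative_reduction` is ours and does not come from the repository.

```python
def relative_reduction(baseline: float, new: float) -> float:
    """Fractional reduction of `new` relative to `baseline` (lower is better)."""
    return (baseline - new) / baseline


# Validation perplexities copied from the results table.
ppl_transformer = 25.4   # F0, unfactored Transformer
ppl_factored = 37.6      # F1, Transformer + factored
ppl_deltalens = 19.01    # DeltaLens + factored

gain = relative_reduction(ppl_transformer, ppl_deltalens)
loss = relative_reduction(ppl_transformer, ppl_factored)

print(f"DeltaLens vs F0: {gain:.1%} lower perplexity")
print(f"F1 vs F0: {-loss:.1%} higher perplexity")
```

A negative return value means the perplexity got worse, which is the case for F1 relative to F0.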

## Architecture

Each DeltaLens layer:
1. **Write**: DeltaNet delta rule (unchanged)
2. **Read**: Cross-attention over state matrix rows (our contribution)
3. **Gate**: Learned combination of linear read + cross-attention read
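
The three steps above can be sketched for a single head and a single token as follows. This is a minimal NumPy illustration under our own assumptions, not the code in `src/deltalens_layer.py`: the state `S` is taken to have shape `(d_k, d_v)`, the attention query `q_attn` is assumed to live in the same space as the state rows, and the gate `g` is a fixed scalar here, whereas in the model both the queries and the gate would be learned functions of the input.

```python
import numpy as np


def delta_write(S, k, v, beta):
    """Delta-rule write: move the value stored under key k toward v."""
    # S has shape (d_k, d_v); S.T @ k is the value currently retrieved for k.
    return S + beta * np.outer(k, v - S.T @ k)


def linear_read(S, q):
    """Linear-attention read: one projection of the whole state by q."""
    return S.T @ q


def cross_attention_read(S, q_attn):
    """Softmax attention over the d_k rows of S (assumed formulation)."""
    scores = S @ q_attn / np.sqrt(S.shape[1])   # one score per state row
    weights = np.exp(scores - scores.max())     # numerically stable softmax
    weights /= weights.sum()
    return weights @ S, weights


def gated_read(S, q, q_attn, g):
    """Blend both reads with a gate g in [0, 1] (a scalar for simplicity)."""
    attn_out, _ = cross_attention_read(S, q_attn)
    return g * linear_read(S, q) + (1.0 - g) * attn_out


d_k, d_v = 4, 3
S = np.zeros((d_k, d_v))
k = np.array([1.0, 0.0, 0.0, 0.0])                            # unit-norm key
S = delta_write(S, k, np.array([1.0, 2.0, 3.0]), beta=1.0)
S = delta_write(S, k, np.array([4.0, 5.0, 6.0]), beta=1.0)    # overwrite same key

out = gated_read(S, q=k, q_attn=np.array([1.0, 0.0, 0.0]), g=0.5)
print(linear_read(S, k), out)
```

With `beta=1.0` and a unit-norm key, the second write fully replaces the first value, so the linear read for `k` returns the most recent value. The cross-attention read instead returns a softmax-weighted mix of state rows, which is what lets it emphasise some rows over others.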
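
The cross-attention reads from the d_k rows of the state matrix rather than from the n tokens of the sequence, so each token attends over d_k entries and the per-token cost is O(d_k), independent of sequence length. Because the state has a fixed size, memory stays O(1) in sequence length.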
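
To make the scaling argument concrete, the sketch below streams tokens through a fixed-size state and counts how many attention scores are needed per sequence. It is an illustration only: the random keys and values, the write strength `0.5`, and the function name `stream` are assumptions of ours, and each score additionally costs a dot product over the row dimension.

```python
import numpy as np


def stream(n_tokens, d_k=8, d_v=8, seed=0):
    """Stream tokens through a delta-rule state and count attention scores."""
    rng = np.random.default_rng(seed)
    S = np.zeros((d_k, d_v))
    state_scores = 0   # scores computed when attending over state rows
    token_scores = 0   # scores a causal token-level attention would compute
    for t in range(1, n_tokens + 1):
        k = rng.standard_normal(d_k)
        k /= np.linalg.norm(k)
        v = rng.standard_normal(d_v)
        S = S + 0.5 * np.outer(k, v - S.T @ k)   # delta-rule write, beta = 0.5
        state_scores += S.shape[0]               # always d_k rows to attend over
        token_scores += t                        # grows with the position t
    return S.shape, state_scores, token_scores


for n in (16, 1024):
    shape, s_cost, t_cost = stream(n)
    print(n, shape, s_cost, t_cost)
```

The state shape is the same for both sequence lengths, and the state-row score count grows linearly in n (a constant d_k per token), while the token-level count grows quadratically.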

## Checkpoint

`checkpoints/deltalens-751m/model.safetensors` -- unfactored weights, ready to load:

```python
from safetensors.torch import load_file
from src.deltalens_layer import DeltaLensModel

# Hyperparameters must match the checkpoint being loaded.
model = DeltaLensModel(
    vocab_size=32000, d_model=2048, n_layers=24,
    d_state=512, n_heads=16,
)
# Read the unfactored (materialized) weights and copy them into the model.
state = load_file("checkpoints/deltalens-751m/model.safetensors")
model.load_state_dict(state)
```

Note: This checkpoint was trained with a factored decomposition (W = B @ A) and then unfactored (W materialized) for easy loading. The unfactored model has 1.93B parameters; the effective parameter count during training was 751M.
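
The gap between the two parameter counts follows from how a factored weight is stored. The sketch below shows the idea for one weight matrix; the dimensions and in particular `rank` are hypothetical values chosen for illustration and are not the ones used for this checkpoint.

```python
import numpy as np


def factored_params(d_out, d_in, rank):
    """Parameters stored when W (d_out x d_in) is kept as B (d_out x rank) @ A (rank x d_in)."""
    return d_out * rank + rank * d_in


d_out, d_in, rank = 2048, 2048, 256   # rank is a hypothetical choice
rng = np.random.default_rng(0)
B = rng.standard_normal((d_out, rank))
A = rng.standard_normal((rank, d_in))

W = B @ A   # "unfactoring": materialize the dense weight once

print("dense parameters:   ", W.size)
print("factored parameters:", factored_params(d_out, d_in, rank))
```

The materialized `W` computes the same linear map as applying `A` and then `B`, but stores `d_out * d_in` numbers instead of `rank * (d_out + d_in)`, which is why the unfactored checkpoint is larger than the parameter count seen during training.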

## Requirements

- PyTorch >= 2.1
- flash-linear-attention (`pip install flash-linear-attention`)
- safetensors

## Paper

**"DeltaLens: Selective Reading from Compressed Memory via Cross-Attention"**

Preprint available on Zenodo (DOI to be added).

Training logs: https://wandb.ai/2264k-none/lora-merge-pretraining

## License

CC-BY-4.0