# DeltaLens: Selective Reading from Compressed Memory via Cross-Attention
DeltaLens replaces linear attention's read operation with cross-attention over the compressed state matrix. Existing DeltaNet variants (Gated DeltaNet, KDA, DeltaProduct) focus on improving the **write** mechanism, but the **read** is still a single linear projection of the query through the state, which mixes all stored associations into the output instead of selecting among them. DeltaLens addresses this by making retrieval selective: the query attends over the rows of the state matrix and can concentrate on the ones that are relevant.
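To make the read-side contrast concrete, here is one way to write it for a single head. The notation is our assumption, not taken from the paper: the state is $S_t \in \mathbb{R}^{d_k \times d_v}$ with rows $s_{t,i}$, $q_t \in \mathbb{R}^{d_k}$ is the usual linear-attention query, and $\tilde q_t \in \mathbb{R}^{d_v}$ is a hypothetical separate query used to score rows. The write is the standard (ungated) delta rule.

```math
\begin{aligned}
\text{write (delta rule):}\quad & S_t = S_{t-1} + \beta_t\, k_t \left(v_t - S_{t-1}^{\top} k_t\right)^{\top} \\
\text{linear read:}\quad & o_t = S_t^{\top} q_t = \sum_{i=1}^{d_k} q_{t,i}\, s_{t,i} \\
\text{cross-attention read:}\quad & \alpha_t = \operatorname{softmax}\!\left(\frac{S_t\, \tilde q_t}{\sqrt{d_v}}\right), \qquad o_t = \sum_{i=1}^{d_k} \alpha_{t,i}\, s_{t,i}
\end{aligned}
```

The linear read is a $q$-weighted sum over every row, so all stored associations contribute unless $q_t$ happens to be nearly one-hot; the softmax weights $\alpha_t$ are non-negative and sum to one, which lets the read concentrate on a few rows.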
## Key Results (1.36B scale, 1B tokens of C4)
| Model | Architecture | Params (as trained) | Val PPL |
|-------|-------------|--------|---------|
| F0 | Transformer (unfactored) | 1,364M | 25.4 |
| F1 | Transformer + factored | 515M | 37.6 |
| DN-F1 | Gated DeltaNet + factored | 1,236M | 22.2 |
| **DeltaLens** | **DeltaLens + factored** | **751M** | **19.01** |
Factored training alone degrades the Transformer (25.4 → 37.6 PPL). DeltaLens overcomes this handicap and still beats the unfactored full Transformer, reaching 25% lower validation perplexity (19.01 vs. 25.4).
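The percentages follow directly from the table. A quick check in Python, with the values copied from the Val PPL column:

```python
def relative_change(before: float, after: float) -> float:
    """Signed relative change from `before` to `after`, as a fraction."""
    return (after - before) / before


# Validation perplexities from the Key Results table.
ppl = {"F0": 25.4, "F1": 37.6, "DN-F1": 22.2, "DeltaLens": 19.01}

factoring_penalty = relative_change(ppl["F0"], ppl["F1"])       # about +0.48
deltalens_gain = relative_change(ppl["F0"], ppl["DeltaLens"])   # about -0.25
vs_deltanet = relative_change(ppl["DN-F1"], ppl["DeltaLens"])   # about -0.14

print(f"Factoring the Transformer: {factoring_penalty:+.1%} PPL")  # +48.0%
print(f"DeltaLens vs. F0:          {deltalens_gain:+.1%} PPL")     # -25.2%
print(f"DeltaLens vs. DN-F1:       {vs_deltanet:+.1%} PPL")        # -14.4%
```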
## Architecture
Each DeltaLens layer:
1. **Write**: DeltaNet delta rule (unchanged)
2. **Read**: Cross-attention over state matrix rows (our contribution)
3. **Gate**: Learned combination of linear read + cross-attention read
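The three steps can be sketched for a single head in NumPy. This is an illustration under stated assumptions, not the code in `src.deltalens_layer`: it uses the plain (ungated) delta rule with L2-normalized keys, uses the rows of `S` directly as both keys and values for the cross-attention, and mixes the two reads with a scalar sigmoid gate. The names `q_rows` and `gate_logit`, and this exact parameterization, are hypothetical; the real layer may add projections on the rows, multiple heads, or a Gated DeltaNet decay.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def deltalens_step(S, k, v, q, q_rows, beta, gate_logit):
    """One recurrent step for one head (hypothetical sketch).

    S:          (d_k, d_v) state; row i is one stored slot.
    k, q:       (d_k,) write key and linear-read query.
    v:          (d_v,) value to write.
    q_rows:     (d_v,) hypothetical query that scores the d_k rows.
    beta:       write strength in (0, 1].
    gate_logit: hypothetical scalar; sigmoid(gate_logit) weights the attention read.
    """
    # 1. Write: delta rule S <- S + beta * k (v - S^T k)^T with a unit-norm key.
    k = k / np.linalg.norm(k)
    S = S + beta * np.outer(k, v - S.T @ k)

    # 2a. Linear read: q-weighted sum over every row of S.
    o_lin = S.T @ q

    # 2b. Cross-attention read: softmax over the d_k rows, not over past tokens.
    weights = softmax(S @ q_rows / np.sqrt(S.shape[1]))
    o_attn = weights @ S

    # 3. Gate: convex combination of the two reads.
    g = 1.0 / (1.0 + np.exp(-gate_logit))
    return S, g * o_attn + (1.0 - g) * o_lin


# Usage: run a few steps on random inputs.
rng = np.random.default_rng(0)
d_k, d_v = 8, 4
S = np.zeros((d_k, d_v))
for _ in range(5):
    S, o = deltalens_step(
        S, rng.standard_normal(d_k), rng.standard_normal(d_v),
        rng.standard_normal(d_k), rng.standard_normal(d_v),
        beta=0.5, gate_logit=0.0,
    )
print(S.shape, o.shape)  # (8, 4) (4,)
```

With `beta=1` the write makes the state return exactly `v` for the (normalized) key just written, which is the defining property of the delta rule; the read side is where this sketch departs from plain DeltaNet.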
The cross-attention attends over the d_k rows of the state matrix rather than over the n previous tokens, so each token computes O(d_k) attention scores regardless of sequence length, and the O(1) memory of the recurrent state is preserved.
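A back-of-the-envelope comparison of per-head read cost and decode-time memory against softmax attention. The head sizes are illustrative assumptions (2048 / 16 heads = 128 per head), not values confirmed for the released model:

```python
def per_token_read_cost(seq_len: int, d_k: int) -> dict[str, int]:
    """Attention scores one head computes when reading at position seq_len."""
    return {
        "softmax attention": seq_len,       # one score per past token
        "DeltaLens cross-attention": d_k,   # one score per state row
    }


def decode_memory(seq_len: int, d_k: int, d_v: int) -> dict[str, int]:
    """Floats one head keeps in memory to produce the next token."""
    return {
        "softmax attention (KV cache)": seq_len * (d_k + d_v),
        "DeltaLens (state matrix)": d_k * d_v,
    }


# Illustrative head sizes (assumed, see above).
d_k = d_v = 128
for n in (1_024, 16_384, 131_072):
    print(n, per_token_read_cost(n, d_k), decode_memory(n, d_k, d_v))
```

Both DeltaLens entries stay fixed as `n` grows, while the softmax-attention entries grow linearly with it.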
## Checkpoint
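`checkpoints/deltalens-751m/model.safetensors` contains unfactored weights, ready to load:
```python
from safetensors.torch import load_file
from src.deltalens_layer import DeltaLensModel  # run from the repository root

# Hyperparameters matching the released checkpoint
model = DeltaLensModel(
    vocab_size=32000, d_model=2048, n_layers=24,
    d_state=512, n_heads=16,
)

# Weights are stored dense (already unfactored), so they load directly
state = load_file("checkpoints/deltalens-751m/model.safetensors")
model.load_state_dict(state)
model.eval()  # inference mode (disables dropout, if any)
```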
Note: this checkpoint was trained with a factored decomposition (W = B @ A) and then unfactored (each W materialized as a dense matrix) for easy loading. The unfactored model therefore has 1.93B parameters, while the effective parameter count during training was 751M.
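Why unfactoring increases the parameter count can be seen in a small sketch. It assumes each factored layer stores its factors under keys ending in `.B` (shape `d_out × r`) and `.A` (shape `r × d_in`); this key naming and the rank used below are hypothetical, since the training-time checkpoint format is not documented here.

```python
import numpy as np


def unfactor_state_dict(factored: dict) -> dict:
    """Merge hypothetical `<name>.B` / `<name>.A` pairs into `<name>.weight = B @ A`."""
    dense = {}
    for key, value in factored.items():
        if key.endswith(".B"):
            prefix = key[: -len(".B")]
            dense[prefix + ".weight"] = value @ factored[prefix + ".A"]
        elif key.endswith(".A"):
            continue  # consumed together with its matching .B
        else:
            dense[key] = value  # embeddings, norms, biases pass through unchanged
    return dense


def param_count(state: dict) -> int:
    """Total number of scalar parameters in a state dict."""
    return sum(int(t.size) for t in state.values())


# One 2048 x 2048 projection at an assumed rank of 256.
factored = {"proj.B": np.zeros((2048, 256)), "proj.A": np.zeros((256, 2048))}
print(param_count(factored))                       # 1048576
print(param_count(unfactor_state_dict(factored)))  # 4194304
```

A rank-r factorization stores `r * (d_out + d_in)` numbers instead of `d_out * d_in`, so materializing W trades the training-time savings for a checkpoint that loads like any dense model.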
## Requirements
- PyTorch >= 2.1
- flash-linear-attention (`pip install flash-linear-attention`)
- safetensors
## Paper
**"DeltaLens: Selective Reading from Compressed Memory via Cross-Attention"**
Preprint available on Zenodo (DOI to be added).
Training logs: https://wandb.ai/2264k-none/lora-merge-pretraining
## License
CC-BY-4.0