---
license: mit
tags:
- test-time-learning
- memory-augmented
- atlas
- nested-learning
- polynomial-memory
- omega-rule
datasets:
- HuggingFaceTB/smollm-corpus
language:
- en
pipeline_tag: text-generation
---

# Atlas-MAG with Omega Rule

A 43M-parameter implementation of the Atlas paper's Memory-As-Gate (MAG) architecture with polynomial memory and test-time learning (TTL).

**Paper**: [Atlas: Learning to Optimally Memorize the Context at Test Time](https://arxiv.org/abs/2505.23735) (Behrouz et al., Google Research)

**Code**: [toddwbucy/Atlas-MAG_OmegaRule](https://github.com/toddwbucy/Atlas-MAG_OmegaRule)

## What This Model Demonstrates

This checkpoint exists to demonstrate a concrete infrastructure problem: **test-time learning models cannot be served by existing deployment stacks**.

Atlas-MAG uses gradient descent *during the forward pass* to update its memory. PyTorch gates that update behind `if self.training`, and every serving framework switches models into the equivalent of inference mode before handling requests. The result: the model's memory architecture is silenced.
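
The gating problem can be sketched with a toy module. Everything here (`TTLMemory`, the reconstruction loss, the plain SGD step) is an illustrative stand-in, not the repo's actual class, but it shows the exact mechanism that `eval()` turns off:

```python
import torch
import torch.nn as nn

class TTLMemory(nn.Module):
    """Toy test-time-learning memory (illustrative, not the repo's class)."""

    def __init__(self, dim: int):
        super().__init__()
        self.W = nn.Parameter(torch.zeros(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:  # <- the gate that serving stacks switch off
            with torch.enable_grad():
                # One inner gradient step on a reconstruction loss (schematic).
                W = self.W.detach().requires_grad_(True)
                loss = ((x @ W - x) ** 2).mean()
                (grad,) = torch.autograd.grad(loss, W)
            self.W.data = self.W.data - 0.01 * grad  # eta = 0.01
        return x @ self.W

mem = TTLMemory(4)
x = torch.ones(2, 4)

mem.train()
mem(x)                               # training mode: memory updates
w_trained = mem.W.detach().clone()

mem.eval()
mem(x)                               # eval mode: the TTL branch never runs
assert torch.equal(mem.W.detach(), w_trained)
```

Every serving framework effectively calls `mem.eval()` before the first request, so the memory stays frozen for the model's entire deployment.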

Two scripts in the GitHub repo let you see this firsthand.

## Quick Start

```bash
git clone https://github.com/toddwbucy/Atlas-MAG_OmegaRule.git
cd Atlas-MAG_OmegaRule
pip install torch huggingface_hub tokenizers

# Demo: same model, same weights, different outputs depending on training flag
python scripts/demo_ttl_inference.py

# Benchmark: NIAH memory probe, TTL ON vs TTL OFF side by side
python scripts/benchmark_niah.py
```

Both scripts auto-download this checkpoint. No manual download needed.

## Model Details

| Property | Value |
|---|---|
| **Architecture** | Atlas-MAG (Memory-As-Gate) |
| **Parameters** | 43M |
| **Dimensions** | dim=512, 6 layers, 8 heads |
| **Memory** | Polynomial degree-2, rank-512 |
| **Attention** | Sliding window, size=512 |
| **TTL** | Muon optimizer (NS-5), theta=0.9, alpha=0.999, eta=0.01 |
| **Vocab** | 49,152 (SmolLM tokenizer) |
| **Training Steps** | 8,800 |
| **Training Hardware** | 2x NVIDIA A6000 48GB |
| **Training Data** | SmolLM-Corpus (cosmopedia 40%, fineweb-edu 50%, python-edu 10%) |
| **NIAH Accuracy** | 85.9% |
| **Checkpoint Size** | 473MB |
| **Format** | PyTorch (.pt) |
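
The TTL row above names Muon with a 5-step Newton-Schulz orthogonalization (NS-5). As a rough sketch of what that inner step does, here is the quintic Newton-Schulz iteration with the coefficients from the public Muon reference implementation; the function name and details below are assumptions, not code from this repo:

```python
import torch

def ns5_orthogonalize(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a 2-D update matrix with quintic
    Newton-Schulz steps, as Muon does before applying an update.
    Coefficients follow the public Muon reference implementation."""
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (g.norm() + 1e-7)          # scale so singular values are <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T                         # iterate on the wide orientation
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * s @ s) @ x
    return x.T if transposed else x

update = ns5_orthogonalize(torch.randn(64, 64))
```

When TTL is active this kind of inner-loop work runs on every forward pass, which is exactly the computation that `eval()` gates off at serving time.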

## Architecture

```
Input -> Embedding -> [MAGBlock x 6] -> RMSNorm -> LM Head -> Output

MAGBlock:
  x --+--> [Sliding Window Attention] --> attn_out
      |                                      |
      +--> [Deep Polynomial Memory] --> mem_out
                                             |
  output = x + attn_out * sigmoid(mem_out)
```

The polynomial feature map increases memory capacity from O(d_k) to O(d_k^2) per layer, roughly 64x more associations.
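
The capacity arithmetic checks out directly: with dim=512 and 8 heads, each head's key dimension is d_k = 64. A degree-2 feature map, sketched here as a flattened self outer product (the repo's exact map may differ), expands that to d_k^2 = 4096 features:

```python
import torch

def poly2_features(k: torch.Tensor) -> torch.Tensor:
    # Degree-2 polynomial features: flattened outer product k (x) k.
    # A linear associative memory over these features can store
    # O(d_k^2) associations instead of O(d_k).
    return torch.einsum('...i,...j->...ij', k, k).flatten(-2)

d_k = 512 // 8                     # per-head key dim from the table above
phi = poly2_features(torch.randn(d_k))
print(phi.shape[0] // d_k)         # capacity factor: 64
```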

## Loading

```python
import torch
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hub
ckpt_path = hf_hub_download("r3d91ll/Atlas-MAG_OmegaRule", "checkpoint_step008800.pt")

# weights_only=False is needed because the checkpoint bundles plain Python
# dicts (the config) alongside tensors; only load checkpoints you trust.
checkpoint = torch.load(ckpt_path, map_location="cuda:0", weights_only=False)

# The checkpoint contains:
# - "model_state_dict": model weights
# - "config": full training configuration dict
print(checkpoint["config"])
```

For full model loading, see the [GitHub repository](https://github.com/toddwbucy/Atlas-MAG_OmegaRule), which includes the model class and demo scripts.
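
Since the 473MB checkpoint also carries optimizer state (see the Files table below), an inference-only copy can keep just the two documented keys. The helper, the `optimizer_state_dict` key name, and the stand-in dict here are assumptions about the file's layout; inspect `checkpoint.keys()` on the real file first:

```python
import io
import torch

def slim_checkpoint(ckpt: dict) -> dict:
    # Keep only what inference needs; optimizer state exists to resume training.
    return {k: ckpt[k] for k in ("model_state_dict", "config")}

# Stand-in mirroring the documented layout (key names assumed, not verified)
ckpt = {
    "model_state_dict": {"w": torch.zeros(2, 2)},
    "config": {"dim": 512, "n_layers": 6, "n_heads": 8},
    "optimizer_state_dict": {"step": 8800},
}

buf = io.BytesIO()
torch.save(slim_checkpoint(ckpt), buf)  # smaller inference-only artifact
```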

## Files

| File | Size | Description |
|------|------|-------------|
| `checkpoint_step008800.pt` | 473MB | Model weights + config + optimizer state |
| `tokenizer_smollm.json` | 2.2MB | BPE tokenizer (SmolLM) |

## Citation

```bibtex
@article{behrouz2025atlas,
  title={Atlas: Learning to Optimally Memorize the Context at Test Time},
  author={Behrouz, Ali and Li, Yingcong and Kacham, Praneeth and Daliri, Poria and Deng, Zhihao and Zhong, Peilin and Razaviyayn, Meisam and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2505.23735},
  year={2025}
}
```

## License

MIT