---
license: mit
tags:
- test-time-learning
- memory-augmented
- atlas
- nested-learning
- polynomial-memory
- omega-rule
datasets:
- HuggingFaceTB/smollm-corpus
language:
- en
pipeline_tag: text-generation
---
# Atlas-MAG with Omega Rule
A 43M-parameter implementation of the Atlas paper's Memory-As-Gate (MAG) architecture with polynomial memory and test-time learning (TTL).
**Paper**: [Atlas: Learning to Optimally Memorize the Context at Test Time](https://arxiv.org/abs/2505.23735) (Behrouz et al., Google Research)
**Code**: [toddwbucy/Atlas-MAG_OmegaRule](https://github.com/toddwbucy/Atlas-MAG_OmegaRule)
## What This Model Demonstrates
This checkpoint exists to demonstrate a concrete infrastructure problem: **test-time learning models cannot be served by existing deployment stacks**.
Atlas-MAG runs gradient descent *during the forward pass* to update its memory. PyTorch gates this update behind `if self.training`, and every serving framework puts the model into eval/inference mode before handling requests. The result: the model's memory architecture is silently disabled.
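A minimal sketch of the failure mode. The toy module below is illustrative only: the Hebbian-style update stands in for the real Omega-rule gradient step, but the `if self.training` gate is exactly the switch that serving stacks flip off.

```python
import torch
import torch.nn as nn

class TinyTTLMemory(nn.Module):
    """Illustrative only: a memory that updates itself during forward()."""
    def __init__(self, dim=8):
        super().__init__()
        # A buffer, not a parameter: it changes at test time,
        # not via the outer training loop
        self.register_buffer("memory", torch.zeros(dim, dim))
        self.lr = 0.1

    def forward(self, x):
        out = x @ self.memory
        if self.training:  # the gate serving frameworks flip off
            # stand-in for the real Omega-rule gradient step
            self.memory.add_(self.lr * (x.transpose(0, 1) @ x))
        return out

torch.manual_seed(0)
m = TinyTTLMemory()
x = torch.randn(4, 8)

m.train()
m(x)
updated = m.memory.clone()

m.eval()            # what every serving framework does before serving
m(torch.randn(4, 8))
frozen = m.memory.clone()

print(torch.equal(updated, frozen))  # True: memory stopped learning
```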
Two scripts in the GitHub repo let you see this firsthand.
## Quick Start
```bash
git clone https://github.com/toddwbucy/Atlas-MAG_OmegaRule.git
cd Atlas-MAG_OmegaRule
pip install torch huggingface_hub tokenizers
# Demo: same model, same weights, different outputs depending on training flag
python scripts/demo_ttl_inference.py
# Benchmark: needle-in-a-haystack (NIAH) memory probe, TTL ON vs TTL OFF side by side
python scripts/benchmark_niah.py
```
Both scripts auto-download this checkpoint. No manual download needed.
## Model Details
| | |
|---|---|
| **Architecture** | Atlas-MAG (Memory-As-Gate) |
| **Parameters** | 43M |
| **Dimensions** | dim=512, 6 layers, 8 heads |
| **Memory** | Polynomial degree-2, rank-512 |
| **Attention** | Sliding window, size=512 |
| **TTL** | Muon optimizer (NS-5), theta=0.9, alpha=0.999, eta=0.01 |
| **Vocab** | 49,152 (SmolLM tokenizer) |
| **Training Steps** | 8,800 |
| **Training Hardware** | 2x NVIDIA A6000 48GB |
| **Training Data** | SmolLM-Corpus (cosmopedia 40%, fineweb-edu 50%, python-edu 10%) |
| **NIAH Accuracy** | 85.9% |
| **Checkpoint Size** | 473MB |
| **Format** | PyTorch (.pt) |
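The TTL row above lists Muon with five Newton-Schulz iterations (NS-5). As a rough illustration of what that inner step does, here is the widely used quintic Newton-Schulz orthogonalization; the coefficients and exact variant used in this repo may differ.

```python
import torch

def newton_schulz_5(G, steps=5):
    """Illustrative Newton-Schulz orthogonalization as used by
    Muon-style optimizers: drives the singular values of G toward 1.
    Coefficients follow the common quintic variant (assumption)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)  # normalize so the iteration converges
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X

torch.manual_seed(0)
G = torch.randn(16, 16)
O = newton_schulz_5(G)
print(O.shape)  # torch.Size([16, 16])
```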
## Architecture
```
Input -> Embedding -> [MAGBlock x 6] -> RMSNorm -> LM Head -> Output
MAGBlock:
x --+--> [Sliding Window Attention] --> attn_out
| |
+--> [Deep Polynomial Memory] --> mem_out
|
output = x + attn_out * sigmoid(mem_out)
```
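The gating step in the diagram above is a single line of tensor code. A shapes-only sketch with random stand-ins for the two branches:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 16, 512)      # (batch, seq, dim)
attn_out = torch.randn_like(x)   # stand-in for sliding-window attention
mem_out = torch.randn_like(x)    # stand-in for polynomial memory

# Memory output gates the attention path, plus a residual connection
output = x + attn_out * torch.sigmoid(mem_out)
print(output.shape)  # torch.Size([2, 16, 512])
```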
The polynomial feature map increases memory capacity from O(d_k) to O(d_k^2) per layer; with d_k = 64 per head (dim 512 across 8 heads), that is roughly 64x more associations.
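A minimal sketch of that capacity claim, assuming a degree-2 feature map built from the outer product of the key with itself (the paper's exact map may differ):

```python
import torch

d_k = 64  # per-head key dim: 512 model dim / 8 heads

def poly2_features(k):
    """Degree-2 feature map (illustrative): outer product of the key
    with itself, flattened to a d_k^2-dimensional vector."""
    return torch.einsum("...i,...j->...ij", k, k).flatten(-2)

k = torch.randn(d_k)
phi = poly2_features(k)
print(phi.shape[-1])         # 4096 = 64^2 features
print(phi.shape[-1] // d_k)  # 64x more dimensions than the linear map
```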
## Loading
```python
import torch
from huggingface_hub import hf_hub_download

# Download the checkpoint (~473MB)
ckpt_path = hf_hub_download("r3d91ll/Atlas-MAG_OmegaRule", "checkpoint_step008800.pt")

# weights_only=False because the checkpoint stores a config dict and
# optimizer state alongside the tensors; load to CPU first, then move
# to GPU if one is available
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)

# The checkpoint contains:
# - "model_state_dict": model weights
# - "config": full training configuration dict
print(checkpoint["config"])
```
For full model loading, see the [GitHub repository](https://github.com/toddwbucy/Atlas-MAG_OmegaRule) which includes the model class and demo scripts.
## Files
| File | Size | Description |
|------|------|-------------|
| `checkpoint_step008800.pt` | 473MB | Model weights + config + optimizer state |
| `tokenizer_smollm.json` | 2.2MB | BPE tokenizer (SmolLM) |
## Citation
```bibtex
@article{behrouz2025atlas,
title={Atlas: Learning to Optimally Memorize the Context at Test Time},
author={Behrouz, Ali and Li, Yingcong and Kacham, Praneeth and Daliri, Poria and Deng, Zhihao and Zhong, Peilin and Razaviyayn, Meisam and Mirrokni, Vahab},
journal={arXiv preprint arXiv:2505.23735},
year={2025}
}
```
## License
MIT