---
license: mit
tags:
- test-time-learning
- memory-augmented
- atlas
- nested-learning
- polynomial-memory
- omega-rule
datasets:
- HuggingFaceTB/smollm-corpus
language:
- en
pipeline_tag: text-generation
---

# Atlas-MAG with Omega Rule

A 43M-parameter implementation of the Atlas paper's Memory-As-Gate (MAG) architecture with polynomial memory and test-time learning (TTL).

**Paper**: [Atlas: Learning to Optimally Memorize the Context at Test Time](https://arxiv.org/abs/2505.23735) (Behrouz et al., Google Research)

**Code**: [toddwbucy/Atlas-MAG_OmegaRule](https://github.com/toddwbucy/Atlas-MAG_OmegaRule)

## What This Model Demonstrates

This checkpoint exists to demonstrate a concrete infrastructure problem: **test-time learning models cannot be served by existing deployment stacks**.

Atlas-MAG uses gradient descent *during the forward pass* to update its memory. PyTorch gates this behind `if self.training`, and every serving framework calls `model.eval()` (or an equivalent inference mode) before serving, which sets that flag to `False`. The result: the model's memory architecture is silently disabled.
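
The failure mode can be illustrated with a dependency-free toy (hypothetical class, not the real model, which performs actual gradient steps with the Muon optimizer):

```python
class ToyTTLMemory:
    """Toy stand-in for a test-time-learning memory module.

    Mimics PyTorch's convention: ``training`` is True by default, and
    ``eval()`` flips it to False -- which is what serving stacks do.
    """

    def __init__(self):
        self.training = True
        self.state = 0.0  # stand-in for the memory parameters

    def eval(self):
        self.training = False

    def forward(self, x):
        if self.training:          # the gate that serving frameworks disable
            self.state += 0.1 * x  # toy stand-in for the in-forward update
        return self.state


mem = ToyTTLMemory()
mem.forward(1.0)
mem.forward(1.0)
print(mem.state)  # 0.2 -- memory adapted during the forward pass

mem.eval()        # what every serving stack does before inference
mem.forward(1.0)
print(mem.state)  # still 0.2 -- memory frozen, TTL silently off
```

Same object, same weights: once `eval()` has run, the forward pass no longer writes to memory.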

Two scripts in the GitHub repo let you see this firsthand.

## Quick Start

```bash
git clone https://github.com/toddwbucy/Atlas-MAG_OmegaRule.git
cd Atlas-MAG_OmegaRule
pip install torch huggingface_hub tokenizers

# Demo: same model, same weights, different outputs depending on training flag
python scripts/demo_ttl_inference.py

# Benchmark: NIAH memory probe — TTL ON vs TTL OFF side by side
python scripts/benchmark_niah.py
```

Both scripts auto-download this checkpoint. No manual download needed.

## Model Details

| Property | Value |
|---|---|
| **Architecture** | Atlas-MAG (Memory-As-Gate) |
| **Parameters** | 43M |
| **Dimensions** | dim=512, 6 layers, 8 heads |
| **Memory** | Polynomial degree-2, rank-512 |
| **Attention** | Sliding window, size=512 |
| **TTL** | Muon optimizer (NS-5), theta=0.9, alpha=0.999, eta=0.01 |
| **Vocab** | 49,152 (SmolLM tokenizer) |
| **Training Steps** | 8,800 |
| **Training Hardware** | 2x NVIDIA A6000 48GB |
| **Training Data** | SmolLM-Corpus (cosmopedia 40%, fineweb-edu 50%, python-edu 10%) |
| **NIAH Accuracy** | 85.9% |
| **Checkpoint Size** | 473MB |
| **Format** | PyTorch (.pt) |

## Architecture

```
Input -> Embedding -> [MAGBlock x 6] -> RMSNorm -> LM Head -> Output

MAGBlock:
  x --+--> [Sliding Window Attention] --> attn_out
      |                                      |
      +--> [Deep Polynomial Memory] ---> mem_out
                                             |
  output = x + attn_out * sigmoid(mem_out)
```
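
The combine step at the bottom of the diagram is a sigmoid gate: the memory branch decides how much of the attention output enters the residual stream. A minimal elementwise sketch (hypothetical helper, scalar values for clarity):

```python
import math

def mag_combine(x, attn_out, mem_out):
    """Residual update where the memory output gates the attention branch."""
    gate = 1.0 / (1.0 + math.exp(-mem_out))  # sigmoid(mem_out), in (0, 1)
    return x + attn_out * gate

# Large positive mem_out -> gate near 1: attention passes through almost fully
print(mag_combine(0.0, 1.0, 10.0))
# Large negative mem_out -> gate near 0: attention is suppressed
print(mag_combine(0.0, 1.0, -10.0))
```

In the real model this is applied elementwise over tensors; the scalar version shows only the gating arithmetic.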

The degree-2 polynomial feature map raises memory capacity from O(d_k) to O(d_k^2) per layer. With a per-head key dimension of d_k = 64 (512 dims across 8 heads), that is roughly 64x more storable associations.
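
The capacity factor can be checked by counting feature slots. A sketch of a degree-2 map, assuming it concatenates the key with all pairwise products (the exact map in the paper may differ in its terms and normalization):

```python
from itertools import product

def poly2_features(k):
    """Degree-2 polynomial map: [k, outer(k, k) flattened]."""
    return list(k) + [a * b for a, b in product(k, k)]

d_k = 64                    # per-head key dim: 512 model dims / 8 heads
phi = poly2_features([1.0] * d_k)
print(len(phi))             # 64 + 64*64 = 4160 feature slots
print(len(phi) // d_k)      # 65 -- roughly 64x the O(d_k) baseline
```

The quadratic term dominates, so capacity scales as d_k^2 rather than d_k.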

## Loading

```python
import torch
from huggingface_hub import hf_hub_download

# Download checkpoint
ckpt_path = hf_hub_download("r3d91ll/Atlas-MAG_OmegaRule", "checkpoint_step008800.pt")
checkpoint = torch.load(ckpt_path, map_location="cuda:0", weights_only=False)

# The checkpoint contains:
# - "model_state_dict": model weights
# - "config": full training configuration dict
print(checkpoint["config"])
```

For full model loading, see the [GitHub repository](https://github.com/toddwbucy/Atlas-MAG_OmegaRule), which includes the model class and demo scripts.

## Files

| File | Size | Description |
|------|------|-------------|
| `checkpoint_step008800.pt` | 473MB | Model weights + config + optimizer state |
| `tokenizer_smollm.json` | 2.2MB | BPE tokenizer (SmolLM) |

## Citation

```bibtex
@article{behrouz2025atlas,
  title={Atlas: Learning to Optimally Memorize the Context at Test Time},
  author={Behrouz, Ali and Li, Yingcong and Kacham, Praneeth and Daliri, Poria and Deng, Zhihao and Zhong, Peilin and Razaviyayn, Meisam and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2505.23735},
  year={2025}
}
```

## License

MIT