---
license: mit
tags:
- test-time-learning
- memory-augmented
- atlas
- nested-learning
- polynomial-memory
- omega-rule
datasets:
- HuggingFaceTB/smollm-corpus
language:
- en
pipeline_tag: text-generation
---

# Atlas-MAG with Omega Rule

A 43M-parameter implementation of the Atlas paper's Memory-As-Gate (MAG) architecture with polynomial memory and test-time learning (TTL).

**Paper**: [Atlas: Learning to Optimally Memorize the Context at Test Time](https://arxiv.org/abs/2505.23735) (Behrouz et al., Google Research)

**Code**: [toddwbucy/Atlas-MAG_OmegaRule](https://github.com/toddwbucy/Atlas-MAG_OmegaRule)

## What This Model Demonstrates

This checkpoint exists to demonstrate a concrete infrastructure problem: **test-time learning models cannot be served by existing deployment stacks**.

Atlas-MAG runs gradient descent *during the forward pass* to update its memory. PyTorch gates this behind `if self.training`, and every serving framework puts the model into the equivalent of inference mode before handling requests. The result: the model's memory architecture is silenced. Two scripts in the GitHub repo let you see this firsthand.

## Quick Start

```bash
git clone https://github.com/toddwbucy/Atlas-MAG_OmegaRule.git
cd Atlas-MAG_OmegaRule
pip install torch huggingface_hub tokenizers

# Demo: same model, same weights, different outputs depending on the training flag
python scripts/demo_ttl_inference.py

# Benchmark: NIAH memory probe -- TTL ON vs TTL OFF side by side
python scripts/benchmark_niah.py
```

Both scripts auto-download this checkpoint; no manual download needed.
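The gating problem is easy to reproduce in isolation. The sketch below is a hypothetical toy, not the repo's memory module: a module that applies one inner gradient step to a memory matrix during `forward`, but only while `self.training` is true, so `model.eval()` silently freezes the memory exactly as described above.

```python
import torch
import torch.nn as nn

class ToyTTLMemory(nn.Module):
    """Hypothetical sketch of test-time learning: a memory matrix
    updated by an inner gradient step inside the forward pass."""
    def __init__(self, dim=8, lr=0.1):
        super().__init__()
        self.M = nn.Parameter(torch.zeros(dim, dim))
        self.lr = lr

    def forward(self, k, v):
        if self.training:  # the gate that serving stacks flip off
            with torch.enable_grad():
                pred = k @ self.M
                loss = ((pred - v) ** 2).mean()
                grad, = torch.autograd.grad(loss, self.M)
            self.M.data -= self.lr * grad  # inner-loop memory update
        return k @ self.M

torch.manual_seed(0)
mem = ToyTTLMemory()
k, v = torch.randn(4, 8), torch.randn(4, 8)

mem.train()            # TTL active: forward pass writes to memory
mem(k, v)
updated = mem.M.detach().clone()

mem.eval()             # TTL silenced: forward pass is read-only
mem(k, v)
frozen = mem.M.detach().clone()

print(updated.abs().sum().item() > 0)   # memory changed in train mode
print(torch.equal(frozen, updated))     # eval() left memory untouched
```

Running the same inputs through the same weights gives different memory states depending only on the training flag, which is the core of the serving problem the demo scripts exercise.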
## Model Details

| | |
|---|---|
| **Architecture** | Atlas-MAG (Memory-As-Gate) |
| **Parameters** | 43M |
| **Dimensions** | dim=512, 6 layers, 8 heads |
| **Memory** | Polynomial degree-2, rank-512 |
| **Attention** | Sliding window, size=512 |
| **TTL** | Muon optimizer (NS-5), theta=0.9, alpha=0.999, eta=0.01 |
| **Vocab** | 49,152 (SmolLM tokenizer) |
| **Training Steps** | 8,800 |
| **Training Hardware** | 2x NVIDIA A6000 48GB |
| **Training Data** | SmolLM-Corpus (cosmopedia 40%, fineweb-edu 50%, python-edu 10%) |
| **NIAH Accuracy** | 85.9% |
| **Checkpoint Size** | 473MB |
| **Format** | PyTorch (.pt) |

## Architecture

```
Input -> Embedding -> [MAGBlock x 6] -> RMSNorm -> LM Head -> Output

MAGBlock:
  x --+--> [Sliding Window Attention] --> attn_out
      |
      +--> [Deep Polynomial Memory]   --> mem_out

  output = x + attn_out * sigmoid(mem_out)
```

The polynomial feature map increases memory capacity from O(d_k) to O(d_k^2) per layer, roughly 64x more associations (d_k = 512 / 8 heads = 64).

## Loading

```python
import torch
from huggingface_hub import hf_hub_download

# Download checkpoint
ckpt_path = hf_hub_download("r3d91ll/Atlas-MAG_OmegaRule", "checkpoint_step008800.pt")
checkpoint = torch.load(ckpt_path, map_location="cuda:0", weights_only=False)

# The checkpoint contains:
# - "model_state_dict": model weights
# - "config": full training configuration dict
print(checkpoint["config"])
```

For full model loading, see the [GitHub repository](https://github.com/toddwbucy/Atlas-MAG_OmegaRule), which includes the model class and demo scripts.
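For intuition, the MAGBlock combination and the degree-2 feature map described in the Architecture section can be sketched as follows. This is a minimal illustration with stubbed attention and memory outputs, not the repo's implementation:

```python
import torch

def poly_features_deg2(k):
    """Degree-2 polynomial feature map: concatenate k with all pairwise
    products k_i * k_j, lifting capacity from O(d_k) to O(d_k^2)."""
    outer = torch.einsum('...i,...j->...ij', k, k)
    return torch.cat([k, outer.flatten(-2)], dim=-1)

def mag_combine(x, attn_out, mem_out):
    """Memory-As-Gate residual: the memory readout gates the attention branch."""
    return x + attn_out * torch.sigmoid(mem_out)

x = torch.randn(2, 16, 512)        # (batch, seq, dim) as in the table above
attn_out = torch.randn_like(x)     # stub for sliding-window attention output
mem_out = torch.randn_like(x)      # stub for polynomial memory readout

y = mag_combine(x, attn_out, mem_out)
print(y.shape)                     # same shape as x: torch.Size([2, 16, 512])

k = torch.randn(2, 16, 64)         # per-head key dim: d_k = 512 / 8 = 64
phi = poly_features_deg2(k)
print(phi.shape[-1])               # 64 + 64*64 = 4160 features per key
```

The feature dimension grows from d_k to d_k + d_k^2, which is where the roughly 64x capacity increase comes from.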
## Files

| File | Size | Description |
|------|------|-------------|
| `checkpoint_step008800.pt` | 473MB | Model weights + config + optimizer state |
| `tokenizer_smollm.json` | 2.2MB | BPE tokenizer (SmolLM) |

## Citation

```bibtex
@article{behrouz2025atlas,
  title={Atlas: Learning to Optimally Memorize the Context at Test Time},
  author={Behrouz, Ali and Li, Yingcong and Kacham, Praneeth and Daliri, Poria and Deng, Zhihao and Zhong, Peilin and Razaviyayn, Meisam and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2505.23735},
  year={2025}
}
```

## License

MIT