---
license: mit
tags:
- test-time-learning
- memory-augmented
- atlas
- nested-learning
- polynomial-memory
- omega-rule
datasets:
- HuggingFaceTB/smollm-corpus
language:
- en
pipeline_tag: text-generation
---

# Atlas-MAG with Omega Rule

A 43M-parameter implementation of the Atlas paper's Memory-As-Gate (MAG) architecture with polynomial memory and test-time learning (TTL).

**Paper**: [Atlas: Learning to Optimally Memorize the Context at Test Time](https://arxiv.org/abs/2505.23735) (Behrouz et al., Google Research)

**Code**: [toddwbucy/Atlas-MAG_OmegaRule](https://github.com/toddwbucy/Atlas-MAG_OmegaRule)

## What This Model Demonstrates

This checkpoint exists to demonstrate a concrete infrastructure problem: **test-time learning models cannot be served by existing deployment stacks**.

Atlas-MAG uses gradient descent *during the forward pass* to update its memory. PyTorch gates this behavior behind `if self.training`, and every serving framework puts the model into the equivalent of inference mode before handling requests. As a result, the model's memory architecture is silently disabled.

Two scripts in the GitHub repo let you see this firsthand.
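
The gating problem can be reproduced without any of the repo's code. Below is a minimal pure-Python sketch (a hypothetical toy class, not the actual Atlas-MAG module) of the `if self.training` pattern and what `eval()` does to it:

```python
# Minimal sketch (not the repo's actual code) of how a training flag
# silently disables test-time learning: the memory write happens only
# when `training` is True, mirroring PyTorch's `if self.training` gate.

class ToyMemoryModule:
    def __init__(self):
        self.training = True   # analogous to nn.Module.training
        self.memory = 0.0      # stands in for the TTL memory state

    def eval(self):
        # What every serving stack effectively calls before inference.
        self.training = False
        return self

    def forward(self, x):
        if self.training:           # TTL gate: memory only updates in train mode
            self.memory += 0.1 * x  # stand-in for the inner gradient step
        return x + self.memory      # output depends on accumulated memory

model = ToyMemoryModule()
train_outs = [model.forward(1.0) for _ in range(3)]  # memory grows each call
model.eval()
eval_outs = [model.forward(1.0) for _ in range(3)]   # memory is frozen

print(train_outs)  # outputs drift as memory accumulates
print(eval_outs)   # identical outputs: TTL silenced
```

Same weights, same input, different behavior depending on a single boolean — which is exactly the failure mode the two demo scripts exercise on the real checkpoint.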

## Quick Start

```bash
git clone https://github.com/toddwbucy/Atlas-MAG_OmegaRule.git
cd Atlas-MAG_OmegaRule
pip install torch huggingface_hub tokenizers

# Demo: same model, same weights, different outputs depending on training flag
python scripts/demo_ttl_inference.py

# Benchmark: NIAH memory probe — TTL ON vs TTL OFF side by side
python scripts/benchmark_niah.py
```

Both scripts auto-download this checkpoint. No manual download needed.

## Model Details

| Property | Value |
|---|---|
| **Architecture** | Atlas-MAG (Memory-As-Gate) |
| **Parameters** | 43M |
| **Dimensions** | dim=512, 6 layers, 8 heads |
| **Memory** | Polynomial degree-2, rank-512 |
| **Attention** | Sliding window, size=512 |
| **TTL** | Muon optimizer (NS-5), theta=0.9, alpha=0.999, eta=0.01 |
| **Vocab** | 49,152 (SmolLM tokenizer) |
| **Training Steps** | 8,800 |
| **Training Hardware** | 2x NVIDIA A6000 48GB |
| **Training Data** | SmolLM-Corpus (cosmopedia 40%, fineweb-edu 50%, python-edu 10%) |
| **NIAH Accuracy** | 85.9% |
| **Checkpoint Size** | 473MB |
| **Format** | PyTorch (.pt) |

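The NS-5 entry in the table refers to five Newton–Schulz iterations, the approximate-orthogonalization step inside the Muon optimizer. As a hedged illustration — the quintic coefficients below come from the public Muon reference implementation, not from this repo — here is the iteration acting on a single singular value, which is how the matrix version behaves in its SVD basis:

```python
# Hedged sketch of "NS-5": five Newton–Schulz steps that push a matrix's
# singular values toward 1. Coefficients are assumed from the widely used
# Muon reference implementation, shown here on a scalar singular value.

A, B, C = 3.4445, -4.7750, 2.0315  # quintic coefficients (assumption)

def ns_step(s):
    # One Newton–Schulz step applied to a singular value s.
    return A * s + B * s**3 + C * s**5

def ns5(s):
    for _ in range(5):
        s = ns_step(s)
    return s

# Small and large singular values both land in a loose band around 1
# after 5 steps — Muon only needs approximate orthogonality.
for s0 in (0.2, 0.6, 1.0):
    print(s0, "->", round(ns5(s0), 3))
```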
## Architecture

```
Input -> Embedding -> [MAGBlock x 6] -> RMSNorm -> LM Head -> Output

MAGBlock:
  x --+--> [Sliding Window Attention] --> attn_out
      |                                     |
      +--> [Deep Polynomial Memory] --> mem_out
                                            |
  output = x + attn_out * sigmoid(mem_out)
```
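
The combination rule at the bottom of the diagram can be checked with plain numbers. A scalar sketch (shapes and real tensors elided):

```python
import math

# Scalar sketch of the MAGBlock combination rule: the memory branch gates
# the attention branch through a sigmoid, and the block keeps a residual
# connection back to x.

def mag_combine(x, attn_out, mem_out):
    gate = 1.0 / (1.0 + math.exp(-mem_out))  # sigmoid(mem_out), in (0, 1)
    return x + attn_out * gate

# Large positive mem_out -> gate ~ 1: attention passes through.
open_gate = mag_combine(0.5, 2.0, 10.0)
# Large negative mem_out -> gate ~ 0: block reduces to the residual x.
closed_gate = mag_combine(0.5, 2.0, -10.0)

print(open_gate, closed_gate)
```

So when TTL is silenced and the memory branch emits nothing useful, the gate — not the attention weights — is what degrades the output.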

The polynomial feature map increases memory capacity from O(d_k) to O(d_k^2) per layer — roughly 64x more associations, since d_k = 64 per head (dim=512 over 8 heads).
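
A degree-2 feature map can be sketched directly. The exact monomial basis used in the repo may differ, but the 64x count follows from lifting d_k = 64 key dimensions to d_k^2 features:

```python
from itertools import product

# Hedged sketch of a degree-2 polynomial feature map: keys are lifted from
# d_k dimensions to d_k^2 dimensions via the flattened outer product, so a
# linear associative memory over phi(k) can store O(d_k^2) associations
# instead of O(d_k). (Illustrative basis; the repo's may differ.)

def poly2(k):
    # Outer product k (x) k flattened: every pairwise monomial k_i * k_j.
    return [a * b for a, b in product(k, repeat=2)]

d_k = 64
k = [1.0] * d_k
features = poly2(k)
print(len(k), "->", len(features))  # 64 -> 4096, a 64x expansion
```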

## Loading

```python
import torch
from huggingface_hub import hf_hub_download

# Download checkpoint
ckpt_path = hf_hub_download("r3d91ll/Atlas-MAG_OmegaRule", "checkpoint_step008800.pt")
checkpoint = torch.load(ckpt_path, map_location="cuda:0", weights_only=False)

# The checkpoint contains:
# - "model_state_dict": model weights
# - "config": full training configuration dict
print(checkpoint["config"])
```
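
One quick sanity check on the 43M figure is to total the tensor shapes in `model_state_dict`. A sketch over hypothetical shape tuples — with the real checkpoint you would read each tensor's `.shape` (or sum `.numel()`) instead:

```python
from math import prod

# Hedged sketch: estimating a parameter count from tensor shapes alone.
# The shapes below are illustrative stand-ins, not the checkpoint's real
# keys; only the embedding shape is implied by the model card (49,152
# vocab x 512 dim).

def count_params(shapes):
    # Sum the element count of every tensor shape in the dict.
    return sum(prod(s) for s in shapes.values())

toy_shapes = {
    "embedding.weight": (49152, 512),            # vocab x dim, as in this model
    "layers.0.attn.q_proj.weight": (512, 512),   # hypothetical layer tensor
}
print(count_params(toy_shapes))
```

Note that the embedding table alone accounts for roughly 25M of the 43M parameters at this vocab size.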

For full model loading, see the [GitHub repository](https://github.com/toddwbucy/Atlas-MAG_OmegaRule), which includes the model class and demo scripts.

## Files

| File | Size | Description |
|------|------|-------------|
| `checkpoint_step008800.pt` | 473MB | Model weights + config + optimizer state |
| `tokenizer_smollm.json` | 2.2MB | BPE tokenizer (SmolLM) |

## Citation

```bibtex
@article{behrouz2025atlas,
  title={Atlas: Learning to Optimally Memorize the Context at Test Time},
  author={Behrouz, Ali and Li, Yingcong and Kacham, Praneeth and Daliri, Poria and Deng, Zhihao and Zhong, Peilin and Razaviyayn, Meisam and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2505.23735},
  year={2025}
}
```

## License

MIT