amoudgl committed on
Commit
b1bfe10
·
verified ·
1 Parent(s): d040338

Update README

Files changed (1)
  1. README.md +101 -3
README.md CHANGED
@@ -1,7 +1,105 @@
  ---
  license: mit
  ---
- Official weights for Celo2 learned update rule proposed in paper:
- [Celo2: Towards Learned Optimization Free Lunch](https://huggingface.co/papers/2602.19142)
 
- Code repo: https://github.com/amoudgl/celo2
  ---
  license: mit
+ library_name: optax
+ tags:
+ - optimizer
+ - learned-optimizer
+ - meta-learning
+ - jax
  ---
 
+ # Celo2: Towards Learned Optimization Free Lunch
+
+ <p>
+ <a href="https://arxiv.org/abs/2602.19142"><img alt="Paper" src="https://img.shields.io/badge/arXiv-2602.19142-b31b1b.svg"></a>
+ <a href="https://github.com/amoudgl/celo2"><img alt="Code" src="https://img.shields.io/badge/GitHub-black?logo=github&logoColor=white&labelColor=grey"></a>
+ <a href="https://opensource.org/licenses/MIT"><img alt="License: MIT" src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
+ </p>
+
+
+ Official pretrained weights for the **Celo2** learned update rule. This variant applies, and is meta-trained with, a harness that adds Newton-Schulz orthogonalization on top of the learned update for matrix parameters and uses AdamW for biases and embeddings. For a fully learned variant without any harness, see [celo2-base](https://huggingface.co/amoudgl/celo2-base).
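For readers unfamiliar with the orthogonalization step mentioned above, here is a minimal sketch of a Newton-Schulz iteration in JAX. It uses the classic cubic iteration purely as an illustration: the function name `newton_schulz_orthogonalize`, the iteration count, and the Frobenius-norm scaling are assumptions, and the iteration and coefficients used inside Celo2 may differ.

```python
import jax
import jax.numpy as jnp


def newton_schulz_orthogonalize(update, num_iters=30, eps=1e-7):
    """Approximate the nearest semi-orthogonal matrix to `update`.

    Illustrative cubic Newton-Schulz iteration; not taken from the Celo2 code.
    """
    # Scale so that all singular values are at most 1, which the iteration needs to converge.
    x = update / (jnp.linalg.norm(update) + eps)
    # Work in the wide orientation so that x @ x.T is the smaller Gram matrix.
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    for _ in range(num_iters):
        # Each step maps every singular value s to 1.5*s - 0.5*s**3, pushing it towards 1.
        x = 1.5 * x - 0.5 * (x @ x.T) @ x
    return x.T if transposed else x


# Toy 4x6 "update" matrix with arbitrary values.
toy_update = jax.random.normal(jax.random.PRNGKey(0), (4, 6))
orthogonalized = newton_schulz_orthogonalize(toy_update)
# The rows should now be close to orthonormal, so this is close to the identity.
gram = orthogonalized @ orthogonalized.T
```

The result keeps the singular vectors of the input and replaces its singular values with values close to 1, which is why it is applied only to matrix-shaped parameters.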
+
+
+ ## Quickstart
+
+ Install the package and download the checkpoint:
+ ```bash
+ pip install git+https://github.com/amoudgl/celo2.git
+ hf download amoudgl/celo2 --local-dir ./celo2
+ ```
+
+ Use `load_checkpoint` to load the pretrained params from the checkpoint path:
+ ```python
+ from celo2_optax import load_checkpoint
+ pretrained_params = load_checkpoint('./celo2/theta.state')
+ ```
+
+ `scale_by_celo2` takes the pretrained params as input and composes like any other optax transformation:
+ ```python
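+ import jax
+ import optax
+ from celo2_optax import scale_by_celo2
+
+ # `lr_schedule` and `weight_decay` are user-defined, e.g. an optax schedule and a float.
+ optimizer = optax.multi_transform(
+     transforms={
+         # Learned update with orthogonalization enabled, followed by weight decay and the learning rate.
+         'celo2': optax.chain(
+             scale_by_celo2(pretrained_params, orthogonalize=True),
+             optax.add_decayed_weights(weight_decay),
+             optax.scale_by_learning_rate(lr_schedule),
+         ),
+         # AdamW with b1=0.9 and b2=0.95 for the remaining parameters.
+         'adam': optax.adamw(lr_schedule, 0.9, 0.95, weight_decay=weight_decay),
+     },
+     # Route parameters with at most one dimension, and embeddings, to AdamW; everything else to Celo2.
+     param_labels=lambda params: jax.tree.map_with_path(
+         lambda path, val: 'adam' if val.ndim <= 1 or 'embed' in jax.tree_util.keystr(path) else 'celo2',
+         params,
+     ),
+ )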
+ ```
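To see how the `param_labels` rule in the snippet above routes parameters, the sketch below applies the same rule to a toy parameter tree. The helper name `label_params` and the toy parameter names are hypothetical, and `jax.tree_util.tree_map_with_path` is used here as the longer-standing spelling of the same tree-mapping utility.

```python
import jax
import jax.numpy as jnp


def label_params(params):
    """Label each leaf 'adam' or 'celo2' with the same rule as in the Quickstart."""
    return jax.tree_util.tree_map_with_path(
        lambda path, val: 'adam'
        if val.ndim <= 1 or 'embed' in jax.tree_util.keystr(path)
        else 'celo2',
        params,
    )


# Hypothetical toy parameter tree; the names are illustrative only.
toy_params = {
    'embed': {'table': jnp.zeros((100, 16))},
    'dense': {'kernel': jnp.zeros((16, 32)), 'bias': jnp.zeros((32,))},
}
labels = label_params(toy_params)
print(labels)
```

The 2-D embedding table is matched by name and the bias by its dimensionality, so only the dense kernel is handed to the learned update rule.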
+
+ ## Loading and inspecting MLP update rule weights
+
+ ```python
+ import jax
+ from celo2_optax import load_checkpoint
+
+ # `load_checkpoint` returns a dictionary containing the weights.
+ pretrained_params = load_checkpoint('./celo2/theta.state')
+ # Print the shape of every array in the checkpoint.
+ print(jax.tree.map(lambda x: x.shape, pretrained_params))
+ ```
+
+ The checkpoint contains a small MLP stored under the `ff_mod_stack` key, with weight matrices (`w0__*`, `w1`, `w2`) and biases (`b0`, `b1`, `b2`). Each `w0__*` key holds the weights for one particular input feature, such as momentum, gradient, or parameter.
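As a rough illustration of how such a key layout could be consumed, the sketch below runs a forward pass through a stand-in MLP with the same key names and the hidden width of 8 from the meta-training config. Everything else is an assumption made for the example: the feature names and their dimensions, the ReLU activation, the single output, the random weights, and the function name `mlp_update_rule`. It does not reproduce the actual Celo2 forward pass.

```python
import jax
import jax.numpy as jnp

HIDDEN = 8  # hidden width from the meta-training config


def mlp_update_rule(theta, features):
    """Two-hidden-layer MLP whose first layer is split per input feature."""
    # Summing per-feature projections equals one matmul on the concatenated features.
    h = theta['b0']
    for name, value in features.items():
        h = h + value @ theta[f'w0__{name}']
    h = jax.nn.relu(h)  # the activation is an assumption
    h = jax.nn.relu(h @ theta['w1'] + theta['b1'])
    return h @ theta['w2'] + theta['b2']


# Hypothetical features for 5 scalar parameters, each feature being 1-dimensional.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
features = {
    'momentum': jax.random.normal(keys[0], (5, 1)),
    'gradient': jax.random.normal(keys[1], (5, 1)),
    'parameter': jax.random.normal(keys[2], (5, 1)),
}
# Randomly initialized stand-in weights that mimic the key layout described above.
theta = {
    'w0__momentum': jax.random.normal(keys[3], (1, HIDDEN)),
    'w0__gradient': jax.random.normal(keys[4], (1, HIDDEN)),
    'w0__parameter': jax.random.normal(keys[5], (1, HIDDEN)),
    'b0': jnp.zeros(HIDDEN),
    'w1': jax.random.normal(keys[6], (HIDDEN, HIDDEN)),
    'b1': jnp.zeros(HIDDEN),
    'w2': jax.random.normal(keys[7], (HIDDEN, 1)),
    'b2': jnp.zeros(1),
}
per_param_output = mlp_update_rule(theta, features)  # one output per parameter
```

Storing the first layer as one matrix per feature makes it easy to inspect how strongly each input feature feeds into the hidden units.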
+
+ ## Meta-training config
+
+ | Key | Value |
+ | ----------------------- | ------------------------------------------------------------ |
+ | **Optimizer architecture** | MLP, 2 hidden layers, 8 units each |
+ | **Meta-training tasks** | 4 image classification tasks (MNIST, FMNIST, CIFAR-10, SVHN) |
+ | **Task architecture** | MLP (64-32-10) |
+ | **Meta-trainer** | Persistent Evolution Strategies (PES) |
+ | **Outer iterations** | 100K |
+ | **Truncation length** | 50 |
+ | **Min unroll length** | 100 |
+ | **Max unroll length** | 2000 |
+
+ For more details, see the [config JSON](./config.json) included in the repo.
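To relate the truncation and unroll lengths in the table, the small calculation below counts how many truncation windows cover one inner-training unroll. It assumes the usual truncated-unroll setup, in which an unroll is processed in consecutive windows of the truncation length; the helper name `num_truncations` is hypothetical.

```python
import math

TRUNCATION_LENGTH = 50    # from the table above
MIN_UNROLL_LENGTH = 100   # from the table above
MAX_UNROLL_LENGTH = 2000  # from the table above


def num_truncations(unroll_length, truncation_length=TRUNCATION_LENGTH):
    """Number of truncation windows needed to cover one inner-training unroll."""
    return math.ceil(unroll_length / truncation_length)


shortest = num_truncations(MIN_UNROLL_LENGTH)  # 100 / 50
longest = num_truncations(MAX_UNROLL_LENGTH)   # 2000 / 50
print(shortest, longest)  # prints: 2 40
```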
+
+ ## Files
+
+ | File | Description |
+ | ------------- | -------------------------------- |
+ | `theta.state` | Pretrained MLP optimizer weights |
+ | `config.json` | Meta-training configuration |
+
+
+ ## Citation
+
+ ```bibtex
+ @misc{moudgil2026celo2,
+   title={Celo2: Towards Learned Optimization Free Lunch},
+   author={Abhinav Moudgil and Boris Knyazev and Eugene Belilovsky},
+   year={2026},
+   eprint={2602.19142},
+   archivePrefix={arXiv},
+   primaryClass={cs.LG},
+   url={https://arxiv.org/abs/2602.19142},
+ }
+ ```