---
license: apache-2.0
---
# disco-torch

A PyTorch port of DeepMind's **Disco103** — the meta-learned reinforcement learning update rule from [*Discovering State-of-the-art Reinforcement Learning Algorithms*](https://doi.org/10.1038/s41586-025-09761-x) (Nature, 2025).

## What is DiscoRL?

Instead of hand-crafted loss functions such as those of PPO or GRPO, DiscoRL uses a small LSTM network (the "meta-network") that **generates loss targets** for RL agents. Given a rollout of agent experience — policy logits, rewards, advantages, auxiliary predictions — the meta-network outputs target distributions. The agent then minimizes the KL divergence between its outputs and these learned targets.
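
Concretely, the agent-side objective is just distribution matching. The following is a minimal self-contained sketch of that idea; `kl_to_targets` is an illustrative helper, not part of the disco-torch API:

```python
import numpy as np

def kl_to_targets(agent_logits, target_probs):
    """KL(target || agent policy), summed over the action axis.

    Illustrates the loss described above: the meta-network emits target
    distributions and the agent is trained to match them. Illustrative
    only, not the package's actual loss function.
    """
    # Numerically stable log-softmax over the action axis.
    shifted = agent_logits - agent_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # KL(p || q) = sum_a p(a) * (log p(a) - log q(a))
    return (target_probs * (np.log(target_probs) - log_probs)).sum(axis=-1)

# Uniform targets against uniform logits give (numerically) zero divergence.
zero_kl = kl_to_targets(np.zeros((1, 4)), np.full((1, 4), 0.25))
```

In the real update rule the targets come from the meta-network forward pass rather than being supplied by hand.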

The Disco103 checkpoint (754,778 parameters) was meta-trained by DeepMind across thousands of Atari-like environments. It generalizes as a drop-in update rule for new tasks — no reward shaping, no task-specific loss design.

## Why a PyTorch port?

The [original implementation](https://github.com/google-deepmind/disco_rl) uses JAX + Haiku. This port lets you use Disco103 in PyTorch training pipelines with no JAX dependency at inference time.

## Installation

```bash
pip install disco-torch
```

With optional extras:

```bash
pip install disco-torch[hub]       # HuggingFace Hub weight downloads
pip install disco-torch[examples]  # gymnasium for running examples
pip install disco-torch[dev]       # pytest + all extras for development
```

### Weights

Option 1 — download from the HuggingFace Hub (requires `pip install disco-torch[hub]`):

```python
from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from the HuggingFace Hub
```

Option 2 — manual download from the [disco_rl repo](https://github.com/google-deepmind/disco_rl):

```bash
cp path/to/disco_103.npz weights/
```

```python
load_disco103_weights(rule, "weights/disco_103.npz")
```

## Quick start

```python
import torch
from disco_torch import DiscoUpdateRule, UpdateRuleInputs, load_disco103_weights

# Load the meta-network with pretrained weights
rule = DiscoUpdateRule()
load_disco103_weights(rule, "weights/disco_103.npz")

# Initialize meta-RNN state (persists across training steps)
state = rule.meta_net.initial_meta_rnn_state()

# Run the meta-network on a rollout
# (`inputs` is an UpdateRuleInputs built from your agent's rollout)
with torch.no_grad():
    meta_out, new_state = rule.meta_net(inputs, state)
# meta_out["pi"] — policy loss targets [T, B, A]
# meta_out["y"]  — value loss targets [T, B, 600]
# meta_out["z"]  — auxiliary loss targets [T, B, 600]
```

### Full training loop

```python
# At each learner step:
meta_out, new_meta_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)

# Agent loss (KL divergence against the meta-network targets)
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)

# Value-function loss (no meta-gradient)
value_loss, value_logs = rule.agent_loss_no_meta(rollout, meta_out, hyper_params)
```
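
`value_utils.py` (see the package structure below) ports V-trace for the value-side computations. As a reference for what that estimator does, here is a standalone V-trace target computation following the standard backward recursion; the package's actual function names and signatures may differ:

```python
import numpy as np

def vtrace_targets(values, bootstrap, rewards, discounts, rhos,
                   clip_rho=1.0, clip_c=1.0):
    """Backward-recursive V-trace value targets.

    values:    V(x_t) for t = 0..T-1
    bootstrap: V(x_T)
    rhos:      per-step importance ratios pi/mu
    """
    rho = np.minimum(clip_rho, rhos)
    c = np.minimum(clip_c, rhos)
    next_values = np.append(values[1:], bootstrap)
    deltas = rho * (rewards + discounts * next_values - values)
    vs = np.zeros_like(values)
    acc = 0.0  # carries v_{t+1} - V(x_{t+1}); zero at the boundary
    for t in reversed(range(len(values))):
        acc = deltas[t] + discounts[t] * c[t] * acc
        vs[t] = values[t] + acc
    return vs
```

With importance ratios of 1, the targets reduce to ordinary n-step returns, which is a useful sanity check.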

## Architecture

```
Outer (per-trajectory):
  y_net            MLP [600 -> 16 -> 1]        Value prediction embedding
  z_net            MLP [600 -> 16 -> 1]        Auxiliary prediction embedding
  policy_net       Conv1dNet [9 -> 16 -> 2]    Action-conditional embedding
  trajectory_rnn   LSTM(27, 256)               Reverse-unrolled over trajectory
  state_gate       Linear(128 -> 256)          Multiplicative gate from meta-RNN
  y_head / z_head  Linear(256 -> 600)          Loss targets for y and z
  pi_conv + head   Conv1dNet [258 -> 16] -> 1  Policy loss target (per action)

Meta-RNN (per-lifetime):
  Separate y/z/policy nets, input MLP(29 -> 16), LSTMCell(16, 128)
```

The outer network processes each trajectory with a reverse-unrolled LSTM. The meta-RNN operates at a slower timescale: it sees batch-time averages and modulates the outer network via a multiplicative gate. This two-level architecture lets the update rule adapt its behavior over an agent's lifetime.
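
The multiplicative gate can be pictured with plain arrays. Shapes follow the table above; the random weights and the exact point at which the gate is applied are illustrative assumptions, not the port's internals:

```python
import numpy as np

T, B = 5, 2  # trajectory length, batch size
rng = np.random.default_rng(0)
W_gate = rng.normal(size=(128, 256))     # state_gate: Linear(128 -> 256)
traj_out = rng.normal(size=(T, B, 256))  # trajectory_rnn output [T, B, hidden]
meta_h = rng.normal(size=(B, 128))       # slow-timescale meta-RNN hidden state

gate = meta_h @ W_gate        # [B, 256]; broadcasts over the time axis below
gated = traj_out * gate       # elementwise modulation of the fast pathway
assert gated.shape == (T, B, 256)
```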

## End-to-end example

See [`examples/cartpole_disco.py`](examples/cartpole_disco.py) for a complete training loop that trains a CartPole agent with the Disco103 update rule:

```bash
# With pretrained weights
python examples/cartpole_disco.py --weights weights/disco_103.npz

# With random meta-network weights (still demonstrates the full pipeline)
python examples/cartpole_disco.py
```

## Package structure

```
disco_torch/
  __init__.py              Public API
  types.py                 Dataclasses: UpdateRuleInputs, MetaNetInputOption, ValueOuts, etc.
  transforms.py            Input transforms and construct_input()
  meta_net.py              DiscoMetaNet — the full LSTM meta-network
  update_rule.py           DiscoUpdateRule — meta-net + value computation + loss
  value_utils.py           V-trace, TD-error, advantage estimation, Q-values
  utils.py                 batch_lookup, signed_logp1, 2-hot encoding, EMA
  load_weights.py          Maps JAX/Haiku NPZ keys -> PyTorch modules

examples/
  cartpole_disco.py        End-to-end CartPole training with Disco103

scripts/
  inspect_disco103.py      Print NPZ weight names and shapes
  validate_against_jax.py  Numerical comparison: PyTorch vs JAX reference

tests/
  test_utils.py            Unit tests for utility functions
  test_building_blocks.py  Unit tests for network building blocks
  test_meta_net.py         Snapshot tests for the meta-network forward pass
```
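
Two of the `utils.py` primitives listed above have standard forms worth stating precisely. The sketch below gives those standard definitions; the package's exact signatures may differ:

```python
import numpy as np

def signed_logp1(x):
    # sign(x) * log(1 + |x|): compresses large-magnitude returns symmetrically
    return np.sign(x) * np.log1p(np.abs(x))

def signed_expm1(x):
    # exact inverse of signed_logp1
    return np.sign(x) * np.expm1(np.abs(x))

def two_hot(x, bins):
    """Spread a scalar across its two nearest bins so the expectation over
    bin values reconstructs x: the classic 2-hot categorical encoding."""
    x = float(np.clip(x, bins[0], bins[-1]))
    hi = int(np.clip(np.searchsorted(bins, x), 1, len(bins) - 1))
    lo = hi - 1
    w = (x - bins[lo]) / (bins[hi] - bins[lo])
    out = np.zeros(len(bins))
    out[lo], out[hi] = 1.0 - w, w
    return out
```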

## Numerical validation

All outputs match the JAX reference implementation within float32 precision:

| Output | Max diff | Status |
|--------|----------|--------|
| pi (policy targets) | < 1.3e-06 | PASS |
| y (value targets) | < 1.3e-06 | PASS |
| z (auxiliary targets) | < 1.3e-06 | PASS |
| meta_input_emb | < 1.3e-06 | PASS |
| meta_rnn_h | < 1.3e-06 | PASS |

To run the test suite (no JAX required):

```bash
pip install disco-torch[dev]
pytest
```

To run the JAX cross-validation (requires JAX and disco_rl):

```bash
pip install disco_rl jax dm-haiku rlax distrax
python scripts/validate_against_jax.py
```

## Key implementation details

- **HaikuLSTMCell**: Haiku fuses LSTM gates in the order `[i, g, f, o]` and adds a +1 forget-gate bias, whereas PyTorch uses `[i, f, g, o]`. A custom LSTM cell handles the difference.
- **Weight mapping**: The 42 JAX/Haiku parameters have nested path names (e.g., `lstm/~/meta_lstm/~unroll/mlp_2/~/linear_0/w`). `load_weights.py` maps each one to the correct PyTorch module.
- **Conv1dBlock**: Each block concatenates per-action features with their mean across actions before the convolution, matching the JAX implementation's broadcast pattern.
- **Value utilities**: V-trace, Retrace-style Q-value estimation, signed hyperbolic transforms, and 2-hot categorical encoding are all ported.
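
The gate-order point in the first bullet is concrete enough to show directly. Below is a single LSTM step that reads its fused gate projection in Haiku's `[i, g, f, o]` order with the +1 forget-gate bias; this is a standalone numpy sketch, not the port's `HaikuLSTMCell`:

```python
import numpy as np

def haiku_lstm_step(x, h, c, W, b):
    """One LSTM step with Haiku's gate layout.

    Haiku stacks the four gates as [i, g, f, o] along the projection's
    output axis, whereas torch.nn.LSTMCell expects [i, f, g, o]; a port
    must either reorder the weight rows or, as here, split the gates in
    Haiku order directly.
    """
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    gates = np.concatenate([x, h], axis=-1) @ W + b
    i, g, f, o = np.split(gates, 4, axis=-1)                 # Haiku order
    c_new = sigmoid(f + 1.0) * c + sigmoid(i) * np.tanh(g)   # +1 forget bias
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new
```

With all-zero weights the forget gate still passes `sigmoid(1) ≈ 0.73` of the cell state through, which is exactly the effect of the +1 bias.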

## Requirements

- Python >= 3.11
- PyTorch >= 2.0
- NumPy >= 1.24

## License

Apache 2.0 — the same license as the original disco_rl.

## Citation

If you use this port, please cite the original paper:

```bibtex
@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}
```

## Acknowledgments

This is a community port of [google-deepmind/disco_rl](https://github.com/google-deepmind/disco_rl). All credit for the algorithm, architecture, and pretrained weights goes to the original authors.