asystemoffields committed
Commit 170c57a · verified · 1 parent: ca7c6c8

Update README.md

Files changed (1):
  1. README.md +60 -46
README.md CHANGED
@@ -1,63 +1,77 @@
- ---
- license: apache-2.0
- library_name: disco-torch
- tags:
- - reinforcement-learning
- - meta-learning
- - pytorch
- - disco-rl
- language:
- - en
- pipeline_tag: reinforcement-learning
- ---

- # Disco103 — Meta-Learned RL Update Rule

- Pretrained weights for the **Disco103** meta-network from [*Discovering State-of-the-art Reinforcement Learning Algorithms*](https://doi.org/10.1038/s41586-025-09761-x) (Nature, 2025).

- ## What is this?

- A small LSTM neural network (754,778 parameters) that **generates loss targets** for RL agents. Instead of hand-crafted loss functions like PPO or GRPO, Disco103 observes an agent's rollout — policy logits, rewards, advantages, auxiliary predictions — and outputs target distributions the agent should match.

- Meta-trained by DeepMind across 103 complex environments (Atari, ProcGen, DMLab-30). Originally in JAX, this is a PyTorch port.

- ## Usage

- ```python
- from disco_torch import DiscoUpdateRule, load_disco103_weights

- rule = DiscoUpdateRule()
- load_disco103_weights(rule)  # auto-downloads from this repo

- # Generate loss targets from a rollout
- meta_out, new_state = rule.unroll_meta_net(
-     rollout, agent_params, meta_state, unroll_fn, hyper_params
- )
- loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)
- ```

- ## File

- `disco_103.npz` — NumPy archive with 42 parameters (754,778 values total), converted from the [original JAX checkpoint](https://github.com/google-deepmind/disco_rl).

- ## PyTorch Port

- See [asystemoffields/disco-torch](https://github.com/asystemoffields/disco-torch) for the full PyTorch implementation, examples, and experiment results.

- ## Citation

- ```bibtex
- @article{oh2025disco,
-   title={Discovering State-of-the-art Reinforcement Learning Algorithms},
-   author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
-   journal={Nature},
-   volume={648},
-   pages={312--319},
-   year={2025},
-   doi={10.1038/s41586-025-09761-x}
- }
- ```

- ## License

- Apache 2.0 — same as the original [google-deepmind/disco_rl](https://github.com/google-deepmind/disco_rl).
+ Pretrained weights for the **Disco103** meta-network from [*Discovering State-of-the-art Reinforcement Learning Algorithms*](https://doi.org/10.1038/s41586-025-09761-x) (Nature, 2025).
+
+ ## What is this?
+
+ A small LSTM neural network (754,778 parameters) that **generates loss targets** for RL agents. Instead of hand-crafted loss functions like PPO or GRPO, Disco103 observes an agent's rollout — policy logits, rewards, advantages, auxiliary predictions — and outputs target distributions the agent should match.
+
+ Meta-trained by DeepMind across 103 complex environments (Atari, ProcGen, DMLab-30). Originally in JAX, this is a PyTorch port.
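Concretely, "target distributions the agent should match" means the agent's loss reduces to distribution matching: a divergence between the agent's policy and the per-step targets the meta-network emits. A minimal NumPy sketch of that idea, with illustrative shapes and names that are not the disco-torch API:

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T, B, A = 29, 2, 3                             # rollout length, batch, actions (illustrative)
policy = softmax(rng.normal(size=(T, B, A)))   # agent's action distribution per step
target = softmax(rng.normal(size=(T, B, A)))   # distribution proposed by the meta-network

# KL(target || policy), averaged over time and batch:
# zero exactly when the agent's policy matches the generated targets.
kl = (target * (np.log(target) - np.log(policy))).sum(-1).mean()
```

The gradient of this loss flows into the agent only; the meta-network is a frozen target generator.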
+
+ ## Quick Start
+
+ ```python
+ from disco_torch import DiscoTrainer, collect_rollout
+
+ agent = YourAgent(obs_dim=64, num_actions=3).to(device)
+ trainer = DiscoTrainer(agent, device=device)  # auto-downloads weights
+
+ env = YourEnv(num_envs=2)
+ obs = env.obs()
+ lstm_state = agent.init_lstm_state(env.num_envs, device)
+
+ def step_fn(actions):
+     rewards, dones = env.step(actions)
+     return env.obs(), rewards, dones
+
+ for step in range(1000):
+     rollout, obs, lstm_state = collect_rollout(
+         agent, step_fn, obs, lstm_state, rollout_len=29, device=device,
+     )
+     logs = trainer.step(rollout)  # replay buffer, gradient loop, target updates — all handled
+ ```
+
+ `DiscoTrainer` encapsulates the full training loop: replay buffer, 32× inner gradient steps, per-element gradient clipping, Polyak target network updates, and meta-state management. See [examples/catch_disco.py](https://github.com/asystemoffields/disco-torch/blob/main/examples/catch_disco.py) for a complete working example that reaches a 99% catch rate in 1000 steps.
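Two of those pieces are standard and easy to state precisely: per-element gradient clipping bounds every gradient entry independently (unlike global-norm clipping, which rescales the whole gradient), and a Polyak update moves each target-network parameter a small step toward its online counterpart. A generic NumPy sketch; the `tau` and `limit` values are illustrative, not necessarily what `DiscoTrainer` uses:

```python
import numpy as np

def clip_per_element(grad, limit=1.0):
    # Bound each gradient entry independently in [-limit, limit].
    return np.clip(grad, -limit, limit)

def polyak_update(target_params, online_params, tau=0.005):
    # target <- (1 - tau) * target + tau * online, per parameter tensor.
    return {name: (1 - tau) * target_params[name] + tau * online_params[name]
            for name in target_params}

online = {"w": np.array([1.0, -3.0]), "b": np.array([0.5])}
target = {"w": np.array([0.0, 0.0]), "b": np.array([0.0])}
target = polyak_update(target, online, tau=0.5)
# With tau=0.5, target["w"] moves halfway toward online["w"]: [0.5, -1.5]
```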
+
+ ## Advanced: Low-level API
+
+ ```python
+ from disco_torch import DiscoUpdateRule, load_disco103_weights
+
+ rule = DiscoUpdateRule()
+ load_disco103_weights(rule)  # auto-downloads from this repo
+
+ # Generate loss targets from a rollout
+ meta_out, new_state = rule.unroll_meta_net(
+     rollout, agent_params, meta_state, unroll_fn, hyper_params
+ )
+ loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)
+ ```
+
+ ## File
+
+ `disco_103.npz` — NumPy archive with 42 parameters (754,778 values total), converted from the [original JAX checkpoint](https://github.com/google-deepmind/disco_rl).
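A plain NumPy load is enough to inspect an `.npz` archive like this one. The snippet below builds a small stand-in file so it runs anywhere; point `np.load` at the downloaded `disco_103.npz` to check the real checkpoint:

```python
import os
import tempfile

import numpy as np

# Stand-in archive; a real check would load the downloaded disco_103.npz instead.
path = os.path.join(tempfile.mkdtemp(), "example.npz")
np.savez(path, w_ih=np.zeros((4, 8)), w_hh=np.zeros((4, 4)), bias=np.zeros(4))

with np.load(path) as ckpt:
    names = sorted(ckpt.files)                          # parameter names
    total = sum(int(ckpt[n].size) for n in ckpt.files)  # total scalar count
# For the real disco_103.npz, expect 42 names and total == 754_778.
```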
+
+ ## PyTorch Port
+
+ See [asystemoffields/disco-torch](https://github.com/asystemoffields/disco-torch) for the full PyTorch implementation, examples, and experiment results.
+
+ ## Citation
+
+ ```bibtex
+ @article{oh2025disco,
+   title={Discovering State-of-the-art Reinforcement Learning Algorithms},
+   author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
+   journal={Nature},
+   volume={648},
+   pages={312--319},
+   year={2025},
+   doi={10.1038/s41586-025-09761-x}
+ }
+ ```
+
+ ## License
+
+ Apache 2.0 — same as the original [google-deepmind/disco_rl](https://github.com/google-deepmind/disco_rl).