---
license: apache-2.0
pipeline_tag: reinforcement-learning
tags:
- reinforcement-learning
- meta-learning
- pytorch
---
Pretrained weights for the Disco103 meta-network from [Discovering State-of-the-art Reinforcement Learning
Algorithms](https://doi.org/10.1038/s41586-025-09761-x) (Nature, 2025).
## What is this?
A small LSTM neural network (754,778 parameters) that generates loss targets for RL agents. Instead of hand-crafted
loss functions like PPO or GRPO, Disco103 observes an agent's rollout — policy logits, rewards, advantages, auxiliary
predictions — and outputs target distributions the agent should match.
Meta-trained by DeepMind across 103 complex environments (Atari, ProcGen, DMLab-30). The original implementation is in JAX; this repository provides a PyTorch port.
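Conceptually, the agent's update reduces to matching its policy distribution against the targets the meta-network emits. A minimal sketch of such a target-matching loss, where the tensor names and shapes are illustrative assumptions rather than the library's actual API:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: a batch of 4 states with 3 actions (assumed, not the real API).
logits = torch.randn(4, 3, requires_grad=True)        # agent's policy logits from a rollout
pi_target = torch.softmax(torch.randn(4, 3), dim=-1)  # target distributions from the meta-network

# Cross-entropy between the meta-generated targets and the agent's policy:
# minimized when the agent's policy matches the targets exactly.
loss = -(pi_target * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
loss.backward()  # gradients flow into the agent's logits as in a normal training step
```

The actual loss Disco103 produces also covers value and auxiliary predictions; this sketch only shows the policy-matching idea.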
## Quick Start
```python
import torch
from disco_torch import DiscoTrainer, collect_rollout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent = YourAgent(obs_dim=64, num_actions=3).to(device)
trainer = DiscoTrainer(agent, device=device)  # auto-downloads weights

env = YourEnv(num_envs=2)
obs = env.obs()
lstm_state = agent.init_lstm_state(env.num_envs, device)

def step_fn(actions):
    rewards, dones = env.step(actions)
    return env.obs(), rewards, dones

for step in range(1000):
    rollout, obs, lstm_state = collect_rollout(
        agent, step_fn, obs, lstm_state, rollout_len=29, device=device,
    )
    logs = trainer.step(rollout)  # replay buffer, gradient loop, target updates — all handled
```
DiscoTrainer encapsulates the full training loop: replay buffer, 32x inner gradient steps, per-element gradient
clipping, Polyak target network updates, and meta-state management. See
https://github.com/asystemoffields/disco-torch/blob/main/examples/catch_disco.py for a complete working example that
reaches 99% catch rate in 1000 steps.
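Of those components, the Polyak target update is simple enough to illustrate. A minimal sketch, where the mixing rate `tau` is an arbitrary illustrative value rather than the one DiscoTrainer uses:

```python
import torch

def polyak_update(online: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    """Move each target parameter a small step toward its online counterpart."""
    with torch.no_grad():
        for p_t, p_o in zip(target.parameters(), online.parameters()):
            # target <- (1 - tau) * target + tau * online
            p_t.mul_(1.0 - tau).add_(p_o, alpha=tau)

online = torch.nn.Linear(4, 2)
target = torch.nn.Linear(4, 2)
polyak_update(online, target, tau=0.005)
```

Compared with a hard copy every N steps, the soft update keeps the targets changing smoothly, which tends to stabilize bootstrapped value learning.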
## Advanced: Low-level API
```python
from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from this repo

# Generate loss targets from a rollout
meta_out, new_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)
```
## File

`disco_103.npz` — a NumPy archive containing 42 parameter arrays (754,778 values in total), converted from the original JAX checkpoint.
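To verify what the archive contains, you can iterate over its arrays with NumPy. The sketch below builds a tiny synthetic `.npz` in memory to stay self-contained; the same loop applies unchanged to `disco_103.npz`:

```python
import io
import numpy as np

# Synthetic stand-in archive (two small arrays); pass "disco_103.npz" for the real file.
buf = io.BytesIO()
np.savez(buf, w=np.zeros((3, 2)), b=np.zeros(2))
buf.seek(0)

archive = np.load(buf)
total = sum(archive[name].size for name in archive.files)
print(len(archive.files), total)  # number of arrays, total scalar values
```

For the real checkpoint, the two printed numbers should be 42 and 754,778.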
## PyTorch Port
See https://github.com/asystemoffields/disco-torch for the full PyTorch implementation, examples, and experiment
results.
## Citation
```bibtex
@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}
```
## License
Apache 2.0 — same as the original https://github.com/google-deepmind/disco_rl. |