---
license: apache-2.0
pipeline_tag: reinforcement-learning
tags:
- reinforcement-learning
- meta-learning
- pytorch
---
Pretrained weights for the Disco103 meta-network from [Discovering State-of-the-art Reinforcement Learning Algorithms](https://doi.org/10.1038/s41586-025-09761-x) (Nature, 2025).

## What is this?

A small LSTM network (754,778 parameters) that generates loss targets for RL agents. Instead of relying on a hand-crafted loss function such as PPO's or GRPO's, Disco103 observes an agent's rollout (policy logits, rewards, advantages, auxiliary predictions) and outputs the target distributions the agent should match.

Meta-trained by DeepMind across 103 complex environments (Atari, ProcGen, DMLab-30). The original implementation is in JAX; this is a PyTorch port.
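The "output target distributions the agent should match" step can be pictured as a plain distribution-matching objective. The sketch below is a generic illustration only, not the library's actual loss; the function name and the `[T, B, A]` (time, batch, action) shapes are our assumptions:

```python
import torch
import torch.nn.functional as F

def match_targets_loss(policy_logits, target_probs):
    # Cross-entropy between the target distribution and the policy
    # (equal to KL(target || policy) up to a constant in the targets).
    log_pi = F.log_softmax(policy_logits, dim=-1)
    return -(target_probs * log_pi).sum(-1).mean()
```

In practice `DiscoTrainer` computes the targets and the loss for you; this sketch only conveys the shape of the idea.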
## Quick Start

```python
from disco_torch import DiscoTrainer, collect_rollout

agent = YourAgent(obs_dim=64, num_actions=3).to(device)
trainer = DiscoTrainer(agent, device=device)  # auto-downloads weights

env = YourEnv(num_envs=2)
obs = env.obs()
lstm_state = agent.init_lstm_state(env.num_envs, device)

def step_fn(actions):
    rewards, dones = env.step(actions)
    return env.obs(), rewards, dones

for step in range(1000):
    rollout, obs, lstm_state = collect_rollout(
        agent, step_fn, obs, lstm_state, rollout_len=29, device=device,
    )
    logs = trainer.step(rollout)  # replay buffer, gradient loop, target updates: all handled
```
`DiscoTrainer` encapsulates the full training loop: replay buffer, 32 inner gradient steps, per-element gradient clipping, Polyak target-network updates, and meta-state management. See https://github.com/asystemoffields/disco-torch/blob/main/examples/catch_disco.py for a complete working example that reaches a 99% catch rate in 1000 steps.
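The Polyak (soft) target-network update is handled inside `DiscoTrainer`. As a generic sketch of the technique in plain PyTorch (illustrative only, not the library's code; `tau=0.005` is an arbitrary default):

```python
import torch

@torch.no_grad()
def polyak_update(online: torch.nn.Module, target: torch.nn.Module, tau: float = 0.005):
    # Soft update, parameter by parameter:
    #   target <- tau * online + (1 - tau) * target
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(1.0 - tau).add_(p_online, alpha=tau)
```

Small `tau` makes the target network a slowly moving average of the online network, which stabilizes the loss targets it produces.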
## Advanced: Low-level API

```python
from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from this repo

# Generate loss targets from a rollout
meta_out, new_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)
```
## File

`disco_103.npz`: a NumPy archive with 42 parameter arrays (754,778 values in total), converted from the original JAX checkpoint.
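The archive can be inspected with NumPy alone before loading it into the model. A minimal sketch (the helper name is ours; the key names inside the archive are whatever the converter emitted):

```python
import numpy as np

def summarize_npz(path):
    # An .npz file is a zip of named arrays; map each name to its element count.
    with np.load(path) as archive:
        sizes = {name: archive[name].size for name in archive.files}
    return sizes, sum(sizes.values())

# For disco_103.npz, the card above states 42 entries and 754,778 values in total.
```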
## PyTorch Port

See https://github.com/asystemoffields/disco-torch for the full PyTorch implementation, examples, and experiment results.
## Citation

```bibtex
@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}
```
## License

Apache 2.0, the same license as the original https://github.com/google-deepmind/disco_rl.