---
license: apache-2.0
pipeline_tag: reinforcement-learning
tags:
- reinforcement-learning
- meta-learning
- pytorch
---

Pretrained weights for the Disco103 meta-network from [Discovering State-of-the-art Reinforcement Learning Algorithms](https://doi.org/10.1038/s41586-025-09761-x) (Nature, 2025).

## What is this?

A small LSTM network (754,778 parameters) that generates loss targets for RL agents. Instead of relying on a hand-crafted loss function such as PPO or GRPO, Disco103 observes an agent's rollout (policy logits, rewards, advantages, auxiliary predictions) and outputs the target distributions the agent should match.

Disco103 was meta-trained by DeepMind across 103 complex environments (Atari, ProcGen, and DMLab-30). The original implementation is in JAX; this is a PyTorch port.

## Quick Start

```python
import torch

from disco_torch import DiscoTrainer, collect_rollout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# YourAgent and YourEnv are placeholders for your own agent and vectorized environment.
agent = YourAgent(obs_dim=64, num_actions=3).to(device)
trainer = DiscoTrainer(agent, device=device)  # auto-downloads weights

env = YourEnv(num_envs=2)
obs = env.obs()
lstm_state = agent.init_lstm_state(env.num_envs, device)

def step_fn(actions):
    # Advance every environment by one step and return (next_obs, rewards, dones).
    rewards, dones = env.step(actions)
    return env.obs(), rewards, dones

for step in range(1000):
    rollout, obs, lstm_state = collect_rollout(
        agent, step_fn, obs, lstm_state,
        rollout_len=29, device=device,
    )
    # Replay buffer, gradient loop, and target updates are all handled here.
    logs = trainer.step(rollout)
```

`DiscoTrainer` encapsulates the full training loop: replay buffer, 32 inner gradient steps, per-element gradient clipping, Polyak target-network updates, and meta-state management.

See https://github.com/asystemoffields/disco-torch/blob/main/examples/catch_disco.py for a complete working example that reaches a 99% catch rate in 1000 steps.
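Two of the pieces `DiscoTrainer` bundles, per-element gradient clipping and Polyak target-network updates, can be illustrated in plain PyTorch. This is a minimal sketch, not `DiscoTrainer`'s actual code: the helper names `clipped_inner_step` and `polyak_update`, the clip value `1.0`, and the coefficient `tau=0.01` are hypothetical choices.

```python
import copy

import torch
import torch.nn as nn


def clipped_inner_step(model: nn.Module, loss: torch.Tensor,
                       optimizer: torch.optim.Optimizer, clip_value: float = 1.0) -> None:
    """One gradient step with per-element clipping (hypothetical helper)."""
    optimizer.zero_grad()
    loss.backward()
    # Clamp every individual gradient entry to [-clip_value, clip_value],
    # instead of rescaling the whole gradient by its global norm.
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
    optimizer.step()


@torch.no_grad()
def polyak_update(target: nn.Module, online: nn.Module, tau: float = 0.01) -> None:
    """Move target weights a fraction tau toward the online weights (hypothetical helper)."""
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        # target <- (1 - tau) * target + tau * online
        t_param.lerp_(o_param, tau)


# Usage: a toy linear model stands in for the agent.
torch.manual_seed(0)
online = nn.Linear(4, 2)
target = copy.deepcopy(online)
optimizer = torch.optim.SGD(online.parameters(), lr=0.1)

x = torch.randn(8, 4)
for _ in range(32):  # several inner gradient steps on the same batch
    loss = online(x).pow(2).mean()
    clipped_inner_step(online, loss, optimizer)
polyak_update(target, online, tau=0.01)
```

Per-element clipping bounds each gradient coordinate independently, and the slow Polyak average keeps the target network stable while the online network takes many inner steps.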
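The overview says Disco103 outputs target distributions that the agent should match. A common way to make a network match a fixed target distribution is to minimize the KL divergence from the target to the network's prediction. The sketch below assumes that objective for a single policy head; the helper name `target_matching_loss` and the tensor shapes are hypothetical, and the loss actually computed by this port may combine several such terms with other components.

```python
import torch
import torch.nn.functional as F


def target_matching_loss(agent_logits: torch.Tensor, target_probs: torch.Tensor) -> torch.Tensor:
    """KL(target || agent), averaged over the batch (hypothetical helper).

    agent_logits: [batch, num_actions] unnormalized scores from the agent.
    target_probs: [batch, num_actions] target distribution from the meta-network.
    """
    # Targets act as fixed labels: no gradient flows back through them.
    target_probs = target_probs.detach()
    log_agent = F.log_softmax(agent_logits, dim=-1)
    # F.kl_div expects log-probabilities as input and probabilities as target;
    # "batchmean" divides the summed divergence by the batch size.
    return F.kl_div(log_agent, target_probs, reduction="batchmean")


# Usage: the loss is near zero when the agent already matches the target,
# positive otherwise, and its gradient reaches only the agent's logits.
logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
matched = F.softmax(logits, dim=-1)
shifted = torch.tensor([[0.1, 0.1, 0.8]])
loss = target_matching_loss(logits, shifted)
loss.backward()
```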
## Advanced: Low-level API

```python
from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from this repo

# rollout, agent_params, meta_state, unroll_fn, and hyper_params
# are supplied by your own training code.

# Generate loss targets from a rollout.
meta_out, new_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)
# Compute the agent's loss against those targets.
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)
```

## File

`disco_103.npz`: NumPy archive with 42 parameter arrays (754,778 values total), converted from the original JAX checkpoint.

## PyTorch Port

See https://github.com/asystemoffields/disco-torch for the full PyTorch implementation, examples, and experiment results.

## Citation

```bibtex
@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}
```

## License

Apache 2.0, the same license as the original https://github.com/google-deepmind/disco_rl.