---
license: apache-2.0
pipeline_tag: reinforcement-learning
tags:
  - reinforcement-learning
  - meta-learning
  - pytorch
---

Pretrained weights for the Disco103 meta-network from [Discovering State-of-the-art Reinforcement Learning Algorithms](https://doi.org/10.1038/s41586-025-09761-x) (Nature, 2025).

## What is this?

A small LSTM neural network (754,778 parameters) that generates loss targets for RL agents. Instead of hand-crafted loss functions like PPO or GRPO, Disco103 observes an agent's rollout (policy logits, rewards, advantages, auxiliary predictions) and outputs target distributions the agent should match.

Meta-trained by DeepMind across 103 complex environments (Atari, ProcGen, DMLab-30). The original implementation is in JAX; this is a PyTorch port.
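The distribution-matching objective can be illustrated with a stand-in for the meta-network's output: given target distributions over actions, the agent minimizes a divergence between those targets and its own policy. A minimal NumPy sketch of the idea, not the library's actual loss (the KL direction and names here are illustrative):

```python
import numpy as np

def softmax(logits, axis=-1):
    # Numerically stable softmax over the action dimension.
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def kl_to_targets(agent_logits, target_probs, eps=1e-8):
    """Mean KL(target || agent): pulls the agent's policy toward the targets."""
    agent_probs = softmax(agent_logits)
    kl = np.sum(target_probs * (np.log(target_probs + eps) - np.log(agent_probs + eps)), axis=-1)
    return float(kl.mean())

rng = np.random.default_rng(0)
agent_logits = rng.normal(size=(4, 3))           # 4 timesteps, 3 actions
target_probs = softmax(rng.normal(size=(4, 3)))  # stand-in for meta-network output

loss = kl_to_targets(agent_logits, target_probs)
print(f"loss: {loss:.4f}")
```

When the agent's policy already matches the targets, the loss is zero; gradient descent on it moves the policy toward whatever distributions the meta-network emits.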

## Quick Start

```python
import torch

from disco_torch import DiscoTrainer, collect_rollout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent = YourAgent(obs_dim=64, num_actions=3).to(device)
trainer = DiscoTrainer(agent, device=device)  # auto-downloads weights

env = YourEnv(num_envs=2)
obs = env.obs()
lstm_state = agent.init_lstm_state(env.num_envs, device)

def step_fn(actions):
    rewards, dones = env.step(actions)
    return env.obs(), rewards, dones

for step in range(1000):
    rollout, obs, lstm_state = collect_rollout(
        agent, step_fn, obs, lstm_state, rollout_len=29, device=device,
    )
    logs = trainer.step(rollout)  # replay buffer, gradient loop, target updates are all handled
```

`DiscoTrainer` encapsulates the full training loop: replay buffer, 32x inner gradient steps, per-element gradient clipping, Polyak target network updates, and meta-state management. See https://github.com/asystemoffields/disco-torch/blob/main/examples/catch_disco.py for a complete working example that reaches a 99% catch rate in 1000 steps.
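Two of the components named above can be sketched in isolation; the function names and constants below are illustrative, not the library's internals:

```python
import numpy as np

def polyak_update(target, online, tau=0.005):
    """Exponential moving average: each target parameter drifts toward its online twin."""
    return {k: (1.0 - tau) * target[k] + tau * online[k] for k in target}

def clip_per_element(grads, max_abs=1.0):
    """Clip every gradient element independently, rather than by a global norm."""
    return {k: np.clip(g, -max_abs, max_abs) for k, g in grads.items()}

online = {"w": np.ones(3)}
target = {"w": np.zeros(3)}
target = polyak_update(target, online, tau=0.5)  # halfway toward the online params
print(target["w"])

grads = {"w": np.array([-3.0, 0.2, 5.0])}
print(clip_per_element(grads)["w"])
```

Per-element clipping bounds each coordinate of the gradient separately, so one exploding entry cannot shrink the update to all the others the way global-norm clipping does.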

## Advanced: Low-level API

```python
from disco_torch import DiscoUpdateRule, load_disco103_weights

rule = DiscoUpdateRule()
load_disco103_weights(rule)  # auto-downloads from this repo

# Generate loss targets from a rollout
meta_out, new_state = rule.unroll_meta_net(
    rollout, agent_params, meta_state, unroll_fn, hyper_params
)
loss, logs = rule.agent_loss(rollout, meta_out, hyper_params)
```

## File

`disco_103.npz`: a NumPy archive with 42 parameters (754,778 values total), converted from the original JAX checkpoint.
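The archive can be inspected with plain NumPy to list parameter names, shapes, and the total value count. The snippet below builds a small in-memory stand-in, since the real file has to be downloaded first; point `np.load` at `disco_103.npz` instead:

```python
import io
import numpy as np

# Stand-in archive; substitute the downloaded disco_103.npz path here.
buf = io.BytesIO()
np.savez(buf,
         lstm_w=np.zeros((16, 64), dtype=np.float32),
         lstm_b=np.zeros(64, dtype=np.float32))
buf.seek(0)

names, total = [], 0
with np.load(buf) as archive:
    for name in archive.files:
        arr = archive[name]
        names.append(name)
        total += arr.size
        print(f"{name}: shape={arr.shape} dtype={arr.dtype}")
print(f"{len(names)} parameters, {total:,} values")
```

Run against the real file, the totals should come out to 42 parameters and 754,778 values.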

## PyTorch Port

See https://github.com/asystemoffields/disco-torch for the full PyTorch implementation, examples, and experiment results.

## Citation

```bibtex
@article{oh2025disco,
  title={Discovering State-of-the-art Reinforcement Learning Algorithms},
  author={Oh, Junhyuk and Farquhar, Greg and Kemaev, Iurii and Calian, Dan A. and Hessel, Matteo and Zintgraf, Luisa and Singh, Satinder and van Hasselt, Hado and Silver, David},
  journal={Nature},
  volume={648},
  pages={312--319},
  year={2025},
  doi={10.1038/s41586-025-09761-x}
}
```

## License

Apache 2.0, same as the original https://github.com/google-deepmind/disco_rl.