Commit ec924fd (verified), committed by lainwired · 1 Parent(s): 1ad965e

upload teammate_generation for ckpt-eval support

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. teammate_generation/BRDiv.py +832 -0
  2. teammate_generation/CoMeDi.py +1161 -0
  3. teammate_generation/LBRDiv.py +1098 -0
  4. teammate_generation/__init__.py +0 -0
  5. teammate_generation/configs/algorithm/brdiv/_base_.yaml +40 -0
  6. teammate_generation/configs/algorithm/brdiv/hanabi.yaml +27 -0
  7. teammate_generation/configs/algorithm/brdiv/lbf/lbf_12x12.yaml +18 -0
  8. teammate_generation/configs/algorithm/brdiv/lbf/lbf_7x7_nolevels.yaml +18 -0
  9. teammate_generation/configs/algorithm/brdiv/mini-hanabi.yaml +28 -0
  10. teammate_generation/configs/algorithm/brdiv/overcooked-v1/asymm_advantages.yaml +18 -0
  11. teammate_generation/configs/algorithm/brdiv/overcooked-v1/coord_ring.yaml +18 -0
  12. teammate_generation/configs/algorithm/brdiv/overcooked-v1/counter_circuit.yaml +18 -0
  13. teammate_generation/configs/algorithm/brdiv/overcooked-v1/cramped_room.yaml +21 -0
  14. teammate_generation/configs/algorithm/brdiv/overcooked-v1/forced_coord.yaml +18 -0
  15. teammate_generation/configs/algorithm/comedi/_base_.yaml +36 -0
  16. teammate_generation/configs/algorithm/comedi/hanabi.yaml +26 -0
  17. teammate_generation/configs/algorithm/comedi/lbf/lbf_12x12.yaml +18 -0
  18. teammate_generation/configs/algorithm/comedi/lbf/lbf_7x7_nolevels.yaml +18 -0
  19. teammate_generation/configs/algorithm/comedi/mini-hanabi.yaml +27 -0
  20. teammate_generation/configs/algorithm/comedi/overcooked-v1/asymm_advantages.yaml +16 -0
  21. teammate_generation/configs/algorithm/comedi/overcooked-v1/coord_ring.yaml +16 -0
  22. teammate_generation/configs/algorithm/comedi/overcooked-v1/counter_circuit.yaml +16 -0
  23. teammate_generation/configs/algorithm/comedi/overcooked-v1/cramped_room.yaml +17 -0
  24. teammate_generation/configs/algorithm/comedi/overcooked-v1/forced_coord.yaml +16 -0
  25. teammate_generation/configs/algorithm/fcp/_base_.yaml +37 -0
  26. teammate_generation/configs/algorithm/fcp/hanabi.yaml +32 -0
  27. teammate_generation/configs/algorithm/fcp/lbf/lbf_12x12.yaml +17 -0
  28. teammate_generation/configs/algorithm/fcp/lbf/lbf_7x7_nolevels.yaml +17 -0
  29. teammate_generation/configs/algorithm/fcp/mini-hanabi.yaml +26 -0
  30. teammate_generation/configs/algorithm/fcp/overcooked-v1/asymm_advantages.yaml +17 -0
  31. teammate_generation/configs/algorithm/fcp/overcooked-v1/coord_ring.yaml +16 -0
  32. teammate_generation/configs/algorithm/fcp/overcooked-v1/counter_circuit.yaml +16 -0
  33. teammate_generation/configs/algorithm/fcp/overcooked-v1/cramped_room.yaml +16 -0
  34. teammate_generation/configs/algorithm/fcp/overcooked-v1/forced_coord.yaml +16 -0
  35. teammate_generation/configs/algorithm/lbrdiv/_base_.yaml +38 -0
  36. teammate_generation/configs/algorithm/lbrdiv/hanabi.yaml +26 -0
  37. teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_12x12.yaml +17 -0
  38. teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_7x7_nolevels.yaml +17 -0
  39. teammate_generation/configs/algorithm/lbrdiv/mini-hanabi.yaml +27 -0
  40. teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/asymm_advantages.yaml +18 -0
  41. teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/coord_ring.yaml +18 -0
  42. teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/counter_circuit.yaml +18 -0
  43. teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/cramped_room.yaml +18 -0
  44. teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/forced_coord.yaml +18 -0
  45. teammate_generation/configs/base_config_teammate.yaml +54 -0
  46. teammate_generation/configs/hydra/hydra_simple.yaml +7 -0
  47. teammate_generation/configs/task/hanabi.yaml +16 -0
  48. teammate_generation/configs/task/lbf/lbf_12x12.yaml +7 -0
  49. teammate_generation/configs/task/lbf/lbf_7x7_nolevels.yaml +4 -0
  50. teammate_generation/configs/task/mini-hanabi.yaml +13 -0
teammate_generation/BRDiv.py ADDED
@@ -0,0 +1,832 @@
+'''Implementation of the BRDiv teammate generation algorithm (Rahman et al., TMLR 2023)
+https://arxiv.org/abs/2207.14138
+
+Command to run BRDiv only on LBF:
+python teammate_generation/run.py algorithm=brdiv/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels label=test_brdiv run_heldout_eval=false train_ego=false
+
+Limitations: does not support recurrent actors.
+'''
+import shutil
+import time
+import logging
+from typing import NamedTuple
+from functools import partial
+
+import hydra
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from flax.training.train_state import TrainState
+import wandb
+
+from agents.mlp_actor_critic_agent import ActorWithConditionalCriticPolicy
+from agents.population_interface import AgentPopulation
+from common.plot_utils import get_metric_names
+from common.run_episodes import run_episodes
+from common.save_load_utils import save_train_run
+from envs import make_env
+from envs.log_wrapper import LogWrapper
+from marl.ppo_utils import unbatchify, _create_minibatches
+
+log = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+class XPTransition(NamedTuple):
+    done: jnp.ndarray
+    action: jnp.ndarray
+    value: jnp.ndarray
+    self_onehot_id: jnp.ndarray
+    oppo_onehot_id: jnp.ndarray
+    reward: jnp.ndarray
+    log_prob: jnp.ndarray
+    obs: jnp.ndarray
+    info: jnp.ndarray
+    avail_actions: jnp.ndarray
+
+def _get_all_ids(pop_size):
+    cross_product = np.meshgrid(
+        np.arange(pop_size),
+        np.arange(pop_size)
+    )
+    agent_id_cartesian_product = np.stack([g.ravel() for g in cross_product], axis=-1)
+    all_conf_ids = agent_id_cartesian_product[:, 1]
+    all_br_ids = agent_id_cartesian_product[:, 0]
+    return all_conf_ids, all_br_ids
+
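The ordering returned by `_get_all_ids` determines how the flattened evaluation results map back to (confederate, best-response) pairs. With NumPy's default `indexing='xy'`, column 1 of the stacked grid (the confederate ID) varies slowest and column 0 (the best-response ID) cycles fastest. The following is a dependency-free sketch of that ordering; `all_pair_ids` is a hypothetical helper name used only for illustration and is not part of the commit.

```python
def all_pair_ids(pop_size):
    """Enumerate every (confederate, best-response) pairing.

    Mirrors the ordering of `_get_all_ids`: the confederate ID varies
    slowest and the best-response ID varies fastest.
    """
    conf_ids = [c for c in range(pop_size) for _ in range(pop_size)]
    br_ids = [b for _ in range(pop_size) for b in range(pop_size)]
    return conf_ids, br_ids


conf_ids, br_ids = all_pair_ids(3)
print(conf_ids)  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
print(br_ids)    # [0, 1, 2, 0, 1, 2, 0, 1, 2]
```

Under this layout, entry `c * pop_size + b` of a flattened result corresponds to confederate `c` paired with best response `b`.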
+def gather_params(partner_params_pytree, idx_vec):
+    """
+    partner_params_pytree: pytree with all partner params. Each leaf has shape (n_seeds, m_ckpts, ...).
+    idx_vec: a vector of indices with shape (num_envs,) each in [0, n_seeds*m_ckpts).
+
+    Return a new pytree where each leaf has shape (num_envs, ...). Each leaf has a sampled
+    partner's parameters for each environment.
+    """
+    # We'll define a function that gathers from each leaf
+    # where leaf has shape (n_seeds, m_ckpts, ...), we want [idx_vec[i]] for each i.
+    # We'll vmap a slicing function.
+    def gather_leaf(leaf):
+        def slice_one(idx):
+            return leaf[idx]  # shape (...)
+        return jax.vmap(slice_one)(idx_vec)
+
+    return jax.tree.map(gather_leaf, partner_params_pytree)
+
+def train_brdiv_partners(train_rng, env, config, conf_policy, br_policy):
+    num_agents = env.num_agents
+    assert num_agents == 2, "This code assumes the environment has exactly 2 agents."
+
+    # Define different minibatch sizes for interactions with ego agent and one with BR agent
+    config["NUM_GAME_AGENTS"] = num_agents
+    config["NUM_CONF_ACTORS"] = config["NUM_ENVS"]
+    config["NUM_BR_ACTORS"] = config["NUM_ENVS"]
+    config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // (config["ROLLOUT_LENGTH"] * config["NUM_ENVS"])
+
+    def make_brdiv_agents(config):
+        def linear_schedule(count):
+            frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
+            return config["LR"] * frac
+
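`linear_schedule` converts the optimizer step count into a number of completed PPO updates (each update performs `NUM_MINIBATCHES * UPDATE_EPOCHS` gradient steps) and decays the learning rate linearly to zero over `NUM_UPDATES`. A minimal standalone sketch of the same arithmetic, with the config values passed as plain arguments; the function name `linear_lr` and the numbers below are illustrative assumptions, not values from the commit's configs:

```python
def linear_lr(count, lr, num_minibatches, update_epochs, num_updates):
    """Learning rate after `count` optimizer steps, decayed once per PPO update."""
    updates_done = count // (num_minibatches * update_epochs)
    return lr * (1.0 - updates_done / num_updates)


# Hypothetical settings: 4 minibatches x 2 epochs = 8 optimizer steps per update.
for count in (0, 7, 8, 80):
    print(count, linear_lr(count, lr=1e-3, num_minibatches=4,
                           update_epochs=2, num_updates=10))
```

The rate stays constant within one PPO update (counts 0 through 7 here) and reaches zero once all `num_updates` updates have been applied.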
+        def train(rng):
+            rng, init_conf_rng, init_br_rng = jax.random.split(rng, 3)
+            all_conf_init_rngs = jax.random.split(init_conf_rng, config["PARTNER_POP_SIZE"])
+            all_br_init_rngs = jax.random.split(init_br_rng, config["PARTNER_POP_SIZE"])
+            identity_matrix = jnp.eye(config["PARTNER_POP_SIZE"])
+
+            init_conf_hstate = conf_policy.init_hstate(config["NUM_CONF_ACTORS"])
+            init_br_hstate = br_policy.init_hstate(config["NUM_BR_ACTORS"])
+
+            def init_train_states(rng_agents, rng_brs):
+                def init_single_pair_optimizers(rng_agent, rng_br):
+                    init_params_conf = conf_policy.init_params(rng_agent)
+                    init_params_br = br_policy.init_params(rng_br)
+                    return init_params_conf, init_params_br
+
+                init_all_networks_and_optimizers = jax.vmap(init_single_pair_optimizers)
+                all_conf_params, all_br_params = init_all_networks_and_optimizers(rng_agents, rng_brs)
+
+                # Define optimizers for both confederate and BR policy
+                tx = optax.chain(
+                    optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+                    optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"],
+                               eps=1e-5),
+                )
+                tx_br = optax.chain(
+                    optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
+                    optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"],
+                               eps=1e-5),
+                )
+
+                train_state_conf = TrainState.create(
+                    apply_fn=conf_policy.network.apply,
+                    params=all_conf_params,
+                    tx=tx,
+                )
+
+                train_state_br = TrainState.create(
+                    apply_fn=br_policy.network.apply,
+                    params=all_br_params,
+                    tx=tx_br,
+                )
+
+                return train_state_conf, train_state_br
+
+            all_conf_optims, all_br_optims = init_train_states(
+                all_conf_init_rngs, all_br_init_rngs
+            )
+
+            def forward_pass_conf(params, obs, id, done, avail_actions, hstate, rng):
+                act, val, pi, new_hstate = conf_policy.get_action_value_policy(
+                    params=params,
+                    obs=obs[jnp.newaxis, ...],
+                    done=done[jnp.newaxis, ...],
+                    avail_actions=avail_actions,
+                    hstate=hstate,
+                    rng=rng,
+                    aux_obs=id[jnp.newaxis, ...]
+                )
+                return act, val, pi, new_hstate
+
+            def forward_pass_br(params, obs, id, done, avail_actions, hstate, rng):
+                act, val, pi, new_hstate = br_policy.get_action_value_policy(
+                    params=params,
+                    obs=obs[jnp.newaxis, ...],
+                    done=done[jnp.newaxis, ...],
+                    avail_actions=avail_actions,
+                    hstate=hstate,
+                    rng=rng,
+                    aux_obs=id[jnp.newaxis, ...]
+                )
+                return act, val, pi, new_hstate
+
+            def _env_step(runner_state, unused):
+                """
+                agent_0 = confederate, agent_1 = br
+                Returns updated runner_state, and Transitions for agent_0 and agent_1
+                """
+                (
+                    all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids,
+                    env_state, last_obs, last_done, last_conf_h, last_br_h, rng
+                ) = runner_state
+                rng, act0_rng, act1_rng, step_rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 6)
+
+                # For done envs, resample both conf and brs
+                needs_resample = last_done["__all__"]
+                resampled_conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_CONF_ACTORS"],), 0, config["PARTNER_POP_SIZE"])
+                resampled_br_ids = jax.random.randint(br_sampling_rng, (config["NUM_BR_ACTORS"],), 0, config["PARTNER_POP_SIZE"])
+
+                # Determine final indices based on whether resampling was needed for each env
+                updated_conf_ids = jnp.where(
+                    needs_resample,
+                    resampled_conf_ids,  # Use newly sampled index if True
+                    last_conf_ids  # Else, keep index from previous step
+                )
+
+                updated_br_ids = jnp.where(
+                    needs_resample,
+                    resampled_br_ids,  # Use newly sampled index if True
+                    last_br_ids  # Else, keep index from previous step
+                )
+
+                # Reset the hidden states for resampled conf and br if they are not None
+                # WARNING: BRDiv was not tested with recurrent actors, so the branches that handle a non-None hstate may not work
+                if last_conf_h is not None:
+                    updated_conf_h = jnp.where(
+                        needs_resample,
+                        init_conf_hstate,
+                        last_conf_h
+                    )
+                else:
+                    updated_conf_h = last_conf_h
+
+                if last_br_h is not None:
+                    updated_br_h = jnp.where(
+                        needs_resample,
+                        init_br_hstate,
+                        last_br_h
+                    )
+                else:
+                    updated_br_h = last_br_h
+
+                # Get the corresponding conf and br params
+                updated_conf_params = gather_params(all_train_state_conf.params, updated_conf_ids)
+                updated_br_params = gather_params(all_train_state_br.params, updated_br_ids)
+
+                updated_conf_onehot_ids = identity_matrix[updated_conf_ids]
+                updated_br_onehot_ids = identity_matrix[updated_br_ids]
+
+                # Get available actions for both agents from environment state
+                avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
+                avail_actions = jax.lax.stop_gradient(avail_actions)
+                avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32)
+                avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32)
+
+                # Agent_0 action
+                act0_rng = jax.random.split(act0_rng, config["NUM_ENVS"])
+                act_0, val_0, pi_0, new_conf_h = jax.vmap(forward_pass_conf)(updated_conf_params,
+                    last_obs["agent_0"], updated_br_onehot_ids, last_done["agent_0"], avail_actions_0,
+                    updated_conf_h, act0_rng)
+                logp_0 = pi_0.log_prob(act_0)
+                act_0, val_0, logp_0 = act_0.squeeze(), val_0.squeeze(), logp_0.squeeze()
+
+                # Agent_1 action
+                act1_rng = jax.random.split(act1_rng, config["NUM_ENVS"])
+                act_1, val_1, pi_1, new_br_h = jax.vmap(forward_pass_br)(updated_br_params,
+                    last_obs["agent_1"], updated_conf_onehot_ids, last_done["agent_1"], avail_actions_1,
+                    updated_br_h, act1_rng)
+                logp_1 = pi_1.log_prob(act_1)
+                act_1, val_1, logp_1 = act_1.squeeze(), val_1.squeeze(), logp_1.squeeze()
+
+                # Combine actions into the env format
+                combined_actions = jnp.concatenate([act_0, act_1], axis=0)
+                env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents)
+                env_act = {k: v.flatten() for k, v in env_act.items()}
+
+                # Step env
+                step_rngs = jax.random.split(step_rng, config["NUM_ENVS"])
+                obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))(
+                    step_rngs, env_state, env_act
+                )
+                # note that num_actors = num_envs * num_agents
+                info_0 = jax.tree.map(lambda x: x[:, 0], info)
+                info_1 = jax.tree.map(lambda x: x[:, 1], info)
+
+                def _compute_rewards(conf_id, br_id, agent_rew):
+                    return jax.lax.cond(jnp.equal(
+                            jnp.argmax(conf_id, axis=-1), jnp.argmax(br_id, axis=-1)
+                        ),
+                        lambda x: x,
+                        lambda x: -x,
+                        agent_rew
+                    )
+
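`_compute_rewards` implements the core BRDiv reward shaping as written in this file: when the confederate and best-response IDs match (self-play), the environment reward is kept; when they differ (cross-play), it is negated. A plain-Python sketch of the same rule on integer IDs rather than one-hot vectors (`brdiv_reward` is a hypothetical name used for illustration):

```python
def brdiv_reward(conf_id, br_id, env_reward):
    """Keep the reward for a matched pair (self-play), negate it otherwise (cross-play)."""
    return env_reward if conf_id == br_id else -env_reward


print(brdiv_reward(0, 0, 1.5))  # 1.5
print(brdiv_reward(0, 2, 1.5))  # -1.5
```

The original uses `jax.lax.cond` under `jax.vmap` so that the branch is selected per environment inside the jitted rollout.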
+                agent_0_rews = jax.vmap(_compute_rewards)(updated_conf_onehot_ids, updated_br_onehot_ids, reward["agent_1"])
+                agent_1_rews = jax.vmap(_compute_rewards)(updated_conf_onehot_ids, updated_br_onehot_ids, reward["agent_0"])
+
+                # Store agent_0 data in transition
+                transition_0 = XPTransition(
+                    done=done["agent_0"],
+                    action=act_0,
+                    value=val_0,
+                    self_onehot_id=updated_conf_onehot_ids,
+                    oppo_onehot_id=updated_br_onehot_ids,
+                    reward=agent_0_rews,
+                    log_prob=logp_0,
+                    obs=last_obs["agent_0"],
+                    info=info_0,
+                    avail_actions=avail_actions_0
+                )
+
+                transition_1 = XPTransition(
+                    done=done["agent_1"],
+                    action=act_1,
+                    value=val_1,
+                    self_onehot_id=updated_br_onehot_ids,
+                    oppo_onehot_id=updated_conf_onehot_ids,
+                    reward=agent_1_rews,
+                    log_prob=logp_1,
+                    obs=last_obs["agent_1"],
+                    info=info_1,
+                    avail_actions=avail_actions_1
+                )
+                new_runner_state = (all_train_state_conf, all_train_state_br, updated_conf_ids, updated_br_ids,
+                                    env_state_next, obs_next, done, new_conf_h, new_br_h, rng)
+                return new_runner_state, (transition_0, transition_1)
+
+            def _calculate_gae(traj_batch, last_val):
+                def _get_advantages(gae_and_next_value, transition):
+                    gae, next_value = gae_and_next_value
+                    done, value, reward = (
+                        transition.done,
+                        transition.value,
+                        transition.reward,
+                    )
+                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
+                    gae = (
+                        delta
+                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
+                    )
+                    return (gae, value), gae
+
+                _, advantages = jax.lax.scan(
+                    _get_advantages,
+                    (jnp.zeros_like(last_val), last_val),
+                    traj_batch,
+                    reverse=True,
+                    unroll=16,
+                )
+                return advantages, advantages + traj_batch.value
+
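`_calculate_gae` runs a reverse `jax.lax.scan` over the rollout, carrying the running advantage and the value of the following step. The recursion is easier to follow as an explicit loop; the sketch below reproduces it for a single environment in plain Python. The function name `gae_single_env` and the default `gamma`/`lam` values are illustrative assumptions, not values read from the commit's configs.

```python
def gae_single_env(rewards, values, dones, last_val, gamma=0.99, lam=0.95):
    """Generalized advantage estimation for one trajectory, computed back to front."""
    advantages = [0.0] * len(rewards)
    running, next_value = 0.0, last_val
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        # One-step TD error; bootstrapping is cut when the episode ended at step t.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        running = delta + gamma * lam * not_done * running
        advantages[t] = running
        next_value = values[t]
    # Value targets are advantages plus the value predictions, as in the original.
    targets = [a + v for a, v in zip(advantages, values)]
    return advantages, targets


adv, tgt = gae_single_env([1.0, 1.0], [0.0, 0.0], [False, True], last_val=5.0,
                          gamma=1.0, lam=1.0)
print(adv, tgt)  # [2.0, 1.0] [2.0, 1.0]
```

In the example the episode ends at the second step, so the bootstrap value `last_val=5.0` is ignored and the advantages reduce to the undiscounted returns.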
+            def run_all_episodes(rng, train_state_conf, train_state_br):
+                conf_ids, br_ids = _get_all_ids(config["PARTNER_POP_SIZE"])
+                gathered_conf_model_params = gather_params(train_state_conf.params, conf_ids)
+                gathered_br_model_params = gather_params(train_state_br.params, br_ids)
+
+                rng, eval_rng = jax.random.split(rng)
+                def run_episodes_fixed_rng(conf_param, br_param):
+                    return run_episodes(
+                        eval_rng, env,
+                        conf_param, conf_policy,
+                        br_param, br_policy,
+                        config["ROLLOUT_LENGTH"], config["NUM_EVAL_EPISODES"],
+                    )
+                ep_infos = jax.vmap(run_episodes_fixed_rng)(
+                    gathered_conf_model_params, gathered_br_model_params,  # leaves where shape is (pop_size*pop_size, ...)
+                )
+                return ep_infos
+
+            def _update_epoch(update_state, unused):
+                def _update_minbatch(all_train_states, all_data):
+                    train_state_conf, train_state_br = all_train_states
+                    minbatch_conf, minbatch_br = all_data
+
+                    def _loss_fn(param, agent_policy, minbatch, agent_id):
+                        '''Compute loss for agent corresponding to agent_id.
+                        '''
+                        init_hstate, traj_batch, gae, target_v = minbatch
+                        # get policy and value of confederate versus ego and best response agents respectively
+                        squeezed_param = jax.tree.map(lambda x: jnp.squeeze(x, 0), param)
+                        _, value, pi, _ = agent_policy.get_action_value_policy(
+                            params=squeezed_param,
+                            obs=traj_batch.obs,
+                            done=traj_batch.done,
+                            avail_actions=traj_batch.avail_actions,
+                            hstate=init_hstate,
+                            rng=jax.random.PRNGKey(0),  # only used for action sampling, which is not used here
+                            aux_obs=traj_batch.oppo_onehot_id
+                        )
+                        log_prob = pi.log_prob(traj_batch.action)
+
+                        is_relevant = jnp.equal(
+                            jnp.argmax(traj_batch.self_onehot_id, axis=-1),
+                            agent_id
+                        )
+                        loss_weights = jnp.where(is_relevant, 1, 0).astype(jnp.float32)
+
+                        # Value loss
+                        value_pred_clipped = traj_batch.value + (
+                            value - traj_batch.value
+                            ).clip(
+                            -config["CLIP_EPS"], config["CLIP_EPS"])
+                        value_losses = jnp.square(value - target_v)
+                        value_losses_clipped = jnp.square(value_pred_clipped - target_v)
+                        value_loss = jax.lax.cond(
+                            loss_weights.sum() == 0,
+                            lambda x: jnp.zeros_like(x).astype(jnp.float32),
+                            lambda x: x,
+                            (loss_weights * jnp.maximum(value_losses, value_losses_clipped)).sum() / (loss_weights.sum() + 1e-8)
+                        )
+
+                        n = config["PARTNER_POP_SIZE"]
+                        # Apply different loss weights for SP and XP data
+                        # Loss weights consist of two parts: the first term is the weighting from the BRDiv loss function
+                        # The second term is a reweighting term to compensate for the data collection process, which uniformly and independently
+                        # samples the conf and br ids from 1, ..., n, resulting in P(SP) = 1/n and P(XP) = (n-1)/n.
+                        # To prevent the XP loss term from dominating the SP loss term, we would like P(SP) = P(XP) = 1/2.
+                        # Thus, we set the 2nd term of the SP weight to n/2, and the 2nd term of the XP weight to n/(2 * (n-1)).
+
+                        is_sp = jnp.equal(jnp.argmax(traj_batch.self_onehot_id, axis=-1), jnp.argmax(traj_batch.oppo_onehot_id, axis=-1))
+                        sp_weight = (1 + 2*config["XP_LOSS_WEIGHTS"]) * (n/2)
+                        xp_weight = config["XP_LOSS_WEIGHTS"] * (n / (2 * (n-1)))
+                        actor_weights = jnp.where(is_sp, sp_weight, xp_weight)
+
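The reweighting argument in the comment above can be checked with exact arithmetic. Sampling the confederate and best-response IDs independently and uniformly from `n` policies gives P(SP) = 1/n and P(XP) = (n-1)/n; multiplying by the correction factors n/2 and n/(2(n-1)) should yield an effective mass of 1/2 for each. The sketch below verifies this with `fractions.Fraction`; the helper name `reweighted_mass` is an illustrative assumption.

```python
from fractions import Fraction


def reweighted_mass(n):
    """Effective SP and XP probability mass after the data-collection correction."""
    p_sp = Fraction(1, n)          # P(conf_id == br_id) under independent uniform sampling
    p_xp = Fraction(n - 1, n)      # P(conf_id != br_id)
    sp_correction = Fraction(n, 2)
    xp_correction = Fraction(n, 2 * (n - 1))
    return p_sp * sp_correction, p_xp * xp_correction


print(reweighted_mass(5))  # (Fraction(1, 2), Fraction(1, 2))
```

The XP correction divides by `n - 1`, so the expression is only defined for a population of at least two policies.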
+                        # Policy gradient loss
+                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
+                        gae_norm = (gae - gae.mean()) / (gae.std() + 1e-8)
+                        pg_loss_1 = ratio * gae_norm * actor_weights
+                        pg_loss_2 = jnp.clip(
+                            ratio,
+                            1.0 - config["CLIP_EPS"],
+                            1.0 + config["CLIP_EPS"]) * gae_norm * actor_weights
+                        pg_loss = jax.lax.cond(
+                            loss_weights.sum() == 0,
+                            lambda x: jnp.zeros_like(x).astype(jnp.float32),
+                            lambda x: x,
+                            -(
+                                loss_weights*jnp.minimum(pg_loss_1, pg_loss_2)
+                            ).sum()/(loss_weights.sum() + 1e-8)
+                        )
+
+                        # Entropy
+                        entropy = jax.lax.cond(
+                            loss_weights.sum() == 0,
+                            lambda x: jnp.zeros_like(x).astype(jnp.float32),
+                            lambda x: x,
+                            (loss_weights * pi.entropy()).sum()/(loss_weights.sum() + 1e-8)
+                        )
+
+                        total_loss = pg_loss + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy
+                        return total_loss, (value_loss, pg_loss, entropy)
+
+                    possible_agent_ids = jnp.expand_dims(jnp.arange(config["PARTNER_POP_SIZE"]), 1)
+                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
+
+                    def gather_conf_params_and_return_grads(agent_id):
+                        param_vector = gather_params(train_state_conf.params, agent_id)
+                        (loss_val_conf, aux_vals_conf), grads_conf = grad_fn(
+                            param_vector, conf_policy, minbatch_conf, agent_id
+                        )
+                        return (loss_val_conf, aux_vals_conf), grads_conf
+
+                    def gather_br_params_and_return_grads(agent_id):
+                        param_vector = gather_params(train_state_br.params, agent_id)
+                        (loss_val_br, aux_vals_br), grads_br = grad_fn(
+                            param_vector, br_policy, minbatch_br, agent_id
+                        )
+                        return (loss_val_br, aux_vals_br), grads_br
+
+                    (loss_val_conf, aux_vals_conf), grads_conf = jax.vmap(gather_conf_params_and_return_grads)(possible_agent_ids)
+                    (loss_val_br, aux_vals_br), grads_br = jax.vmap(gather_br_params_and_return_grads)(possible_agent_ids)
+
+                    grads_conf_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_conf)
+                    grads_br_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_br)
+                    train_state_conf = train_state_conf.apply_gradients(grads=grads_conf_new)
+                    train_state_br = train_state_br.apply_gradients(grads=grads_br_new)
+                    return (train_state_conf, train_state_br), ((loss_val_conf, aux_vals_conf), (loss_val_br, aux_vals_br))
+
+                (
+                    train_state_conf, train_state_br,
+                    traj_batch_conf, traj_batch_br,
+                    advantages_conf, advantages_br,
+                    targets_conf, targets_br,
+                    rng
+                ) = update_state
+                rng, perm_rng_conf, perm_rng_br = jax.random.split(rng, 3)
+
+                minibatches_conf = _create_minibatches(traj_batch_conf, advantages_conf, targets_conf, init_conf_hstate,
+                                                       config["NUM_CONF_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_conf)
+                minibatches_br = _create_minibatches(traj_batch_br, advantages_br, targets_br, init_br_hstate,
+                                                     config["NUM_BR_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_br)
+
+                # Update both policies
+                (train_state_conf, train_state_br), all_losses = jax.lax.scan(
+                    _update_minbatch, (train_state_conf, train_state_br), (minibatches_conf, minibatches_br)
+                )
+
+                update_state = (train_state_conf, train_state_br,
+                    traj_batch_conf, traj_batch_br,
+                    advantages_conf, advantages_br,
+                    targets_conf, targets_br,
+                    rng
+                )
+                return update_state, all_losses
+
+            def _update_step(update_runner_state, unused):
+                """
+                1. Collect rollouts
+                2. Compute advantage
+                3. PPO updates
+                """
+                (
+                    all_train_state_conf, all_train_state_br,
+                    last_env_state, last_obs, last_done, last_conf_h, last_br_h,
+                    rng, update_steps
+                ) = update_runner_state
+
+                rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 3)
+
+                conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_ENVS"],), 0, config["PARTNER_POP_SIZE"])
+                br_ids = jax.random.randint(br_sampling_rng, (config["NUM_ENVS"],), 0, config["PARTNER_POP_SIZE"])
+
+                runner_state = (
+                    all_train_state_conf, all_train_state_br, conf_ids, br_ids,
+                    last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng
+                )
+                runner_state, traj_batch = jax.lax.scan(
+                    _env_step, runner_state, None, config["ROLLOUT_LENGTH"])
+                (all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids,
+                 last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng) = runner_state
+
+                # Get the last conf and br params and ids
+                last_conf_params = gather_params(all_train_state_conf.params, last_conf_ids)
+                last_br_params = gather_params(all_train_state_br.params, last_br_ids)
+
+                last_conf_one_hots = identity_matrix[last_conf_ids]
+                last_br_one_hots = identity_matrix[last_br_ids]
+
+                # Get agent 0 and agent 1 trajectories from interaction between conf policy and its BR policy.
+                traj_batch_conf, traj_batch_br = traj_batch
+
+                # Compute advantage for confederate agent from interaction with br policy
+                avail_actions_0 = jax.vmap(env.get_avail_actions)(last_env_state.env_state)["agent_0"].astype(jnp.float32)
+                _, last_val_conf, _, _ = jax.vmap(forward_pass_conf)(
+                    params=last_conf_params,
+                    obs=last_obs["agent_0"],
+                    id=last_br_one_hots,
+                    done=last_done["agent_0"],
+                    avail_actions=avail_actions_0,
+                    hstate=last_conf_h,
+                    rng=jax.random.split(jax.random.PRNGKey(0), config["NUM_ENVS"])  # Dummy key since we're just extracting the value
+                )
+                last_val_conf = last_val_conf.squeeze()
+                advantages_conf, targets_conf = _calculate_gae(traj_batch_conf, last_val_conf)
+
+                # Compute advantage for br policy from interaction with confederate agent
+                avail_actions_1 = jax.vmap(env.get_avail_actions)(last_env_state.env_state)["agent_1"].astype(jnp.float32)
+                _, last_val_br, _, _ = jax.vmap(forward_pass_br)(
+                    params=last_br_params,
+                    obs=last_obs["agent_1"],
+                    id=last_conf_one_hots,
+                    done=last_done["agent_1"],
+                    avail_actions=avail_actions_1,
+                    hstate=last_br_h,
+                    rng=jax.random.split(jax.random.PRNGKey(0), config["NUM_ENVS"])  # Dummy key since we're just extracting the value
+                )
+                last_val_br = last_val_br.squeeze()
+                advantages_br, targets_br = _calculate_gae(traj_batch_br, last_val_br)
+
+                # 3) PPO update
+                rng, update_rng = jax.random.split(rng, 2)
+                update_state = (
+                    all_train_state_conf, all_train_state_br,
+                    traj_batch_conf, traj_batch_br,
+                    advantages_conf, advantages_br,
+                    targets_conf, targets_br,
+                    update_rng
+                )
+
+                update_state, all_losses = jax.lax.scan(
+                    _update_epoch, update_state, None, config["UPDATE_EPOCHS"])
+                all_train_state_conf, all_train_state_br = update_state[:2]
+                (_, (value_loss_conf, pg_loss_conf, entropy_conf)), (_, (value_loss_br, pg_loss_br, entropy_br)) = all_losses
+
+                # Metrics
+                def mask_and_mean(x, mask):
+                    return jnp.where(mask, x, 0).sum() / jnp.maximum(1, mask.sum())
+
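`mask_and_mean` averages a logged quantity only over the timesteps where the mask is set (here, steps on which an episode returned), and guards against an empty mask by clamping the denominator to at least 1. A list-based sketch of the same computation, written without JAX purely for illustration:

```python
def mask_and_mean(xs, mask):
    """Mean of the entries of `xs` selected by `mask`; 0.0 if nothing is selected."""
    kept_sum = sum(x for x, m in zip(xs, mask) if m)
    kept_count = sum(1 for m in mask if m)
    return kept_sum / max(1, kept_count)


print(mask_and_mean([1.0, 2.0, 3.0], [True, False, True]))   # 2.0
print(mask_and_mean([1.0, 2.0, 3.0], [False, False, False]))  # 0.0
```

The clamp means that a rollout with no finished episodes reports 0 for that metric rather than NaN.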
+                mask = traj_batch_conf.info.get("returned_episode", jnp.ones_like(traj_batch_conf.reward))
+                metric = jax.tree.map(lambda x: mask_and_mean(x, mask), traj_batch_conf.info)
+                metric["update_steps"] = update_steps
+                metric["value_loss_conf_agent"] = value_loss_conf.mean(axis=(0, 1))
+                metric["value_loss_br_agent"] = value_loss_br.mean(axis=(0, 1))
+
+                metric["pg_loss_conf_agent"] = pg_loss_conf.mean(axis=(0, 1))
+                metric["pg_loss_br_agent"] = pg_loss_br.mean(axis=(0, 1))
+
+                metric["entropy_conf"] = entropy_conf.mean(axis=(0, 1))
+                metric["entropy_br"] = entropy_br.mean(axis=(0, 1))
+
+                new_runner_state = (
+                    all_train_state_conf, all_train_state_br,
+                    last_env_state, last_obs, last_done, last_conf_h, last_br_h,
+                    rng, update_steps + 1
+                )
+                return (new_runner_state, metric)
+
+            # --------------------------
+            # PPO Update and Checkpoint saving
+            # --------------------------
+            ckpt_and_eval_interval = config["NUM_UPDATES"] // max(1, config["NUM_CHECKPOINTS"] - 1)  # -1 because we store a ckpt at the last update
+            num_ckpts = config["NUM_CHECKPOINTS"]
+
+            # Build a PyTree that holds parameters for all conf agent checkpoints
+            def init_ckpt_array(params_pytree):
+                return jax.tree.map(
+                    lambda x: jnp.zeros((num_ckpts,) + x.shape, x.dtype),
+                    params_pytree)
+
+            def _update_step_with_ckpt(state_with_ckpt, unused):
+                (update_runner_state, checkpoint_array_conf, checkpoint_array_br, ckpt_idx,
+                 eval_info) = state_with_ckpt
+
+                # Single PPO update
+                new_runner_state, metric = _update_step(update_runner_state, None)
+
+                train_state_conf, train_state_br, last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng, update_steps = new_runner_state
+
+                # Decide if we store a checkpoint
+                # update steps is 1-indexed because it was incremented at the end of the update step
+                to_store = jnp.logical_or(jnp.equal(jnp.mod(update_steps-1, ckpt_and_eval_interval), 0),
+                                          jnp.equal(update_steps, config["NUM_UPDATES"]))
+
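The `to_store` rule saves a checkpoint whenever the 1-indexed update count minus one is a multiple of `ckpt_and_eval_interval`, and additionally at the final update. Which updates that selects is easiest to see by enumerating them. The sketch below does so in plain Python; `checkpoint_steps` is a hypothetical helper, and it assumes `num_updates >= num_checkpoints - 1` so that the interval is nonzero.

```python
def checkpoint_steps(num_updates, num_checkpoints):
    """1-indexed update steps at which the `to_store` condition is true."""
    interval = num_updates // max(1, num_checkpoints - 1)
    return [u for u in range(1, num_updates + 1)
            if (u - 1) % interval == 0 or u == num_updates]


print(checkpoint_steps(10, 3))  # [1, 6, 10]
print(checkpoint_steps(11, 4))  # [1, 4, 7, 10, 11]
```

With 10 updates and 3 checkpoints the rule fires exactly three times. With 11 updates and 4 checkpoints it fires five times, one more than the number of slots allocated by `init_ckpt_array`, so settings where `NUM_UPDATES` is not aligned with the interval may be worth checking against the checkpoint array size.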
+                def store_and_eval_ckpt(args):
+                    ckpt_arr_and_ep_infos, rng, cidx = args
+                    ckpt_arr_conf, ckpt_arr_br, _ = ckpt_arr_and_ep_infos
+                    new_ckpt_arr_conf = jax.tree.map(
+                        lambda c_arr, p: c_arr.at[cidx].set(p),
+                        ckpt_arr_conf, train_state_conf.params
+                    )
+                    new_ckpt_arr_br = jax.tree.map(
+                        lambda c_arr, p: c_arr.at[cidx].set(p),
+                        ckpt_arr_br, train_state_br.params
+                    )
+
+                    rng, eval_rng = jax.random.split(rng)
+                    ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
+                        run_all_episodes(eval_rng, train_state_conf, train_state_br))
+
+                    return ((new_ckpt_arr_conf, new_ckpt_arr_br, ep_last_info), rng, cidx + 1)
+
+                def skip_ckpt(args):
+                    return args
+
+                (checkpoint_array_and_infos, rng, ckpt_idx) = jax.lax.cond(
+                    to_store,
+                    store_and_eval_ckpt,
+                    skip_ckpt,
+                    ((checkpoint_array_conf, checkpoint_array_br, eval_info), rng, ckpt_idx)
+                )
+                checkpoint_array_conf, checkpoint_array_br, eval_ep_last_info = checkpoint_array_and_infos
+
+                metric["eval_ep_last_info"] = eval_ep_last_info  # return of confederate
+
+                return ((train_state_conf, train_state_br,
+                         last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng, update_steps),
+                        checkpoint_array_conf, checkpoint_array_br, ckpt_idx,
+                        eval_ep_last_info), metric
+
+            # Initialize checkpoint array
+            checkpoint_array_conf = init_ckpt_array(all_conf_optims.params)
+            checkpoint_array_br = init_ckpt_array(all_br_optims.params)
+            ckpt_idx = 0
+
+            # Initialize state for scan over _update_step_with_ckpt
+            update_steps = 0
+
+            rng, rng_eval = jax.random.split(rng, 2)
+            eval_ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
+                run_all_episodes(rng_eval, all_conf_optims, all_br_optims))
+
+            # Initialize environment
+            rng, reset_rng = jax.random.split(rng)
+            reset_rngs = jax.random.split(reset_rng, config["NUM_ENVS"])
+            init_obs, init_env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rngs)
+            init_done = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
+
+            # Initialize conf and br hstates
+            init_conf_h = conf_policy.init_hstate(config["NUM_CONF_ACTORS"])
+            init_br_h = br_policy.init_hstate(config["NUM_BR_ACTORS"])
+
+            update_runner_state = (
+                all_conf_optims, all_br_optims,
+                init_env_state, init_obs, init_done, init_conf_h, init_br_h,
+                rng, update_steps
+            )
+
+            state_with_ckpt = (
+                update_runner_state, checkpoint_array_conf,
+                checkpoint_array_br, ckpt_idx, eval_ep_last_info
+            )
+
+            # run training
+            state_with_ckpt, metrics = jax.lax.scan(
+                _update_step_with_ckpt,
+                state_with_ckpt,
+                xs=None,
+                length=config["NUM_UPDATES"]
+            )
+
+            (
+                final_runner_state, checkpoint_array_conf, checkpoint_array_br,
+                final_ckpt_idx, all_ep_infos
+            ) = state_with_ckpt
+
+            out = {
+                "final_params_conf": final_runner_state[0].params,
+                "final_params_br": final_runner_state[1].params,
+                "checkpoints_conf": checkpoint_array_conf,
+                "checkpoints_br": checkpoint_array_br,
+                "metrics": metrics,  # metrics is from the perspective of the confederate agent (averaged over population)
690
+ "all_pair_returns": all_ep_infos
691
+ }
692
+ return out
693
+
694
+ return train
695
+ # ------------------------------
696
+ # Actually run the adversarial teammate training
697
+ # ------------------------------
698
+ train_fn = make_brdiv_agents(config)
699
+ out = train_fn(train_rng)
700
+ return out
701
+
702
+ def get_brdiv_population(config, out, env):
703
+ '''
704
+ Get the partner params and partner population for ego training.
705
+ '''
706
+ brdiv_pop_size = config["algorithm"]["PARTNER_POP_SIZE"]
707
+
708
+ # partner_params has shape (num_seeds, brdiv_pop_size, ...)
709
+ partner_params = out['final_params_conf']
710
+
711
+ partner_policy = ActorWithConditionalCriticPolicy(
712
+ action_dim=env.action_space(env.agents[1]).n,
713
+ obs_dim=env.observation_space(env.agents[1]).shape[0],
714
+ pop_size=brdiv_pop_size, # used to create onehot agent id
715
+ activation=config["algorithm"].get("ACTIVATION", "tanh")
716
+ )
717
+
718
+ # Create partner population
719
+ partner_population = AgentPopulation(
720
+ pop_size=brdiv_pop_size,
721
+ policy_cls=partner_policy
722
+ )
723
+
724
+ return partner_params, partner_population
725
+
726
+ def run_brdiv(config, wandb_logger):
727
+ algorithm_config = dict(config["algorithm"])
728
+
729
+ env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"])
730
+ env = LogWrapper(env)
731
+
732
+ log.info("Starting BRDiv training...")
733
+ start = time.time()
734
+
735
+ # Generate multiple random seeds from the base seed
736
+ rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"])
737
+ rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"])
738
+
739
+ # Initialize br and conf policies
740
+ conf_policy = ActorWithConditionalCriticPolicy(
741
+ action_dim=env.action_space(env.agents[0]).n,
742
+ obs_dim=env.observation_space(env.agents[0]).shape[0],
743
+ pop_size=algorithm_config["PARTNER_POP_SIZE"],
744
+ )
745
+ br_policy = ActorWithConditionalCriticPolicy(
746
+ action_dim=env.action_space(env.agents[0]).n,
747
+ obs_dim=env.observation_space(env.agents[0]).shape[0],
748
+ pop_size=algorithm_config["PARTNER_POP_SIZE"],
749
+ )
750
+
751
+ # Create a vmapped version of train_brdiv_partners
752
+ with jax.disable_jit(False):
753
+ vmapped_train_fn = jax.jit(
754
+ jax.vmap(
755
+ partial(train_brdiv_partners, env=env, config=algorithm_config, conf_policy=conf_policy, br_policy=br_policy)
756
+ )
757
+ )
758
+ out = vmapped_train_fn(rngs)
759
+
760
+ end = time.time()
761
+ log.info(f"BRDiv training complete in {end - start} seconds")
762
+
763
+ metric_names = get_metric_names(algorithm_config["ENV_NAME"])
764
+ log_metrics(config, out, wandb_logger, metric_names)
765
+
766
+ partner_params, partner_population = get_brdiv_population(config, out, env)
767
+
768
+ return partner_params, partner_population
769
+
770
+
771
+ def log_metrics(config, outs, logger, metric_names: tuple):
772
+ metrics = outs["metrics"]
773
+ # metrics now has shape (num_seeds, num_updates, pop_size)
774
+ num_seeds, num_updates, pop_size = metrics["pg_loss_conf_agent"].shape # number of trained pairs
775
+
776
+ ### Log evaluation metrics
777
+ # we plot XP return curves separately from SP return curves
778
+ # shape (num_seeds, num_updates, (pop_size)^2) [pre-scalarized: mean over eval eps and agents taken inside scan]
779
+ all_returns = np.asarray(metrics["eval_ep_last_info"]["returned_episode_returns"])
780
+ xs = list(range(num_updates))
781
+
782
+ all_conf_ids, all_br_ids = _get_all_ids(pop_size)
783
+ sp_mask = (all_conf_ids == all_br_ids)
784
+ sp_returns = all_returns[:, :, sp_mask]
785
+ xp_returns = all_returns[:, :, ~sp_mask]
786
+
787
+ # Average over seeds and agent pairs (eval episodes and agents already averaged inside scan)
788
+ sp_return_curve = sp_returns.mean(axis=(0, 2))
789
+ xp_return_curve = xp_returns.mean(axis=(0, 2))
790
+
791
+ for step in range(num_updates):
792
+ logger.log_item("Eval/AvgSPReturnCurve", sp_return_curve[step], train_step=step)
793
+ logger.log_item("Eval/AvgXPReturnCurve", xp_return_curve[step], train_step=step)
794
+ logger.commit()
795
+
796
+ # log final XP matrix to wandb - average over seeds
797
+ last_returns_array = all_returns[:, -1].mean(axis=0)
798
+ last_returns_array = np.reshape(last_returns_array, (pop_size, pop_size))
799
+ logger.log_xp_matrix("Eval/LastXPMatrix", last_returns_array)
800
+
801
+ ### Log population loss as multi-line plots, where each line is a different population member
802
+ # each loss entry has shape (num_seeds, num_updates, pop_size), matching the unpacking above
803
+ # Average over seeds
804
+ processed_losses = {
805
+ "ConfPGLoss": np.asarray(metrics["pg_loss_conf_agent"]).mean(axis=0).transpose(),
806
+ "BRPGLoss": np.asarray(metrics["pg_loss_br_agent"]).mean(axis=0).transpose(),
807
+ "ConfValLoss": np.asarray(metrics["value_loss_conf_agent"]).mean(axis=0).transpose(),
808
+ "BRValLoss": np.asarray(metrics["value_loss_br_agent"]).mean(axis=0).transpose(),
809
+ "ConfEntropy": np.asarray(metrics["entropy_conf"]).mean(axis=0).transpose(),
810
+ "BREntropy": np.asarray(metrics["entropy_br"]).mean(axis=0).transpose(),
811
+ }
812
+
813
+ xs = list(range(num_updates))
814
+ keys = [f"pair {i}" for i in range(pop_size)]
815
+ for loss_name, loss_data in processed_losses.items():
816
+ if np.isnan(loss_data).any():
817
+ raise ValueError(f"Found nan in loss {loss_name}")
818
+ logger.log_item(f"Losses/{loss_name}",
819
+ wandb.plot.line_series(xs=xs, ys=loss_data, keys=keys,
820
+ title=loss_name, xname="train_step")
821
+ )
822
+
823
+ ### Log artifacts
824
+ savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
825
+ # Save train run output and log to wandb as artifact
826
+ out_savepath = save_train_run(outs, savedir, savename="saved_train_run")
827
+ if config["logger"]["log_train_out"]:
828
+ logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run")
829
+
830
+ # Cleanup locally logged out files
831
+ if not config["local_logger"]["save_train_out"]:
832
+ shutil.rmtree(out_savepath)
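The checkpointing logic in `_update_step_with_ckpt` above (a `jax.lax.cond` inside a `jax.lax.scan` that writes parameters into a preallocated array with `.at[idx].set`) can be illustrated in isolation. The following is a minimal sketch, not the BRDiv implementation: `make_train`, the `+ 1.0` stand-in for a PPO update, and the toy parameter dict are all hypothetical, and evaluation is omitted.

```python
import jax
import jax.numpy as jnp


def make_train(num_updates: int, num_ckpts: int):
    # Same interval rule as in the diff: spread checkpoints over training.
    interval = num_updates // max(1, num_ckpts - 1)

    def step(carry, _):
        params, ckpt_arr, ckpt_idx, update_steps = carry

        # Hypothetical stand-in for one PPO update: add 1 to every leaf.
        params = jax.tree.map(lambda p: p + 1.0, params)
        update_steps = update_steps + 1  # 1-indexed after the increment

        # Store on every `interval`-th update and on the final update.
        to_store = jnp.logical_or(
            jnp.equal(jnp.mod(update_steps - 1, interval), 0),
            jnp.equal(update_steps, num_updates),
        )

        def store(args):
            arr, idx = args
            # Write the current params into slot `idx` of each leaf array.
            arr = jax.tree.map(lambda a, p: a.at[idx].set(p), arr, params)
            return arr, idx + 1

        def skip(args):
            return args

        ckpt_arr, ckpt_idx = jax.lax.cond(to_store, store, skip, (ckpt_arr, ckpt_idx))
        return (params, ckpt_arr, ckpt_idx, update_steps), to_store

    def train(init_params):
        # One leading axis of size num_ckpts per parameter leaf.
        ckpt_arr = jax.tree.map(
            lambda x: jnp.zeros((num_ckpts,) + x.shape, x.dtype), init_params
        )
        carry = (init_params, ckpt_arr, jnp.int32(0), jnp.int32(0))
        (params, ckpt_arr, ckpt_idx, _), stored = jax.lax.scan(
            step, carry, xs=None, length=num_updates
        )
        return params, ckpt_arr, ckpt_idx, stored

    return train


# Usage: 9 updates with 4 checkpoint slots stores after updates 1, 4, 7 and 9.
train = make_train(num_updates=9, num_ckpts=4)
final_params, ckpts, n_stored, stored_mask = train({"w": jnp.zeros((2,))})
```

Both branches of `jax.lax.cond` must return pytrees of identical structure and dtype, which is why `skip` returns its argument unchanged. For some combinations of update count and checkpoint count, this interval rule may trigger more stores than there are slots, so it is probably worth checking `n_stored` against `num_ckpts` when changing either value.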
teammate_generation/CoMeDi.py ADDED
@@ -0,0 +1,1161 @@
1
+ '''Implementation of the CoMeDi teammate generation algorithm (Sarkar et al. NeurIPS 2023)
2
+ https://openreview.net/forum?id=MljeRycu9s
3
+
4
+ Command to run CoMeDi only on LBF:
5
+ python teammate_generation/run.py algorithm=comedi/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels label=test_comedi run_heldout_eval=false train_ego=false
6
+
7
+ Limitations: does not support recurrent actors.
8
+ '''
9
+ from functools import partial
10
+ import logging
11
+ import shutil
12
+ import time
13
+ from typing import NamedTuple
14
+
15
+ from flax.training.train_state import TrainState
16
+ import hydra
17
+ import jax
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+ import optax
21
+ import wandb
22
+
23
+ from agents.mlp_actor_critic_agent import ActorWithConditionalCriticPolicy
24
+ from agents.initialize_agents import initialize_actor_with_conditional_critic
25
+ from agents.population_interface import AgentPopulation
26
+ from agents.population_buffer import BufferedPopulation
27
+ from common.save_load_utils import save_train_run
28
+ from common.plot_utils import get_metric_names
29
+ from common.run_episodes import run_episodes
30
+ from envs import make_env
31
+ from envs.log_wrapper import LogWrapper, LogEnvState
32
+ from marl.ippo import make_train as make_ppo_train
33
+ from marl.ppo_utils import Transition, unbatchify, _create_minibatches
34
+
35
+ log = logging.getLogger(__name__)
36
+ logging.basicConfig(level=logging.INFO)
37
+
38
+ class ResetTransition(NamedTuple):
39
+ '''Stores extra information for resetting agents to a point in some trajectory.'''
40
+ env_state: LogEnvState
41
+ conf_obs: jnp.ndarray
42
+ partner_obs: jnp.ndarray
43
+ conf_done: jnp.ndarray
44
+ partner_done: jnp.ndarray
45
+ conf_hstate: jnp.ndarray
46
+ partner_hstate: jnp.ndarray
47
+
48
+ def train_comedi_partners(train_rng, wandb_logger, env, config):
49
+ num_agents = env.num_agents
50
+ assert num_agents == 2, "This code assumes the environment has exactly 2 agents."
51
+
52
+ # Define 4 types of rollouts: SP, XP, MP, MP2
53
+ config["NUM_GAME_AGENTS"] = num_agents
54
+
55
+ config["NUM_ACTORS"] = num_agents * config["NUM_ENVS"]
56
+ # Right now assume control of both agent and its BR
57
+ config["NUM_CONTROLLED_ACTORS"] = config["NUM_ACTORS"]
58
+
59
+ # Compute number of updates PER outermost iteration
60
+ # Calculate timesteps per update
61
+ # 1. Overhead from population selection rollouts
62
+ # We divide by 2 because for ease in Jax, this implementation uses a vmap over PARTNER_POP_SIZE to
63
+ # evaluate the agent generated at each outermost iteration against all previously
64
+ # generated agents, but a non-Jax implementation would only need to evaluate against
65
+ # *previously* generated agents.
66
+ selection_steps = config["PARTNER_POP_SIZE"] * config["NUM_ARGMAX_ROLLOUT_EPS"] * config["ROLLOUT_LENGTH"] // 2
67
+ # 2. Training rollouts: 4 distinct rollout phases (SP, XP, MP, MP2) each using NUM_ENVS
68
+ training_steps = 4 * config["ROLLOUT_LENGTH"] * config["NUM_ENVS"]
69
+
70
+ steps_per_update = selection_steps + training_steps
71
+ config["NUM_UPDATES"] = int(config["TOTAL_TIMESTEPS_PER_ITERATION"] // steps_per_update)
72
+
73
+ def make_comedi_agents(config):
74
+ def linear_schedule(count):
75
+ frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
76
+ return config["LR"] * frac
77
+
78
+ def train_init_ippo_partners(config, partner_rng, env):
79
+ '''
80
+ Train a pool of IPPO agents w/ parameter sharing.
81
+ Returns out, a dictionary of the model checkpoints, final parameters, and metrics.
82
+ '''
83
+ # POP_SIZE is referenced throughout the CoMeDi training loops
84
+ config["POP_SIZE"] = config["PARTNER_POP_SIZE"]
85
+ # Use a local copy for warmup-specific overrides to avoid
86
+ # mutating the shared config (ACTOR_TYPE, TOTAL_TIMESTEPS)
87
+ warmup_config = dict(config)
88
+ warmup_config["TOTAL_TIMESTEPS"] = config["TOTAL_TIMESTEPS_PER_ITERATION"]
89
+ warmup_config["ACTOR_TYPE"] = "pseudo_actor_with_conditional_critic"
90
+ out = make_ppo_train(warmup_config, env, wandb_logger)(partner_rng)
91
+ return out
92
+
93
+ def train(rng):
94
+ # Start by training a single PPO agent via self-play
95
+ rng, init_ppo_rng, init_conf_rng = jax.random.split(rng, 3)
96
+
97
+ init_ppo_partner = train_init_ippo_partners(config, init_ppo_rng, env)
98
+
99
+ # Initialize a population buffer
100
+ dummy_policy, dummy_init_params = initialize_actor_with_conditional_critic(config, env, init_conf_rng)
101
+ partner_population = BufferedPopulation(
102
+ max_pop_size=config["PARTNER_POP_SIZE"],
103
+ policy_cls=dummy_policy,
104
+ )
105
+
106
+ population_buffer = partner_population.reset_buffer(dummy_init_params)
107
+ population_buffer = partner_population.add_agent(population_buffer, init_ppo_partner["final_params"])
108
+
109
+ def add_conf_policy(pop_buffer, func_input):
110
+ num_existing_agents, rng = func_input
111
+ rng, init_conf_rng = jax.random.split(rng)
112
+
113
+ # Create new confederate agent policy and critic
114
+ policy, init_params = initialize_actor_with_conditional_critic(
115
+ config, env, init_conf_rng
116
+ )
117
+
118
+ # Create a train_state and optimizer for the newly initialized model
119
+ if config["ANNEAL_LR"]:
120
+ tx = optax.chain(
121
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
122
+ optax.adam(learning_rate=linear_schedule, eps=1e-5),
123
+ )
124
+ else:
125
+ tx = optax.chain(
126
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
127
+ optax.adam(config["LR"], eps=1e-5))
128
+
129
+ train_state = TrainState.create(
130
+ apply_fn=policy.network.apply,
131
+ params=init_params,
132
+ tx=tx,
133
+ )
134
+
135
+ # Reset envs for SP, XP, MP, and MP2
136
+ rng, reset_rng_eval, reset_rng_sp, reset_rng_xp, reset_rng_mp, reset_rng_mp2 = jax.random.split(rng, 6)
137
+
138
+ reset_rngs_sps = jax.random.split(reset_rng_sp, config["NUM_ENVS"])
139
+ reset_rngs_xps = jax.random.split(reset_rng_xp, config["NUM_ENVS"])
140
+ reset_rngs_mps = jax.random.split(reset_rng_mp, config["NUM_ENVS"])
141
+ reset_rngs_mps2 = jax.random.split(reset_rng_mp2, config["NUM_ENVS"])
142
+
143
+ obsv_xp, env_state_xp = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_xps)
144
+ obsv_sp, env_state_sp = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_sps)
145
+ obsv_mp, env_state_mp = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_mps)
146
+ obsv_mp2, env_state_mp2 = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_mps2)
147
+
148
+ # build a pytree that can hold the parameters for all checkpoints.
149
+ ckpt_and_eval_interval = config["NUM_UPDATES"] // max(1, config["NUM_CHECKPOINTS"] - 1)
150
+ num_ckpts = config["NUM_CHECKPOINTS"]
151
+ def init_ckpt_array(params_pytree):
152
+ return jax.tree.map(
153
+ lambda x: jnp.zeros((num_ckpts,) + x.shape, x.dtype),
154
+ params_pytree
155
+ )
156
+
157
+ # define evaluation function
158
+ rng, eval_rng = jax.random.split(rng, 2)
159
+ def per_id_run_episode_fixed_rng(agent0_param, agent1_id):
160
+ agent1_param = partner_population.gather_agent_params(pop_buffer,
161
+ agent_indices=agent1_id * jnp.ones((1,), dtype=np.int32))
162
+ agent1_param = jax.tree.map(lambda y: jnp.squeeze(y, 0), agent1_param)
163
+ all_outs = run_episodes(
164
+ rng=eval_rng, env=env,
165
+ agent_0_param=agent0_param, agent_0_policy=policy,
166
+ agent_1_param=agent1_param, agent_1_policy=policy,
167
+ max_episode_steps=config["ROLLOUT_LENGTH"],
168
+ num_eps=config["NUM_ARGMAX_ROLLOUT_EPS"]
169
+ )
170
+ return all_outs
171
+
172
+ def _update_step(update_with_ckpt_runner_state, unused):
173
+ update_runner_state, checkpoint_array, ckpt_idx = update_with_ckpt_runner_state
174
+ (
175
+ train_state, pop_buffer,
176
+ env_state_sp, obsv_sp,
177
+ env_state_xp, obsv_xp,
178
+ env_state_mp, obsv_mp,
179
+ env_state_mp2, obsv_mp2,
180
+ last_dones_xp,
181
+ last_dones_sp,
182
+ last_dones_mp,
183
+ last_dones_mp2,
184
+ rng, update_steps,
185
+ num_prev_trained_conf
186
+ ) = update_runner_state
187
+
188
+ # Identify the expected returns from the newly trained policy
189
+ # when interacting with the previously generated confederate
190
+ # policies
191
+ valid_sampling_indices = jnp.arange(config["POP_SIZE"])
192
+ run_all_rollouts = jax.vmap(per_id_run_episode_fixed_rng, in_axes=(None, 0))(
193
+ train_state.params,valid_sampling_indices)
194
+
195
+ # Mask out the XP returns against invalid policies,
196
+ # i.e. buffer IDs that are not yet set to specific
197
+ # confederate params
198
+ all_mean_returns = run_all_rollouts["returned_episode_returns"][:, :, 0].mean(axis=-1)
199
+ masked_mean_returns = jnp.where(
200
+ valid_sampling_indices >= num_prev_trained_conf, -jnp.inf, all_mean_returns
201
+ )
202
+
203
+ # Pick the right confederate params to act as the XP agent
204
+ max_means_id = masked_mean_returns.argmax()
205
+ xp_param = jax.tree.map(
206
+ lambda x: jnp.squeeze(x, 0),
207
+ partner_population.gather_agent_params(pop_buffer,
208
+ agent_indices=max_means_id * jnp.ones((1,), dtype=np.int32))
209
+ )
210
+
211
+ rng, rng_xp, rng_sp, rng_mp, rng_mp2 = jax.random.split(rng, 5)
212
+
213
+ def _env_step_conf_ego(runner_state, unused):
214
+ """
215
+ agent_0 = confederate, agent_1 = ego
216
+ Returns updated runner_state and a Transition for the confederate.
217
+ """
218
+ train_state, xp_param, xp_id, env_state, last_obs, last_dones, rng = runner_state
219
+ rng, act_rng, partner_rng, step_rng = jax.random.split(rng, 4)
220
+
221
+ obs_0 = last_obs["agent_0"]
222
+ obs_1 = last_obs["agent_1"]
223
+
224
+ # Get available actions for both agents from environment state
225
+ avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
226
+ avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32)
227
+ avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32)
228
+
229
+ # Add one-hot ID of XP teammate
230
+ xp_one_hot_id = jnp.eye(config["POP_SIZE"])[xp_id]
231
+ xp_one_hot_id = jnp.expand_dims(
232
+ jnp.expand_dims(
233
+ xp_one_hot_id, 0
234
+ ), 0
235
+ )
236
+
237
+ # Agent_0 (confederate) action using policy interface
238
+ aux_obs = jnp.repeat(xp_one_hot_id, config["NUM_ENVS"], axis=1)
239
+ act_0, val_0, pi_0, _ = policy.get_action_value_policy(
240
+ params=train_state.params,
241
+ obs=obs_0.reshape(1, config["NUM_ENVS"], -1),
242
+ done=last_dones["agent_0"].reshape(1, config["NUM_ENVS"]),
243
+ avail_actions=jax.lax.stop_gradient(avail_actions_0),
244
+ hstate=None,
245
+ rng=act_rng,
246
+ aux_obs=aux_obs
247
+ )
248
+ logp_0 = pi_0.log_prob(act_0)
249
+
250
+ act_0 = act_0.squeeze()
251
+ logp_0 = logp_0.squeeze()
252
+ val_0 = val_0.squeeze()
253
+
254
+ # Agent_1 (ego) action using policy interface
255
+ act_1, _, _, _ = policy.get_action_value_policy(
256
+ params=xp_param,
257
+ obs=obs_1.reshape(1, config["NUM_ENVS"], -1),
258
+ done=last_dones["agent_1"].reshape(1, config["NUM_ENVS"]),
259
+ avail_actions=jax.lax.stop_gradient(avail_actions_1),
260
+ hstate=None,
261
+ rng=partner_rng,
262
+ aux_obs=aux_obs
263
+ )
264
+ act_1 = act_1.squeeze()
265
+
266
+ # Combine actions into the env format
267
+ combined_actions = jnp.concatenate([act_0, act_1], axis=0) # shape (2*num_envs,)
268
+ env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents)
269
+ env_act = {k: v.flatten() for k, v in env_act.items()}
270
+
271
+ # Step env
272
+ step_rngs = jax.random.split(step_rng, config["NUM_ENVS"])
273
+ obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))(
274
+ step_rngs, env_state, env_act
275
+ )
276
+ # note that num_actors = num_envs * num_agents
277
+ info_0 = jax.tree.map(lambda x: x[:, 0], info)
278
+
279
+ # Store agent_0 data in transition
280
+ transition = Transition(
281
+ done=done["agent_0"],
282
+ action=act_0,
283
+ value=val_0,
284
+ reward=reward["agent_1"],
285
+ log_prob=logp_0,
286
+ obs=obs_0,
287
+ info=info_0,
288
+ avail_actions=avail_actions_0
289
+ )
290
+ new_runner_state = (train_state, xp_param, xp_id, env_state_next, obs_next, done, rng)
291
+ return new_runner_state, transition
292
+
293
+ def _env_step_conf_br(runner_state, unused):
294
+ """
295
+ agent_0 = confederate, agent_1 = best response
296
+ Returns updated runner_state, and Transitions for the confederate and best response.
297
+ """
298
+ train_state, env_state, last_obs, last_dones, rng, current_trained_pop_id, reset_traj_batch = runner_state
299
+ rng, conf_rng, br_rng, step_rng = jax.random.split(rng, 4)
300
+
301
+ def gather_sampled(data_pytree, flat_indices, first_nonbatch_dim: int):
302
+ '''Will treat all dimensions up to the first_nonbatch_dim as batch dimensions. '''
303
+ batch_size = config["ROLLOUT_LENGTH"] * config["NUM_ENVS"]
304
+ flat_data = jax.tree.map(lambda x: x.reshape(batch_size, *x.shape[first_nonbatch_dim:]), data_pytree)
305
+ sampled_data = jax.tree.map(lambda x: x[flat_indices], flat_data) # Shape (N, ...)
306
+ return sampled_data
307
+
308
+ if reset_traj_batch is not None:
309
+ rng, sample_rng = jax.random.split(rng)
310
+ needs_resample = last_dones["__all__"] # shape (N,) bool
311
+
312
+ total_reset_states = config["ROLLOUT_LENGTH"] * config["NUM_ENVS"]
313
+ sampled_indices = jax.random.randint(sample_rng, shape=(config["NUM_ENVS"],), minval=0,
314
+ maxval=total_reset_states)
315
+
316
+ # Gather sampled leaves from each data pytree
317
+ sampled_env_state = gather_sampled(reset_traj_batch.env_state, sampled_indices, first_nonbatch_dim=2)
318
+ sampled_conf_obs = gather_sampled(reset_traj_batch.conf_obs, sampled_indices, first_nonbatch_dim=2)
319
+ sampled_br_obs = gather_sampled(reset_traj_batch.partner_obs, sampled_indices, first_nonbatch_dim=2)
320
+ sampled_conf_done = gather_sampled(reset_traj_batch.conf_done, sampled_indices, first_nonbatch_dim=2)
321
+ sampled_br_done = gather_sampled(reset_traj_batch.partner_done, sampled_indices, first_nonbatch_dim=2)
322
+
323
+ # for done environments, select data corresponding to the reset_traj_batch states
324
+ env_state = jax.tree.map(
325
+ lambda sampled, original: jnp.where(
326
+ needs_resample.reshape((-1,) + (1,) * (original.ndim - 1)),
327
+ sampled, original
328
+ ),
329
+ sampled_env_state,
330
+ env_state
331
+ )
332
+ obs_0 = jnp.where(needs_resample[:, jnp.newaxis], sampled_conf_obs, last_obs["agent_0"])
333
+ obs_1 = jnp.where(needs_resample[:, jnp.newaxis], sampled_br_obs, last_obs["agent_1"])
334
+
335
+ dones_0 = jnp.where(needs_resample, sampled_conf_done, last_dones["agent_0"])
336
+ dones_1 = jnp.where(needs_resample, sampled_br_done, last_dones["agent_1"])
337
+
338
+ else:
339
+
340
+ # Reset conf-br data collection from conf-ego states
341
+ obs_0, obs_1 = last_obs["agent_0"], last_obs["agent_1"]
342
+ dones_0, dones_1 = last_dones["agent_0"], last_dones["agent_1"]
343
+
344
+ # Get available actions for both agents from environment state
345
+ avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
346
+ avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32)
347
+ avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32)
348
+
349
+ # Agent_0 (confederate) action
350
+ # Add one-hot ID of the population member currently being trained
351
+ sp_one_hot_id = jnp.eye(config["POP_SIZE"])[current_trained_pop_id]
352
+ sp_one_hot_id = jnp.expand_dims(
353
+ jnp.expand_dims(
354
+ sp_one_hot_id, 0
355
+ ), 0
356
+ )
357
+
358
+ aux_obs = jnp.repeat(sp_one_hot_id, config["NUM_ENVS"], 1)
359
+ act_0, val_0, pi_0, _ = policy.get_action_value_policy(
360
+ params=train_state.params,
361
+ obs=obs_0.reshape(1, config["NUM_ENVS"], -1),
362
+ done=dones_0.reshape(1, config["NUM_ENVS"]),
363
+ avail_actions=jax.lax.stop_gradient(avail_actions_0),
364
+ hstate=None,
365
+ rng=conf_rng,
366
+ aux_obs=aux_obs
367
+ )
368
+ logp_0 = pi_0.log_prob(act_0)
369
+
370
+ act_0 = act_0.squeeze()
371
+ logp_0 = logp_0.squeeze()
372
+ val_0 = val_0.squeeze()
373
+
374
+ # Agent 1 (best response) action
375
+ act_1, val_1, pi_1, _ = policy.get_action_value_policy(
376
+ params=train_state.params,
377
+ obs=obs_1.reshape(1, config["NUM_ENVS"], -1),
378
+ done=dones_1.reshape(1, config["NUM_ENVS"]),
379
+ avail_actions=jax.lax.stop_gradient(avail_actions_1),
380
+ hstate=None,
381
+ rng=br_rng,
382
+ aux_obs=aux_obs
383
+ )
384
+ logp_1 = pi_1.log_prob(act_1)
385
+
386
+ act_1 = act_1.squeeze()
387
+ logp_1 = logp_1.squeeze()
388
+ val_1 = val_1.squeeze()
389
+
390
+ # Combine actions into the env format
391
+ combined_actions = jnp.concatenate([act_0, act_1], axis=0) # shape (2*num_envs,)
392
+ env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents)
393
+ env_act = {k: v.flatten() for k, v in env_act.items()}
394
+
395
+ # Step env
396
+ step_rngs = jax.random.split(step_rng, config["NUM_ENVS"])
397
+ obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))(
398
+ step_rngs, env_state, env_act
399
+ )
400
+ info_0 = jax.tree.map(lambda x: x[:, 0], info)
401
+ info_1 = jax.tree.map(lambda x: x[:, 1], info)
402
+
403
+ # Store agent_0 (confederate) data in transition
404
+ transition_0 = Transition(
405
+ done=done["agent_0"],
406
+ action=act_0,
407
+ value=val_0,
408
+ reward=reward["agent_0"],
409
+ log_prob=logp_0,
410
+ obs=obs_0,
411
+ info=info_0,
412
+ avail_actions=avail_actions_0
413
+ )
414
+ # Store agent_1 (best response) data in transition
415
+ transition_1 = Transition(
416
+ done=done["agent_1"],
417
+ action=act_1,
418
+ value=val_1,
419
+ reward=reward["agent_1"],
420
+ log_prob=logp_1,
421
+ obs=obs_1,
422
+ info=info_1,
423
+ avail_actions=avail_actions_1
424
+ )
425
+ # Pass current_trained_pop_id and reset_traj_batch through unchanged in the state tuple
426
+ new_runner_state = (train_state, env_state_next, obs_next, done, rng, current_trained_pop_id, reset_traj_batch)
427
+ return new_runner_state, (transition_0, transition_1)
428
+
429
+ def _env_step_mixed(runner_state, unused):
430
+ """
431
+ agent_0 = confederate, agent_1 = ego OR best response
432
+ Returns a ResetTransition for resetting to env states encountered here.
433
+ """
434
+ train_state_conf, ego_param, env_state, last_obs, last_dones, rng, current_trained_pop_id = runner_state
435
+ rng, act_rng, ego_act_rng, br_act_rng, partner_choice_rng, step_rng = jax.random.split(rng, 6)
436
+
437
+ obs_0 = last_obs["agent_0"]
438
+ obs_1 = last_obs["agent_1"]
439
+
440
+ # Get available actions for both agents from environment state
441
+ avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
442
+ avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32)
443
+ avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32)
444
+
445
+ xp_one_hot_id = jnp.eye(config["POP_SIZE"])[current_trained_pop_id]
446
+ xp_one_hot_id = jnp.expand_dims(
447
+ jnp.expand_dims(
448
+ xp_one_hot_id, 0
449
+ ), 0
450
+ )
451
+
452
+ # Broadcast the one-hot ID across envs to form the auxiliary observation
453
+ aux_obs = jnp.repeat(xp_one_hot_id, config["NUM_ENVS"], axis=1)
454
+
455
+ # Agent_0 (confederate) action using policy interface
456
+ act_0, val_0, pi_0, _ = policy.get_action_value_policy(
457
+ params=train_state_conf.params,
458
+ obs=obs_0.reshape(1, config["NUM_ENVS"], -1),
459
+ done=last_dones["agent_0"].reshape(1, config["NUM_ENVS"]),
460
+ avail_actions=jax.lax.stop_gradient(avail_actions_0),
461
+ hstate=None,
462
+ rng=act_rng,
463
+ aux_obs=aux_obs
464
+ )
465
+ logp_0 = pi_0.log_prob(act_0)
466
+
467
+ act_0 = act_0.squeeze()
468
+ logp_0 = logp_0.squeeze()
469
+ val_0 = val_0.squeeze()
470
+
471
+ ### Compute both the ego action and the best response action
472
+ act_ego, _, _, _ = policy.get_action_value_policy(
473
+ params=ego_param,
474
+ obs=obs_1.reshape(1, config["NUM_ENVS"], -1),
475
+ done=last_dones["agent_1"].reshape(1, config["NUM_ENVS"]),
476
+ avail_actions=jax.lax.stop_gradient(avail_actions_1),
477
+ hstate=None,
478
+ rng=ego_act_rng,
479
+ aux_obs=aux_obs
480
+ )
481
+ act_br, _, _, _ = policy.get_action_value_policy(
482
+ params=train_state_conf.params,
483
+ obs=obs_1.reshape(1, config["NUM_ENVS"], -1),
484
+ done=last_dones["agent_1"].reshape(1, config["NUM_ENVS"]),
485
+ avail_actions=jax.lax.stop_gradient(avail_actions_1),
486
+ hstate=None,
487
+ rng=br_act_rng,
488
+ aux_obs=aux_obs
489
+ )
490
+
491
+ act_ego = act_ego.squeeze()
492
+ act_br = act_br.squeeze()
493
+ # Agent 1 (ego or best response) action - choose between ego and best response
494
+ partner_choice = jax.random.randint(partner_choice_rng, shape=(config["NUM_ENVS"],), minval=0, maxval=2)
495
+ act_1 = jnp.where(partner_choice == 0, act_ego, act_br)
496
+
497
+ # Combine actions into the env format
498
+ combined_actions = jnp.concatenate([act_0, act_1], axis=0)
499
+ env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents)
500
+ env_act = {k: v.flatten() for k, v in env_act.items()}
501
+
502
+ # Step env
503
+ step_rngs = jax.random.split(step_rng, config["NUM_ENVS"])
504
+ obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))(
505
+ step_rngs, env_state, env_act
506
+ )
507
+
508
+ reset_transition = ResetTransition(
509
+ # all of these are from before env step
510
+ env_state=env_state,
511
+ conf_obs=obs_0,
512
+ partner_obs=obs_1,
513
+ conf_done=last_dones["agent_0"],
514
+ partner_done=last_dones["agent_1"],
515
+ conf_hstate=None,
516
+ # hstates are recorded as None because recurrent actors are not supported
517
+ partner_hstate=None
518
+ )
519
+ new_runner_state = (train_state_conf, ego_param, env_state_next, obs_next, done, rng, current_trained_pop_id)
520
+ return new_runner_state, reset_transition
521
+
522
+ # Do XP rollout (based on train_state params and the param in pop_buffer identified in Step 1)
523
+ runner_state_xp = (train_state, xp_param, max_means_id, env_state_xp, obsv_xp, last_dones_xp, rng_xp)
524
+ runner_state_xp, traj_batch_xp = jax.lax.scan(
525
+ _env_step_conf_ego, runner_state_xp, None, config["ROLLOUT_LENGTH"])
526
+ (train_state, xp_param, max_means_id, env_state_xp, last_obs_xp, last_dones_xp, rng_xp) = runner_state_xp
527
+
528
+ # Do self-play (based on train_state params) rollout like in the IPPO code
529
+ runner_state_sp = (train_state, env_state_sp, obsv_sp, last_dones_sp, rng_sp, num_prev_trained_conf, None)
530
+ runner_state_sp, (traj_batch_sp_agent0, traj_batch_sp_agent1) = jax.lax.scan(
531
+ _env_step_conf_br, runner_state_sp, None, config["ROLLOUT_LENGTH"])
532
+ (train_state, env_state_sp, last_obs_sp, last_dones_sp, rng_sp, num_prev_trained_conf, mp_traj_batch) = runner_state_sp
533
+
534
+ # Step 4
535
+ # Do MP rollout (based on train_state params and the param in pop_buffer identified in Step 1)
536
+ runner_state_mp = (train_state, xp_param, env_state_mp, obsv_mp, last_dones_mp, rng_mp, num_prev_trained_conf)
537
+ runner_state_mp, traj_batch_mp = jax.lax.scan(
538
+ _env_step_mixed, runner_state_mp, None, config["ROLLOUT_LENGTH"])
539
+ (train_state, xp_param, env_state_mp, last_obs_mp, last_dones_mp, rng_mp, num_prev_trained_conf) = runner_state_mp
540
+
541
+ runner_state_smp = (train_state, env_state_mp2, obsv_mp2, last_dones_mp2, rng_mp2, num_prev_trained_conf, traj_batch_mp)
542
+ runner_state_smp, (traj_batch_smp0, traj_batch_smp1) = jax.lax.scan(
543
+ _env_step_conf_br, runner_state_smp, None, config["ROLLOUT_LENGTH"])
544
+ (train_state, env_state_mp2, last_obs_mp2, last_dones_mp2, rng_mp2, num_prev_trained_conf, mp2_traj_batch) = runner_state_smp
545
+
546
+ def _calculate_gae(traj_batch, last_val):
547
+ def _get_advantages(gae_and_next_value, transition):
548
+ gae, next_value = gae_and_next_value
549
+ done, value, reward = (
550
+ transition.done,
551
+ transition.value,
552
+ transition.reward,
553
+ )
554
+ delta = reward + config["GAMMA"] * next_value * (1 - done) - value
555
+ gae = (
556
+ delta
557
+ + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
558
+ )
559
+ return (gae, value), gae
560
+
561
+ _, advantages = jax.lax.scan(
562
+ _get_advantages,
563
+ (jnp.zeros_like(last_val), last_val),
564
+ traj_batch,
565
+ reverse=True,
566
+ unroll=16,
567
+ )
568
+ return advantages, advantages + traj_batch.value
569
+
570
+ def _compute_advantages_and_targets(env_state, policy, policy_params, policy_hstate,
571
+ last_obs, last_dones, traj_batch, agent_name, value_idx=None):
572
+ '''value_idx is the population ID of the interaction teammate. It is one-hot encoded and passed
573
+ to the policy as aux_obs, so the bootstrap value is conditioned on that teammate.'''
574
+ avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)[agent_name].astype(jnp.float32)
575
+
576
+ # Add one-hot ID of interaction teammate
577
+ xp_one_hot_id = jnp.eye(config["POP_SIZE"])[value_idx]
578
+ xp_one_hot_id = jnp.expand_dims(
579
+ jnp.expand_dims(
580
+ xp_one_hot_id, 0
581
+ ), 0
582
+ )
583
+
584
+ # Broadcast the teammate ID over the batch, then query the bootstrap value for agent_name
585
+ aux_obs = jnp.repeat(xp_one_hot_id, last_obs[agent_name].shape[0], axis=1)
586
+
587
+ _, vals, _, _ = policy.get_action_value_policy(
588
+ params=policy_params,
589
+ obs=last_obs[agent_name].reshape(1, last_obs[agent_name].shape[0], -1),
590
+ done=last_dones[agent_name].reshape(1, last_obs[agent_name].shape[0]),
591
+ avail_actions=jax.lax.stop_gradient(avail_actions),
592
+ hstate=policy_hstate,
593
+ rng=jax.random.PRNGKey(0), # dummy key as we don't sample actions
594
+ aux_obs=aux_obs
595
+ )
596
+ last_val = vals.squeeze()
597
+ advantages, targets = _calculate_gae(traj_batch, last_val)
598
+ return advantages, targets
599
+
600
+ # 5a) Compute conf advantages for XP (conf-ego) interaction
601
+ advantages_xp_conf, targets_xp_conf = _compute_advantages_and_targets(
602
+ env_state_xp, policy, train_state.params, None,
603
+ last_obs_xp, last_dones_xp, traj_batch_xp, "agent_0", value_idx=max_means_id)
604
+
605
+ # 5b) Compute conf and br advantages for SP (conf-br) interaction
606
+ advantages_sp_conf, targets_sp_conf = _compute_advantages_and_targets(
607
+ env_state_sp, policy, train_state.params, None,
608
+ last_obs_sp, last_dones_sp, traj_batch_sp_agent0, "agent_0", value_idx=num_prev_trained_conf)
609
+
610
+ advantages_sp_br, targets_sp_br = _compute_advantages_and_targets(
611
+ env_state_sp, policy, train_state.params, None,
612
+ last_obs_sp, last_dones_sp, traj_batch_sp_agent1, "agent_1", value_idx=num_prev_trained_conf)
613
+
614
+ # 5c) Compute advantages from MP interactions
615
+ advantages_mp_conf, targets_mp_conf = _compute_advantages_and_targets(
616
+ env_state_mp2, policy, train_state.params, None,
617
+ last_obs_mp2, last_dones_mp2, traj_batch_smp0, "agent_0", value_idx=num_prev_trained_conf)
618
+
619
+ advantages_mp_br, targets_mp_br = _compute_advantages_and_targets(
620
+ env_state_mp2, policy, train_state.params, None,
621
+ last_obs_mp2, last_dones_mp2, traj_batch_smp1, "agent_1", value_idx=num_prev_trained_conf)
622
+
623
+ def _update_epoch(update_state, unused):
624
+ def _compute_ppo_value_loss(pred_value, traj_batch, target_v):
625
+ '''Value loss function for PPO'''
626
+ value_pred_clipped = traj_batch.value + (
627
+ pred_value - traj_batch.value
628
+ ).clip(
629
+ -config["CLIP_EPS"], config["CLIP_EPS"])
630
+ value_losses = jnp.square(pred_value - target_v)
631
+ value_losses_clipped = jnp.square(value_pred_clipped - target_v)
632
+ value_loss = (
633
+ jnp.maximum(value_losses, value_losses_clipped).mean()
634
+ )
635
+ return value_loss
636
+
637
+ def _compute_ppo_pg_loss(log_prob, traj_batch, gae):
638
+ '''Policy gradient loss function for PPO'''
639
+ ratio = jnp.exp(log_prob - traj_batch.log_prob)
640
+ gae_norm = (gae - gae.mean()) / (gae.std() + 1e-8)
641
+ pg_loss_1 = ratio * gae_norm
642
+ pg_loss_2 = jnp.clip(
643
+ ratio,
644
+ 1.0 - config["CLIP_EPS"],
645
+ 1.0 + config["CLIP_EPS"]) * gae_norm
646
+ pg_loss = -jnp.mean(jnp.minimum(pg_loss_1, pg_loss_2))
647
+ return pg_loss
648
+
649
+ def _update_minbatch_conf(train_state_conf, batch_infos):
650
+ minbatch_xp, minbatch_sp1, minbatch_sp2, minbatch_mp1, minbatch_mp2, xp_id, sp_id = batch_infos
651
+ _, traj_batch_xp, advantages_xp, returns_xp = minbatch_xp
652
+ _, traj_batch_sp1, advantages_sp1, returns_sp1 = minbatch_sp1
653
+ _, traj_batch_sp2, advantages_sp2, returns_sp2 = minbatch_sp2
654
+ _, traj_batch_mp1, advantages_mp1, returns_mp1 = minbatch_mp1
655
+ _, traj_batch_mp2, advantages_mp2, returns_mp2 = minbatch_mp2
656
+
657
+ def _loss_fn_conf(params, traj_batch_xp, gae_xp, target_v_xp,
658
+ traj_batch_sp, gae_sp, target_v_sp,
659
+ traj_batch_sp2, gae_sp2, target_v_sp2,
660
+ traj_batch_mp, gae_mp, target_v_mp,
661
+ traj_batch_mp2, gae_mp2, target_v_mp2):
662
+ # get policy and value of confederate versus ego and best response agents respectively
663
+ xp_one_hot_id = jnp.eye(config["POP_SIZE"])[xp_id]
664
+ xp_one_hot_id = jnp.expand_dims(
665
+ jnp.expand_dims(
666
+ xp_one_hot_id, 0
667
+ ), 0
668
+ )
669
+
670
+ sp_one_hot_id = jnp.eye(config["POP_SIZE"])[sp_id]
671
+ sp_one_hot_id = jnp.expand_dims(
672
+ jnp.expand_dims(
673
+ sp_one_hot_id, 0
674
+ ), 0
675
+ )
676
+
677
+ # Broadcast the XP teammate ID over (time, batch), then evaluate the confederate's value and policy on XP data
678
+ aux_obs_xp = jnp.repeat(xp_one_hot_id, traj_batch_xp.obs.shape[1], axis=1)
679
+ aux_obs_xp = jnp.repeat(aux_obs_xp, traj_batch_xp.obs.shape[0], axis=0)
680
+
681
+ _, value_xp, pi_xp, _ = policy.get_action_value_policy(
682
+ params=params,
683
+ obs=traj_batch_xp.obs,
684
+ done=traj_batch_xp.done,
685
+ avail_actions=traj_batch_xp.avail_actions,
686
+ hstate=None,
687
+ rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here
688
+ aux_obs=aux_obs_xp
689
+ )
690
+
691
+ aux_obs_sp = jnp.repeat(sp_one_hot_id, traj_batch_sp.obs.shape[1], axis=1)  # SP/MP data is conditioned on sp_id, matching the bootstrap values
692
+ aux_obs_sp = jnp.repeat(aux_obs_sp, traj_batch_sp.obs.shape[0], axis=0)
693
+ _, value_sp, pi_sp, _ = policy.get_action_value_policy(
694
+ params=params,
695
+ obs=traj_batch_sp.obs,
696
+ done=traj_batch_sp.done,
697
+ avail_actions=traj_batch_sp.avail_actions,
698
+ hstate=None,
699
+ rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here
700
+ aux_obs=aux_obs_sp
701
+ )
702
+
703
+ _, value_sp2, pi_sp2, _ = policy.get_action_value_policy(
704
+ params=params,
705
+ obs=traj_batch_sp2.obs,
706
+ done=traj_batch_sp2.done,
707
+ avail_actions=traj_batch_sp2.avail_actions,
708
+ hstate=None,
709
+ rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here
710
+ aux_obs=aux_obs_sp
711
+ )
712
+
713
+ _, value_mp, pi_mp, _ = policy.get_action_value_policy(
714
+ params=params,
715
+ obs=traj_batch_mp.obs,
716
+ done=traj_batch_mp.done,
717
+ avail_actions=traj_batch_mp.avail_actions,
718
+ hstate=None,
719
+ rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here
720
+ aux_obs=aux_obs_sp
721
+ )
722
+
723
+ _, value_mp2, pi_mp2, _ = policy.get_action_value_policy(
724
+ params=params,
725
+ obs=traj_batch_mp2.obs,
726
+ done=traj_batch_mp2.done,
727
+ avail_actions=traj_batch_mp2.avail_actions,
728
+ hstate=None,
729
+ rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here
730
+ aux_obs=aux_obs_sp
731
+ )
732
+
733
+ log_prob_xp = pi_xp.log_prob(traj_batch_xp.action)
734
+ log_prob_sp = pi_sp.log_prob(traj_batch_sp.action)
735
+ log_prob_sp2 = pi_sp2.log_prob(traj_batch_sp2.action)
736
+ log_prob_mp = pi_mp.log_prob(traj_batch_mp.action)
737
+ log_prob_mp2 = pi_mp2.log_prob(traj_batch_mp2.action)
738
+
739
+
740
+ value_loss_xp = _compute_ppo_value_loss(value_xp, traj_batch_xp, target_v_xp)
741
+ value_loss_sp = _compute_ppo_value_loss(value_sp, traj_batch_sp, target_v_sp)
742
+ value_loss_sp2 = _compute_ppo_value_loss(value_sp2, traj_batch_sp2, target_v_sp2)
743
+ value_loss_mp = _compute_ppo_value_loss(value_mp, traj_batch_mp, target_v_mp)
744
+ value_loss_mp2 = _compute_ppo_value_loss(value_mp2, traj_batch_mp2, target_v_mp2)
745
+
746
+ pg_loss_xp = _compute_ppo_pg_loss(log_prob_xp, traj_batch_xp, gae_xp)
747
+ pg_loss_sp = _compute_ppo_pg_loss(log_prob_sp, traj_batch_sp, gae_sp)
748
+ pg_loss_sp2 = _compute_ppo_pg_loss(log_prob_sp2, traj_batch_sp2, gae_sp2)
749
+ pg_loss_mp = _compute_ppo_pg_loss(log_prob_mp, traj_batch_mp, gae_mp)
750
+ pg_loss_mp2 = _compute_ppo_pg_loss(log_prob_mp2, traj_batch_mp2, gae_mp2)
751
+
752
+
753
+ # Policy entropy for each interaction type (XP, SP, MP)
754
+ entropy_xp = jnp.mean(pi_xp.entropy())
755
+ entropy_sp = jnp.mean(pi_sp.entropy())
756
+ entropy_sp2 = jnp.mean(pi_sp2.entropy())
757
+ entropy_mp = jnp.mean(pi_mp.entropy())
758
+ entropy_mp2 = jnp.mean(pi_mp2.entropy())
759
+
760
+ xp_pg_weight = -config["COMEDI_ALPHA"] # negate to minimize the ego agent's PG objective
761
+ sp_pg_weight = 1.0
762
+ mp2_pg_weight = config["COMEDI_BETA"]
763
+
764
+ xp_loss = xp_pg_weight * pg_loss_xp + config["VF_COEF"] * value_loss_xp - config["ENT_COEF"] * entropy_xp
765
+ sp_loss = sp_pg_weight * pg_loss_sp + config["VF_COEF"] * value_loss_sp - config["ENT_COEF"] * entropy_sp
766
+ sp2_loss = sp_pg_weight * pg_loss_sp2 + config["VF_COEF"] * value_loss_sp2 - config["ENT_COEF"] * entropy_sp2
767
+ mp_loss = mp2_pg_weight * pg_loss_mp + config["VF_COEF"] * value_loss_mp - config["ENT_COEF"] * entropy_mp
768
+ mp2_loss = mp2_pg_weight * pg_loss_mp2 + config["VF_COEF"] * value_loss_mp2 - config["ENT_COEF"] * entropy_mp2
769
+
770
+ total_loss = sp_loss + sp2_loss + xp_loss + mp2_loss + mp_loss
771
+ return total_loss, (value_loss_xp, value_loss_sp + value_loss_sp2, value_loss_mp + value_loss_mp2,
772
+ pg_loss_xp, pg_loss_sp + pg_loss_sp2, pg_loss_mp + pg_loss_mp2,
773
+ entropy_xp, entropy_sp + entropy_sp2, entropy_mp + entropy_mp2)
774
+
775
+ grad_fn = jax.value_and_grad(_loss_fn_conf, has_aux=True)
776
+ (loss_val, aux_vals), grads = grad_fn(
777
+ train_state_conf.params,
778
+ traj_batch_xp, advantages_xp, returns_xp,
779
+ traj_batch_sp1, advantages_sp1, returns_sp1,
780
+ traj_batch_sp2, advantages_sp2, returns_sp2,
781
+ traj_batch_mp1, advantages_mp1, returns_mp1,
782
+ traj_batch_mp2, advantages_mp2, returns_mp2)
783
+ train_state_conf = train_state_conf.apply_gradients(grads=grads)
784
+ return train_state_conf, (loss_val, aux_vals)
785
+
786
+ (
787
+ train_state_conf, traj_batch_xp,
788
+ traj_batch_sp_conf, traj_batch_sp_br,
789
+ traj_batch_mp_conf, traj_batch_mp_br,
790
+ advantages_xp_conf, advantages_sp_conf,
791
+ advantages_sp_br, advantages_mp_conf,
792
+ advantages_mp_br, targets_xp_conf,
793
+ targets_sp_conf, targets_sp_br,
794
+ targets_mp_conf, targets_mp_br,
795
+ rng, xp_id, sp_id
796
+ ) = update_state
797
+
798
+ rng, perm_rng_xp, perm_rng_sp_conf, perm_rng_sp_br, perm_rng_mp2_conf, perm_rng_mp2_br = jax.random.split(rng, 6)
799
+
800
+ # Create minibatches for each agent and interaction type
801
+ minibatches_xp = _create_minibatches(
802
+ traj_batch_xp, advantages_xp_conf, targets_xp_conf, None,
803
+ config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_xp
804
+ )
805
+ minibatches_sp_conf = _create_minibatches(
806
+ traj_batch_sp_conf, advantages_sp_conf, targets_sp_conf, None,
807
+ config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_sp_conf
808
+ )
809
+ minibatches_sp_br = _create_minibatches(
810
+ traj_batch_sp_br, advantages_sp_br, targets_sp_br, None,
811
+ config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_sp_br
812
+ )
813
+ minibatches_mp_conf = _create_minibatches(
814
+ traj_batch_mp_conf, advantages_mp_conf, targets_mp_conf, None,
815
+ config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_mp2_conf
816
+ )
817
+ minibatches_mp_br = _create_minibatches(
818
+ traj_batch_mp_br, advantages_mp_br, targets_mp_br, None,
819
+ config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_mp2_br
820
+ )
821
+
822
+ # Update confederate
823
+ repeated_xp_id = jnp.repeat(xp_id, minibatches_xp[1].obs.shape[0], axis=0)
824
+ repeated_sp_id = jnp.repeat(sp_id, minibatches_sp_br[1].obs.shape[0], axis=0)
825
+ train_state_conf, total_loss_conf = jax.lax.scan(
826
+ _update_minbatch_conf, train_state_conf, (
827
+ minibatches_xp, minibatches_sp_conf, minibatches_sp_br,
828
+ minibatches_mp_conf, minibatches_mp_br, repeated_xp_id, repeated_sp_id
829
+ )
830
+ )
831
+
832
+ update_state = (train_state_conf,
833
+ traj_batch_xp, traj_batch_sp_conf, traj_batch_sp_br, traj_batch_mp_conf, traj_batch_mp_br,
834
+ advantages_xp_conf, advantages_sp_conf, advantages_sp_br, advantages_mp_conf, advantages_mp_br,
835
+ targets_xp_conf, targets_sp_conf, targets_sp_br, targets_mp_conf, targets_mp_br,
836
+ rng, xp_id, sp_id
837
+ )
838
+ return update_state, total_loss_conf
839
+
840
+ # 3) PPO update
841
+ rng, sub_rng = jax.random.split(rng, 2)
842
+ update_state = (
843
+ train_state,
844
+ traj_batch_xp, traj_batch_sp_agent0,
845
+ traj_batch_sp_agent1,
846
+ traj_batch_smp0, traj_batch_smp1,
847
+ advantages_xp_conf,
848
+ advantages_sp_conf, advantages_sp_br,
849
+ advantages_mp_conf, advantages_mp_br,
850
+ targets_xp_conf, targets_sp_conf,
851
+ targets_sp_br, targets_mp_conf,
852
+ targets_mp_br, sub_rng,
853
+ max_means_id, num_prev_trained_conf
854
+ )
855
+ update_state, conf_losses = jax.lax.scan(
856
+ _update_epoch, update_state, None, config["UPDATE_EPOCHS"])
857
+ train_state = update_state[0]
858
+
859
+ (
860
+ conf_value_loss_xp, conf_value_loss_sp, conf_value_loss_mp,
861
+ conf_pg_loss_xp, conf_pg_loss_sp, conf_pg_loss_mp,
862
+ conf_entropy_xp, conf_entropy_sp, conf_entropy_mp
863
+ ) = conf_losses[1]
864
+
865
+ new_update_runner_state = (
866
+ train_state, pop_buffer,
867
+ env_state_sp, last_obs_sp,
868
+ env_state_xp, last_obs_xp,
869
+ env_state_mp, last_obs_mp,
870
+ env_state_mp2, last_obs_mp2,
871
+ last_dones_xp, last_dones_sp,
872
+ last_dones_mp, last_dones_mp2,
873
+ rng, update_steps+1, num_prev_trained_conf
874
+ )
875
+
876
+ # Metrics
877
+ def mask_and_mean(x, mask):
878
+ return jnp.where(mask, x, 0).sum() / jnp.maximum(1, mask.sum())
879
+
880
+ mask = traj_batch_xp.info.get("returned_episode", jnp.ones_like(traj_batch_xp.reward))
881
+ metric = jax.tree.map(lambda x: mask_and_mean(x, mask), traj_batch_xp.info)
882
+ metric["update_steps"] = update_steps
883
+ metric["value_loss_conf_xp"] = conf_value_loss_xp.mean()
884
+ metric["value_loss_conf_sp"] = conf_value_loss_sp.mean()
885
+ metric["value_loss_conf_mp"] = conf_value_loss_mp.mean()
886
+
887
+ metric["pg_loss_conf_xp"] = conf_pg_loss_xp.mean()
888
+ metric["pg_loss_conf_sp"] = conf_pg_loss_sp.mean()
889
+ metric["pg_loss_conf_mp"] = conf_pg_loss_mp.mean()
890
+
891
+ metric["entropy_conf_xp"] = conf_entropy_xp.mean()
892
+ metric["entropy_conf_sp"] = conf_entropy_sp.mean()
893
+ metric["entropy_conf_mp"] = conf_entropy_mp.mean()
894
+
895
+ metric["average_rewards_ego"] = jnp.mean(traj_batch_xp.reward)
896
+ metric["average_rewards_br_sp"] = jnp.mean(traj_batch_sp_agent1.reward)
897
+ metric["average_rewards_br_mp2"] = jnp.mean(traj_batch_smp1.reward)
898
+
899
+ return (new_update_runner_state, checkpoint_array, ckpt_idx+1), metric
900
+
901
+ # XP eval against all policies in the buffer
902
+ xp_eval_returns = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
903
+ jax.vmap(per_id_run_episode_fixed_rng, in_axes=(None, 0))(
904
+ train_state.params,jnp.arange(config["POP_SIZE"])))
905
+
906
+ # SP performance against itself
907
+ sp_eval_returns = jax.tree.map(lambda x: x.mean(), run_episodes(
908
+ eval_rng, env,
909
+ agent_0_param=train_state.params, agent_0_policy=policy,
910
+ agent_1_param=train_state.params, agent_1_policy=policy,
911
+ max_episode_steps=config["ROLLOUT_LENGTH"],
912
+ num_eps=config["NUM_EVAL_EPISODES"]
913
+ ))
914
+
915
+
916
+ update_steps = 0
917
+ init_done_xp = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
918
+ init_done_sp = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
919
+ init_done_mp = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
920
+ init_done_mp2 = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
921
+
922
+ update_runner_state = (
923
+ train_state, pop_buffer,
924
+ env_state_sp, obsv_sp,
925
+ env_state_xp, obsv_xp,
926
+ env_state_mp, obsv_mp,
927
+ env_state_mp2, obsv_mp2,
928
+ init_done_xp, init_done_sp,
929
+ init_done_mp, init_done_mp2,
930
+ rng, update_steps,
931
+ num_existing_agents
932
+ )
933
+
934
+ checkpoint_array = init_ckpt_array(train_state.params)
935
+ ckpt_idx = 0
936
+ update_with_ckpt_runner_state = (update_runner_state, checkpoint_array, ckpt_idx, xp_eval_returns, sp_eval_returns)
937
+
938
+ def _update_step_with_ckpt(state_with_ckpt, unused):
939
+
940
+ (update_runner_state, checkpoint_array, ckpt_idx, xp_eval_returns, sp_eval_returns) = state_with_ckpt
941
+ train_state = update_runner_state[0]
942
+
943
+ # Single PPO update
944
+ new_state_with_ckpt, metric = _update_step(
945
+ (update_runner_state, checkpoint_array, ckpt_idx),
946
+ None
947
+ )
948
+ new_update_runner_state = new_state_with_ckpt[0]
949
+ rng, update_steps = new_update_runner_state[-3], new_update_runner_state[-2]
950
+
951
+ # Decide if we store a checkpoint
952
+ # update steps is 1-indexed because it was incremented at the end of the update step
953
+ to_store = jnp.logical_or(jnp.equal(jnp.mod(update_steps-1, ckpt_and_eval_interval), 0),
954
+ jnp.equal(update_steps, config["NUM_UPDATES"]))
955
+
956
+ def store_and_eval_ckpt(args):
957
+ ckpt_arr_conf, rng, cidx, _, _ = args
958
+ new_ckpt_arr_conf = jax.tree.map(
959
+ lambda c_arr, p: c_arr.at[cidx].set(p),
960
+ ckpt_arr_conf, train_state.params
961
+ )
962
+
963
+ # Eval trained agent against all params in the pool
964
+ xp_eval_returns = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
965
+ jax.vmap(per_id_run_episode_fixed_rng, in_axes=(None, 0))(
966
+ train_state.params, jnp.arange(config["POP_SIZE"])))
967
+ # Eval trained agent against itself
968
+ sp_eval_returns = jax.tree.map(lambda x: x.mean(), run_episodes(
969
+ eval_rng, env,
970
+ agent_0_param=train_state.params, agent_0_policy=policy,
971
+ agent_1_param=train_state.params, agent_1_policy=policy,
972
+ max_episode_steps=config["ROLLOUT_LENGTH"],
973
+ num_eps=config["NUM_EVAL_EPISODES"]
974
+ ))
975
+
976
+ return (new_ckpt_arr_conf, rng, cidx + 1, xp_eval_returns, sp_eval_returns)
977
+
978
+ def skip_ckpt(args):
979
+ return args
980
+
981
+ rng, store_and_eval_rng = jax.random.split(rng, 2)
982
+ (checkpoint_array, store_and_eval_rng, ckpt_idx, xp_eval_returns, sp_eval_returns) = jax.lax.cond(
983
+ to_store,
984
+ store_and_eval_ckpt,
985
+ skip_ckpt,
986
+ (checkpoint_array, store_and_eval_rng, ckpt_idx, xp_eval_returns, sp_eval_returns)
987
+ )
988
+
989
+ return (new_update_runner_state, checkpoint_array,
990
+ ckpt_idx, xp_eval_returns, sp_eval_returns), (metric, xp_eval_returns, sp_eval_returns)
991
+
992
+ new_update_with_ckpt_runner_state, (metric, xp_eval_returns, sp_eval_returns) = jax.lax.scan(
993
+ _update_step_with_ckpt,
994
+ update_with_ckpt_runner_state,
995
+ xs=None, # No per-step input data
996
+ length=config["NUM_UPDATES"],
997
+ )
998
+ new_update_runner_state, new_checkpoint_array, _, _ ,_ = new_update_with_ckpt_runner_state
999
+ final_train_state = new_update_runner_state[0]
1000
+
1001
+ updated_pop_buffer = partner_population.add_agent(pop_buffer, final_train_state.params)
1002
+ conf_checkpoints = new_checkpoint_array
1003
+ return updated_pop_buffer, (conf_checkpoints, metric, xp_eval_returns, sp_eval_returns)
1004
+
1005
+ rngs = jax.random.split(rng, config["PARTNER_POP_SIZE"])
1006
+ rng, add_conf_iter_rngs = rngs[0], rngs[1:]
1007
+
1008
+ iter_ids = jnp.arange(1, config["PARTNER_POP_SIZE"])
1009
+ final_population_buffer, (conf_checkpoints, metric, xp_eval_returns, sp_eval_returns) = jax.lax.scan(
1010
+ add_conf_policy, population_buffer, (iter_ids, add_conf_iter_rngs)
1011
+ )
1012
+
1013
+ out = {
1014
+ "final_params_conf": final_population_buffer.params,
1015
+ "checkpoints_conf": conf_checkpoints,
1016
+ "metrics": metric,
1017
+ "last_ep_infos_xp": xp_eval_returns,
1018
+ "last_ep_infos_sp": sp_eval_returns
1019
+ }
1020
+
1021
+ return out
1022
+ return train
1023
+
1024
+ train_fn = make_comedi_agents(config)
1025
+ out = train_fn(train_rng)
1026
+ return out
1027
+
1028
+ def get_comedi_population(config, out, env):
1029
+ '''
1030
+ Get the partner params and partner population for ego training.
1031
+ '''
1032
+ comedi_pop_size = config["algorithm"]["PARTNER_POP_SIZE"]
1033
+
1034
+ # partner_params has shape (num_seeds, comedi_pop_size, ...)
1035
+ partner_params = out['final_params_conf']
1036
+
1037
+ partner_policy = ActorWithConditionalCriticPolicy(
1038
+ action_dim=env.action_space(env.agents[1]).n,
1039
+ obs_dim=env.observation_space(env.agents[1]).shape[0],
1040
+ pop_size=comedi_pop_size, # used to create onehot agent id
1041
+ activation=config["algorithm"].get("ACTIVATION", "tanh")
1042
+ )
1043
+
1044
+ # Create partner population
1045
+ partner_population = AgentPopulation(
1046
+ pop_size=comedi_pop_size,
1047
+ policy_cls=partner_policy
1048
+ )
1049
+
1050
+ return partner_params, partner_population
1051
+
1052
+ def run_comedi(config, wandb_logger):
1053
+ algorithm_config = dict(config["algorithm"])
1054
+
1055
+ env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"])
1056
+ env = LogWrapper(env)
1057
+
1058
+ log.info("Starting CoMeDi training...")
1059
+ start = time.time()
1060
+
1061
+ # Generate multiple random seeds from the base seed
1062
+ rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"])
1063
+ rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"])
1064
+
1065
+ # Create a vmapped version of train_comedi_partners
1066
+ with jax.disable_jit(False):
1067
+ vmapped_train_fn = jax.jit(
1068
+ jax.vmap(
1069
+ partial(train_comedi_partners,
1070
+ wandb_logger=wandb_logger,
1071
+ env=env,
1072
+ config=algorithm_config)
1073
+ )
1074
+ )
1075
+ out = vmapped_train_fn(rngs)
1076
+
1077
+ end = time.time()
1078
+ log.info(f"CoMeDi training complete in {end - start} seconds")
1079
+
1080
+ metric_names = get_metric_names(algorithm_config["ENV_NAME"])
1081
+
1082
+ # Save FIRST so the checkpoint survives even if metric logging OOMs.
1083
+ savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
1084
+ out_savepath = save_train_run(out, savedir, savename="saved_train_run")
1085
+ log_metrics(config, out, wandb_logger, metric_names, out_savepath)
1086
+ partner_params, partner_population = get_comedi_population(config, out, env)
1087
+ return partner_params, partner_population
1088
+
1089
+ def compute_sp_mask_and_ids(pop_size):
1090
+ cross_product = np.meshgrid(
1091
+ np.arange(pop_size),
1092
+ np.arange(pop_size)
1093
+ )
1094
+ agent_id_cartesian_product = np.stack([g.ravel() for g in cross_product], axis=-1)
1095
+ conf_ids = agent_id_cartesian_product[:, 0]
1096
+ ego_ids = agent_id_cartesian_product[:, 1]
1097
+ sp_mask = (conf_ids == ego_ids)
1098
+ return sp_mask, agent_id_cartesian_product
1099
+
1100
+ def log_metrics(config, outs, logger, metric_names: tuple, out_savepath):
1101
+ metrics = outs["metrics"]
1102
+ # pop_size here counts only the trained policies, i.e. it excludes the initial policy
1103
+ num_seeds, pop_size, num_updates = metrics["pg_loss_conf_sp"].shape
1104
+ # TODO: add the eval_ep_last_info metrics
1105
+
1106
+ ### Log evaluation metrics
1107
+ # xp_eval_returns and sp_eval_returns logged at each evaluation only.
1108
+ algorithm_config = config["algorithm"]
1109
+ ckpt_and_eval_interval = max(1, num_updates // max(1, algorithm_config["NUM_CHECKPOINTS"] - 1))
1110
+ # Steps at which store_and_eval_ckpt fires (0-indexed, matching the update_step logged below)
1111
+ eval_steps = list(range(0, num_updates, ckpt_and_eval_interval))
1112
+ if (num_updates - 1) not in eval_steps:
1113
+ eval_steps.append(num_updates - 1)
1114
+
1115
+ # shape (num_seeds, pop_size - 1, num_updates) [pre-scalarized: mean over eval eps and agents taken inside scan]
1116
+ all_returns_sp = np.asarray(outs["last_ep_infos_sp"]["returned_episode_returns"])
1117
+ # shape (num_seeds, pop_size - 1, num_updates, pop_size) [pre-scalarized: mean over eval eps and agents taken inside scan]
1118
+ all_returns_xp = np.asarray(outs["last_ep_infos_xp"]["returned_episode_returns"])
1119
+
1120
+ # Average over seeds only (eval episodes and agents already averaged inside scan)
1121
+ sp_return_curve = all_returns_sp.mean(axis=0) # shape (pop_size - 1, num_updates)
1122
+ xp_return_curve = all_returns_xp.mean(axis=0) # shape (pop_size - 1, num_updates, pop_size)
1123
+
1124
+ for num_add_policies in range(pop_size):
1125
+ for update_step in eval_steps:
1126
+ logger.log_item("Eval/AvgSPReturnCurve", sp_return_curve[num_add_policies, update_step], train_step=update_step)
1127
+ mean_xp_returns = xp_return_curve[num_add_policies, :, :(num_add_policies+1)].mean(axis=-1)
1128
+ logger.log_item("Eval/AvgXPReturnCurve", mean_xp_returns[update_step], train_step=update_step)
1129
+ logger.commit()
1130
+
1131
+ ### Log population loss as multi-line plots, where each line is a different population member
1132
+ # xp, sp, and mp loss metrics have shape (num_seeds, pop_size - 1, num_updates); update epochs and minibatches are already averaged inside the scan
1133
+ # Average over seeds
1134
+ processed_losses = {
1135
+ "ConfPGLossSP": np.asarray(metrics["pg_loss_conf_sp"]).mean(axis=0), # desired shape (pop_size - 1, num_updates)
1136
+ "ConfPGLossXP": np.asarray(metrics["pg_loss_conf_xp"]).mean(axis=0),
1137
+ "ConfPGLossMP": np.asarray(metrics["pg_loss_conf_mp"]).mean(axis=0),
1138
+ "ConfValLossSP": np.asarray(metrics["value_loss_conf_sp"]).mean(axis=0),
1139
+ "ConfValLossXP": np.asarray(metrics["value_loss_conf_xp"]).mean(axis=0),
1140
+ "ConfValLossMP": np.asarray(metrics["value_loss_conf_mp"]).mean(axis=0),
1141
+ "EntropySP": np.asarray(metrics["entropy_conf_sp"]).mean(axis=0),
1142
+ "EntropyXP": np.asarray(metrics["entropy_conf_xp"]).mean(axis=0),
1143
+ "EntropyMP": np.asarray(metrics["entropy_conf_mp"]).mean(axis=0),
1144
+ }
1145
+
1146
+ xs = list(range(num_updates))
1147
+ keys = [f"pair {i}" for i in range(pop_size)]
1148
+
1149
+ for loss_name, loss_data in processed_losses.items():
1150
+ logger.log_item(f"Losses/{loss_name}",
1151
+ wandb.plot.line_series(xs=xs, ys=loss_data, keys=keys,
1152
+ title=loss_name, xname="train_step")
1153
+ )
1154
+
1155
+ ### Log artifacts (already saved by caller; just publish to wandb)
1156
+ if config["logger"]["log_train_out"]:
1157
+ logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run")
1158
+
1159
+ # Cleanup locally logged out files
1160
+ if not config["local_logger"]["save_train_out"]:
1161
+ shutil.rmtree(out_savepath)
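
The `_calculate_gae` helper in CoMeDi.py above runs a reverse `jax.lax.scan` over the trajectory. As a reading aid, here is a minimal plain-Python sketch of the same backward recursion for a single environment. The function name `gae_reference` and its default `gamma` / `gae_lambda` values are illustrative assumptions, not part of this commit; the recursion itself mirrors `_get_advantages`.

```python
def gae_reference(rewards, values, dones, last_val, gamma=0.99, gae_lambda=0.95):
    """Plain-Python reference for the backward GAE recursion (hypothetical helper).

    rewards, values, dones are per-step lists for one environment;
    last_val is the bootstrap value of the observation after the final step.
    Returns (advantages, value_targets).
    """
    gae = 0.0
    next_value = last_val
    advantages = [0.0] * len(rewards)
    # Walk the trajectory backwards, as the reverse scan does.
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        # One-step TD error; the bootstrap term is dropped on terminal steps.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Discounted sum of TD errors, also cut at episode boundaries.
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    # Value targets are advantages plus the recorded value predictions.
    targets = [a + v for a, v in zip(advantages, values)]
    return advantages, targets


adv, tgt = gae_reference(
    rewards=[1.0, 1.0], values=[0.5, 0.5], dones=[False, True],
    last_val=10.0, gamma=0.9, gae_lambda=0.5,
)
```

Because the final step is terminal in this example, the bootstrap value `last_val` has no effect on the result, which is the behaviour the `(1 - done)` factors are meant to give.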
teammate_generation/LBRDiv.py ADDED
@@ -0,0 +1,1098 @@
1
+ '''Implementation of the LBRDiv teammate generation algorithm (Rahman et al., AAAI 2024)
2
+ https://ojs.aaai.org/index.php/AAAI/article/view/29702
3
+
4
+ Command to run LBRDiv only on LBF:
5
+ python teammate_generation/run.py algorithm=lbrdiv/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels label=test_lbrdiv run_heldout_eval=false train_ego=false
6
+
7
+ Suggested debug command:
8
+ python teammate_generation/run.py algorithm=lbrdiv/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels logger.mode=disabled label=debug algorithm.TOTAL_TIMESTEPS=1e5 algorithm.PARTNER_POP_SIZE=2 train_ego=false run_heldout_eval=false
9
+
10
+ Limitations: does not support recurrent actors.
11
+ '''
12
+ import shutil
13
+ import time
14
+ import logging
15
+ from functools import partial
16
+
17
+ import hydra
18
+ import jax
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+ import optax
22
+ from flax.training.train_state import TrainState
23
+ import wandb
24
+
25
+ from agents.mlp_actor_critic_agent import ActorWithConditionalCriticPolicy
26
+ from agents.population_interface import AgentPopulation
27
+ from common.plot_utils import get_metric_names
28
+ from common.run_episodes import run_episodes
29
+ from common.save_load_utils import save_train_run
30
+ from envs import make_env
31
+ from envs.log_wrapper import LogWrapper
32
+ from marl.ppo_utils import unbatchify, _create_minibatches
33
+ from teammate_generation.BRDiv import _get_all_ids, XPTransition, gather_params
34
+
35
+ log = logging.getLogger(__name__)
36
+ logging.basicConfig(level=logging.INFO)
37
+
38
+
39
+ def train_lbrdiv_partners(train_rng, env, config, conf_policy, br_policy):
40
+ num_agents = env.num_agents
41
+ assert num_agents == 2, "This code assumes the environment has exactly 2 agents."
42
+
43
+ # Record the agent count and the number of confederate / BR actors (one per parallel env), then derive the number of PPO updates
44
+ config["NUM_GAME_AGENTS"] = num_agents
45
+ config["NUM_CONF_ACTORS"] = config["NUM_ENVS"]
46
+ config["NUM_BR_ACTORS"] = config["NUM_ENVS"]
47
+ config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // (config["ROLLOUT_LENGTH"] * config["NUM_ENVS"])
48
+
49
+ def make_lbrdiv_agents(config):
50
+ def linear_schedule(count):
51
+ frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
52
+ return config["LR"] * frac
53
+
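The linear learning-rate decay defined above can be checked in isolation. The following is a minimal sketch, assuming hypothetical values for `LR`, `NUM_MINIBATCHES`, `UPDATE_EPOCHS`, and `NUM_UPDATES` (the real values come from the Hydra config); it only shows how the optimizer step count maps to a PPO update index and then to a decayed rate.

```python
# Hypothetical stand-ins for the config entries used by linear_schedule.
LR = 1e-3
NUM_MINIBATCHES = 4
UPDATE_EPOCHS = 2
NUM_UPDATES = 10


def linear_schedule(count):
    # `count` is the optimizer step. Each PPO update performs
    # NUM_MINIBATCHES * UPDATE_EPOCHS optimizer steps, so the integer
    # division recovers the index of the current PPO update.
    frac = 1.0 - (count // (NUM_MINIBATCHES * UPDATE_EPOCHS)) / NUM_UPDATES
    return LR * frac


lr_at_start = linear_schedule(0)
# After 5 of the 10 PPO updates the rate has decayed to half of LR.
lr_after_5_updates = linear_schedule(5 * NUM_MINIBATCHES * UPDATE_EPOCHS)
```

The rate is piecewise constant within a PPO update and reaches zero only at step `NUM_UPDATES * NUM_MINIBATCHES * UPDATE_EPOCHS`.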
54
+ def train(rng):
55
+ rng, init_conf_rng, init_br_rng = jax.random.split(rng, 3)
56
+ all_conf_init_rngs = jax.random.split(init_conf_rng, config["PARTNER_POP_SIZE"])
57
+ all_br_init_rngs = jax.random.split(init_br_rng, config["PARTNER_POP_SIZE"])
58
+ identity_matrix = jnp.eye(config["PARTNER_POP_SIZE"])
59
+
60
+ init_conf_hstate = conf_policy.init_hstate(config["NUM_CONF_ACTORS"])
61
+ init_br_hstate = br_policy.init_hstate(config["NUM_BR_ACTORS"])
62
+
63
+ def init_train_states(rng_agents, rng_brs):
64
+ def init_single_pair_optimizers(rng_agent, rng_br):
65
+ init_params_conf = conf_policy.init_params(rng_agent)
66
+ init_params_br = br_policy.init_params(rng_br)
67
+ return init_params_conf, init_params_br
68
+
69
+ init_all_networks_and_optimizers = jax.vmap(init_single_pair_optimizers)
70
+ all_conf_params, all_br_params = init_all_networks_and_optimizers(rng_agents, rng_brs)
71
+
72
+ # Define optimizers for both confederate and BR policy
73
+ tx = optax.chain(
74
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
75
+ optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"],
76
+ eps=1e-5),
77
+ )
78
+ tx_br = optax.chain(
79
+ optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
80
+ optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"],
81
+ eps=1e-5),
82
+ )
83
+
84
+ train_state_conf = TrainState.create(
85
+ apply_fn=conf_policy.network.apply,
86
+ params=all_conf_params,
87
+ tx=tx,
88
+ )
89
+
90
+ train_state_br = TrainState.create(
91
+ apply_fn=br_policy.network.apply,
92
+ params=all_br_params,
93
+ tx=tx_br,
94
+ )
95
+
96
+ return train_state_conf, train_state_br
97
+
98
+ all_conf_optims, all_br_optims = init_train_states(
99
+ all_conf_init_rngs, all_br_init_rngs
100
+ )
101
+
102
+ def forward_pass_conf(params, obs, id, done, avail_actions, hstate, rng):
103
+ act, val, pi, new_hstate = conf_policy.get_action_value_policy(
104
+ params=params,
105
+ obs=obs[jnp.newaxis, ...],
106
+ done=done[jnp.newaxis, ...],
107
+ avail_actions=avail_actions,
108
+ hstate=hstate,
109
+ rng=rng,
110
+ aux_obs=id[jnp.newaxis, ...]
111
+ )
112
+ return act, val, pi, new_hstate
113
+
114
+ def forward_pass_br(params, obs, id, done, avail_actions, hstate, rng):
115
+ act, val, pi, new_hstate = br_policy.get_action_value_policy(
116
+ params=params,
117
+ obs=obs[jnp.newaxis, ...],
118
+ done=done[jnp.newaxis, ...],
119
+ avail_actions=avail_actions,
120
+ hstate=hstate,
121
+ rng=rng,
122
+ aux_obs=id[jnp.newaxis, ...]
123
+ )
124
+ return act, val, pi, new_hstate
125
+
126
+ def _env_step(runner_state, unused):
127
+ """
128
+ agent_0 = confederate, agent_1 = br
129
+ Returns updated runner_state, and Transitions for agent_0 and agent_1
130
+ """
131
+ (
132
+ all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids,
133
+ env_state, last_obs, last_done, last_conf_h, last_br_h, rng
134
+ ) = runner_state
135
+ rng, act0_rng, act1_rng, step_rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 6)
136
+
137
+ # For done envs, resample both conf and brs
138
+ needs_resample = last_done["__all__"]
139
+ resampled_conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_CONF_ACTORS"],), 0, config["PARTNER_POP_SIZE"])
140
+ resampled_br_ids = jax.random.randint(br_sampling_rng, (config["NUM_BR_ACTORS"],), 0, config["PARTNER_POP_SIZE"])
141
+
142
+ # Determine final indices based on whether resampling was needed for each env
143
+ updated_conf_ids = jnp.where(
144
+ needs_resample,
145
+ resampled_conf_ids, # Use newly sampled index if True
146
+ last_conf_ids # Else, keep index from previous step
147
+ )
148
+
149
+ updated_br_ids = jnp.where(
150
+ needs_resample,
151
+ resampled_br_ids, # Use newly sampled index if True
152
+ last_br_ids # Else, keep index from previous step
153
+ )
154
+
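The per-environment resampling above follows a common pattern: draw a fresh index for every environment, then keep it only where the episode just ended. A minimal standalone sketch, where `resample_ids_on_done` is a hypothetical helper name and the arrays are toy inputs:

```python
import jax
import jax.numpy as jnp


def resample_ids_on_done(rng, last_ids, done, pop_size):
    # Draw a candidate index for every env, even those that will not use it;
    # this keeps the computation shape-static and jit-friendly.
    fresh = jax.random.randint(rng, last_ids.shape, 0, pop_size)
    # Keep the fresh index only where the episode has finished.
    return jnp.where(done, fresh, last_ids)


last_ids = jnp.array([0, 1, 2, 3])
done = jnp.array([False, True, False, True])
new_ids = resample_ids_on_done(jax.random.PRNGKey(0), last_ids, done, pop_size=4)
```

Environments whose episode is still running (indices 0 and 2 here) keep their previous partner index, so a pairing never changes mid-episode.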
155
+ # Reset the hidden states for resampled conf and br if they are not None
156
+ # WARNING: (L)BRDiv has not been tested with recurrent actors, so the branch that handles a non-None hstate may not work
157
+ if last_conf_h is not None:
158
+ updated_conf_h = jnp.where(
159
+ needs_resample,
160
+ init_conf_hstate,
161
+ last_conf_h
162
+ )
163
+ else:
164
+ updated_conf_h = last_conf_h
165
+
166
+ if last_br_h is not None:
167
+ updated_br_h = jnp.where(
168
+ needs_resample,
169
+ init_br_hstate,
170
+ last_br_h
171
+ )
172
+ else:
173
+ updated_br_h = last_br_h
174
+
175
+ # Get the corresponding conf and br params
176
+ updated_conf_params = gather_params(all_train_state_conf.params, updated_conf_ids)
177
+ updated_br_params = gather_params(all_train_state_br.params, updated_br_ids)
178
+
179
+ updated_conf_onehot_ids = identity_matrix[updated_conf_ids]
180
+ updated_br_onehot_ids = identity_matrix[updated_br_ids]
181
+
182
+ # Get available actions for both agents from the environment state
183
+ avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
184
+ avail_actions = jax.lax.stop_gradient(avail_actions)
185
+ avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32)
186
+ avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32)
187
+
188
+ # Agent_0 action
189
+ act0_rng = jax.random.split(act0_rng, config["NUM_ENVS"])
190
+ act_0, val_0, pi_0, new_conf_h = jax.vmap(forward_pass_conf)(updated_conf_params,
191
+ last_obs["agent_0"], updated_br_onehot_ids, last_done["agent_0"], avail_actions_0,
192
+ updated_conf_h, act0_rng)
193
+ logp_0 = pi_0.log_prob(act_0)
194
+ act_0, val_0, logp_0 = act_0.squeeze(), val_0.squeeze(), logp_0.squeeze()
195
+
196
+ # Agent_1 action
197
+ act1_rng = jax.random.split(act1_rng, config["NUM_ENVS"])
198
+ act_1, val_1, pi_1, new_br_h = jax.vmap(forward_pass_br)(updated_br_params,
199
+ last_obs["agent_1"], updated_conf_onehot_ids, last_done["agent_1"], avail_actions_1,
200
+ updated_br_h, act1_rng)
201
+ logp_1 = pi_1.log_prob(act_1)
202
+ act_1, val_1, logp_1 = act_1.squeeze(), val_1.squeeze(), logp_1.squeeze()
203
+
204
+ # Combine actions into the env format
205
+ combined_actions = jnp.concatenate([act_0, act_1], axis=0)
206
+ env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents)
207
+ env_act = {k: v.flatten() for k, v in env_act.items()}
208
+
209
+ # Step env
210
+ step_rngs = jax.random.split(step_rng, config["NUM_ENVS"])
211
+ obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))(
212
+ step_rngs, env_state, env_act
213
+ )
214
+ # note that num_actors = num_envs * num_agents
215
+ info_0 = jax.tree.map(lambda x: x[:, 0], info)
216
+ info_1 = jax.tree.map(lambda x: x[:, 1], info)
217
+
218
+ # Store agent_0 data in transition
219
+ transition_0 = XPTransition(
220
+ done=done["agent_0"],
221
+ action=act_0,
222
+ value=val_0,
223
+ self_onehot_id=updated_conf_onehot_ids,
224
+ oppo_onehot_id=updated_br_onehot_ids,
225
+ reward=reward["agent_1"],
226
+ log_prob=logp_0,
227
+ obs=last_obs["agent_0"],
228
+ info=info_0,
229
+ avail_actions=avail_actions_0
230
+ )
231
+
232
+ transition_1 = XPTransition(
233
+ done=done["agent_1"],
234
+ action=act_1,
235
+ value=val_1,
236
+ self_onehot_id=updated_br_onehot_ids,
237
+ oppo_onehot_id=updated_conf_onehot_ids,
238
+ reward=reward["agent_1"],
239
+ log_prob=logp_1,
240
+ obs=last_obs["agent_1"],
241
+ info=info_1,
242
+ avail_actions=avail_actions_1
243
+ )
244
+ new_runner_state = (all_train_state_conf, all_train_state_br, updated_conf_ids, updated_br_ids,
245
+ env_state_next, obs_next, done, new_conf_h, new_br_h, rng)
246
+ return new_runner_state, (transition_0, transition_1)
247
+
248
+ def _calculate_gae(traj_batch, last_val):
249
+ def _get_advantages(gae_and_next_value, transition):
250
+ gae, next_value = gae_and_next_value
251
+ done, value, reward = (
252
+ transition.done,
253
+ transition.value,
254
+ transition.reward,
255
+ )
256
+ delta = reward + config["GAMMA"] * next_value * (1 - done) - value
257
+ gae = (
258
+ delta
259
+ + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
260
+ )
261
+ return (gae, value), gae
262
+
263
+ _, advantages = jax.lax.scan(
264
+ _get_advantages,
265
+ (jnp.zeros_like(last_val), last_val),
266
+ traj_batch,
267
+ reverse=True,
268
+ unroll=16,
269
+ )
270
+ return advantages, advantages + traj_batch.value
271
+
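The reverse scan used by `_calculate_gae` can be exercised on a single toy trajectory. This is a sketch under assumptions: `GAMMA` and `GAE_LAMBDA` are hypothetical values, and plain arrays stand in for the `XPTransition` fields.

```python
import jax
import jax.numpy as jnp

GAMMA, GAE_LAMBDA = 0.99, 0.95  # hypothetical values, normally read from config


def calculate_gae(rewards, values, dones, last_val):
    def step(carry, x):
        gae, next_value = carry
        reward, value, done = x
        # One-step TD error; (1 - done) cuts the bootstrap at episode ends.
        delta = reward + GAMMA * next_value * (1 - done) - value
        gae = delta + GAMMA * GAE_LAMBDA * (1 - done) * gae
        return (gae, value), gae

    # Scan backwards in time, starting from the bootstrap value.
    _, advantages = jax.lax.scan(
        step,
        (jnp.zeros_like(last_val), last_val),
        (rewards, values, dones),
        reverse=True,
    )
    return advantages, advantages + values


rewards = jnp.array([1.0, 1.0, 1.0])
values = jnp.zeros(3)
dones = jnp.array([0.0, 0.0, 1.0])
advantages, targets = calculate_gae(rewards, values, dones, jnp.array(5.0))
```

Because the last step is terminal, the bootstrap value of 5.0 is masked out and the final advantage equals the final reward.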
272
+ def run_all_episodes(rng, train_state_conf, train_state_br):
273
+ conf_ids, br_ids = _get_all_ids(config["PARTNER_POP_SIZE"])
274
+ gathered_conf_model_params = gather_params(train_state_conf.params, conf_ids)
275
+ gathered_br_model_params = gather_params(train_state_br.params, br_ids)
276
+
277
+ rng, eval_rng = jax.random.split(rng)
278
+ def run_episodes_fixed_rng(conf_param, br_param):
279
+ return run_episodes(
280
+ eval_rng, env,
281
+ conf_param, conf_policy,
282
+ br_param, br_policy,
283
+ config["ROLLOUT_LENGTH"], config["NUM_EVAL_EPISODES"],
284
+ )
285
+ ep_infos = jax.vmap(run_episodes_fixed_rng)(
286
+ gathered_conf_model_params, gathered_br_model_params, # leaves where shape is (pop_size*pop_size, ...)
287
+ )
288
+ return ep_infos
289
+
290
+ def _update_epoch(update_state, unused):
291
+ def _update_minbatch(all_train_states, all_data):
292
+ train_state_conf, train_state_br = all_train_states
293
+ minbatch_conf, minbatch_br, lms_vertical, lms_horizontal = all_data
294
+
295
+ def _loss_fn(param, agent_policy, minbatch, agent_id, lms_vertical, lms_horizontal):
296
+ '''Compute loss for agent corresponding to agent_id.
297
+ '''
298
+ init_hstate, traj_batch, gae, target_v = minbatch
299
+ # get the policy and value of the agent being updated, conditioned on its partner's one-hot ID
300
+ squeezed_param = jax.tree.map(lambda x: jnp.squeeze(x, 0), param)
301
+ _, value, pi, _ = agent_policy.get_action_value_policy(
302
+ params=squeezed_param,
303
+ obs=traj_batch.obs,
304
+ done=traj_batch.done,
305
+ avail_actions=traj_batch.avail_actions,
306
+ hstate=init_hstate,
307
+ rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here
308
+ aux_obs=traj_batch.oppo_onehot_id
309
+ )
310
+ log_prob = pi.log_prob(traj_batch.action)
311
+
312
+ is_relevant = jnp.equal(
313
+ jnp.argmax(traj_batch.self_onehot_id, axis=-1),
314
+ agent_id
315
+ )
316
+ loss_weights = jnp.where(is_relevant, 1, 0).astype(jnp.float32)
317
+ int_self_id = jnp.argmax(traj_batch.self_onehot_id, axis=-1)
318
+ int_oppo_id = jnp.argmax(traj_batch.oppo_onehot_id, axis=-1)
319
+
320
+ # Given a pair of policies that generate SP trajectories,
321
+ # compute the pair's total Lagrange multiplier in the Lagrange dual.
322
+ # Assuming the SP data is generated by population i, the total LMs
323
+ # amount to \sum_{j} lms_vertical[i][j] + \sum_{j} lms_horizontal[i][j]
324
+
325
+ def _gather_sp_weights(ids):
326
+ s_id, _ = ids
327
+ return jnp.sum(lms_vertical, axis=-1)[s_id], jnp.sum(lms_horizontal, axis=-1)[s_id]
328
+
329
+ # Given a pair of policies that generate XP trajectories,
330
+ # compute the pair's total Lagrange multiplier in the Lagrange dual.
331
+ # Assuming the XP data is generated by the i^th conf policy and the j^th BR policy,
332
+ # the total LMs amounts to
333
+ # -lms_vertical[j][i] -lms_horizontal[i][j]
334
+
335
+ def _gather_xp_weights(ids):
336
+ s_id, o_id = ids
337
+ return -lms_vertical[s_id][o_id], -lms_horizontal[o_id][s_id]
338
+
339
+ def _get_weights(s_id, o_id):
340
+ return jax.lax.cond(
341
+ jnp.equal(s_id, o_id),
342
+ _gather_sp_weights,
343
+ _gather_xp_weights,
344
+ (s_id, o_id)
345
+ )
346
+
347
+ # Value loss
348
+ value_pred_clipped = traj_batch.value + (
349
+ value - traj_batch.value
350
+ ).clip(
351
+ -config["CLIP_EPS"], config["CLIP_EPS"])
352
+ value_losses = jnp.square(value - target_v)
353
+ value_losses_clipped = jnp.square(value_pred_clipped - target_v)
354
+ value_loss = jax.lax.cond(
355
+ loss_weights.sum() == 0,
356
+ lambda x: jnp.zeros_like(x).astype(jnp.float32),
357
+ lambda x: x,
358
+ (loss_weights * jnp.maximum(value_losses, value_losses_clipped)).sum() / (loss_weights.sum() + 1e-8)
359
+ )
360
+
361
+ # # Apply different loss weights for SP and XP data
362
+ # # Loss weights consist of two parts: the first term is the weighting from the (L)BRDiv loss function
363
+ # # which is based on the sum of Lagrange multipliers for a given confederate-ego pair expected returns
364
+ # # in the Lagrange dual formulation. This is indicated by weights1 + weights2 in the code below.
365
+
366
+ # # The second term is a reweighting term to compensate for the data collection process, which uniformly and independently
367
+ # # samples the conf and br ids from 1, ..., n, resulting in P(SP) = 1/n and P(XP) = (n-1)/n.
368
+ # # To prevent the XP loss term from dominating the SP loss term, we would like P(SP) = P(XP) = 1/2.
369
+ # # Thus, we set the 2nd term of the SP weight to n/2, and the 2nd term of the XP weight to n/(2 * (n-1)).
370
+
371
+ n = config["PARTNER_POP_SIZE"]
372
+ is_sp = jnp.equal(jnp.argmax(traj_batch.self_onehot_id, axis=-1), jnp.argmax(traj_batch.oppo_onehot_id, axis=-1))
373
+ weights1, weights2 = jax.vmap(jax.vmap(_get_weights))(int_self_id, int_oppo_id)
374
+ actor_weights_sp = (weights1 + weights2) * (n/2)
375
+ actor_weights_xp = (weights1 + weights2) * (n / (2 * (n-1)))
376
+ actor_weights = jnp.where(is_sp, actor_weights_sp, actor_weights_xp)
377
+
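The SP/XP reweighting described in the comments above can be verified with exact arithmetic. The sketch below assumes conf and BR ids are sampled uniformly and independently from a population of size `n >= 2`; `sp_xp_reweighting` is a hypothetical helper used only for this check.

```python
from fractions import Fraction


def sp_xp_reweighting(n):
    # Sampling probabilities under uniform, independent id sampling.
    p_sp = Fraction(1, n)        # conf id == br id
    p_xp = Fraction(n - 1, n)    # conf id != br id
    # Compensation factors used for the actor weights.
    w_sp = Fraction(n, 2)
    w_xp = Fraction(n, 2 * (n - 1))
    # Effective share of each data type after reweighting.
    return p_sp * w_sp, p_xp * w_xp


effective_sp, effective_xp = sp_xp_reweighting(4)
```

For any `n >= 2` both products equal 1/2, which is the balance the comments ask for.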
378
+ # Policy gradient loss
379
+ ratio = jnp.exp(log_prob - traj_batch.log_prob)
380
+ gae_norm = (gae - gae.mean()) / (gae.std() + 1e-8)
381
+ pg_loss_1 = ratio * actor_weights * gae_norm
382
+ pg_loss_2 = jnp.clip(
383
+ ratio,
384
+ 1.0 - config["CLIP_EPS"],
385
+ 1.0 + config["CLIP_EPS"]) * actor_weights * gae_norm
386
+ pg_loss = jax.lax.cond(
387
+ loss_weights.sum() == 0,
388
+ lambda x: jnp.zeros_like(x).astype(jnp.float32),
389
+ lambda x: x,
390
+ -(
391
+ loss_weights * jnp.minimum(pg_loss_1, pg_loss_2)
392
+ ).sum()/(loss_weights.sum() + 1e-8)
393
+ )
394
+
395
+ # Weight entropy based on actor weights
396
+ all_sp_weights1, all_sp_weights2 = jax.vmap(_gather_sp_weights)((int_self_id, int_self_id))
397
+ entropy_scaler = jnp.maximum(all_sp_weights1, all_sp_weights2)
398
+
399
+ # Compute entropy loss
400
+ entropy = jax.lax.cond(
401
+ loss_weights.sum() == 0,
402
+ lambda x: jnp.zeros_like(x).astype(jnp.float32),
403
+ lambda x: x,
404
+ (loss_weights * entropy_scaler * pi.entropy()).sum()/(loss_weights.sum() + 1e-8)
405
+ )
406
+
407
+ total_loss = pg_loss + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy
408
+ return total_loss, (value_loss, pg_loss, entropy)
409
+
410
+ possible_agent_ids = jnp.expand_dims(jnp.arange(config["PARTNER_POP_SIZE"]), 1)
411
+ grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
412
+
413
+ def gather_conf_params_and_return_grads(agent_id):
414
+ # transposing the lm matrices only on the confederate agent side
415
+ # ensures that both the confederate and br policy that interact
416
+ # to generate a trajectory have the same weights when computing
417
+ # the policy gradient loss.
418
+ param_vector = gather_params(train_state_conf.params, agent_id)
419
+ (loss_val_conf, aux_vals_conf), grads_conf = grad_fn(
420
+ param_vector, conf_policy, minbatch_conf, agent_id,
421
+ jnp.transpose(lms_vertical), jnp.transpose(lms_horizontal)
422
+ )
423
+ return (loss_val_conf, aux_vals_conf), grads_conf
424
+
425
+ def gather_br_params_and_return_grads(agent_id):
426
+ param_vector = gather_params(train_state_br.params, agent_id)
427
+ (loss_val_br, aux_vals_br), grads_br = grad_fn(
428
+ param_vector, br_policy, minbatch_br, agent_id,
429
+ lms_vertical, lms_horizontal
430
+ )
431
+ return (loss_val_br, aux_vals_br), grads_br
432
+
433
+ (loss_val_conf, aux_vals_conf), grads_conf = jax.vmap(gather_conf_params_and_return_grads)(possible_agent_ids)
434
+ (loss_val_br, aux_vals_br), grads_br = jax.vmap(gather_br_params_and_return_grads)(possible_agent_ids)
435
+
436
+ grads_conf_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_conf)
437
+ grads_br_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_br)
438
+ train_state_conf = train_state_conf.apply_gradients(grads=grads_conf_new)
439
+ train_state_br = train_state_br.apply_gradients(grads=grads_br_new)
440
+ return (train_state_conf, train_state_br), ((loss_val_conf, aux_vals_conf), (loss_val_br, aux_vals_br))
441
+
442
+ (
443
+ train_state_conf, train_state_br,
444
+ traj_batch_conf, traj_batch_br,
445
+ advantages_conf, advantages_br,
446
+ targets_conf, targets_br,
447
+ rng, lms_vertical, lms_horizontal
448
+ ) = update_state
449
+ rng, perm_rng_conf, perm_rng_br = jax.random.split(rng, 3)
450
+
451
+ minibatches_conf = _create_minibatches(traj_batch_conf, advantages_conf, targets_conf, init_conf_hstate,
452
+ config["NUM_CONF_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_conf)
453
+ minibatches_br = _create_minibatches(traj_batch_br, advantages_br, targets_br, init_br_hstate,
454
+ config["NUM_BR_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_br)
455
+
456
+ # Update both policies
457
+ num_minibatches = minibatches_br[1].obs.shape[0]
458
+
459
+ repeated_lms_vertical = lms_vertical[jnp.newaxis, ...].repeat(num_minibatches, axis=0)
460
+ repeated_lms_horizontal = lms_horizontal[jnp.newaxis, ...].repeat(num_minibatches, axis=0)
461
+
462
+ (train_state_conf, train_state_br), all_losses = jax.lax.scan(
463
+ _update_minbatch, (train_state_conf, train_state_br),
464
+ (minibatches_conf, minibatches_br, repeated_lms_vertical, repeated_lms_horizontal)
465
+ )
466
+
467
+ update_state = (train_state_conf, train_state_br,
468
+ traj_batch_conf, traj_batch_br,
469
+ advantages_conf, advantages_br,
470
+ targets_conf, targets_br,
471
+ rng, lms_vertical, lms_horizontal
472
+ )
473
+ return update_state, all_losses
474
+
475
+ def _update_step(update_runner_state, unused):
476
+ """
477
+ 1. Collect rollouts
478
+ 2. Compute advantage
479
+ 3. PPO updates (UPDATE_EPOCHS epochs)
480
+ 4. Lagrange multiplier update (once, after all PPO epochs)
481
+ """
482
+ (
483
+ all_train_state_conf, all_train_state_br,
484
+ last_env_state, last_obs, last_done, last_conf_h, last_br_h,
485
+ rng, update_steps, lms_vertical, lms_horizontal
486
+ ) = update_runner_state
487
+
488
+ rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 3)
489
+
490
+ conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_ENVS"],), 0, config["PARTNER_POP_SIZE"])
491
+ br_ids = jax.random.randint(br_sampling_rng, (config["NUM_ENVS"],), 0, config["PARTNER_POP_SIZE"])
492
+
493
+ runner_state = (
494
+ all_train_state_conf, all_train_state_br, conf_ids, br_ids,
495
+ last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng
496
+ )
497
+ runner_state, traj_batch = jax.lax.scan(
498
+ _env_step, runner_state, None, config["ROLLOUT_LENGTH"])
499
+ (all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids,
500
+ last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng) = runner_state
501
+
502
+ # Get the last conf and br params and ids
503
+ last_conf_params = gather_params(all_train_state_conf.params, last_conf_ids)
504
+ last_br_params = gather_params(all_train_state_br.params, last_br_ids)
505
+
506
+ last_conf_one_hots = identity_matrix[last_conf_ids]
507
+ last_br_one_hots = identity_matrix[last_br_ids]
508
+
509
+ # Get agent 0 and agent 1 trajectories from interaction between conf policy and its BR policy.
510
+ traj_batch_conf, traj_batch_br = traj_batch
511
+
512
+ # Compute advantage for confederate agent from interaction with br policy
513
+ avail_actions_0 = jax.vmap(env.get_avail_actions)(last_env_state.env_state)["agent_0"].astype(jnp.float32)
514
+ _, last_val_conf, _, _ = jax.vmap(forward_pass_conf)(
515
+ params=last_conf_params,
516
+ obs=last_obs["agent_0"],
517
+ id=last_br_one_hots,
518
+ done=last_done["agent_0"],
519
+ avail_actions=avail_actions_0,
520
+ hstate=last_conf_h,
521
+ rng=jax.random.split(jax.random.PRNGKey(0), config["NUM_ENVS"]) # Dummy key since we're just extracting the value
522
+ )
523
+ last_val_conf = last_val_conf.squeeze()
524
+ advantages_conf, targets_conf = _calculate_gae(traj_batch_conf, last_val_conf)
525
+
526
+ # Compute advantage for br policy from interaction with confederate agent
527
+ avail_actions_1 = jax.vmap(env.get_avail_actions)(last_env_state.env_state)["agent_1"].astype(jnp.float32)
528
+ _, last_val_br, _, _ = jax.vmap(forward_pass_br)(
529
+ params=last_br_params,
530
+ obs=last_obs["agent_1"],
531
+ id=last_conf_one_hots,
532
+ done=last_done["agent_1"],
533
+ avail_actions=avail_actions_1,
534
+ hstate=last_br_h,
535
+ rng=jax.random.split(jax.random.PRNGKey(0), config["NUM_ENVS"]) # Dummy key since we're just extracting the value
536
+ )
537
+ last_val_br = last_val_br.squeeze()
538
+ advantages_br, targets_br = _calculate_gae(traj_batch_br, last_val_br)
539
+
540
+ # 3) PPO update
541
+ rng, update_rng = jax.random.split(rng, 2)
542
+ update_state = (
543
+ all_train_state_conf, all_train_state_br,
544
+ traj_batch_conf, traj_batch_br,
545
+ advantages_conf, advantages_br,
546
+ targets_conf, targets_br,
547
+ update_rng, lms_vertical, lms_horizontal
548
+ )
549
+
550
+ update_state, all_losses = jax.lax.scan(
551
+ _update_epoch, update_state, None, config["UPDATE_EPOCHS"])
552
+ all_train_state_conf, all_train_state_br = update_state[:2]
553
+ lms_vertical, lms_horizontal = update_state[-2:]
554
+
555
+ # Compute Lagrange gradient updates once per update step (after all PPO epochs).
556
+ # Diagonal and off-diagonal pairs use separate vmaps to avoid evaluating both
557
+ # branches of lax.cond for all pop_size^2 elements under vmap.
558
+ def compute_lagrange_grads_same(params_br, batch, target_value, ids):
559
+ conf_id, br_id = ids
560
+
561
+ all_target_value = jnp.reshape(target_value, (-1, 1))
562
+ repeated_value_sp = jnp.repeat(
563
+ jnp.reshape(all_target_value, (1, -1)),
564
+ config["PARTNER_POP_SIZE"],
565
+ axis=0
566
+ )
567
+
568
+ relevant_conf_params = gather_params(params_br, jnp.reshape(conf_id, (1,)))
569
+ relevant_conf_params = jax.tree.map(lambda x: jnp.squeeze(x, 0), relevant_conf_params)
570
+ def _get_value_xp_vary_conf(param, agent_onehot_id):
571
+ ts, bs = batch.obs.shape[:2]
572
+ agent_onehot_id = agent_onehot_id[jnp.newaxis, jnp.newaxis, ...].repeat(ts, axis=0).repeat(bs, axis=1)
573
+ _, value_xp_vary_conf, _, _ = br_policy.get_action_value_policy(
574
+ params=param,
575
+ obs=batch.obs,
576
+ done=batch.done,
577
+ avail_actions=batch.avail_actions,
578
+ hstate=init_br_hstate,
579
+ rng=jax.random.PRNGKey(0),
580
+ aux_obs=agent_onehot_id
581
+ )
582
+ return value_xp_vary_conf.reshape(ts*bs)
583
+
584
+ all_possible_value_xp_vary_conf = jax.vmap(
585
+ lambda agent_id: _get_value_xp_vary_conf(relevant_conf_params, agent_id)
586
+ )(jnp.eye(config["PARTNER_POP_SIZE"]))
587
+ all_possible_value_xp_vary_conf = all_possible_value_xp_vary_conf.at[conf_id].set(
588
+ repeated_value_sp[conf_id]
589
+ )
590
+
591
+ offsetting_thresholds = jnp.zeros_like(repeated_value_sp)
592
+ offsetting_thresholds = offsetting_thresholds.at[conf_id].set(
593
+ config["TOLERANCE_FACTOR"] * jnp.ones_like(offsetting_thresholds[conf_id])
594
+ )
595
+ grad_sp_vary_conf = repeated_value_sp + offsetting_thresholds - (
596
+ all_possible_value_xp_vary_conf + config["TOLERANCE_FACTOR"] * jnp.ones_like(offsetting_thresholds)
597
+ )
598
+
599
+ ##### Compute grad_sp_vary_br
600
+ # This code tries to measure the expected returns of the ego agent had the BR policy been
601
+ # substituted by another BR policy
602
+
603
+ # Let's say that R_{i,-j} is the ego agent's returns when following the BR policy of the i^th pair
604
+ # against the confederate policy of the j^th pair.
605
+
606
+ # Then grad_sp_vary_conf computes R_{i,-i} - R_{i,-j} - tolerance factor
607
+ # for all possible j. For j=i, the XP value is replaced by the target returns and the tolerance
608
+ # factor is added back via offsetting_thresholds above, so that the j=i entry is exactly 0.
609
+
610
+ # Meanwhile grad_sp_vary_br below computes R_{i,-i} - R_{j,-i} - tolerance factor
611
+ # for all possible j.
612
+
613
+ # Vary the BR policy parameters (j) used in value computation
614
+ # Use the experience generating pop id (batch.self_onehot_id) <i> as the conf ID.
615
+
616
+ relevant_params = gather_params(params_br, jnp.arange(config["PARTNER_POP_SIZE"]))
617
+ def _get_value_xp_vary_br(param):
618
+ ts, bs = batch.obs.shape[:2]
619
+ conf_one_hot = jnp.eye(config["PARTNER_POP_SIZE"])[conf_id]
620
+ conf_one_hot = conf_one_hot[jnp.newaxis, jnp.newaxis, ...].repeat(ts, axis=0).repeat(bs, axis=1)
621
+ _, value_xp_vary_br, _, _ = br_policy.get_action_value_policy(
622
+ params=param,
623
+ obs=batch.obs,
624
+ done=batch.done,
625
+ avail_actions=batch.avail_actions,
626
+ hstate=init_br_hstate,
627
+ rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here
628
+ aux_obs=conf_one_hot
629
+ )
630
+ return value_xp_vary_br.reshape(ts*bs)
631
+
632
+ all_possible_value_xp_vary_br = jax.vmap(
633
+ lambda param: _get_value_xp_vary_br(param)
634
+ )(relevant_params)
635
+ all_possible_value_xp_vary_br = jnp.reshape(
636
+ all_possible_value_xp_vary_br, (config["PARTNER_POP_SIZE"], -1)
637
+ )
638
+ all_possible_value_xp_vary_br = all_possible_value_xp_vary_br.at[conf_id].set(
639
+ repeated_value_sp[conf_id]
640
+ )
641
+
642
+ grad_sp_vary_br = repeated_value_sp + offsetting_thresholds - (
643
+ all_possible_value_xp_vary_br + config["TOLERANCE_FACTOR"] * jnp.ones_like(offsetting_thresholds)
644
+ )
645
+
646
+ all_self_id_int = jnp.reshape(
647
+ batch.self_onehot_id, (-1, jnp.shape(batch.self_onehot_id)[-1])
648
+ ).argmax(axis=-1)
649
+ all_oppo_id_int = jnp.reshape(
650
+ batch.oppo_onehot_id, (-1, jnp.shape(batch.oppo_onehot_id)[-1])
651
+ ).argmax(axis=-1)
652
+
653
+ self_is_conf = jnp.equal(all_self_id_int, conf_id).astype(jnp.float32)
654
+ oppo_is_conf = jnp.equal(all_oppo_id_int, conf_id).astype(jnp.float32)
655
+ loss_weights = self_is_conf * oppo_is_conf
656
+ repeated_loss_weights = jnp.repeat(
657
+ jnp.expand_dims(loss_weights, axis=0),
658
+ config["PARTNER_POP_SIZE"],
659
+ axis=0
660
+ )
661
+
662
+ # Compute vertical and horizontal gradient
663
+ vertical_grads = jnp.sum(grad_sp_vary_conf * repeated_loss_weights, axis=-1) / (jnp.sum(loss_weights) + 1e-8)
664
+ horizontal_grads = jnp.sum(grad_sp_vary_br * repeated_loss_weights, axis=-1) / (jnp.sum(loss_weights) + 1e-8)
665
+
666
+ output_grad_matrix_vertical = jnp.zeros((config["PARTNER_POP_SIZE"], config["PARTNER_POP_SIZE"]))
667
+ output_grad_matrix_horizontal = jnp.zeros((config["PARTNER_POP_SIZE"], config["PARTNER_POP_SIZE"]))
668
+ output_grad_matrix_vertical = output_grad_matrix_vertical.at[conf_id].set(vertical_grads)
669
+ output_grad_matrix_horizontal = output_grad_matrix_horizontal.at[conf_id].set(horizontal_grads)
670
+ return output_grad_matrix_vertical, output_grad_matrix_horizontal
671
+
672
+ def compute_lagrange_grads_diff(params_br, batch, target_returns, ids):
673
+ conf_id, br_id = ids
674
+ param_conf_id = gather_params(params_br, jnp.reshape(conf_id, (1,)))
675
+ param_br_id = gather_params(params_br, jnp.reshape(br_id, (1,)))
676
+ param_br_id = jax.tree.map(lambda x: jnp.squeeze(x, 0), param_br_id)
677
+ param_conf_id = jax.tree.map(lambda x: jnp.squeeze(x, 0), param_conf_id)
678
+
679
+ all_self_id_int = jnp.reshape(
680
+ batch.self_onehot_id, (-1, jnp.shape(batch.self_onehot_id)[-1])
681
+ ).argmax(axis=-1)
682
+ all_oppo_id_int = jnp.reshape(
683
+ batch.oppo_onehot_id, (-1, jnp.shape(batch.oppo_onehot_id)[-1])
684
+ ).argmax(axis=-1)
685
+ all_target_returns = jnp.reshape(target_returns, (-1))
686
+
687
+ # Compute data weights based on whether selected ID
688
+ # is relevant for the gradient computation process
689
+ oppo_is_conf = jnp.equal(all_oppo_id_int, conf_id).astype(jnp.float32)
690
+ self_is_br = jnp.equal(all_self_id_int, br_id).astype(jnp.float32)
691
+ loss_weights = oppo_is_conf * self_is_br
692
+
693
+ ts, bs = batch.obs.shape[:2]
694
+ conf_one_hot = jnp.eye(config["PARTNER_POP_SIZE"])[conf_id]
695
+ conf_one_hot = conf_one_hot[jnp.newaxis, jnp.newaxis, ...].repeat(ts, axis=0).repeat(bs, axis=1)
696
+ br_one_hot = jnp.eye(config["PARTNER_POP_SIZE"])[br_id]
697
+ br_one_hot = br_one_hot[jnp.newaxis, jnp.newaxis, ...].repeat(ts, axis=0).repeat(bs, axis=1)
698
+
699
+ _, value_sp_pop_is_br, _, _ = br_policy.get_action_value_policy(
700
+ params=param_br_id,
701
+ obs=batch.obs,
702
+ done=batch.done,
703
+ avail_actions=batch.avail_actions,
704
+ hstate=init_br_hstate,
705
+ rng=jax.random.PRNGKey(0),
706
+ aux_obs=br_one_hot
707
+ )
708
+ value_sp_pop_is_br = value_sp_pop_is_br.reshape(bs*ts)
709
+
710
+ _, value_sp_pop_is_not_br, _, _ = br_policy.get_action_value_policy(
711
+ params=param_conf_id,
712
+ obs=batch.obs,
713
+ done=batch.done,
714
+ avail_actions=batch.avail_actions,
715
+ hstate=init_br_hstate,
716
+ rng=jax.random.PRNGKey(0),
717
+ aux_obs=conf_one_hot
718
+ )
719
+ value_sp_pop_is_not_br = value_sp_pop_is_not_br.reshape(bs*ts)
720
+
721
+ vertical_diff = value_sp_pop_is_br - all_target_returns - config["TOLERANCE_FACTOR"]
722
+ horizontal_diff = value_sp_pop_is_not_br - all_target_returns - config["TOLERANCE_FACTOR"]
723
+
724
+ total_grad_vertical = (loss_weights * vertical_diff).sum() / (loss_weights.sum() + 1e-8)
725
+ total_grad_horizontal = (loss_weights * horizontal_diff).sum() / (loss_weights.sum() + 1e-8)
726
+
727
+ output_grad_matrix_vertical = jnp.zeros((config["PARTNER_POP_SIZE"], config["PARTNER_POP_SIZE"]))
728
+ output_grad_matrix_horizontal = jnp.zeros((config["PARTNER_POP_SIZE"], config["PARTNER_POP_SIZE"]))
729
+ output_grad_matrix_vertical = output_grad_matrix_vertical.at[br_id, conf_id].set(total_grad_vertical)
730
+ output_grad_matrix_horizontal = output_grad_matrix_horizontal.at[conf_id, br_id].set(total_grad_horizontal)
731
+ return output_grad_matrix_vertical, output_grad_matrix_horizontal
732
+
+                 # Diagonal pairs (conf_id == br_id): vmap over pop_size elements only
+                 diag_ids = np.arange(config["PARTNER_POP_SIZE"])
+                 diag_lagrange_grads = jax.vmap(
+                     lambda conf_id, br_id: compute_lagrange_grads_same(
+                         all_train_state_br.params, traj_batch_br, targets_br, (conf_id, br_id)
+                     )
+                 )(diag_ids, diag_ids)
+
+                 # Off-diagonal pairs (conf_id != br_id): vmap over pop_size*(pop_size-1) elements only
+                 all_conf_ids_np, all_br_ids_np = _get_all_ids(config["PARTNER_POP_SIZE"])
+                 off_diag_mask = all_conf_ids_np != all_br_ids_np
+                 off_diag_conf_ids = all_conf_ids_np[off_diag_mask]
+                 off_diag_br_ids = all_br_ids_np[off_diag_mask]
+                 off_diag_lagrange_grads = jax.vmap(
+                     lambda conf_id, br_id: compute_lagrange_grads_diff(
+                         all_train_state_br.params, traj_batch_br, targets_br, (conf_id, br_id)
+                     )
+                 )(off_diag_conf_ids, off_diag_br_ids)
+
+                 averaged_grad_vertical = (
+                     jnp.sum(diag_lagrange_grads[0], axis=0) +
+                     jnp.sum(off_diag_lagrange_grads[0], axis=0)
+                 )
+                 averaged_grad_horizontal = (
+                     jnp.sum(diag_lagrange_grads[1], axis=0) +
+                     jnp.sum(off_diag_lagrange_grads[1], axis=0)
+                 )
+
+                 lms_vertical = jnp.maximum(
+                     lms_vertical - config["LAGRANGE_LR"] * averaged_grad_vertical,
+                     0.5 * jnp.eye(config["PARTNER_POP_SIZE"])
+                 )
+                 lms_vertical = jnp.fill_diagonal(
+                     lms_vertical, 0.5 * jnp.ones((config["PARTNER_POP_SIZE"]), dtype=jnp.float32),
+                     inplace=False
+                 )
+                 lms_horizontal = jnp.maximum(
+                     lms_horizontal - config["LAGRANGE_LR"] * averaged_grad_horizontal,
+                     0.5 * jnp.eye(config["PARTNER_POP_SIZE"]),
+                 )
+                 lms_horizontal = jnp.fill_diagonal(
+                     lms_horizontal, 0.5 * jnp.ones((config["PARTNER_POP_SIZE"]), dtype=jnp.float32),
+                     inplace=False
+                 )
+
+                 (_, (value_loss_conf, pg_loss_conf, entropy_conf)), (_, (value_loss_br, pg_loss_br, entropy_br)) = all_losses
+
+                 # Metrics
+                 def mask_and_mean(x, mask):
+                     return jnp.where(mask, x, 0).sum() / jnp.maximum(1, mask.sum())
+
+                 mask = traj_batch_conf.info.get("returned_episode", jnp.ones_like(traj_batch_conf.reward))
+                 metric = jax.tree.map(lambda x: mask_and_mean(x, mask), traj_batch_conf.info)
+                 metric["lms_vertical"] = lms_vertical
+                 metric["lms_horizontal"] = lms_horizontal
+                 metric["update_steps"] = update_steps
+                 metric["value_loss_conf_agent"] = value_loss_conf.mean(axis=(0, 1))
+                 metric["value_loss_br_agent"] = value_loss_br.mean(axis=(0, 1))
+
+                 metric["pg_loss_conf_agent"] = pg_loss_conf.mean(axis=(0, 1))
+                 metric["pg_loss_br_agent"] = pg_loss_br.mean(axis=(0, 1))
+
+                 metric["entropy_conf"] = entropy_conf.mean(axis=(0, 1))
+                 metric["entropy_br"] = entropy_br.mean(axis=(0, 1))
+
+                 new_runner_state = (
+                     all_train_state_conf, all_train_state_br,
+                     last_env_state, last_obs, last_done, last_conf_h, last_br_h,
+                     rng, update_steps + 1,
+                     lms_vertical, lms_horizontal
+                 )
+                 return (new_runner_state, metric)
+
+             # --------------------------
+             # PPO Update and Checkpoint saving
+             # --------------------------
+             ckpt_and_eval_interval = config["NUM_UPDATES"] // max(1, config["NUM_CHECKPOINTS"] - 1) # -1 because we store a ckpt at the last update
+             num_ckpts = config["NUM_CHECKPOINTS"]
+
+             # Build a PyTree that holds parameters for all conf agent checkpoints
+             def init_ckpt_array(params_pytree):
+                 return jax.tree.map(
+                     lambda x: jnp.zeros((num_ckpts,) + x.shape, x.dtype),
+                     params_pytree)
+
+             def _update_step_with_ckpt(state_with_ckpt, unused):
+                 (update_runner_state, checkpoint_array_conf, checkpoint_array_br, ckpt_idx,
+                  eval_info) = state_with_ckpt
+
+                 # Single PPO update
+                 new_runner_state, metric = _update_step(update_runner_state, None)
+
+                 (
+                     train_state_conf, train_state_br,
+                     last_env_state, last_obs, last_done, last_conf_h, last_br_h,
+                     rng, update_steps, lms_vertical, lms_horizontal
+                 ) = new_runner_state
+
+                 # Decide if we store a checkpoint
+                 # update steps is 1-indexed because it was incremented at the end of the update step
+                 to_store = jnp.logical_or(jnp.equal(jnp.mod(update_steps-1, ckpt_and_eval_interval), 0),
+                                           jnp.equal(update_steps, config["NUM_UPDATES"]))
+
+                 def store_and_eval_ckpt(args):
+                     ckpt_arr_and_ep_infos, rng, cidx = args
+                     ckpt_arr_conf, ckpt_arr_br, _ = ckpt_arr_and_ep_infos
+                     new_ckpt_arr_conf = jax.tree.map(
+                         lambda c_arr, p: c_arr.at[cidx].set(p),
+                         ckpt_arr_conf, train_state_conf.params
+                     )
+                     new_ckpt_arr_br = jax.tree.map(
+                         lambda c_arr, p: c_arr.at[cidx].set(p),
+                         ckpt_arr_br, train_state_br.params
+                     )
+
+                     rng, eval_rng = jax.random.split(rng)
+                     ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
+                                                 run_all_episodes(eval_rng, train_state_conf, train_state_br))
+
+                     return ((new_ckpt_arr_conf, new_ckpt_arr_br, ep_last_info), rng, cidx + 1)
+
+                 def skip_ckpt(args):
+                     return args
+
+                 (checkpoint_array_and_infos, rng, ckpt_idx) = jax.lax.cond(
+                     to_store,
+                     store_and_eval_ckpt,
+                     skip_ckpt,
+                     ((checkpoint_array_conf, checkpoint_array_br, eval_info), rng, ckpt_idx)
+                 )
+                 checkpoint_array_conf, checkpoint_array_br, eval_ep_last_info = checkpoint_array_and_infos
+
+                 metric["eval_ep_last_info"] = eval_ep_last_info # return of confederate
+
+                 return ((train_state_conf, train_state_br,
+                          last_env_state, last_obs, last_done, last_conf_h, last_br_h,
+                          rng, update_steps, lms_vertical, lms_horizontal),
+                         checkpoint_array_conf, checkpoint_array_br, ckpt_idx,
+                         eval_ep_last_info), metric
+
+             # Initialize checkpoint array
+             checkpoint_array_conf = init_ckpt_array(all_conf_optims.params)
+             checkpoint_array_br = init_ckpt_array(all_br_optims.params)
+             ckpt_idx = 0
+
+             # Initialize state for scan over _update_step_with_ckpt
+             update_steps = 0
+
+             rng, rng_eval = jax.random.split(rng, 2)
+             eval_ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
+                                              run_all_episodes(rng_eval, all_conf_optims, all_br_optims))
+
+             # Initialize environment
+             rng, reset_rng = jax.random.split(rng)
+             reset_rngs = jax.random.split(reset_rng, config["NUM_ENVS"])
+             init_obs, init_env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rngs)
+             init_done = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
+
+             # Initialize conf and br hstates
+             init_conf_h = conf_policy.init_hstate(config["NUM_CONF_ACTORS"])
+             init_br_h = br_policy.init_hstate(config["NUM_BR_ACTORS"])
+
+             # Initialize LMs
+             # lm_vertical[i, j] stores the lagrange multiplier for upholding
+             # R_{conf(i), BR(i)} >= R_{conf(j), BR(i)} + tolerance_factor
+
+             # lm_horizontal[i, j] stores the lagrange multiplier for upholding
+             # R_{conf(i), BR(i)} >= R_{conf(i), BR(j)} + tolerance_factor
+
+             # Diagonal elements of both matrices sum up to 1.
+             # Providing a weight of 1 to maximize the SP return from any population
+             lagrange_multipliers_vertical = 0.5 * jnp.eye(config["PARTNER_POP_SIZE"])
+             lagrange_multipliers_horizontal = 0.5 * jnp.eye(config["PARTNER_POP_SIZE"])
+
+             update_runner_state = (
+                 all_conf_optims, all_br_optims,
+                 init_env_state, init_obs, init_done, init_conf_h, init_br_h,
+                 rng, update_steps,
+                 lagrange_multipliers_vertical, lagrange_multipliers_horizontal
+             )
+
+             state_with_ckpt = (
+                 update_runner_state, checkpoint_array_conf,
+                 checkpoint_array_br, ckpt_idx, eval_ep_last_info
+             )
+
+             # run training
+             state_with_ckpt, metrics = jax.lax.scan(
+                 _update_step_with_ckpt,
+                 state_with_ckpt,
+                 xs=None,
+                 length=config["NUM_UPDATES"]
+             )
+
+             (
+                 final_runner_state, checkpoint_array_conf, checkpoint_array_br,
+                 final_ckpt_idx, all_ep_infos
+             ) = state_with_ckpt
+
+             out = {
+                 "final_params_conf": final_runner_state[0].params,
+                 "final_params_br": final_runner_state[1].params,
+                 "checkpoints_conf": checkpoint_array_conf,
+                 "checkpoints_br": checkpoint_array_br,
+                 "metrics": metrics, # metrics is from the perspective of the confederate agent (averaged over population)
+                 "all_pair_returns": all_ep_infos
+             }
+             return out
+
+         return train
+     # ------------------------------
+     # Actually run the adversarial teammate training
+     # ------------------------------
+     train_fn = make_lbrdiv_agents(config)
+     out = train_fn(train_rng)
+     return out
+
+ def get_lbrdiv_population(config, out, env):
+     '''
+     Get the partner params and partner population for ego training.
+     '''
+     pop_size = config["algorithm"]["PARTNER_POP_SIZE"]
+
+     # partner_params has shape (num_seeds, pop_size, ...)
+     partner_params = out['final_params_conf']
+
+     partner_policy = ActorWithConditionalCriticPolicy(
+         action_dim=env.action_space(env.agents[1]).n,
+         obs_dim=env.observation_space(env.agents[1]).shape[0],
+         pop_size=pop_size, # used to create onehot agent id
+         activation=config["algorithm"].get("ACTIVATION", "tanh")
+     )
+
+     # Create partner population
+     partner_population = AgentPopulation(
+         pop_size=pop_size,
+         policy_cls=partner_policy
+     )
+
+     return partner_params, partner_population
+
+ def run_lbrdiv(config, wandb_logger):
+     algorithm_config = dict(config["algorithm"])
+
+     env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"])
+     env = LogWrapper(env)
+
+     log.info("Starting LBRDiv training...")
+     start = time.time()
+
+     # Generate multiple random seeds from the base seed
+     rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"])
+     rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"])
+
+     # Initialize br and conf policies
+     conf_policy = ActorWithConditionalCriticPolicy(
+         action_dim=env.action_space(env.agents[0]).n,
+         obs_dim=env.observation_space(env.agents[0]).shape[0],
+         pop_size=algorithm_config["PARTNER_POP_SIZE"],
+     )
+     br_policy = ActorWithConditionalCriticPolicy(
+         action_dim=env.action_space(env.agents[0]).n,
+         obs_dim=env.observation_space(env.agents[0]).shape[0],
+         pop_size=algorithm_config["PARTNER_POP_SIZE"],
+     )
+
+     # Create a vmapped version of train_lbrdiv_partners
+     with jax.disable_jit(False):
+         vmapped_train_fn = jax.jit(
+             jax.vmap(
+                 partial(train_lbrdiv_partners, env=env, config=algorithm_config, conf_policy=conf_policy, br_policy=br_policy)
+             )
+         )
+         out = vmapped_train_fn(rngs)
+
+     end = time.time()
+     log.info(f"LBRDiv training complete in {end - start} seconds")
+
+     metric_names = get_metric_names(algorithm_config["ENV_NAME"])
+     log_metrics(config, out, wandb_logger, metric_names)
+
+     partner_params, partner_population = get_lbrdiv_population(config, out, env)
+
+     return partner_params, partner_population
+
+
+ def log_metrics(config, outs, logger, metric_names: tuple):
+     metrics = outs["metrics"]
+     # metrics now has shape (num_seeds, num_updates, pop_size)
+     num_seeds, num_updates, pop_size = metrics["pg_loss_conf_agent"].shape # number of trained pairs
+
+     ### Log evaluation metrics
+     # shape (num_seeds, num_updates, (pop_size)^2) [pre-scalarized: mean over eval eps and agents taken inside scan]
+     all_returns = np.asarray(metrics["eval_ep_last_info"]["returned_episode_returns"])
+     xs = list(range(num_updates))
+
+     all_conf_ids, all_br_ids = _get_all_ids(pop_size)
+     sp_mask = (all_conf_ids == all_br_ids)
+     sp_returns = all_returns[:, :, sp_mask]
+     xp_returns = all_returns[:, :, ~sp_mask]
+
+     # Average over seeds and agent pairs (eval episodes and agents already averaged inside scan)
+     sp_return_curve = sp_returns.mean(axis=(0, 2))
+     xp_return_curve = xp_returns.mean(axis=(0, 2))
+
+     for step in range(num_updates):
+         logger.log_item("Eval/AvgSPReturnCurve", sp_return_curve[step], train_step=step)
+         logger.log_item("Eval/AvgXPReturnCurve", xp_return_curve[step], train_step=step)
+         logger.commit()
+
+     # log final XP matrix to wandb - average over seeds
+     last_returns_array = all_returns[:, -1].mean(axis=0)
+     last_returns_array = np.reshape(last_returns_array, (pop_size, pop_size))
+     logger.log_xp_matrix("Eval/LastXPMatrix", last_returns_array)
+
+     ### Log population loss as multi-line plots, where each line is a different population member
+     # shape (num_seeds, num_updates, pop_size): update epochs and minibatches were already averaged inside the update step
+     # Average over seeds
+     processed_losses = {
+         "ConfPGLoss": np.asarray(metrics["pg_loss_conf_agent"]).mean(axis=0).transpose(),
+         "BRPGLoss": np.asarray(metrics["pg_loss_br_agent"]).mean(axis=0).transpose(),
+         "ConfValLoss": np.asarray(metrics["value_loss_conf_agent"]).mean(axis=0).transpose(),
+         "BRValLoss": np.asarray(metrics["value_loss_br_agent"]).mean(axis=0).transpose(),
+         "ConfEntropy": np.asarray(metrics["entropy_conf"]).mean(axis=0).transpose(),
+         "BREntropy": np.asarray(metrics["entropy_br"]).mean(axis=0).transpose(),
+     }
+
+     xs = list(range(num_updates))
+     keys = [f"pair {i}" for i in range(pop_size)]
+     for loss_name, loss_data in processed_losses.items():
+         if np.isnan(loss_data).any():
+             raise ValueError(f"Found nan in loss {loss_name}")
+         logger.log_item(f"Losses/{loss_name}",
+             wandb.plot.line_series(xs=xs, ys=loss_data, keys=keys,
+             title=loss_name, xname="train_step")
+         )
+
+     # Average over seeds for Lagrange multipliers
+     lm_keys = [f"pair {i}, {j}" for i in range(pop_size) for j in range(pop_size)]
+     lm_horizontal = np.asarray(metrics["lms_horizontal"]).mean(axis=0)
+     lm_vertical = np.asarray(metrics["lms_vertical"]).mean(axis=0)
+     lagrange_multipliers = {
+         "LMs_Horizontal": np.reshape(lm_horizontal, (lm_horizontal.shape[0], -1)).transpose(),
+         "LMs_Vertical": np.reshape(lm_vertical, (lm_vertical.shape[0], -1)).transpose()
+     }
+
+     for array_name, array_data in lagrange_multipliers.items():
+         if np.isnan(array_data).any():
+             raise ValueError(f"Found nan in loss {array_name}")
+         logger.log_item(
+             f"Losses/{array_name}",
+             wandb.plot.line_series(xs=xs, ys=array_data, keys=lm_keys,
+             title=array_name, xname="train_step")
+         )
+     logger.commit()
+
+     ### Log artifacts
+     savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
+     # Save train run output and log to wandb as artifact
+     out_savepath = save_train_run(outs, savedir, savename="saved_train_run")
+     if config["logger"]["log_train_out"]:
+         logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run")
+
+     # Cleanup locally logged out files
+     if not config["local_logger"]["save_train_out"]:
+         shutil.rmtree(out_savepath)
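Note on the Lagrange multiplier update in `_update_step` of LBRDiv.py: the step `jnp.maximum(lms - LAGRANGE_LR * grad, 0.5 * eye)` followed by `jnp.fill_diagonal(..., 0.5)` can be read as a projected gradient step. The sketch below is only an illustration under stated assumptions: it uses plain nested lists instead of JAX arrays, and the helper name `update_multipliers` and the sample numbers are hypothetical, not part of the commit.

```python
def update_multipliers(lms, grads, lr):
    """One projected gradient step on a square matrix of Lagrange multipliers.

    Mirrors max(lms - lr * grads, 0.5 * I) followed by resetting the diagonal
    to 0.5. Hypothetical helper: plain lists stand in for JAX arrays.
    """
    n = len(lms)
    updated = []
    for i in range(n):
        row = []
        for j in range(n):
            if i == j:
                # Diagonal (self-play) entries are pinned to 0.5.
                row.append(0.5)
            else:
                # Off-diagonal entries take a gradient step and are clipped at 0.
                row.append(max(lms[i][j] - lr * grads[i][j], 0.0))
        updated.append(row)
    return updated


# Hypothetical 2x2 example: a negative gradient entry raises the multiplier,
# a positive one pushes it down until it is clipped at zero.
lms = [[0.5, 0.0], [0.0, 0.5]]
grads = [[0.0, -2.0], [3.0, 0.0]]
new_lms = update_multipliers(lms, grads, lr=0.1)
```

In this reading, an off-diagonal multiplier grows only while its gradient entry is negative, and the diagonal never moves, so the two matrices always contribute a combined self-play weight of 1.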
teammate_generation/__init__.py ADDED
File without changes
teammate_generation/configs/algorithm/brdiv/_base_.yaml ADDED
@@ -0,0 +1,40 @@
+ # @package algorithm
+ # ^ tells hydra to place these values directly under algorithm key
+ ALG: brdiv
+ TOTAL_TIMESTEPS: 4.5e7 # divided among each pair of BR and Conf agents
+ NUM_CHECKPOINTS: 5
+ PARTNER_POP_SIZE: 4
+ NUM_ENVS: 64
+ # SP weight = 1 + 2*XP weight.
+ # Thus, as XP weight -> 0, SP/(SP+XP) -> 1.
+ # If XP weight -> infinity, XP/(SP+XP) -> 1/3, and SP/(SP+XP) -> 2/3.
+ XP_LOSS_WEIGHTS: 1
+ LR: 1e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 4
+ GAMMA: 0.99
+ GAE_LAMBDA: 0.95
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ VF_COEF: 0.5
+ MAX_GRAD_NORM: 1.0
+ ANNEAL_LR: false
+ ego_train_algorithm:
+   EGO_ACTOR_TYPE: s5
+   S5_D_MODEL: 16
+   S5_SSM_SIZE: 16
+   S5_ACTOR_CRITIC_HIDDEN_DIM: 64
+   FC_N_LAYERS: 2
+   TOTAL_TIMESTEPS: 1e7
+   NUM_CHECKPOINTS: 5
+   NUM_ENVS: 8
+   LR: 1e-4
+   UPDATE_EPOCHS: 15
+   NUM_MINIBATCHES: 4
+   GAMMA: 0.99
+   GAE_LAMBDA: 0.95
+   CLIP_EPS: 0.05
+   ENT_COEF: 0.01
+   VF_COEF: 0.5
+   MAX_GRAD_NORM: 1.0
+   ANNEAL_LR: true
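Note on the `XP_LOSS_WEIGHTS` comment in this config: the stated limits follow from simple arithmetic on SP weight = 1 + 2*XP weight. The sketch below only checks that arithmetic; the function name `sp_xp_fractions` is hypothetical and is not taken from the training code.

```python
def sp_xp_fractions(xp_weight: float) -> tuple[float, float]:
    """Return (SP/(SP+XP), XP/(SP+XP)) for a given XP weight.

    Uses the relation stated in the config comment: SP weight = 1 + 2 * XP weight.
    """
    sp_weight = 1.0 + 2.0 * xp_weight
    total = sp_weight + xp_weight
    return sp_weight / total, xp_weight / total


# XP weight 0: all weight is on self-play.
at_zero = sp_xp_fractions(0.0)
# XP weight 1 (the value in this base config): SP = 3, XP = 1.
at_one = sp_xp_fractions(1.0)
# Very large XP weight: the fractions approach 2/3 and 1/3.
at_large = sp_xp_fractions(1e9)
```

With the base value `XP_LOSS_WEIGHTS: 1`, three quarters of the combined weight falls on the SP term under this relation.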
teammate_generation/configs/algorithm/brdiv/hanabi.yaml ADDED
@@ -0,0 +1,27 @@
+ defaults:
+   - brdiv/_base_
+   - _self_
+
+ TOTAL_TIMESTEPS: 5e8
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ XP_LOSS_WEIGHTS: 0.05
+ LR: 5e-4
+ UPDATE_EPOCHS: 4
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.2
+ ENT_COEF: 0.01
+ ANNEAL_LR: true
+ GAMMA: 0.999
+ GAE_LAMBDA: 0.95
+ MAX_GRAD_NORM: 0.5
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 1e8
+   LR: 5e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.2
+   ANNEAL_LR: true
+   UPDATE_EPOCHS: 4
+   GAMMA: 0.999
+   GAE_LAMBDA: 0.95
+   MAX_GRAD_NORM: 0.5
teammate_generation/configs/algorithm/brdiv/lbf/lbf_12x12.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - brdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4.5e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 64
+ XP_LOSS_WEIGHTS: 0.05 # 0.1
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 2 # 4
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
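Note on the `defaults` list in these environment configs: listing `_self_` after `brdiv/_base_` means that values from this file override the values from the base file. The sketch below imitates that effect with a plain recursive dictionary merge. It is an assumption-laden illustration, not Hydra's actual implementation; `merge_config` is a hypothetical helper, and only a few keys from the two files are reproduced.

```python
def merge_config(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base`.

    Hypothetical helper: nested dicts are merged key by key, and any other
    value in `override` replaces the value in `base`.
    """
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


# A few keys from brdiv/_base_.yaml ...
base = {
    "LR": 1e-4,
    "NUM_MINIBATCHES": 4,
    "ego_train_algorithm": {"TOTAL_TIMESTEPS": 1e7, "ANNEAL_LR": True},
}
# ... and the matching overrides from brdiv/lbf/lbf_12x12.yaml.
override = {
    "LR": 5e-4,
    "NUM_MINIBATCHES": 2,
    "ego_train_algorithm": {"TOTAL_TIMESTEPS": 3e7},
}
merged = merge_config(base, override)
```

Keys that the environment file does not mention, such as `ANNEAL_LR` under `ego_train_algorithm`, keep their base values in this sketch.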
teammate_generation/configs/algorithm/brdiv/lbf/lbf_7x7_nolevels.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - brdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4.5e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 64
+ XP_LOSS_WEIGHTS: 0.05 # 0.1
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 2 # 4
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/brdiv/mini-hanabi.yaml ADDED
@@ -0,0 +1,28 @@
+ defaults:
+   - brdiv/_base_
+   - _self_
+
+ # Mini-Hanabi (3c/3r/hand3) BRDiv config.
+ TOTAL_TIMESTEPS: 1e8
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ XP_LOSS_WEIGHTS: 0.05
+ LR: 5e-4
+ UPDATE_EPOCHS: 4
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.2
+ ENT_COEF: 0.01
+ ANNEAL_LR: true
+ GAMMA: 0.999
+ GAE_LAMBDA: 0.95
+ MAX_GRAD_NORM: 0.5
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 1e8
+   LR: 5e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.2
+   ANNEAL_LR: true
+   UPDATE_EPOCHS: 4
+   GAMMA: 0.999
+   GAE_LAMBDA: 0.95
+   MAX_GRAD_NORM: 0.5
teammate_generation/configs/algorithm/brdiv/overcooked-v1/asymm_advantages.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - brdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4.5e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 64
+ XP_LOSS_WEIGHTS: 1
+ LR: .0001
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.3
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/brdiv/overcooked-v1/coord_ring.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - brdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 9e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ XP_LOSS_WEIGHTS: 0.007
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.1
+ ENT_COEF: 0.05
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-3
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/brdiv/overcooked-v1/counter_circuit.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - brdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 9e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ XP_LOSS_WEIGHTS: 0.005
+ LR: 1e-3
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 8
+ CLIP_EPS: 0.01
+ ENT_COEF: 0.05
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-3
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/brdiv/overcooked-v1/cramped_room.yaml ADDED
@@ -0,0 +1,21 @@
+ defaults:
+   - brdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4.5e7
+ PARTNER_POP_SIZE: 4
+ NUM_ENVS: 64
+ # SP weight = 1 + 2*XP weight.
+ # Thus, as XP weight -> 0, SP/(SP+XP) -> 1.
+ # If XP weight -> infinity, XP/(SP+XP) -> 1/3, and SP/(SP+XP) -> 2/3.
+ XP_LOSS_WEIGHTS: 0.5 # 10
+ LR: 1e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/brdiv/overcooked-v1/forced_coord.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - brdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 9e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ XP_LOSS_WEIGHTS: 0.01
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-3
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/comedi/_base_.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package algorithm
+ # ^ tells hydra to place these values directly under algorithm key
+ ALG: comedi
+ TOTAL_TIMESTEPS_PER_ITERATION: 1.2e7 # number of steps used to train each comedi agent at each iteration
+ NUM_CHECKPOINTS: 5
+ PARTNER_POP_SIZE: 4
+ NUM_ENVS: 48
+ LR: 1e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 8
+ GAMMA: 0.99
+ GAE_LAMBDA: 0.95
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ VF_COEF: 0.5
+ MAX_GRAD_NORM: 1.0
+ ANNEAL_LR: false
+ ACTOR_TYPE: actor_with_conditional_critic
+ NUM_ARGMAX_ROLLOUT_EPS: 20
+ COMEDI_ALPHA: 1.0
+ COMEDI_BETA: 0.5
+ ego_train_algorithm:
+   EGO_ACTOR_TYPE: s5
+   TOTAL_TIMESTEPS: 1e7
+   NUM_CHECKPOINTS: 5
+   NUM_ENVS: 8
+   LR: 1e-4
+   UPDATE_EPOCHS: 15
+   NUM_MINIBATCHES: 4
+   GAMMA: 0.99
+   GAE_LAMBDA: 0.95
+   CLIP_EPS: 0.05
+   ENT_COEF: 0.01
+   VF_COEF: 0.5
+   MAX_GRAD_NORM: 1.0
+   ANNEAL_LR: true
teammate_generation/configs/algorithm/comedi/hanabi.yaml ADDED
@@ -0,0 +1,26 @@
+ defaults:
+   - comedi/_base_
+   - _self_
+
+ TOTAL_TIMESTEPS_PER_ITERATION: 2e7
+ PARTNER_POP_SIZE: 5
+ LR: 5e-4
+ UPDATE_EPOCHS: 4
+ CLIP_EPS: 0.2
+ ENT_COEF: 0.01
+ ANNEAL_LR: true
+ GAMMA: 0.999
+ GAE_LAMBDA: 0.95
+ MAX_GRAD_NORM: 0.5
+ COMEDI_ALPHA: 0.2
+ COMEDI_BETA: 0.4
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 1e8
+   LR: 5e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.2
+   ANNEAL_LR: true
+   UPDATE_EPOCHS: 4
+   GAMMA: 0.999
+   GAE_LAMBDA: 0.95
+   MAX_GRAD_NORM: 0.5
teammate_generation/configs/algorithm/comedi/lbf/lbf_12x12.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - comedi/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS_PER_ITERATION: 6e6
+ PARTNER_POP_SIZE: 10
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.001
+ COMEDI_ALPHA: 0.2 # weight on XP return
+ COMEDI_BETA: 0.4 # weight on SXP return
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 5e-5
+   ENT_COEF: 1e-4
+   CLIP_EPS: 0.1
+   ANNEAL_LR: false
teammate_generation/configs/algorithm/comedi/lbf/lbf_7x7_nolevels.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - comedi/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS_PER_ITERATION: 6e6
+ PARTNER_POP_SIZE: 10
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.001
+ COMEDI_ALPHA: 0.2 # weight on XP return
+ COMEDI_BETA: 0.4 # weight on SXP return
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 5e-5
+   ENT_COEF: 1e-4
+   CLIP_EPS: 0.1
+   ANNEAL_LR: false
teammate_generation/configs/algorithm/comedi/mini-hanabi.yaml ADDED
@@ -0,0 +1,27 @@
+ defaults:
+   - comedi/_base_
+   - _self_
+
+ # Mini-Hanabi (3c/3r/hand3) CoMeDi config.
+ TOTAL_TIMESTEPS_PER_ITERATION: 2e6
+ PARTNER_POP_SIZE: 5
+ LR: 5e-4
+ UPDATE_EPOCHS: 4
+ CLIP_EPS: 0.2
+ ENT_COEF: 0.01
+ ANNEAL_LR: true
+ GAMMA: 0.999
+ GAE_LAMBDA: 0.95
+ MAX_GRAD_NORM: 0.5
+ COMEDI_ALPHA: 0.2
+ COMEDI_BETA: 0.4
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 1e8
+   LR: 5e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.2
+   ANNEAL_LR: true
+   UPDATE_EPOCHS: 4
+   GAMMA: 0.999
+   GAE_LAMBDA: 0.95
+   MAX_GRAD_NORM: 0.5
teammate_generation/configs/algorithm/comedi/overcooked-v1/asymm_advantages.yaml ADDED
@@ -0,0 +1,16 @@
+ defaults:
+   - comedi/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 6e6
+ PARTNER_POP_SIZE: 10
+ LR: .0001
+ UPDATE_EPOCHS: 15
+ CLIP_EPS: 0.3
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 5e-5
+   ENT_COEF: .001
+   CLIP_EPS: 0.1
+   UPDATE_EPOCHS: 10
teammate_generation/configs/algorithm/comedi/overcooked-v1/coord_ring.yaml ADDED
@@ -0,0 +1,16 @@
+ defaults:
+   - comedi/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 1e7
+ PARTNER_POP_SIZE: 10
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ CLIP_EPS: 0.1
+ ENT_COEF: 0.05
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 3e-5
+   ENT_COEF: .001
+   CLIP_EPS: 0.1
+   UPDATE_EPOCHS: 10
teammate_generation/configs/algorithm/comedi/overcooked-v1/counter_circuit.yaml ADDED
@@ -0,0 +1,16 @@
+ defaults:
+   - comedi/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 1e7
+ PARTNER_POP_SIZE: 10
+ LR: 1e-3
+ UPDATE_EPOCHS: 15
+ CLIP_EPS: 0.01 # 0.1
+ ENT_COEF: 0.05
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 5e-5
+   ENT_COEF: .001
+   CLIP_EPS: 0.1
+   UPDATE_EPOCHS: 10
teammate_generation/configs/algorithm/comedi/overcooked-v1/cramped_room.yaml ADDED
@@ -0,0 +1,17 @@
+ defaults:
+   - comedi/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 6e6
+ PARTNER_POP_SIZE: 10
+ LR: 1e-4
+ UPDATE_EPOCHS: 15
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 5e-5
+   ANNEAL_LR: false
+   ENT_COEF: .001
+   CLIP_EPS: 0.1
+   UPDATE_EPOCHS: 10
teammate_generation/configs/algorithm/comedi/overcooked-v1/forced_coord.yaml ADDED
@@ -0,0 +1,16 @@
+ defaults:
+   - comedi/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 1e7
+ PARTNER_POP_SIZE: 10
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-5
+   ENT_COEF: 1e-4
+   CLIP_EPS: 0.1
+   UPDATE_EPOCHS: 5
teammate_generation/configs/algorithm/fcp/_base_.yaml ADDED
@@ -0,0 +1,37 @@
+ # @package algorithm
+ # ^ tells hydra to place these values directly under algorithm key
+ ALG: fcp
+ ACTOR_TYPE: mlp
+ TOTAL_TIMESTEPS: 1e6 # per PARTNER_POP_SIZE trained
+ NUM_CHECKPOINTS: 5
+ PARTNER_POP_SIZE: 20 # true partner pop size is PARTNER_POP_SIZE * NUM_CHECKPOINTS
+ NUM_ENVS: 8
+ LR: 1e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 4
+ GAMMA: 0.99
+ GAE_LAMBDA: 0.95
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ VF_COEF: 0.5
+ MAX_GRAD_NORM: 1.0
+ ANNEAL_LR: true
+ ego_train_algorithm:
+   EGO_ACTOR_TYPE: s5
+   S5_D_MODEL: 16
+   S5_SSM_SIZE: 16
+   S5_ACTOR_CRITIC_HIDDEN_DIM: 64
+   FC_N_LAYERS: 2
+   TOTAL_TIMESTEPS: 1e7
+   NUM_CHECKPOINTS: 5
+   NUM_ENVS: 8
+   LR: 1e-4
+   UPDATE_EPOCHS: 15
+   NUM_MINIBATCHES: 4
+   GAMMA: 0.99
+   GAE_LAMBDA: 0.95
+   CLIP_EPS: 0.05
+   ENT_COEF: 0.01
+   VF_COEF: 0.5
+   MAX_GRAD_NORM: 1.0
+   ANNEAL_LR: true
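Note on the `PARTNER_POP_SIZE` comment in this FCP config: the effective partner population is the number of trained partners times the number of checkpoints kept per partner. The sketch below only restates that arithmetic; the function name `effective_partner_pop_size` is hypothetical and does not appear in the commit.

```python
def effective_partner_pop_size(partner_pop_size: int, num_checkpoints: int) -> int:
    """Partners available to the ego agent when every checkpoint of every
    trained partner is used as a separate teammate (hypothetical helper)."""
    return partner_pop_size * num_checkpoints


# Base FCP config: PARTNER_POP_SIZE 20, NUM_CHECKPOINTS 5.
base_total = effective_partner_pop_size(20, 5)
# A smaller population of 3 partners with the same 5 checkpoints.
small_total = effective_partner_pop_size(3, 5)
```

Under these values the base config yields 100 partners and a population of 3 yields 15.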
teammate_generation/configs/algorithm/fcp/hanabi.yaml ADDED
@@ -0,0 +1,32 @@
+ defaults:
+   - fcp/_base_
+   - _self_
+
+ # Full 2-player Hanabi FCP config. Trains IPPO partners then ego.
+ # Hyperparameters aligned with JaxMARL Hanabi consensus.
+ #
+ # PARTNER_POP_SIZE=3 (not 10): FCP vmaps across pop size, so 10
+ # parallel IPPO instances with 658-dim obs OOMs on 48GB. 3 partners
+ # x 5 checkpoints = 15 total partners, enough for diversity.
+ TOTAL_TIMESTEPS: 1e9
+ PARTNER_POP_SIZE: 3
+ LR: 5e-4
+ NUM_ENVS: 32
+ UPDATE_EPOCHS: 4
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.2
+ ENT_COEF: 0.01
+ ANNEAL_LR: true
+ GAMMA: 0.999
+ GAE_LAMBDA: 0.95
+ MAX_GRAD_NORM: 0.5
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 1e9
+   LR: 5e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.2
+   ANNEAL_LR: true
+   UPDATE_EPOCHS: 4
+   GAMMA: 0.999
+   GAE_LAMBDA: 0.95
+   MAX_GRAD_NORM: 0.5
teammate_generation/configs/algorithm/fcp/lbf/lbf_12x12.yaml ADDED
@@ -0,0 +1,17 @@
+ defaults:
+   - fcp/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 1e6
+ LR: .0001
+ NUM_ENVS: 8
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.03
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
+
teammate_generation/configs/algorithm/fcp/lbf/lbf_7x7_nolevels.yaml ADDED
@@ -0,0 +1,17 @@
+ defaults:
+   - fcp/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 1e6
+ LR: .0001
+ NUM_ENVS: 8
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.03
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
+
teammate_generation/configs/algorithm/fcp/mini-hanabi.yaml ADDED
@@ -0,0 +1,26 @@
+ defaults:
+   - fcp/_base_
+   - _self_
+
+ # Mini-Hanabi (3c/3r/hand3) FCP config.
+ TOTAL_TIMESTEPS: 1e8
+ LR: 5e-4
+ NUM_ENVS: 128
+ UPDATE_EPOCHS: 4
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.2
+ ENT_COEF: 0.01
+ ANNEAL_LR: true
+ GAMMA: 0.999
+ GAE_LAMBDA: 0.95
+ MAX_GRAD_NORM: 0.5
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 1e8
+   LR: 5e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.2
+   ANNEAL_LR: true
+   UPDATE_EPOCHS: 4
+   GAMMA: 0.999
+   GAE_LAMBDA: 0.95
+   MAX_GRAD_NORM: 0.5
teammate_generation/configs/algorithm/fcp/overcooked-v1/asymm_advantages.yaml ADDED
@@ -0,0 +1,17 @@
+ defaults:
+   - fcp/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 2e6
+ LR: .0001
+ NUM_ENVS: 8
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.3
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
+
teammate_generation/configs/algorithm/fcp/overcooked-v1/coord_ring.yaml ADDED
@@ -0,0 +1,16 @@
+ defaults:
+   - fcp/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4e6
+ LR: 1e-3
+ NUM_ENVS: 8
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.1
+ ENT_COEF: 0.05
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-3
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/fcp/overcooked-v1/counter_circuit.yaml ADDED
@@ -0,0 +1,16 @@
+ defaults:
+   - fcp/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4e6
+ LR: 1e-3
+ NUM_ENVS: 8
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.1
+ ENT_COEF: 0.05
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-3
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/fcp/overcooked-v1/cramped_room.yaml ADDED
@@ -0,0 +1,16 @@
+ defaults:
+   - fcp/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 2e6
+ LR: .0001
+ NUM_ENVS: 8
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.2
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/fcp/overcooked-v1/forced_coord.yaml ADDED
@@ -0,0 +1,16 @@
+ defaults:
+   - fcp/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4e6
+ LR: 1e-3
+ NUM_ENVS: 8
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.1
+ ENT_COEF: 0.05
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-3
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/lbrdiv/_base_.yaml ADDED
@@ -0,0 +1,38 @@
+ # @package algorithm
+ # ^ tells hydra to place these values directly under the algorithm key
+ ALG: lbrdiv
+ TOTAL_TIMESTEPS: 4.5e7 # divided among each pair of BR and Conf agents
+ NUM_CHECKPOINTS: 5
+ PARTNER_POP_SIZE: 4
+ NUM_ENVS: 64
+ TOLERANCE_FACTOR: 0.1 # require that SP - XP > TOLERANCE_FACTOR
+ LAGRANGE_LR: 0.01 # specific to L-BRDiv
+ LR: 1e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 4
+ GAMMA: 0.99
+ GAE_LAMBDA: 0.95
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ VF_COEF: 0.5
+ MAX_GRAD_NORM: 1.0
+ ANNEAL_LR: false
+ ego_train_algorithm:
+   EGO_ACTOR_TYPE: s5
+   S5_D_MODEL: 16
+   S5_SSM_SIZE: 16
+   S5_ACTOR_CRITIC_HIDDEN_DIM: 64
+   FC_N_LAYERS: 2
+   TOTAL_TIMESTEPS: 1e7
+   NUM_CHECKPOINTS: 5
+   NUM_ENVS: 8
+   LR: 1e-4
+   UPDATE_EPOCHS: 15
+   NUM_MINIBATCHES: 4
+   GAMMA: 0.99
+   GAE_LAMBDA: 0.95
+   CLIP_EPS: 0.05
+   ENT_COEF: 0.01
+   VF_COEF: 0.5
+   MAX_GRAD_NORM: 1.0
+   ANNEAL_LR: true
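
Reviewer note: the `TOLERANCE_FACTOR` comment states the constraint `SP - XP > TOLERANCE_FACTOR`, and `LAGRANGE_LR` is marked as specific to L-BRDiv. The sketch below shows the textbook projected dual-ascent step that such a pair of settings usually drives. It is an assumption about how the two values interact, not a copy of what `LBRDiv.py` does; the function names `constraint_slack` and `update_multiplier` are hypothetical.

```python
def constraint_slack(sp_return, xp_return, tolerance):
    """Slack of the constraint SP - XP > tolerance (positive means satisfied)."""
    return (sp_return - xp_return) - tolerance


def update_multiplier(lmbda, sp_return, xp_return, tolerance, lagrange_lr):
    """One projected dual-ascent step (assumed rule, not taken from the repo).

    The multiplier grows while the constraint is violated, shrinks while it
    holds, and is clipped so it never becomes negative.
    """
    slack = constraint_slack(sp_return, xp_return, tolerance)
    return max(0.0, lmbda - lagrange_lr * slack)


# Values from lbrdiv/_base_.yaml: TOLERANCE_FACTOR=0.1, LAGRANGE_LR=0.01.
violated = update_multiplier(1.0, sp_return=5.0, xp_return=5.0,
                             tolerance=0.1, lagrange_lr=0.01)
satisfied = update_multiplier(0.0, sp_return=10.0, xp_return=0.0,
                              tolerance=0.1, lagrange_lr=0.01)
```

With equal self-play and cross-play returns the constraint is violated, so the multiplier rises slightly; with a large gap it stays clipped at zero.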
teammate_generation/configs/algorithm/lbrdiv/hanabi.yaml ADDED
@@ -0,0 +1,26 @@
+ defaults:
+   - lbrdiv/_base_
+   - _self_
+
+ TOTAL_TIMESTEPS: 5e8
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ LR: 5e-4
+ UPDATE_EPOCHS: 4
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.2
+ ENT_COEF: 0.01
+ ANNEAL_LR: true
+ GAMMA: 0.999
+ GAE_LAMBDA: 0.95
+ MAX_GRAD_NORM: 0.5
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 1e8
+   LR: 5e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.2
+   ANNEAL_LR: true
+   UPDATE_EPOCHS: 4
+   GAMMA: 0.999
+   GAE_LAMBDA: 0.95
+   MAX_GRAD_NORM: 0.5
teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_12x12.yaml ADDED
@@ -0,0 +1,17 @@
+ defaults:
+   - lbrdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4.5e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 64
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
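
Reviewer note: because `_self_` comes after `lbrdiv/_base_` in the defaults list, keys in this file win over the base file, and nested mappings such as `ego_train_algorithm` are merged key by key rather than replaced wholesale. The sketch below imitates that recursive merge with plain dictionaries so the effect is easy to check by hand. `deep_merge` is a hypothetical helper, not Hydra's implementation, and only a subset of the keys from the two files is reproduced.

```python
def deep_merge(base, override):
    """Recursively merge `override` into a copy of `base`; override wins."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# Subset of lbrdiv/_base_.yaml.
base = {
    "PARTNER_POP_SIZE": 4,
    "NUM_ENVS": 64,
    "LR": 1e-4,
    "ego_train_algorithm": {"TOTAL_TIMESTEPS": 1e7, "NUM_ENVS": 8, "LR": 1e-4},
}
# Subset of lbrdiv/lbf/lbf_12x12.yaml.
override = {
    "PARTNER_POP_SIZE": 3,
    "LR": 5e-4,
    "ego_train_algorithm": {"TOTAL_TIMESTEPS": 3e7},
}

merged = deep_merge(base, override)
```

Here `merged["ego_train_algorithm"]["NUM_ENVS"]` is still 8 from the base file, while `TOTAL_TIMESTEPS` under the same key becomes `3e7`.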
teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_7x7_nolevels.yaml ADDED
@@ -0,0 +1,17 @@
+ defaults:
+   - lbrdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4.5e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 64
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/lbrdiv/mini-hanabi.yaml ADDED
@@ -0,0 +1,27 @@
+ defaults:
+   - lbrdiv/_base_
+   - _self_
+
+ # Mini-Hanabi (3c/3r/hand3) LBRDiv config.
+ TOTAL_TIMESTEPS: 1e8
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ LR: 5e-4
+ UPDATE_EPOCHS: 4
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.2
+ ENT_COEF: 0.01
+ ANNEAL_LR: true
+ GAMMA: 0.999
+ GAE_LAMBDA: 0.95
+ MAX_GRAD_NORM: 0.5
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 1e8
+   LR: 5e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.2
+   ANNEAL_LR: true
+   UPDATE_EPOCHS: 4
+   GAMMA: 0.999
+   GAE_LAMBDA: 0.95
+   MAX_GRAD_NORM: 0.5
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/asymm_advantages.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - lbrdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4.5e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 64
+ TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR
+ LR: .0001
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.3
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/coord_ring.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - lbrdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 9e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 4
+ CLIP_EPS: 0.1
+ ENT_COEF: 0.05
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-3
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/counter_circuit.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - lbrdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 9e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR
+ LR: 1e-3
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 8
+ CLIP_EPS: 0.01
+ ENT_COEF: 0.05
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-3
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/cramped_room.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - lbrdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 4.5e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 64
+ TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR
+ LR: 1e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 3e7
+   LR: 1e-4
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/forced_coord.yaml ADDED
@@ -0,0 +1,18 @@
+ defaults:
+   - lbrdiv/_base_
+   - _self_ # values from this file override the values from the base file
+
+ TOTAL_TIMESTEPS: 9e7
+ PARTNER_POP_SIZE: 3
+ NUM_ENVS: 128
+ TOLERANCE_FACTOR: 5.0 # require that SP - XP > TOLERANCE_FACTOR
+ LR: 5e-4
+ UPDATE_EPOCHS: 15
+ NUM_MINIBATCHES: 16
+ CLIP_EPS: 0.05
+ ENT_COEF: 0.01
+ ego_train_algorithm:
+   TOTAL_TIMESTEPS: 6e7
+   LR: 1e-3
+   ENT_COEF: 0.01
+   CLIP_EPS: 0.05
teammate_generation/configs/base_config_teammate.yaml ADDED
@@ -0,0 +1,54 @@
+ defaults:
+   - task: lbf/lbf_7x7_nolevels # task configs
+   - algorithm@algorithm: fcp/${task} # task-specific algorithm configs
+   - hydra: hydra_simple
+   - ../../evaluation/configs/global_heldout_settings
+   - _self_
+
+ ENV_NAME: ${task.ENV_NAME}
+ ENV_KWARGS: ${task.ENV_KWARGS}
+ ROLLOUT_LENGTH: ${task.ROLLOUT_LENGTH}
+ TASK_NAME: ${task.TASK_NAME}
+
+ # training settings
+ train_ego: true # whether to train the ego agent
+ run_heldout_eval: true # whether to run a heldout evaluation of the ego agent
+
+ # teammate generation settings
+ algorithm:
+   NUM_EVAL_EPISODES: 20 # used during training
+   TRAIN_SEED: 20374 # 112358 # 20374
+   NUM_SEEDS: 1
+   ENV_NAME: ${ENV_NAME}
+   ENV_KWARGS: ${ENV_KWARGS}
+   ROLLOUT_LENGTH: ${ROLLOUT_LENGTH}
+   # ego training settings
+   ego_train_algorithm:
+     NUM_EGO_TRAIN_SEEDS: 1 # per seed of teammate generation
+     NUM_EVAL_EPISODES: 20
+     TRAIN_SEED: 204829
+     ENV_NAME: ${ENV_NAME}
+     ENV_KWARGS: ${ENV_KWARGS}
+     ROLLOUT_LENGTH: ${ROLLOUT_LENGTH}
+
+ label: "default_label"
+ name: ${TASK_NAME}/${algorithm.ALG}/${label}
+
+ # wandb settings
+ logger:
+   project: aht-benchmark
+   entity: aht-project
+   tags:
+     - ${algorithm.ALG}
+     - ${TASK_NAME}
+     - seed=${algorithm.TRAIN_SEED}
+     - ${label}
+   mode: offline # options: online, offline, disabled
+   verbose: false
+   log_train_out: true # whether to log the out dictionary
+   log_eval_out: true # whether to log the eval metrics
+
+ # Local logger
+ local_logger:
+   save_train_out: true
+   save_eval_out: true
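
Reviewer note: this file leans heavily on `${...}` interpolation, for example `name: ${TASK_NAME}/${algorithm.ALG}/${label}`, where `TASK_NAME` is itself `${task.TASK_NAME}`. The sketch below resolves such chained references over a plain dictionary to show what the composed `name` looks like. `lookup` and `resolve` are hypothetical stand-ins for the real resolver, and `ALG: lbrdiv` is used as the example value because that is the `_base_` file visible in this diff (the default selection here is `fcp/${task}`).

```python
import re

# Matches one ${dotted.path} reference with no nested braces.
_REF = re.compile(r"\$\{([^${}]+)\}")


def lookup(root, dotted):
    """Follow a dotted key path such as 'algorithm.ALG' through nested dicts."""
    node = root
    for part in dotted.split("."):
        node = node[part]
    return node


def resolve(root, value):
    """Resolve ${...} references in `value`, following chains of references."""
    while isinstance(value, str):
        whole = _REF.fullmatch(value)
        if whole:
            # A value that is a single reference takes the referenced object.
            value = lookup(root, whole.group(1))
            continue
        replaced = _REF.sub(
            lambda m: str(resolve(root, lookup(root, m.group(1)))), value
        )
        if replaced == value:
            break
        value = replaced
    return value


cfg = {
    "task": {"TASK_NAME": "lbf/lbf_7x7_nolevels"},
    "TASK_NAME": "${task.TASK_NAME}",
    "algorithm": {"ALG": "lbrdiv", "TRAIN_SEED": 20374},
    "label": "default_label",
    "name": "${TASK_NAME}/${algorithm.ALG}/${label}",
}

run_name = resolve(cfg, cfg["name"])
seed_tag = resolve(cfg, "seed=${algorithm.TRAIN_SEED}")
```

Under these assumptions `run_name` is `lbf/lbf_7x7_nolevels/lbrdiv/default_label`, which is also the path fragment used by the Hydra run directory.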
teammate_generation/configs/hydra/hydra_simple.yaml ADDED
@@ -0,0 +1,7 @@
+ job:
+   chdir: true
+ run:
+   dir: results/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
+ sweep:
+   dir: results_sweep/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
+   subdir: ${run.seed}
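
Reviewer note: the `run.dir` pattern combines the composed `name` with two `now` timestamps, which use `strftime`-style format codes. The sketch below builds the same string with the standard library so the resulting layout is easy to picture. `run_dir` is a hypothetical helper, and the example `name` assumes the LBRDiv setting on the default task.

```python
from datetime import datetime


def run_dir(name, now, root="results"):
    """Build '<root>/<name>/<YYYY-MM-DD>_<HH-MM-SS>' as in hydra_simple.yaml."""
    date_part = now.strftime("%Y-%m-%d")
    time_part = now.strftime("%H-%M-%S")
    return f"{root}/{name}/{date_part}_{time_part}"


example = run_dir(
    "lbf/lbf_7x7_nolevels/lbrdiv/default_label",
    datetime(2024, 1, 2, 3, 4, 5),
)
```

Because `name` itself contains slashes, each run lands several directories deep under `results/`.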
teammate_generation/configs/task/hanabi.yaml ADDED
@@ -0,0 +1,16 @@
+ # Hanabi: teammate generation task config.
+ # Mirrors ego_agent_training/configs/task/hanabi.yaml because
+ # teammate_generation methods (FCP, BRDiv, LBRDiv, CoMeDi) call into
+ # ego_agent_training as a subroutine, which asserts num_agents == 2.
+ # Hanabi is natively 2-player so this is satisfied by default.
+ ENV_NAME: hanabi
+ ROLLOUT_LENGTH: 128
+ ENV_KWARGS:
+   num_agents: 2
+   num_colors: 5
+   num_ranks: 5
+   hand_size: 5
+   max_info_tokens: 8
+   max_life_tokens: 3
+   num_cards_of_rank: [3, 2, 2, 2, 1]
+ TASK_NAME: hanabi
teammate_generation/configs/task/lbf/lbf_12x12.yaml ADDED
@@ -0,0 +1,7 @@
+ ENV_NAME: lbf
+ ROLLOUT_LENGTH: 128
+ ENV_KWARGS:
+   grid_size: 12
+   num_food: 6
+   different_levels: true
+ TASK_NAME: lbf/lbf_12x12
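
Reviewer note: `ROLLOUT_LENGTH` from the task config combines with `NUM_ENVS`, `TOTAL_TIMESTEPS`, and `NUM_MINIBATCHES` from the algorithm config to fix the PPO batch geometry. The sketch below uses the common convention of one update per `NUM_ENVS * ROLLOUT_LENGTH` transitions. That convention is an assumption, since the training loop is not part of this section, and `ppo_schedule` is a hypothetical helper. The numbers are the ego-training settings for `lbrdiv/lbf/lbf_12x12.yaml` merged with its `_base_` file.

```python
def ppo_schedule(total_timesteps, num_envs, rollout_length, num_minibatches):
    """Return (num_updates, batch_size, minibatch_size) under the assumed convention."""
    batch_size = num_envs * rollout_length
    # YAML values such as 3e7 arrive as floats, so convert before integer division.
    num_updates = int(float(total_timesteps)) // batch_size
    minibatch_size = batch_size // num_minibatches
    return num_updates, batch_size, minibatch_size


# ego_train_algorithm: TOTAL_TIMESTEPS=3e7, NUM_ENVS=8, NUM_MINIBATCHES=4; task: ROLLOUT_LENGTH=128.
num_updates, batch_size, minibatch_size = ppo_schedule(3e7, 8, 128, 4)
```

Note that the base file describes the teammate-generation `TOTAL_TIMESTEPS` as divided among BR and Conf agent pairs, so this arithmetic applies most directly to the ego-training block.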
teammate_generation/configs/task/lbf/lbf_7x7_nolevels.yaml ADDED
@@ -0,0 +1,4 @@
+ ENV_NAME: lbf
+ ROLLOUT_LENGTH: 128
+ ENV_KWARGS: {}
+ TASK_NAME: lbf/lbf_7x7_nolevels
teammate_generation/configs/task/mini-hanabi.yaml ADDED
@@ -0,0 +1,13 @@
+ # Mini-Hanabi: teammate generation task config.
+ # Mirrors ego_agent_training/configs/task/mini-hanabi.yaml.
+ ENV_NAME: hanabi
+ ROLLOUT_LENGTH: 128
+ ENV_KWARGS:
+   num_agents: 2
+   num_colors: 3
+   num_ranks: 3
+   hand_size: 3
+   max_info_tokens: 5
+   max_life_tokens: 3
+   num_cards_of_rank: [2, 2, 1]
+ TASK_NAME: mini-hanabi
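
Reviewer note: the two Hanabi task configs differ only in `ENV_KWARGS`, and the size of the resulting game follows directly from those values. The sketch below computes deck size and maximum score for both, assuming `num_cards_of_rank` lists the number of copies of each rank per color (the usual Hanabi convention; the environment code is not shown here). `deck_size` and `max_score` are hypothetical helpers.

```python
def deck_size(num_colors, num_cards_of_rank):
    """Total cards: every color has the same per-rank copy counts."""
    return num_colors * sum(num_cards_of_rank)


def max_score(num_colors, num_ranks):
    """Perfect game: one completed stack of every rank for each color."""
    return num_colors * num_ranks


full_deck = deck_size(5, [3, 2, 2, 2, 1])   # hanabi.yaml
mini_deck = deck_size(3, [2, 2, 1])         # mini-hanabi.yaml
full_max = max_score(5, 5)
mini_max = max_score(3, 3)
```

Under that assumption the full game has 50 cards and a maximum score of 25, while Mini-Hanabi has 15 cards and a maximum score of 9.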