Spaces:
Running
Running
upload teammate_generation for ckpt-eval support
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- teammate_generation/BRDiv.py +832 -0
- teammate_generation/CoMeDi.py +1161 -0
- teammate_generation/LBRDiv.py +1098 -0
- teammate_generation/__init__.py +0 -0
- teammate_generation/configs/algorithm/brdiv/_base_.yaml +40 -0
- teammate_generation/configs/algorithm/brdiv/hanabi.yaml +27 -0
- teammate_generation/configs/algorithm/brdiv/lbf/lbf_12x12.yaml +18 -0
- teammate_generation/configs/algorithm/brdiv/lbf/lbf_7x7_nolevels.yaml +18 -0
- teammate_generation/configs/algorithm/brdiv/mini-hanabi.yaml +28 -0
- teammate_generation/configs/algorithm/brdiv/overcooked-v1/asymm_advantages.yaml +18 -0
- teammate_generation/configs/algorithm/brdiv/overcooked-v1/coord_ring.yaml +18 -0
- teammate_generation/configs/algorithm/brdiv/overcooked-v1/counter_circuit.yaml +18 -0
- teammate_generation/configs/algorithm/brdiv/overcooked-v1/cramped_room.yaml +21 -0
- teammate_generation/configs/algorithm/brdiv/overcooked-v1/forced_coord.yaml +18 -0
- teammate_generation/configs/algorithm/comedi/_base_.yaml +36 -0
- teammate_generation/configs/algorithm/comedi/hanabi.yaml +26 -0
- teammate_generation/configs/algorithm/comedi/lbf/lbf_12x12.yaml +18 -0
- teammate_generation/configs/algorithm/comedi/lbf/lbf_7x7_nolevels.yaml +18 -0
- teammate_generation/configs/algorithm/comedi/mini-hanabi.yaml +27 -0
- teammate_generation/configs/algorithm/comedi/overcooked-v1/asymm_advantages.yaml +16 -0
- teammate_generation/configs/algorithm/comedi/overcooked-v1/coord_ring.yaml +16 -0
- teammate_generation/configs/algorithm/comedi/overcooked-v1/counter_circuit.yaml +16 -0
- teammate_generation/configs/algorithm/comedi/overcooked-v1/cramped_room.yaml +17 -0
- teammate_generation/configs/algorithm/comedi/overcooked-v1/forced_coord.yaml +16 -0
- teammate_generation/configs/algorithm/fcp/_base_.yaml +37 -0
- teammate_generation/configs/algorithm/fcp/hanabi.yaml +32 -0
- teammate_generation/configs/algorithm/fcp/lbf/lbf_12x12.yaml +17 -0
- teammate_generation/configs/algorithm/fcp/lbf/lbf_7x7_nolevels.yaml +17 -0
- teammate_generation/configs/algorithm/fcp/mini-hanabi.yaml +26 -0
- teammate_generation/configs/algorithm/fcp/overcooked-v1/asymm_advantages.yaml +17 -0
- teammate_generation/configs/algorithm/fcp/overcooked-v1/coord_ring.yaml +16 -0
- teammate_generation/configs/algorithm/fcp/overcooked-v1/counter_circuit.yaml +16 -0
- teammate_generation/configs/algorithm/fcp/overcooked-v1/cramped_room.yaml +16 -0
- teammate_generation/configs/algorithm/fcp/overcooked-v1/forced_coord.yaml +16 -0
- teammate_generation/configs/algorithm/lbrdiv/_base_.yaml +38 -0
- teammate_generation/configs/algorithm/lbrdiv/hanabi.yaml +26 -0
- teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_12x12.yaml +17 -0
- teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_7x7_nolevels.yaml +17 -0
- teammate_generation/configs/algorithm/lbrdiv/mini-hanabi.yaml +27 -0
- teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/asymm_advantages.yaml +18 -0
- teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/coord_ring.yaml +18 -0
- teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/counter_circuit.yaml +18 -0
- teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/cramped_room.yaml +18 -0
- teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/forced_coord.yaml +18 -0
- teammate_generation/configs/base_config_teammate.yaml +54 -0
- teammate_generation/configs/hydra/hydra_simple.yaml +7 -0
- teammate_generation/configs/task/hanabi.yaml +16 -0
- teammate_generation/configs/task/lbf/lbf_12x12.yaml +7 -0
- teammate_generation/configs/task/lbf/lbf_7x7_nolevels.yaml +4 -0
- teammate_generation/configs/task/mini-hanabi.yaml +13 -0
teammate_generation/BRDiv.py
ADDED
|
@@ -0,0 +1,832 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''Implementation of the BRDiv teammate generation algorithm (Rahman et al., TMLR 2023)
|
| 2 |
+
https://arxiv.org/abs/2207.14138
|
| 3 |
+
|
| 4 |
+
Command to run BRDiv only on LBF:
|
| 5 |
+
python teammate_generation/run.py algorithm=brdiv/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels label=test_brdiv run_heldout_eval=false train_ego=false
|
| 6 |
+
|
| 7 |
+
Limitations: does not support recurrent actors.
|
| 8 |
+
'''
|
| 9 |
+
import shutil
|
| 10 |
+
import time
|
| 11 |
+
import logging
|
| 12 |
+
from typing import NamedTuple
|
| 13 |
+
from functools import partial
|
| 14 |
+
|
| 15 |
+
import hydra
|
| 16 |
+
import jax
|
| 17 |
+
import jax.numpy as jnp
|
| 18 |
+
import numpy as np
|
| 19 |
+
import optax
|
| 20 |
+
from flax.training.train_state import TrainState
|
| 21 |
+
import wandb
|
| 22 |
+
|
| 23 |
+
from agents.mlp_actor_critic_agent import ActorWithConditionalCriticPolicy
|
| 24 |
+
from agents.population_interface import AgentPopulation
|
| 25 |
+
from common.plot_utils import get_metric_names
|
| 26 |
+
from common.run_episodes import run_episodes
|
| 27 |
+
from common.save_load_utils import save_train_run
|
| 28 |
+
from envs import make_env
|
| 29 |
+
from envs.log_wrapper import LogWrapper
|
| 30 |
+
from marl.ppo_utils import unbatchify, _create_minibatches
|
| 31 |
+
|
| 32 |
+
log = logging.getLogger(__name__)
|
| 33 |
+
logging.basicConfig(level=logging.INFO)
|
| 34 |
+
|
| 35 |
+
class XPTransition(NamedTuple):
    """Per-step rollout record for one side (confederate or best response) of a pairing.

    `_env_step` builds one instance per agent at every environment step.
    """
    done: jnp.ndarray  # done flag for this agent, taken from the env step output
    action: jnp.ndarray  # action sampled by this agent's policy
    value: jnp.ndarray  # value estimate returned alongside the action
    self_onehot_id: jnp.ndarray  # one-hot population id of the agent that acted
    oppo_onehot_id: jnp.ndarray  # one-hot population id of the partner it was paired with
    reward: jnp.ndarray  # reward after the sign flip applied when conf and BR ids differ
    log_prob: jnp.ndarray  # log-probability of `action` under the acting policy
    obs: jnp.ndarray  # observation the action was computed from (pre-step)
    info: jnp.ndarray  # NOTE(review): filled via jax.tree.map over env info, so likely a pytree rather than one array
    avail_actions: jnp.ndarray  # available actions for this agent, cast to float32
|
| 46 |
+
|
| 47 |
+
def _get_all_ids(pop_size):
|
| 48 |
+
cross_product = np.meshgrid(
|
| 49 |
+
np.arange(pop_size),
|
| 50 |
+
np.arange(pop_size)
|
| 51 |
+
)
|
| 52 |
+
agent_id_cartesian_product = np.stack([g.ravel() for g in cross_product], axis=-1)
|
| 53 |
+
all_conf_ids = agent_id_cartesian_product[:, 1]
|
| 54 |
+
all_br_ids = agent_id_cartesian_product[:, 0]
|
| 55 |
+
return all_conf_ids, all_br_ids
|
| 56 |
+
|
| 57 |
+
def gather_params(partner_params_pytree, idx_vec):
    """Select per-index parameter slices from a stacked parameter pytree.

    partner_params_pytree: pytree whose leaves are indexed along their leading axis.
    idx_vec: 1-D vector of indices into that leading axis.

    Returns a pytree with the same structure in which every leaf is the stack of
    leaf[idx] for each idx in idx_vec, i.e. its leading axis has len(idx_vec) entries.
    """
    # vmap over the index vector so each index picks one slice of the leaf.
    return jax.tree.map(
        lambda leaf: jax.vmap(lambda idx: leaf[idx])(idx_vec),
        partner_params_pytree,
    )
|
| 74 |
+
|
| 75 |
+
def train_brdiv_partners(train_rng, env, config, conf_policy, br_policy):
|
| 76 |
+
num_agents = env.num_agents
|
| 77 |
+
assert num_agents == 2, "This code assumes the environment has exactly 2 agents."
|
| 78 |
+
|
| 79 |
+
# Define different minibatch sizes for interactions with ego agent and one with BR agent
|
| 80 |
+
config["NUM_GAME_AGENTS"] = num_agents
|
| 81 |
+
config["NUM_CONF_ACTORS"] = config["NUM_ENVS"]
|
| 82 |
+
config["NUM_BR_ACTORS"] = config["NUM_ENVS"]
|
| 83 |
+
config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // (config["ROLLOUT_LENGTH"] * config["NUM_ENVS"])
|
| 84 |
+
|
| 85 |
+
def make_brdiv_agents(config):
|
| 86 |
+
def linear_schedule(count):
|
| 87 |
+
frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
|
| 88 |
+
return config["LR"] * frac
|
| 89 |
+
|
| 90 |
+
def train(rng):
|
| 91 |
+
rng, init_conf_rng, init_br_rng = jax.random.split(rng, 3)
|
| 92 |
+
all_conf_init_rngs = jax.random.split(init_conf_rng, config["PARTNER_POP_SIZE"])
|
| 93 |
+
all_br_init_rngs = jax.random.split(init_br_rng, config["PARTNER_POP_SIZE"])
|
| 94 |
+
identity_matrix = jnp.eye(config["PARTNER_POP_SIZE"])
|
| 95 |
+
|
| 96 |
+
init_conf_hstate = conf_policy.init_hstate(config["NUM_CONF_ACTORS"])
|
| 97 |
+
init_br_hstate = br_policy.init_hstate(config["NUM_BR_ACTORS"])
|
| 98 |
+
|
| 99 |
+
def init_train_states(rng_agents, rng_brs):
|
| 100 |
+
def init_single_pair_optimizers(rng_agent, rng_br):
|
| 101 |
+
init_params_conf = conf_policy.init_params(rng_agent)
|
| 102 |
+
init_params_br = br_policy.init_params(rng_br)
|
| 103 |
+
return init_params_conf, init_params_br
|
| 104 |
+
|
| 105 |
+
init_all_networks_and_optimizers = jax.vmap(init_single_pair_optimizers)
|
| 106 |
+
all_conf_params, all_br_params = init_all_networks_and_optimizers(rng_agents, rng_brs)
|
| 107 |
+
|
| 108 |
+
# Define optimizers for both confederate and BR policy
|
| 109 |
+
tx = optax.chain(
|
| 110 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 111 |
+
optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"],
|
| 112 |
+
eps=1e-5),
|
| 113 |
+
)
|
| 114 |
+
tx_br = optax.chain(
|
| 115 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 116 |
+
optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"],
|
| 117 |
+
eps=1e-5),
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
train_state_conf = TrainState.create(
|
| 121 |
+
apply_fn=conf_policy.network.apply,
|
| 122 |
+
params=all_conf_params,
|
| 123 |
+
tx=tx,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
train_state_br = TrainState.create(
|
| 127 |
+
apply_fn=br_policy.network.apply,
|
| 128 |
+
params=all_br_params,
|
| 129 |
+
tx=tx_br,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
return train_state_conf, train_state_br
|
| 133 |
+
|
| 134 |
+
all_conf_optims, all_br_optims = init_train_states(
|
| 135 |
+
all_conf_init_rngs, all_br_init_rngs
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def forward_pass_conf(params, obs, id, done, avail_actions, hstate, rng):
|
| 139 |
+
act, val, pi, new_hstate = conf_policy.get_action_value_policy(
|
| 140 |
+
params=params,
|
| 141 |
+
obs=obs[jnp.newaxis, ...],
|
| 142 |
+
done=done[jnp.newaxis, ...],
|
| 143 |
+
avail_actions=avail_actions,
|
| 144 |
+
hstate=hstate,
|
| 145 |
+
rng=rng,
|
| 146 |
+
aux_obs=id[jnp.newaxis, ...]
|
| 147 |
+
)
|
| 148 |
+
return act, val, pi, new_hstate
|
| 149 |
+
|
| 150 |
+
def forward_pass_br(params, obs, id, done, avail_actions, hstate, rng):
|
| 151 |
+
act, val, pi, new_hstate = br_policy.get_action_value_policy(
|
| 152 |
+
params=params,
|
| 153 |
+
obs=obs[jnp.newaxis, ...],
|
| 154 |
+
done=done[jnp.newaxis, ...],
|
| 155 |
+
avail_actions=avail_actions,
|
| 156 |
+
hstate=hstate,
|
| 157 |
+
rng=rng,
|
| 158 |
+
aux_obs=id[jnp.newaxis, ...]
|
| 159 |
+
)
|
| 160 |
+
return act, val, pi, new_hstate
|
| 161 |
+
|
| 162 |
+
def _env_step(runner_state, unused):
|
| 163 |
+
"""
|
| 164 |
+
agent_0 = confederate, agent_1 = br
|
| 165 |
+
Returns updated runner_state, and Transitions for agent_0 and agent_1
|
| 166 |
+
"""
|
| 167 |
+
(
|
| 168 |
+
all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids,
|
| 169 |
+
env_state, last_obs, last_done, last_conf_h, last_br_h, rng
|
| 170 |
+
) = runner_state
|
| 171 |
+
rng, act0_rng, act1_rng, step_rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 6)
|
| 172 |
+
|
| 173 |
+
# For done envs, resample both conf and brs
|
| 174 |
+
needs_resample = last_done["__all__"]
|
| 175 |
+
resampled_conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_CONF_ACTORS"],), 0, config["PARTNER_POP_SIZE"])
|
| 176 |
+
resampled_br_ids = jax.random.randint(br_sampling_rng, (config["NUM_BR_ACTORS"],), 0, config["PARTNER_POP_SIZE"])
|
| 177 |
+
|
| 178 |
+
# Determine final indices based on whether resampling was needed for each env
|
| 179 |
+
updated_conf_ids = jnp.where(
|
| 180 |
+
needs_resample,
|
| 181 |
+
resampled_conf_ids, # Use newly sampled index if True
|
| 182 |
+
last_conf_ids # Else, keep index from previous step
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
updated_br_ids = jnp.where(
|
| 186 |
+
needs_resample,
|
| 187 |
+
resampled_br_ids, # Use newly sampled index if True
|
| 188 |
+
last_br_ids # Else, keep index from previous step
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# Reset the hidden states for resampled conf and br if they are not None
|
| 192 |
+
# WARNING: BRDiv was not tested with recurrent actors, so the code for if the hstate is not None may not work
|
| 193 |
+
if last_conf_h is not None:
|
| 194 |
+
updated_conf_h = jnp.where(
|
| 195 |
+
needs_resample,
|
| 196 |
+
init_conf_hstate,
|
| 197 |
+
last_conf_h
|
| 198 |
+
)
|
| 199 |
+
else:
|
| 200 |
+
updated_conf_h = last_conf_h
|
| 201 |
+
|
| 202 |
+
if last_br_h is not None:
|
| 203 |
+
updated_br_h = jnp.where(
|
| 204 |
+
needs_resample,
|
| 205 |
+
init_br_hstate,
|
| 206 |
+
last_br_h
|
| 207 |
+
)
|
| 208 |
+
else:
|
| 209 |
+
updated_br_h = last_br_h
|
| 210 |
+
|
| 211 |
+
# Get the corresponding conf and br params
|
| 212 |
+
updated_conf_params = gather_params(all_train_state_conf.params, updated_conf_ids)
|
| 213 |
+
updated_br_params = gather_params(all_train_state_br.params, updated_br_ids)
|
| 214 |
+
|
| 215 |
+
updated_conf_onehot_ids = identity_matrix[updated_conf_ids]
|
| 216 |
+
updated_br_onehot_ids = identity_matrix[updated_br_ids]
|
| 217 |
+
|
| 218 |
+
# Get available actions for both agents from environment state
|
| 219 |
+
avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
|
| 220 |
+
avail_actions = jax.lax.stop_gradient(avail_actions)
|
| 221 |
+
avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32)
|
| 222 |
+
avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32)
|
| 223 |
+
|
| 224 |
+
# Agent_0 action
|
| 225 |
+
act0_rng = jax.random.split(act0_rng, config["NUM_ENVS"])
|
| 226 |
+
act_0, val_0, pi_0, new_conf_h = jax.vmap(forward_pass_conf)(updated_conf_params,
|
| 227 |
+
last_obs["agent_0"], updated_br_onehot_ids, last_done["agent_0"], avail_actions_0,
|
| 228 |
+
updated_conf_h, act0_rng)
|
| 229 |
+
logp_0 = pi_0.log_prob(act_0)
|
| 230 |
+
act_0, val_0, logp_0 = act_0.squeeze(), val_0.squeeze(), logp_0.squeeze()
|
| 231 |
+
|
| 232 |
+
# Agent_1 action
|
| 233 |
+
act1_rng = jax.random.split(act1_rng, config["NUM_ENVS"])
|
| 234 |
+
act_1, val_1, pi_1, new_br_h = jax.vmap(forward_pass_br)(updated_br_params,
|
| 235 |
+
last_obs["agent_1"], updated_conf_onehot_ids, last_done["agent_1"], avail_actions_1,
|
| 236 |
+
updated_br_h, act1_rng)
|
| 237 |
+
logp_1 = pi_1.log_prob(act_1)
|
| 238 |
+
act_1, val_1, logp_1 = act_1.squeeze(), val_1.squeeze(), logp_1.squeeze()
|
| 239 |
+
|
| 240 |
+
# Combine actions into the env format
|
| 241 |
+
combined_actions = jnp.concatenate([act_0, act_1], axis=0)
|
| 242 |
+
env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents)
|
| 243 |
+
env_act = {k: v.flatten() for k, v in env_act.items()}
|
| 244 |
+
|
| 245 |
+
# Step env
|
| 246 |
+
step_rngs = jax.random.split(step_rng, config["NUM_ENVS"])
|
| 247 |
+
obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))(
|
| 248 |
+
step_rngs, env_state, env_act
|
| 249 |
+
)
|
| 250 |
+
# note that num_actors = num_envs * num_agents
|
| 251 |
+
info_0 = jax.tree.map(lambda x: x[:, 0], info)
|
| 252 |
+
info_1 = jax.tree.map(lambda x: x[:, 1], info)
|
| 253 |
+
|
| 254 |
+
def _compute_rewards(conf_id, br_id, agent_rew):
|
| 255 |
+
return jax.lax.cond(jnp.equal(
|
| 256 |
+
jnp.argmax(conf_id, axis=-1), jnp.argmax(br_id, axis=-1)
|
| 257 |
+
),
|
| 258 |
+
lambda x: x,
|
| 259 |
+
lambda x: -x,
|
| 260 |
+
agent_rew
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
agent_0_rews = jax.vmap(_compute_rewards)(updated_conf_onehot_ids, updated_br_onehot_ids, reward["agent_1"])
|
| 264 |
+
agent_1_rews = jax.vmap(_compute_rewards)(updated_conf_onehot_ids, updated_br_onehot_ids, reward["agent_0"])
|
| 265 |
+
|
| 266 |
+
# Store agent_0 data in transition
|
| 267 |
+
transition_0 = XPTransition(
|
| 268 |
+
done=done["agent_0"],
|
| 269 |
+
action=act_0,
|
| 270 |
+
value=val_0,
|
| 271 |
+
self_onehot_id=updated_conf_onehot_ids,
|
| 272 |
+
oppo_onehot_id=updated_br_onehot_ids,
|
| 273 |
+
reward=agent_0_rews,
|
| 274 |
+
log_prob=logp_0,
|
| 275 |
+
obs=last_obs["agent_0"],
|
| 276 |
+
info=info_0,
|
| 277 |
+
avail_actions=avail_actions_0
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
transition_1 = XPTransition(
|
| 281 |
+
done=done["agent_1"],
|
| 282 |
+
action=act_1,
|
| 283 |
+
value=val_1,
|
| 284 |
+
self_onehot_id=updated_br_onehot_ids,
|
| 285 |
+
oppo_onehot_id=updated_conf_onehot_ids,
|
| 286 |
+
reward=agent_1_rews,
|
| 287 |
+
log_prob=logp_1,
|
| 288 |
+
obs=last_obs["agent_1"],
|
| 289 |
+
info=info_1,
|
| 290 |
+
avail_actions=avail_actions_1
|
| 291 |
+
)
|
| 292 |
+
new_runner_state = (all_train_state_conf, all_train_state_br, updated_conf_ids, updated_br_ids,
|
| 293 |
+
env_state_next, obs_next, done, new_conf_h, new_br_h, rng)
|
| 294 |
+
return new_runner_state, (transition_0, transition_1)
|
| 295 |
+
|
| 296 |
+
def _calculate_gae(traj_batch, last_val):
|
| 297 |
+
def _get_advantages(gae_and_next_value, transition):
|
| 298 |
+
gae, next_value = gae_and_next_value
|
| 299 |
+
done, value, reward = (
|
| 300 |
+
transition.done,
|
| 301 |
+
transition.value,
|
| 302 |
+
transition.reward,
|
| 303 |
+
)
|
| 304 |
+
delta = reward + config["GAMMA"] * next_value * (1 - done) - value
|
| 305 |
+
gae = (
|
| 306 |
+
delta
|
| 307 |
+
+ config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
|
| 308 |
+
)
|
| 309 |
+
return (gae, value), gae
|
| 310 |
+
|
| 311 |
+
_, advantages = jax.lax.scan(
|
| 312 |
+
_get_advantages,
|
| 313 |
+
(jnp.zeros_like(last_val), last_val),
|
| 314 |
+
traj_batch,
|
| 315 |
+
reverse=True,
|
| 316 |
+
unroll=16,
|
| 317 |
+
)
|
| 318 |
+
return advantages, advantages + traj_batch.value
|
| 319 |
+
|
| 320 |
+
def run_all_episodes(rng, train_state_conf, train_state_br):
|
| 321 |
+
conf_ids, br_ids = _get_all_ids(config["PARTNER_POP_SIZE"])
|
| 322 |
+
gathered_conf_model_params = gather_params(train_state_conf.params, conf_ids)
|
| 323 |
+
gathered_br_model_params = gather_params(train_state_br.params, br_ids)
|
| 324 |
+
|
| 325 |
+
rng, eval_rng = jax.random.split(rng)
|
| 326 |
+
def run_episodes_fixed_rng(conf_param, br_param):
|
| 327 |
+
return run_episodes(
|
| 328 |
+
eval_rng, env,
|
| 329 |
+
conf_param, conf_policy,
|
| 330 |
+
br_param, br_policy,
|
| 331 |
+
config["ROLLOUT_LENGTH"], config["NUM_EVAL_EPISODES"],
|
| 332 |
+
)
|
| 333 |
+
ep_infos = jax.vmap(run_episodes_fixed_rng)(
|
| 334 |
+
gathered_conf_model_params, gathered_br_model_params, # leaves where shape is (pop_size*pop_size, ...)
|
| 335 |
+
)
|
| 336 |
+
return ep_infos
|
| 337 |
+
|
| 338 |
+
def _update_epoch(update_state, unused):
|
| 339 |
+
def _update_minbatch(all_train_states, all_data):
|
| 340 |
+
train_state_conf, train_state_br = all_train_states
|
| 341 |
+
minbatch_conf, minbatch_br = all_data
|
| 342 |
+
|
| 343 |
+
def _loss_fn(param, agent_policy, minbatch, agent_id):
|
| 344 |
+
'''Compute loss for agent corresponding to agent_id.
|
| 345 |
+
'''
|
| 346 |
+
init_hstate, traj_batch, gae, target_v = minbatch
|
| 347 |
+
# get policy and value of confederate versus ego and best response agents respectively
|
| 348 |
+
squeezed_param = jax.tree.map(lambda x: jnp.squeeze(x, 0), param)
|
| 349 |
+
_, value, pi, _ = agent_policy.get_action_value_policy(
|
| 350 |
+
params=squeezed_param,
|
| 351 |
+
obs=traj_batch.obs,
|
| 352 |
+
done=traj_batch.done,
|
| 353 |
+
avail_actions=traj_batch.avail_actions,
|
| 354 |
+
hstate=init_hstate,
|
| 355 |
+
rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here
|
| 356 |
+
aux_obs=traj_batch.oppo_onehot_id
|
| 357 |
+
)
|
| 358 |
+
log_prob = pi.log_prob(traj_batch.action)
|
| 359 |
+
|
| 360 |
+
is_relevant = jnp.equal(
|
| 361 |
+
jnp.argmax(traj_batch.self_onehot_id, axis=-1),
|
| 362 |
+
agent_id
|
| 363 |
+
)
|
| 364 |
+
loss_weights = jnp.where(is_relevant, 1, 0).astype(jnp.float32)
|
| 365 |
+
|
| 366 |
+
# Value loss
|
| 367 |
+
value_pred_clipped = traj_batch.value + (
|
| 368 |
+
value - traj_batch.value
|
| 369 |
+
).clip(
|
| 370 |
+
-config["CLIP_EPS"], config["CLIP_EPS"])
|
| 371 |
+
value_losses = jnp.square(value - target_v)
|
| 372 |
+
value_losses_clipped = jnp.square(value_pred_clipped - target_v)
|
| 373 |
+
value_loss = jax.lax.cond(
|
| 374 |
+
loss_weights.sum() == 0,
|
| 375 |
+
lambda x: jnp.zeros_like(x).astype(jnp.float32),
|
| 376 |
+
lambda x: x,
|
| 377 |
+
(loss_weights * jnp.maximum(value_losses, value_losses_clipped)).sum() / (loss_weights.sum() + 1e-8)
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
n = config["PARTNER_POP_SIZE"]
|
| 381 |
+
# Apply different loss weights for SP and XP data
|
| 382 |
+
# Loss weights consist of two parts: the first term is the weighting from the BRDiv loss function
|
| 383 |
+
# The second term is a reweighting term to compensate for the data collection process, which uniformly and independently
|
| 384 |
+
# samples the conf and br ids from 1, ..., n, resulting in P(SP) = 1/n and P(XP) = (n-1)/n.
|
| 385 |
+
# To prevent the XP loss term from dominating the SP loss term, we would like P(SP) = P(XP) = 1/2.
|
| 386 |
+
# Thus, we set the 2nd term of the SP weight to n/2, and the 2nd term of the XP weight to n/(2 * (n-1)).
|
| 387 |
+
|
| 388 |
+
is_sp = jnp.equal(jnp.argmax(traj_batch.self_onehot_id, axis=-1), jnp.argmax(traj_batch.oppo_onehot_id, axis=-1))
|
| 389 |
+
sp_weight = (1 + 2*config["XP_LOSS_WEIGHTS"]) * (n/2)
|
| 390 |
+
xp_weight = config["XP_LOSS_WEIGHTS"] * (n / (2 * (n-1)))
|
| 391 |
+
actor_weights = jnp.where(is_sp, sp_weight, xp_weight)
|
| 392 |
+
|
| 393 |
+
# Policy gradient loss
|
| 394 |
+
ratio = jnp.exp(log_prob - traj_batch.log_prob)
|
| 395 |
+
gae_norm = (gae - gae.mean()) / (gae.std() + 1e-8)
|
| 396 |
+
pg_loss_1 = ratio * gae_norm * actor_weights
|
| 397 |
+
pg_loss_2 = jnp.clip(
|
| 398 |
+
ratio,
|
| 399 |
+
1.0 - config["CLIP_EPS"],
|
| 400 |
+
1.0 + config["CLIP_EPS"]) * gae_norm * actor_weights
|
| 401 |
+
pg_loss = jax.lax.cond(
|
| 402 |
+
loss_weights.sum() == 0,
|
| 403 |
+
lambda x: jnp.zeros_like(x).astype(jnp.float32),
|
| 404 |
+
lambda x: x,
|
| 405 |
+
-(
|
| 406 |
+
loss_weights*jnp.minimum(pg_loss_1, pg_loss_2)
|
| 407 |
+
).sum()/(loss_weights.sum() + 1e-8)
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# Entropy
|
| 411 |
+
entropy = jax.lax.cond(
|
| 412 |
+
loss_weights.sum() == 0,
|
| 413 |
+
lambda x: jnp.zeros_like(x).astype(jnp.float32),
|
| 414 |
+
lambda x: x,
|
| 415 |
+
(loss_weights * pi.entropy()).sum()/(loss_weights.sum() + 1e-8)
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
total_loss = pg_loss + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy
|
| 419 |
+
return total_loss, (value_loss, pg_loss, entropy)
|
| 420 |
+
|
| 421 |
+
possible_agent_ids = jnp.expand_dims(jnp.arange(config["PARTNER_POP_SIZE"]), 1)
|
| 422 |
+
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
|
| 423 |
+
|
| 424 |
+
def gather_conf_params_and_return_grads(agent_id):
|
| 425 |
+
param_vector = gather_params(train_state_conf.params, agent_id)
|
| 426 |
+
(loss_val_conf, aux_vals_conf), grads_conf = grad_fn(
|
| 427 |
+
param_vector, conf_policy, minbatch_conf, agent_id
|
| 428 |
+
)
|
| 429 |
+
return (loss_val_conf, aux_vals_conf), grads_conf
|
| 430 |
+
|
| 431 |
+
def gather_br_params_and_return_grads(agent_id):
|
| 432 |
+
param_vector = gather_params(train_state_br.params, agent_id)
|
| 433 |
+
(loss_val_br, aux_vals_br), grads_br = grad_fn(
|
| 434 |
+
param_vector, br_policy, minbatch_br, agent_id
|
| 435 |
+
)
|
| 436 |
+
return (loss_val_br, aux_vals_br), grads_br
|
| 437 |
+
|
| 438 |
+
(loss_val_conf, aux_vals_conf), grads_conf = jax.vmap(gather_conf_params_and_return_grads)(possible_agent_ids)
|
| 439 |
+
(loss_val_br, aux_vals_br), grads_br = jax.vmap(gather_br_params_and_return_grads)(possible_agent_ids)
|
| 440 |
+
|
| 441 |
+
grads_conf_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_conf)
|
| 442 |
+
grads_br_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_br)
|
| 443 |
+
train_state_conf = train_state_conf.apply_gradients(grads=grads_conf_new)
|
| 444 |
+
train_state_br = train_state_br.apply_gradients(grads=grads_br_new)
|
| 445 |
+
return (train_state_conf, train_state_br), ((loss_val_conf, aux_vals_conf), (loss_val_br, aux_vals_br))
|
| 446 |
+
|
| 447 |
+
(
|
| 448 |
+
train_state_conf, train_state_br,
|
| 449 |
+
traj_batch_conf, traj_batch_br,
|
| 450 |
+
advantages_conf, advantages_br,
|
| 451 |
+
targets_conf, targets_br,
|
| 452 |
+
rng
|
| 453 |
+
) = update_state
|
| 454 |
+
rng, perm_rng_conf, perm_rng_br = jax.random.split(rng, 3)
|
| 455 |
+
|
| 456 |
+
minibatches_conf = _create_minibatches(traj_batch_conf, advantages_conf, targets_conf, init_conf_hstate,
|
| 457 |
+
config["NUM_CONF_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_conf)
|
| 458 |
+
minibatches_br = _create_minibatches(traj_batch_br, advantages_br, targets_br, init_br_hstate,
|
| 459 |
+
config["NUM_BR_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_br)
|
| 460 |
+
|
| 461 |
+
# Update both policies
|
| 462 |
+
(train_state_conf, train_state_br), all_losses = jax.lax.scan(
|
| 463 |
+
_update_minbatch, (train_state_conf, train_state_br), (minibatches_conf, minibatches_br)
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
update_state = (train_state_conf, train_state_br,
|
| 467 |
+
traj_batch_conf, traj_batch_br,
|
| 468 |
+
advantages_conf, advantages_br,
|
| 469 |
+
targets_conf, targets_br,
|
| 470 |
+
rng
|
| 471 |
+
)
|
| 472 |
+
return update_state, all_losses
|
| 473 |
+
|
| 474 |
+
def _update_step(update_runner_state, unused):
    """Run one training iteration: rollout, advantage estimation, PPO epochs.

    Args:
        update_runner_state: 9-tuple ``(conf train state, BR train state,
            env state, obs, done, conf hstate, BR hstate, rng,
            update_steps)`` carried from one iteration to the next.
        unused: placeholder for the ``jax.lax.scan`` input slot; ignored.

    Returns:
        ``(new_runner_state, metric)``. The new runner state has the same
        layout as the input, with ``update_steps`` incremented by one.
        ``metric`` holds the masked means of the confederate trajectory
        ``info`` entries plus loss/entropy summaries for both policies.
    """
    (
        conf_states, br_states,
        env_state, obs, done, conf_h, br_h,
        rng, update_steps,
    ) = update_runner_state
    num_envs = config["NUM_ENVS"]
    num_members = config["PARTNER_POP_SIZE"]

    # 1) Collect rollouts: draw one confederate id and one BR id per env.
    rng, conf_key, br_key = jax.random.split(rng, 3)
    conf_ids = jax.random.randint(conf_key, (num_envs,), 0, num_members)
    br_ids = jax.random.randint(br_key, (num_envs,), 0, num_members)

    rollout_carry, traj_batch = jax.lax.scan(
        _env_step,
        (conf_states, br_states, conf_ids, br_ids,
         env_state, obs, done, conf_h, br_h, rng),
        None,
        config["ROLLOUT_LENGTH"],
    )
    (conf_states, br_states, conf_ids, br_ids,
     env_state, obs, done, conf_h, br_h, rng) = rollout_carry

    # The scan yields one trajectory batch per side (agent 0 = confederate,
    # agent 1 = best response).
    traj_batch_conf, traj_batch_br = traj_batch

    # 2) Compute advantages, bootstrapping from the value of the last state.
    conf_params = gather_params(conf_states.params, conf_ids)
    br_params = gather_params(br_states.params, br_ids)
    conf_one_hots = identity_matrix[conf_ids]
    br_one_hots = identity_matrix[br_ids]

    avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
    # Only the value output is kept below, so a fixed dummy key is enough.
    dummy_rngs = jax.random.split(jax.random.PRNGKey(0), num_envs)

    # Confederate value: its `id` input is the one-hot id of the paired BR.
    _, last_val_conf, _, _ = jax.vmap(forward_pass_conf)(
        params=conf_params,
        obs=obs["agent_0"],
        id=br_one_hots,
        done=done["agent_0"],
        avail_actions=avail_actions["agent_0"].astype(jnp.float32),
        hstate=conf_h,
        rng=dummy_rngs,
    )
    advantages_conf, targets_conf = _calculate_gae(
        traj_batch_conf, last_val_conf.squeeze())

    # BR value: its `id` input is the one-hot id of the paired confederate.
    _, last_val_br, _, _ = jax.vmap(forward_pass_br)(
        params=br_params,
        obs=obs["agent_1"],
        id=conf_one_hots,
        done=done["agent_1"],
        avail_actions=avail_actions["agent_1"].astype(jnp.float32),
        hstate=br_h,
        rng=dummy_rngs,
    )
    advantages_br, targets_br = _calculate_gae(
        traj_batch_br, last_val_br.squeeze())

    # 3) PPO update: scan the epoch function UPDATE_EPOCHS times.
    rng, update_rng = jax.random.split(rng, 2)
    update_state, all_losses = jax.lax.scan(
        _update_epoch,
        (conf_states, br_states,
         traj_batch_conf, traj_batch_br,
         advantages_conf, advantages_br,
         targets_conf, targets_br,
         update_rng),
        None,
        config["UPDATE_EPOCHS"],
    )
    conf_states, br_states = update_state[:2]
    conf_losses, br_losses = all_losses
    _, (value_loss_conf, pg_loss_conf, entropy_conf) = conf_losses
    _, (value_loss_br, pg_loss_br, entropy_br) = br_losses

    # Metrics: average each info entry over the steps flagged by
    # "returned_episode" (or over every step when that key is absent).
    info = traj_batch_conf.info
    mask = info.get("returned_episode", jnp.ones_like(traj_batch_conf.reward))

    def _masked_mean(x):
        return jnp.where(mask, x, 0).sum() / jnp.maximum(1, mask.sum())

    metric = jax.tree.map(_masked_mean, info)
    metric["update_steps"] = update_steps
    # Losses are reduced over their two leading axes.
    loss_summaries = (
        ("value_loss_conf_agent", value_loss_conf),
        ("value_loss_br_agent", value_loss_br),
        ("pg_loss_conf_agent", pg_loss_conf),
        ("pg_loss_br_agent", pg_loss_br),
        ("entropy_conf", entropy_conf),
        ("entropy_br", entropy_br),
    )
    for metric_key, values in loss_summaries:
        metric[metric_key] = values.mean(axis=(0, 1))

    new_runner_state = (
        conf_states, br_states,
        env_state, obs, done, conf_h, br_h,
        rng, update_steps + 1,
    )
    return (new_runner_state, metric)
|
| 575 |
+
|
| 576 |
+
# --------------------------
|
| 577 |
+
# PPO Update and Checkpoint saving
|
| 578 |
+
# --------------------------
|
| 579 |
+
ckpt_and_eval_interval = config["NUM_UPDATES"] // max(1, config["NUM_CHECKPOINTS"] - 1) # -1 because we store a ckpt at the last update
|
| 580 |
+
num_ckpts = config["NUM_CHECKPOINTS"]
|
| 581 |
+
|
| 582 |
+
# Build a PyTree that holds parameters for all conf agent checkpoints
|
| 583 |
+
def init_ckpt_array(params_pytree):
    """Return a zero-filled pytree with one slot per checkpoint.

    Each leaf of ``params_pytree`` is replaced by zeros of the same dtype
    whose shape is the leaf's shape prefixed with a ``num_ckpts`` axis.
    """
    def _zeros_with_ckpt_axis(leaf):
        return jnp.zeros((num_ckpts,) + leaf.shape, leaf.dtype)

    return jax.tree.map(_zeros_with_ckpt_axis, params_pytree)
|
| 587 |
+
|
| 588 |
+
def _update_step_with_ckpt(state_with_ckpt, unused):
    """Do one PPO update, then conditionally checkpoint and evaluate.

    Args:
        state_with_ckpt: 5-tuple ``(runner state, conf checkpoint array,
            BR checkpoint array, next checkpoint index, latest eval info)``.
        unused: placeholder for the ``jax.lax.scan`` input slot; ignored.

    Returns:
        ``(new_state_with_ckpt, metric)`` where the state keeps the same
        layout and ``metric["eval_ep_last_info"]`` carries the most recent
        evaluation result (unchanged on steps where nothing is stored).
    """
    (runner_state, ckpts_conf, ckpts_br, ckpt_idx,
     prev_eval_info) = state_with_ckpt

    # Single PPO update
    runner_state, metric = _update_step(runner_state, None)
    (train_state_conf, train_state_br,
     env_state, obs, done, conf_h, br_h,
     rng, update_steps) = runner_state

    # `update_steps` was already incremented by `_update_step`, so it is
    # 1-indexed here: store on every interval boundary and on the last update.
    on_interval = jnp.equal(
        jnp.mod(update_steps - 1, ckpt_and_eval_interval), 0)
    is_final_update = jnp.equal(update_steps, config["NUM_UPDATES"])
    to_store = jnp.logical_or(on_interval, is_final_update)

    def _write_slot(ckpt_tree, params, slot):
        # Copy every parameter leaf into position `slot` of its checkpoint leaf.
        return jax.tree.map(
            lambda arr, leaf: arr.at[slot].set(leaf), ckpt_tree, params)

    def store_and_eval_ckpt(args):
        (conf_arr, br_arr, _), key, slot = args
        conf_arr = _write_slot(conf_arr, train_state_conf.params, slot)
        br_arr = _write_slot(br_arr, train_state_br.params, slot)

        key, eval_key = jax.random.split(key)
        # Average the evaluation outputs over their last two axes.
        eval_info = jax.tree.map(
            lambda x: x.mean(axis=(-2, -1)),
            run_all_episodes(eval_key, train_state_conf, train_state_br),
        )
        return ((conf_arr, br_arr, eval_info), key, slot + 1)

    cond_operand = ((ckpts_conf, ckpts_br, prev_eval_info), rng, ckpt_idx)
    (ckpts_conf, ckpts_br, eval_info), rng, ckpt_idx = jax.lax.cond(
        to_store,
        store_and_eval_ckpt,
        lambda args: args,  # skip: pass everything through untouched
        cond_operand,
    )

    metric["eval_ep_last_info"] = eval_info  # return of confederate

    next_runner_state = (
        train_state_conf, train_state_br,
        env_state, obs, done, conf_h, br_h, rng, update_steps,
    )
    return (next_runner_state, ckpts_conf, ckpts_br, ckpt_idx,
            eval_info), metric
|
| 637 |
+
|
| 638 |
+
# Initialize checkpoint array
|
| 639 |
+
checkpoint_array_conf = init_ckpt_array(all_conf_optims.params)
|
| 640 |
+
checkpoint_array_br = init_ckpt_array(all_br_optims.params)
|
| 641 |
+
ckpt_idx = 0
|
| 642 |
+
|
| 643 |
+
# Initialize state for scan over _update_step_with_ckpt
|
| 644 |
+
update_steps = 0
|
| 645 |
+
|
| 646 |
+
rng, rng_eval = jax.random.split(rng, 2)
|
| 647 |
+
eval_ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
|
| 648 |
+
run_all_episodes(rng_eval, all_conf_optims, all_br_optims))
|
| 649 |
+
|
| 650 |
+
# Initialize environment
|
| 651 |
+
rng, reset_rng = jax.random.split(rng)
|
| 652 |
+
reset_rngs = jax.random.split(reset_rng, config["NUM_ENVS"])
|
| 653 |
+
init_obs, init_env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rngs)
|
| 654 |
+
init_done = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
|
| 655 |
+
|
| 656 |
+
# Initialize conf and br hstates
|
| 657 |
+
init_conf_h = conf_policy.init_hstate(config["NUM_CONF_ACTORS"])
|
| 658 |
+
init_br_h = br_policy.init_hstate(config["NUM_BR_ACTORS"])
|
| 659 |
+
|
| 660 |
+
update_runner_state = (
|
| 661 |
+
all_conf_optims, all_br_optims,
|
| 662 |
+
init_env_state, init_obs, init_done, init_conf_h, init_br_h,
|
| 663 |
+
rng, update_steps
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
state_with_ckpt = (
|
| 667 |
+
update_runner_state, checkpoint_array_conf,
|
| 668 |
+
checkpoint_array_br, ckpt_idx, eval_ep_last_info
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
# run training
|
| 672 |
+
state_with_ckpt, metrics = jax.lax.scan(
|
| 673 |
+
_update_step_with_ckpt,
|
| 674 |
+
state_with_ckpt,
|
| 675 |
+
xs=None,
|
| 676 |
+
length=config["NUM_UPDATES"]
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
(
|
| 680 |
+
final_runner_state, checkpoint_array_conf, checkpoint_array_br,
|
| 681 |
+
final_ckpt_idx, all_ep_infos
|
| 682 |
+
) = state_with_ckpt
|
| 683 |
+
|
| 684 |
+
out = {
|
| 685 |
+
"final_params_conf": final_runner_state[0].params,
|
| 686 |
+
"final_params_br": final_runner_state[1].params,
|
| 687 |
+
"checkpoints_conf": checkpoint_array_conf,
|
| 688 |
+
"checkpoints_br": checkpoint_array_br,
|
| 689 |
+
"metrics": metrics, # metrics is from the perspective of the confederate agent (averaged over population)
|
| 690 |
+
"all_pair_returns": all_ep_infos
|
| 691 |
+
}
|
| 692 |
+
return out
|
| 693 |
+
|
| 694 |
+
return train
|
| 695 |
+
# ------------------------------
|
| 696 |
+
# Actually run the adversarial teammate training
|
| 697 |
+
# ------------------------------
|
| 698 |
+
train_fn = make_brdiv_agents(config)
|
| 699 |
+
out = train_fn(train_rng)
|
| 700 |
+
return out
|
| 701 |
+
|
| 702 |
+
def get_brdiv_population(config, out, env):
    """Return the partner parameters and population used for ego training.

    Args:
        config: run configuration; ``config["algorithm"]`` supplies
            ``PARTNER_POP_SIZE`` and the optional ``ACTIVATION``
            (defaults to ``"tanh"``).
        out: training output; its ``'final_params_conf'`` entry is returned
            as the partner parameters.
        env: environment used to size the partner policy from the action
            and observation spaces of ``env.agents[1]``.

    Returns:
        ``(partner_params, partner_population)``.
    """
    algorithm_cfg = config["algorithm"]
    pop_size = algorithm_cfg["PARTNER_POP_SIZE"]

    # Final confederate parameters; expected to be stacked as
    # (num_seeds, pop_size, ...) -- TODO confirm against the training output.
    partner_params = out['final_params_conf']

    partner_agent = env.agents[1]
    partner_policy = ActorWithConditionalCriticPolicy(
        action_dim=env.action_space(partner_agent).n,
        obs_dim=env.observation_space(partner_agent).shape[0],
        pop_size=pop_size,  # used to create onehot agent id
        activation=algorithm_cfg.get("ACTIVATION", "tanh"),
    )

    partner_population = AgentPopulation(
        pop_size=pop_size,
        policy_cls=partner_policy,
    )
    return partner_params, partner_population
|
| 725 |
+
|
| 726 |
+
def run_brdiv(config, wandb_logger):
    """Train BRDiv populations over several seeds and return the partners.

    Args:
        config: run configuration; ``config["algorithm"]`` is copied into a
            plain dict and drives environment creation and training. The
            full ``config`` is forwarded to ``log_metrics`` and
            ``get_brdiv_population``.
        wandb_logger: logger forwarded to ``log_metrics``.

    Returns:
        ``(partner_params, partner_population)`` as produced by
        ``get_brdiv_population``.
    """
    algorithm_config = dict(config["algorithm"])

    env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"])
    env = LogWrapper(env)

    log.info("Starting BRDiv training...")
    start = time.time()

    # Generate multiple random seeds from the base seed
    rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"])
    rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"])

    # Forward a configured activation so the policies used for training match
    # the one get_brdiv_population rebuilds from the same config. When the
    # key is absent the policy class default is kept, as before.
    policy_kwargs = {"pop_size": algorithm_config["PARTNER_POP_SIZE"]}
    if "ACTIVATION" in algorithm_config:
        policy_kwargs["activation"] = algorithm_config["ACTIVATION"]

    # The confederate acts as agent 0 and the best response as agent 1, so
    # each policy is sized from the spaces of the agent it controls.
    conf_agent, br_agent = env.agents[0], env.agents[1]
    conf_policy = ActorWithConditionalCriticPolicy(
        action_dim=env.action_space(conf_agent).n,
        obs_dim=env.observation_space(conf_agent).shape[0],
        **policy_kwargs,
    )
    br_policy = ActorWithConditionalCriticPolicy(
        action_dim=env.action_space(br_agent).n,
        obs_dim=env.observation_space(br_agent).shape[0],
        **policy_kwargs,
    )

    # Create a vmapped version of train_brdiv_partners.
    # disable_jit(False) keeps JIT on; flip it to True when debugging.
    with jax.disable_jit(False):
        vmapped_train_fn = jax.jit(
            jax.vmap(
                partial(train_brdiv_partners, env=env, config=algorithm_config,
                        conf_policy=conf_policy, br_policy=br_policy)
            )
        )
        out = vmapped_train_fn(rngs)
        # JAX dispatches asynchronously: wait for the results so the
        # duration logged below covers the actual computation.
        out = jax.block_until_ready(out)

    end = time.time()
    log.info(f"BRDiv training complete in {end - start} seconds")

    metric_names = get_metric_names(algorithm_config["ENV_NAME"])
    log_metrics(config, out, wandb_logger, metric_names)

    partner_params, partner_population = get_brdiv_population(config, out, env)

    return partner_params, partner_population
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
def log_metrics(config, outs, logger, metric_names: tuple):
    """Log evaluation curves, per-member losses and the saved training run.

    Args:
        config: run configuration; reads ``config["logger"]["log_train_out"]``
            and ``config["local_logger"]["save_train_out"]``.
        outs: training output; ``outs["metrics"]`` is logged and the whole
            object is passed to ``save_train_run``.
        logger: object providing ``log_item``, ``commit``, ``log_xp_matrix``
            and ``log_artifact``.
        metric_names: accepted for interface compatibility; not used here.

    Raises:
        ValueError: if any seed-averaged loss array contains NaN.
    """
    train_metrics = outs["metrics"]
    # Loss metrics are 3-D: (num_seeds, num_updates, pop_size).
    _, num_updates, pop_size = train_metrics["pg_loss_conf_agent"].shape

    ### Log evaluation metrics
    # XP and SP return curves are plotted separately. The returns array is
    # indexed as (seed, update, conf/br pair); evaluation episodes and agents
    # were already averaged inside the training scan.
    returns = np.asarray(
        train_metrics["eval_ep_last_info"]["returned_episode_returns"])

    conf_ids, br_ids = _get_all_ids(pop_size)
    is_self_play = (conf_ids == br_ids)

    # Average over seeds and agent pairs.
    sp_return_curve = returns[:, :, is_self_play].mean(axis=(0, 2))
    xp_return_curve = returns[:, :, ~is_self_play].mean(axis=(0, 2))

    for step in range(num_updates):
        logger.log_item("Eval/AvgSPReturnCurve", sp_return_curve[step], train_step=step)
        logger.log_item("Eval/AvgXPReturnCurve", xp_return_curve[step], train_step=step)
        logger.commit()

    # Final cross-play matrix, averaged over seeds.
    final_xp_matrix = np.reshape(
        returns[:, -1].mean(axis=0), (pop_size, pop_size))
    logger.log_xp_matrix("Eval/LastXPMatrix", final_xp_matrix)

    ### Log population losses as multi-line plots, one line per member.
    # Each array is averaged over seeds, then transposed so that members
    # come first.
    loss_sources = (
        ("ConfPGLoss", "pg_loss_conf_agent"),
        ("BRPGLoss", "pg_loss_br_agent"),
        ("ConfValLoss", "value_loss_conf_agent"),
        ("BRValLoss", "value_loss_br_agent"),
        ("ConfEntropy", "entropy_conf"),
        ("BREntropy", "entropy_br"),
    )
    processed_losses = {
        label: np.asarray(train_metrics[key]).mean(axis=0).transpose()
        for label, key in loss_sources
    }

    xs = list(range(num_updates))
    keys = [f"pair {i}" for i in range(pop_size)]
    for loss_name, loss_data in processed_losses.items():
        if np.isnan(loss_data).any():
            raise ValueError(f"Found nan in loss {loss_name}")
        series_plot = wandb.plot.line_series(
            xs=xs, ys=loss_data, keys=keys,
            title=loss_name, xname="train_step")
        logger.log_item(f"Losses/{loss_name}", series_plot)

    ### Log artifacts
    savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    # Save train run output and log to wandb as artifact
    out_savepath = save_train_run(outs, savedir, savename="saved_train_run")
    if config["logger"]["log_train_out"]:
        logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run")

    # Cleanup locally logged out files
    if not config["local_logger"]["save_train_out"]:
        shutil.rmtree(out_savepath)
|
teammate_generation/CoMeDi.py
ADDED
|
@@ -0,0 +1,1161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''Implementation of the CoMeDi teammate generation algorithm (Sarkar et al. NeurIPS 2023)
|
| 2 |
+
https://openreview.net/forum?id=MljeRycu9s
|
| 3 |
+
|
| 4 |
+
Command to run CoMeDi only on LBF:
|
| 5 |
+
python teammate_generation/run.py algorithm=comedi/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels label=test_comedi run_heldout_eval=false train_ego=false
|
| 6 |
+
|
| 7 |
+
Limitations: does not support recurrent actors.
|
| 8 |
+
'''
|
| 9 |
+
from functools import partial
|
| 10 |
+
import logging
|
| 11 |
+
import shutil
|
| 12 |
+
import time
|
| 13 |
+
from typing import NamedTuple
|
| 14 |
+
|
| 15 |
+
from flax.training.train_state import TrainState
|
| 16 |
+
import hydra
|
| 17 |
+
import jax
|
| 18 |
+
import jax.numpy as jnp
|
| 19 |
+
import numpy as np
|
| 20 |
+
import optax
|
| 21 |
+
import wandb
|
| 22 |
+
|
| 23 |
+
from agents.mlp_actor_critic_agent import ActorWithConditionalCriticPolicy
|
| 24 |
+
from agents.initialize_agents import initialize_actor_with_conditional_critic
|
| 25 |
+
from agents.population_interface import AgentPopulation
|
| 26 |
+
from agents.population_buffer import BufferedPopulation
|
| 27 |
+
from common.save_load_utils import save_train_run
|
| 28 |
+
from common.plot_utils import get_metric_names
|
| 29 |
+
from common.run_episodes import run_episodes
|
| 30 |
+
from envs import make_env
|
| 31 |
+
from envs.log_wrapper import LogWrapper, LogEnvState
|
| 32 |
+
from marl.ippo import make_train as make_ppo_train
|
| 33 |
+
from marl.ppo_utils import Transition, unbatchify, _create_minibatches
|
| 34 |
+
|
| 35 |
+
log = logging.getLogger(__name__)
|
| 36 |
+
logging.basicConfig(level=logging.INFO)
|
| 37 |
+
|
| 38 |
+
class ResetTransition(NamedTuple):
    '''Stores extra information for resetting agents to a point in some trajectory.

    Field meanings below are inferred from the field names only; confirm
    them against the code that constructs this tuple.
    '''
    # Log-wrapped environment state at the stored point.
    env_state: LogEnvState
    # Presumably the observations of the confederate and of its partner -- TODO confirm.
    conf_obs: jnp.ndarray
    partner_obs: jnp.ndarray
    # Presumably the done flags of the confederate and of its partner -- TODO confirm.
    conf_done: jnp.ndarray
    partner_done: jnp.ndarray
    # Presumably the hidden states of the confederate and of its partner -- TODO confirm.
    conf_hstate: jnp.ndarray
    partner_hstate: jnp.ndarray
|
| 47 |
+
|
| 48 |
+
def train_comedi_partners(train_rng, wandb_logger, env, config):
|
| 49 |
+
num_agents = env.num_agents
|
| 50 |
+
assert num_agents == 2, "This code assumes the environment has exactly 2 agents."
|
| 51 |
+
|
| 52 |
+
# Define 4 types of rollouts: SP, XP, MP, MP2
|
| 53 |
+
config["NUM_GAME_AGENTS"] = num_agents
|
| 54 |
+
|
| 55 |
+
config["NUM_ACTORS"] = num_agents * config["NUM_ENVS"]
|
| 56 |
+
# Right now assume control of both agent and its BR
|
| 57 |
+
config["NUM_CONTROLLED_ACTORS"] = config["NUM_ACTORS"]
|
| 58 |
+
|
| 59 |
+
# Compute number of updates PER outermost iteration
|
| 60 |
+
# Calculate timesteps per update
|
| 61 |
+
# 1. Overhead from population selection rollouts
|
| 62 |
+
# We divide by 2 because for ease in Jax, this implementation uses a vmap over PARTNER_POP_SIZE to
|
| 63 |
+
# evaluate the agent generated at each outermost iteration against all previously
|
| 64 |
+
# generated agents, but a non-Jax implementation would only need to evaluate against
|
| 65 |
+
# *previously* generated agents.
|
| 66 |
+
selection_steps = config["PARTNER_POP_SIZE"] * config["NUM_ARGMAX_ROLLOUT_EPS"] * config["ROLLOUT_LENGTH"] // 2
|
| 67 |
+
# 2. Training rollouts: 4 distinct rollout phases (SP, XP, MP, MP2) each using NUM_ENVS
|
| 68 |
+
training_steps = 4 * config["ROLLOUT_LENGTH"] * config["NUM_ENVS"]
|
| 69 |
+
|
| 70 |
+
steps_per_update = selection_steps + training_steps
|
| 71 |
+
config["NUM_UPDATES"] = int(config["TOTAL_TIMESTEPS_PER_ITERATION"] // steps_per_update)
|
| 72 |
+
|
| 73 |
+
def make_comedi_agents(config):
|
| 74 |
+
def linear_schedule(count):
|
| 75 |
+
frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
|
| 76 |
+
return config["LR"] * frac
|
| 77 |
+
|
| 78 |
+
def train_init_ippo_partners(config, partner_rng, env):
|
| 79 |
+
'''
|
| 80 |
+
Train a pool of IPPO agents w/ parameter sharing.
|
| 81 |
+
Returns out, a dictionary of the model checkpoints, final parameters, and metrics.
|
| 82 |
+
'''
|
| 83 |
+
# POP_SIZE is referenced throughout the CoMeDi training loops
|
| 84 |
+
config["POP_SIZE"] = config["PARTNER_POP_SIZE"]
|
| 85 |
+
# Use a local copy for warmup-specific overrides to avoid
|
| 86 |
+
# mutating the shared config (ACTOR_TYPE, TOTAL_TIMESTEPS)
|
| 87 |
+
warmup_config = dict(config)
|
| 88 |
+
warmup_config["TOTAL_TIMESTEPS"] = config["TOTAL_TIMESTEPS_PER_ITERATION"]
|
| 89 |
+
warmup_config["ACTOR_TYPE"] = "pseudo_actor_with_conditional_critic"
|
| 90 |
+
out = make_ppo_train(warmup_config, env, wandb_logger)(partner_rng)
|
| 91 |
+
return out
|
| 92 |
+
|
| 93 |
+
def train(rng):
|
| 94 |
+
# Start by training a single PPO agent via self-play
|
| 95 |
+
rng, init_ppo_rng, init_conf_rng = jax.random.split(rng, 3)
|
| 96 |
+
|
| 97 |
+
init_ppo_partner = train_init_ippo_partners(config, init_ppo_rng, env)
|
| 98 |
+
|
| 99 |
+
# Initialize a population buffer
|
| 100 |
+
dummy_policy, dummy_init_params = initialize_actor_with_conditional_critic(config, env, init_conf_rng)
|
| 101 |
+
partner_population = BufferedPopulation(
|
| 102 |
+
max_pop_size=config["PARTNER_POP_SIZE"],
|
| 103 |
+
policy_cls=dummy_policy,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
population_buffer = partner_population.reset_buffer(dummy_init_params)
|
| 107 |
+
population_buffer = partner_population.add_agent(population_buffer, init_ppo_partner["final_params"])
|
| 108 |
+
|
| 109 |
+
def add_conf_policy(pop_buffer, func_input):
|
| 110 |
+
num_existing_agents, rng = func_input
|
| 111 |
+
rng, init_conf_rng = jax.random.split(rng)
|
| 112 |
+
|
| 113 |
+
# Create new confederate agent policy and critic
|
| 114 |
+
policy, init_params = initialize_actor_with_conditional_critic(
|
| 115 |
+
config, env, init_conf_rng
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# Create a train_state and optimizer for the newly initialized model
|
| 119 |
+
if config["ANNEAL_LR"]:
|
| 120 |
+
tx = optax.chain(
|
| 121 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 122 |
+
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
| 123 |
+
)
|
| 124 |
+
else:
|
| 125 |
+
tx = optax.chain(
|
| 126 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 127 |
+
optax.adam(config["LR"], eps=1e-5))
|
| 128 |
+
|
| 129 |
+
train_state = TrainState.create(
|
| 130 |
+
apply_fn=policy.network.apply,
|
| 131 |
+
params=init_params,
|
| 132 |
+
tx=tx,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# Reset envs for SP, XP, and MP
|
| 136 |
+
rng, reset_rng_eval, reset_rng_sp, reset_rng_xp, reset_rng_mp, reset_rng_mp2 = jax.random.split(rng, 6)
|
| 137 |
+
|
| 138 |
+
reset_rngs_sps = jax.random.split(reset_rng_sp, config["NUM_ENVS"])
|
| 139 |
+
reset_rngs_xps = jax.random.split(reset_rng_xp, config["NUM_ENVS"])
|
| 140 |
+
reset_rngs_mps = jax.random.split(reset_rng_mp, config["NUM_ENVS"])
|
| 141 |
+
reset_rngs_mps2 = jax.random.split(reset_rng_mp2, config["NUM_ENVS"])
|
| 142 |
+
|
| 143 |
+
obsv_xp, env_state_xp = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_sps)
|
| 144 |
+
obsv_sp, env_state_sp = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_xps)
|
| 145 |
+
obsv_mp, env_state_mp = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_mps)
|
| 146 |
+
obsv_mp2, env_state_mp2 = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_mps2)
|
| 147 |
+
|
| 148 |
+
# build a pytree that can hold the parameters for all checkpoints.
|
| 149 |
+
ckpt_and_eval_interval = config["NUM_UPDATES"] // max(1, config["NUM_CHECKPOINTS"] - 1)
|
| 150 |
+
num_ckpts = config["NUM_CHECKPOINTS"]
|
| 151 |
+
def init_ckpt_array(params_pytree):
    """Allocate zero-filled storage for one copy of the params per checkpoint.

    Each leaf of ``params_pytree`` is mirrored by a zero array of the same
    dtype whose shape gains a leading axis of length ``num_ckpts`` (read from
    the enclosing scope).
    """
    def _stacked_zeros(leaf):
        stacked_shape = (num_ckpts,) + leaf.shape
        return jnp.zeros(stacked_shape, leaf.dtype)

    return jax.tree.map(_stacked_zeros, params_pytree)
|
| 156 |
+
|
| 157 |
+
# define evaluation function
|
| 158 |
+
rng, eval_rng = jax.random.split(rng, 2)
|
| 159 |
+
def per_id_run_episode_fixed_rng(agent0_param, agent1_id):
    """Run evaluation episodes of ``agent0_param`` against one population member.

    Args:
        agent0_param: parameters used for agent 0.
        agent1_id: index of the partner whose parameters are gathered from
            ``pop_buffer`` and used for agent 1.

    Returns:
        Whatever ``run_episodes`` returns for
        ``config["NUM_ARGMAX_ROLLOUT_EPS"]`` episodes of at most
        ``config["ROLLOUT_LENGTH"]`` steps.

    Both seats are played through ``policy``, and every call reuses the same
    ``eval_rng`` from the enclosing scope.
    """
    partner_index = agent1_id * jnp.ones((1,), dtype=np.int32)
    agent1_param = partner_population.gather_agent_params(
        pop_buffer, agent_indices=partner_index)
    # A single index was gathered, so drop the leading axis of every leaf.
    # jax.tree.map is used (as elsewhere in this file) instead of the
    # deprecated jax.tree_map alias.
    agent1_param = jax.tree.map(lambda y: jnp.squeeze(y, 0), agent1_param)
    all_outs = run_episodes(
        rng=eval_rng, env=env,
        agent_0_param=agent0_param, agent_0_policy=policy,
        agent_1_param=agent1_param, agent_1_policy=policy,
        max_episode_steps=config["ROLLOUT_LENGTH"],
        num_eps=config["NUM_ARGMAX_ROLLOUT_EPS"]
    )
    return all_outs
|
| 171 |
+
|
| 172 |
+
def _update_step(update_with_ckpt_runner_state, unused):
|
| 173 |
+
update_runner_state, checkpoint_array, ckpt_idx = update_with_ckpt_runner_state
|
| 174 |
+
(
|
| 175 |
+
train_state, pop_buffer,
|
| 176 |
+
env_state_sp, obsv_sp,
|
| 177 |
+
env_state_xp, obsv_xp,
|
| 178 |
+
env_state_mp, obsv_mp,
|
| 179 |
+
env_state_mp2, obsv_mp2,
|
| 180 |
+
last_dones_xp,
|
| 181 |
+
last_dones_sp,
|
| 182 |
+
last_dones_mp,
|
| 183 |
+
last_dones_mp2,
|
| 184 |
+
rng, update_steps,
|
| 185 |
+
num_prev_trained_conf
|
| 186 |
+
) = update_runner_state
|
| 187 |
+
|
| 188 |
+
# Identify the expected returns from the newly trained policy
|
| 189 |
+
# when interacting with the previously generated confederate
|
| 190 |
+
# policies
|
| 191 |
+
valid_sampling_indices = jnp.arange(config["POP_SIZE"])
|
| 192 |
+
run_all_rollouts = jax.vmap(per_id_run_episode_fixed_rng, in_axes=(None, 0))(
|
| 193 |
+
train_state.params,valid_sampling_indices)
|
| 194 |
+
|
| 195 |
+
# Mask out the XP returns against invalid policies
|
| 196 |
+
# resulting from IDs that are not yet set to a specific
|
| 197 |
+
# confederate params
|
| 198 |
+
all_mean_returns = run_all_rollouts["returned_episode_returns"][:, :, 0].mean(axis=-1)
|
| 199 |
+
masked_mean_returns = jnp.where(
|
| 200 |
+
valid_sampling_indices >= num_prev_trained_conf, -jnp.inf, all_mean_returns
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# Pick the right confederate params to act as the XP agent
|
| 204 |
+
max_means_id = masked_mean_returns.argmax()
|
| 205 |
+
xp_param = jax.tree_map(
|
| 206 |
+
lambda x: jnp.squeeze(x, 0),
|
| 207 |
+
partner_population.gather_agent_params(pop_buffer,
|
| 208 |
+
agent_indices=max_means_id * jnp.ones((1,), dtype=np.int32))
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
rng, rng_xp, rng_sp, rng_mp, rng_mp2 = jax.random.split(rng, 5)
|
| 212 |
+
|
| 213 |
+
def _env_step_conf_ego(runner_state, unused):
    """
    Advance the XP (cross-play) rollout by one vectorised environment step.

    agent_0 acts with ``train_state.params`` (the confederate being trained);
    agent_1 acts with ``xp_param`` (the partner picked from the population).
    Both policies receive the one-hot encoding of ``xp_id`` as auxiliary input.

    Returns the updated runner_state and a Transition holding agent_0's data.
    The reward stored in that Transition is read from ``reward["agent_1"]``.
    """
    train_state, xp_param, xp_id, env_state, last_obs, last_dones, rng = runner_state
    rng, act_rng, partner_rng, step_rng = jax.random.split(rng, 4)
    n_envs = config["NUM_ENVS"]

    conf_obs = last_obs["agent_0"]
    partner_obs = last_obs["agent_1"]

    # Action masks for both agents, derived from the current environment state
    action_masks = jax.vmap(env.get_avail_actions)(env_state.env_state)
    conf_mask = action_masks["agent_0"].astype(jnp.float32)
    partner_mask = action_masks["agent_1"].astype(jnp.float32)

    # One-hot ID of the XP teammate with two leading singleton axes,
    # then tiled along axis 1 so there is one copy per environment.
    teammate_one_hot = jnp.eye(config["POP_SIZE"])[xp_id][jnp.newaxis, jnp.newaxis]
    aux_obs = jnp.repeat(teammate_one_hot, n_envs, axis=1)

    # Agent_0 (confederate) action, value and log-probability
    conf_act, conf_val, conf_pi, _ = policy.get_action_value_policy(
        params=train_state.params,
        obs=conf_obs.reshape(1, n_envs, -1),
        done=last_dones["agent_0"].reshape(1, n_envs),
        avail_actions=jax.lax.stop_gradient(conf_mask),
        hstate=None,
        rng=act_rng,
        aux_obs=aux_obs
    )
    conf_logp = conf_pi.log_prob(conf_act).squeeze()
    conf_act = conf_act.squeeze()
    conf_val = conf_val.squeeze()

    # Agent_1 (partner) action; only the sampled action is needed
    partner_act, _, _, _ = policy.get_action_value_policy(
        params=xp_param,
        obs=partner_obs.reshape(1, n_envs, -1),
        done=last_dones["agent_1"].reshape(1, n_envs),
        avail_actions=jax.lax.stop_gradient(partner_mask),
        hstate=None,
        rng=partner_rng,
        aux_obs=aux_obs
    )
    partner_act = partner_act.squeeze()

    # Pack both agents' actions into the per-agent dict the env expects
    stacked_actions = jnp.concatenate([conf_act, partner_act], axis=0)
    env_act = {
        agent: acts.flatten()
        for agent, acts in unbatchify(stacked_actions, env.agents, n_envs, num_agents).items()
    }

    # Step every environment with its own key
    step_rngs = jax.random.split(step_rng, n_envs)
    batched_step = jax.vmap(env.step, in_axes=(0, 0, 0))
    obs_next, env_state_next, reward, done, info = batched_step(step_rngs, env_state, env_act)

    # Keep column 0 of every info leaf for the agent_0 transition
    conf_info = jax.tree.map(lambda leaf: leaf[:, 0], info)

    transition = Transition(
        done=done["agent_0"],
        action=conf_act,
        value=conf_val,
        reward=reward["agent_1"],
        log_prob=conf_logp,
        obs=conf_obs,
        info=conf_info,
        avail_actions=conf_mask
    )
    return (train_state, xp_param, xp_id, env_state_next, obs_next, done, rng), transition
|
| 292 |
+
|
| 293 |
+
def _env_step_conf_br(runner_state, unused):
    """
    Advance a confederate / best-response rollout by one vectorised step.

    agent_0 = confederate, agent_1 = best response. Both act with
    ``train_state.params`` and receive the one-hot encoding of
    ``current_trained_pop_id`` as auxiliary input.

    When ``reset_traj_batch`` is not None, every environment whose previous
    ``last_dones["__all__"]`` flag is set is restarted from a uniformly sampled
    state stored in ``reset_traj_batch`` before acting.

    Returns the updated runner_state and a pair of Transitions
    (confederate, best response).
    """
    (train_state, env_state, last_obs, last_dones, rng,
     current_trained_pop_id, reset_traj_batch) = runner_state
    rng, conf_rng, br_rng, step_rng = jax.random.split(rng, 4)
    n_envs = config["NUM_ENVS"]

    if reset_traj_batch is None:
        # No stored reset states: continue from the previous step as-is
        obs_0, obs_1 = last_obs["agent_0"], last_obs["agent_1"]
        dones_0, dones_1 = last_dones["agent_0"], last_dones["agent_1"]
    else:
        rng, sample_rng = jax.random.split(rng)
        needs_resample = last_dones["__all__"]

        # One candidate reset state per environment, drawn from the flattened batch
        total_reset_states = config["ROLLOUT_LENGTH"] * n_envs
        sampled_indices = jax.random.randint(
            sample_rng, shape=(n_envs,), minval=0, maxval=total_reset_states)

        def gather_sampled(data_pytree, flat_indices, first_nonbatch_dim: int):
            """Flatten the dims before first_nonbatch_dim into one axis and index it."""
            return jax.tree.map(
                lambda leaf: leaf.reshape(
                    total_reset_states, *leaf.shape[first_nonbatch_dim:])[flat_indices],
                data_pytree)

        sampled_env_state = gather_sampled(reset_traj_batch.env_state, sampled_indices, first_nonbatch_dim=2)
        sampled_conf_obs = gather_sampled(reset_traj_batch.conf_obs, sampled_indices, first_nonbatch_dim=2)
        sampled_br_obs = gather_sampled(reset_traj_batch.partner_obs, sampled_indices, first_nonbatch_dim=2)
        sampled_conf_done = gather_sampled(reset_traj_batch.conf_done, sampled_indices, first_nonbatch_dim=2)
        sampled_br_done = gather_sampled(reset_traj_batch.partner_done, sampled_indices, first_nonbatch_dim=2)

        def _pick_per_env(sampled, original):
            # Broadcast the per-env flag over the trailing dims of this leaf
            flag = needs_resample.reshape((-1,) + (1,) * (original.ndim - 1))
            return jnp.where(flag, sampled, original)

        # Finished environments take the sampled data, the rest keep their own
        env_state = jax.tree.map(_pick_per_env, sampled_env_state, env_state)
        resample_col = needs_resample[:, jnp.newaxis]
        obs_0 = jnp.where(resample_col, sampled_conf_obs, last_obs["agent_0"])
        obs_1 = jnp.where(resample_col, sampled_br_obs, last_obs["agent_1"])
        dones_0 = jnp.where(needs_resample, sampled_conf_done, last_dones["agent_0"])
        dones_1 = jnp.where(needs_resample, sampled_br_done, last_dones["agent_1"])

    # Action masks for both agents, derived from the (possibly reset) env state
    action_masks = jax.vmap(env.get_avail_actions)(env_state.env_state)
    avail_actions_0 = action_masks["agent_0"].astype(jnp.float32)
    avail_actions_1 = action_masks["agent_1"].astype(jnp.float32)

    # One-hot ID of the population slot being trained, tiled along axis 1 per env
    pop_one_hot = jnp.eye(config["POP_SIZE"])[current_trained_pop_id][jnp.newaxis, jnp.newaxis]
    aux_obs = jnp.repeat(pop_one_hot, n_envs, 1)

    # Agent_0 (confederate) action
    act_0, val_0, pi_0, _ = policy.get_action_value_policy(
        params=train_state.params,
        obs=obs_0.reshape(1, n_envs, -1),
        done=dones_0.reshape(1, n_envs),
        avail_actions=jax.lax.stop_gradient(avail_actions_0),
        hstate=None,
        rng=conf_rng,
        aux_obs=aux_obs
    )
    logp_0 = pi_0.log_prob(act_0).squeeze()
    act_0 = act_0.squeeze()
    val_0 = val_0.squeeze()

    # Agent_1 (best response) action
    act_1, val_1, pi_1, _ = policy.get_action_value_policy(
        params=train_state.params,
        obs=obs_1.reshape(1, n_envs, -1),
        done=dones_1.reshape(1, n_envs),
        avail_actions=jax.lax.stop_gradient(avail_actions_1),
        hstate=None,
        rng=br_rng,
        aux_obs=aux_obs
    )
    logp_1 = pi_1.log_prob(act_1).squeeze()
    act_1 = act_1.squeeze()
    val_1 = val_1.squeeze()

    # Pack both agents' actions into the per-agent dict the env expects
    stacked_actions = jnp.concatenate([act_0, act_1], axis=0)
    env_act = {
        agent: acts.flatten()
        for agent, acts in unbatchify(stacked_actions, env.agents, n_envs, num_agents).items()
    }

    # Step every environment with its own key
    step_rngs = jax.random.split(step_rng, n_envs)
    batched_step = jax.vmap(env.step, in_axes=(0, 0, 0))
    obs_next, env_state_next, reward, done, info = batched_step(step_rngs, env_state, env_act)
    info_0 = jax.tree.map(lambda leaf: leaf[:, 0], info)
    info_1 = jax.tree.map(lambda leaf: leaf[:, 1], info)

    # Agent_0 (confederate) transition
    transition_0 = Transition(
        done=done["agent_0"],
        action=act_0,
        value=val_0,
        reward=reward["agent_0"],
        log_prob=logp_0,
        obs=obs_0,
        info=info_0,
        avail_actions=avail_actions_0
    )
    # Agent_1 (best response) transition
    transition_1 = Transition(
        done=done["agent_1"],
        action=act_1,
        value=val_1,
        reward=reward["agent_1"],
        log_prob=logp_1,
        obs=obs_1,
        info=info_1,
        avail_actions=avail_actions_1
    )
    # current_trained_pop_id and reset_traj_batch are carried through unchanged
    new_runner_state = (
        train_state, env_state_next, obs_next, done, rng,
        current_trained_pop_id, reset_traj_batch,
    )
    return new_runner_state, (transition_0, transition_1)
|
| 428 |
+
|
| 429 |
+
def _env_step_mixed(runner_state, unused):
    """
    Advance the MP (mixed-play) rollout by one vectorised environment step.

    agent_0 = confederate, acting with ``train_state_conf.params``.
    agent_1 = per environment, a uniformly random choice between the action
    sampled with ``ego_param`` and the action sampled with
    ``train_state_conf.params`` (best response).
    All policies receive the one-hot encoding of ``current_trained_pop_id``.

    Returns the updated runner_state and a ResetTransition recording the
    environment state, observations and done flags from *before* the step,
    so that later rollouts can be reset to the states visited here.
    """
    (train_state_conf, ego_param, env_state, last_obs, last_dones, rng,
     current_trained_pop_id) = runner_state
    rng, act_rng, ego_act_rng, br_act_rng, partner_choice_rng, step_rng = jax.random.split(rng, 6)
    n_envs = config["NUM_ENVS"]

    obs_0 = last_obs["agent_0"]
    obs_1 = last_obs["agent_1"]

    # Action masks for both agents, derived from the current environment state
    avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
    avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32)
    avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32)

    # One-hot ID of the population slot being trained, tiled along axis 1 per env
    pop_one_hot = jnp.eye(config["POP_SIZE"])[current_trained_pop_id][jnp.newaxis, jnp.newaxis]
    aux_obs = jnp.repeat(pop_one_hot, n_envs, axis=1)

    # Agent_0 (confederate) action; value and distribution are not needed here
    act_0, _, _, _ = policy.get_action_value_policy(
        params=train_state_conf.params,
        obs=obs_0.reshape(1, n_envs, -1),
        done=last_dones["agent_0"].reshape(1, n_envs),
        avail_actions=jax.lax.stop_gradient(avail_actions_0),
        hstate=None,
        rng=act_rng,
        aux_obs=aux_obs
    )
    act_0 = act_0.squeeze()

    # Compute both candidate partner actions: ego and best response
    act_ego, _, _, _ = policy.get_action_value_policy(
        params=ego_param,
        obs=obs_1.reshape(1, n_envs, -1),
        done=last_dones["agent_1"].reshape(1, n_envs),
        avail_actions=jax.lax.stop_gradient(avail_actions_1),
        hstate=None,
        rng=ego_act_rng,
        aux_obs=aux_obs
    )
    # Use the train state threaded through the scan carry rather than the
    # `train_state` captured from the enclosing scope, so this step depends
    # only on its runner_state.
    act_br, _, _, _ = policy.get_action_value_policy(
        params=train_state_conf.params,
        obs=obs_1.reshape(1, n_envs, -1),
        done=last_dones["agent_1"].reshape(1, n_envs),
        avail_actions=jax.lax.stop_gradient(avail_actions_1),
        hstate=None,
        rng=br_act_rng,
        aux_obs=aux_obs
    )
    act_ego = act_ego.squeeze()
    act_br = act_br.squeeze()

    # Per environment: 0 -> play the ego action, 1 -> play the best-response action
    partner_choice = jax.random.randint(partner_choice_rng, shape=(n_envs,), minval=0, maxval=2)
    act_1 = jnp.where(partner_choice == 0, act_ego, act_br)

    # Pack both agents' actions into the per-agent dict the env expects
    combined_actions = jnp.concatenate([act_0, act_1], axis=0)
    env_act = unbatchify(combined_actions, env.agents, n_envs, num_agents)
    env_act = {k: v.flatten() for k, v in env_act.items()}

    # Step every environment with its own key; rewards and infos are not recorded
    step_rngs = jax.random.split(step_rng, n_envs)
    obs_next, env_state_next, _, done, _ = jax.vmap(env.step, in_axes=(0, 0, 0))(
        step_rngs, env_state, env_act
    )

    reset_transition = ResetTransition(
        # all of these are from before the env step
        env_state=env_state,
        conf_obs=obs_0,
        partner_obs=obs_1,
        conf_done=last_dones["agent_0"],
        partner_done=last_dones["agent_1"],
        # no hidden state is recorded for either agent
        conf_hstate=None,
        partner_hstate=None
    )
    new_runner_state = (
        train_state_conf, ego_param, env_state_next, obs_next, done, rng,
        current_trained_pop_id,
    )
    return new_runner_state, reset_transition
|
| 521 |
+
|
| 522 |
+
# Do XP rollout (based on train_state params and the param in pop_buffer identified in Step 1)
|
| 523 |
+
runner_state_xp = (train_state, xp_param, max_means_id, env_state_xp, obsv_xp, last_dones_xp, rng_xp)
|
| 524 |
+
runner_state_xp, traj_batch_xp = jax.lax.scan(
|
| 525 |
+
_env_step_conf_ego, runner_state_xp, None, config["ROLLOUT_LENGTH"])
|
| 526 |
+
(train_state, xp_param, max_means_id, env_state_xp, last_obs_xp, last_dones_xp, rng_xp) = runner_state_xp
|
| 527 |
+
|
| 528 |
+
# Do self-play (based on train_state params) rollout like in the IPPO code
|
| 529 |
+
runner_state_sp = (train_state, env_state_sp, obsv_sp, last_dones_sp, rng_sp, num_prev_trained_conf, None)
|
| 530 |
+
runner_state_sp, (traj_batch_sp_agent0, traj_batch_sp_agent1) = jax.lax.scan(
|
| 531 |
+
_env_step_conf_br, runner_state_sp, None, config["ROLLOUT_LENGTH"])
|
| 532 |
+
(train_state, env_state_sp, last_obs_sp, last_dones_sp, rng_sp, num_prev_trained_conf, mp_traj_batch) = runner_state_sp
|
| 533 |
+
|
| 534 |
+
# Step 4
|
| 535 |
+
# Do MP rollout (based on train_state params and the param in pop_buffer identified in Step 1)
|
| 536 |
+
runner_state_mp = (train_state, xp_param, env_state_mp, obsv_mp, last_dones_mp, rng_mp, num_prev_trained_conf)
|
| 537 |
+
runner_state_mp, traj_batch_mp = jax.lax.scan(
|
| 538 |
+
_env_step_mixed, runner_state_mp, None, config["ROLLOUT_LENGTH"])
|
| 539 |
+
(train_state, xp_param, env_state_mp, last_obs_mp, last_dones_mp, rng_mp, num_prev_trained_conf) = runner_state_mp
|
| 540 |
+
|
| 541 |
+
runner_state_smp = (train_state, env_state_mp2, obsv_mp2, last_dones_mp2, rng_mp2, num_prev_trained_conf, traj_batch_mp)
|
| 542 |
+
runner_state_smp, (traj_batch_smp0, traj_batch_smp1) = jax.lax.scan(
|
| 543 |
+
_env_step_conf_br, runner_state_smp, None, config["ROLLOUT_LENGTH"])
|
| 544 |
+
(train_state, env_state_mp2, last_obs_mp2, last_dones_mp2, rng_mp2, num_prev_trained_conf, mp2_traj_batch) = runner_state_smp
|
| 545 |
+
|
| 546 |
+
def _calculate_gae(traj_batch, last_val):
    """Compute GAE advantages and value targets for a rollout.

    Scans ``traj_batch`` in reverse, starting from a zero advantage and the
    bootstrap value ``last_val``, using ``config["GAMMA"]`` and
    ``config["GAE_LAMBDA"]``.

    Returns:
        (advantages, targets) where ``targets = advantages + traj_batch.value``.
    """
    gamma = config["GAMMA"]
    gae_lambda = config["GAE_LAMBDA"]

    def _get_advantages(carry, transition):
        running_gae, next_value = carry
        # (1 - done) cuts both the bootstrap and the accumulated advantage
        # at a step whose done flag is set.
        not_done = 1 - transition.done
        td_error = transition.reward + gamma * next_value * not_done - transition.value
        running_gae = td_error + gamma * gae_lambda * not_done * running_gae
        return (running_gae, transition.value), running_gae

    initial_carry = (jnp.zeros_like(last_val), last_val)
    _, advantages = jax.lax.scan(
        _get_advantages,
        initial_carry,
        traj_batch,
        reverse=True,
        unroll=16,
    )
    targets = advantages + traj_batch.value
    return advantages, targets
|
| 569 |
+
|
| 570 |
+
def _compute_advantages_and_targets(env_state, policy, policy_params, policy_hstate,
                                    last_obs, last_dones, traj_batch, agent_name, value_idx=None):
    """Bootstrap the final value for ``agent_name`` and compute GAE for ``traj_batch``.

    ``value_idx`` selects the row of a ``POP_SIZE`` identity matrix that is fed
    to the policy as the one-hot auxiliary observation, i.e. it is the ID of
    the interaction teammate. Every caller visible in this file passes it
    explicitly.

    Returns:
        (advantages, targets) as produced by ``_calculate_gae``.
    """
    agent_obs = last_obs[agent_name]
    batch = agent_obs.shape[0]

    all_masks = jax.vmap(env.get_avail_actions)(env_state.env_state)
    avail_actions = all_masks[agent_name].astype(jnp.float32)

    # One-hot ID of the interaction teammate, tiled along axis 1 per batch entry
    teammate_one_hot = jnp.eye(config["POP_SIZE"])[value_idx][jnp.newaxis, jnp.newaxis]
    aux_obs = jnp.repeat(teammate_one_hot, batch, axis=1)

    # Only the value output is used; the key is a dummy because no sampled action is kept
    _, vals, _, _ = policy.get_action_value_policy(
        params=policy_params,
        obs=agent_obs.reshape(1, batch, -1),
        done=last_dones[agent_name].reshape(1, batch),
        avail_actions=jax.lax.stop_gradient(avail_actions),
        hstate=policy_hstate,
        rng=jax.random.PRNGKey(0),
        aux_obs=aux_obs
    )
    return _calculate_gae(traj_batch, vals.squeeze())
|
| 599 |
+
|
| 600 |
+
# 5a) Compute conf advantages for XP (conf-ego) interaction
|
| 601 |
+
advantages_xp_conf, targets_xp_conf = _compute_advantages_and_targets(
|
| 602 |
+
env_state_xp, policy, train_state.params, None,
|
| 603 |
+
last_obs_xp, last_dones_xp, traj_batch_xp, "agent_0", value_idx=max_means_id)
|
| 604 |
+
|
| 605 |
+
# 5b) Compute conf and br advantages for SP (conf-br) interaction
|
| 606 |
+
advantages_sp_conf, targets_sp_conf = _compute_advantages_and_targets(
|
| 607 |
+
env_state_sp, policy, train_state.params, None,
|
| 608 |
+
last_obs_sp, last_dones_sp, traj_batch_sp_agent0, "agent_0", value_idx=num_prev_trained_conf)
|
| 609 |
+
|
| 610 |
+
advantages_sp_br, targets_sp_br = _compute_advantages_and_targets(
|
| 611 |
+
env_state_sp, policy, train_state.params, None,
|
| 612 |
+
last_obs_sp, last_dones_sp, traj_batch_sp_agent1, "agent_1", value_idx=num_prev_trained_conf)
|
| 613 |
+
|
| 614 |
+
# 5c) Compute advantages from MP interactions
|
| 615 |
+
advantages_mp_conf, targets_mp_conf = _compute_advantages_and_targets(
|
| 616 |
+
env_state_mp2, policy, train_state.params, None,
|
| 617 |
+
last_obs_mp2, last_dones_mp2, traj_batch_smp0, "agent_0", value_idx=num_prev_trained_conf)
|
| 618 |
+
|
| 619 |
+
advantages_mp_br, targets_mp_br = _compute_advantages_and_targets(
|
| 620 |
+
env_state_mp2, policy, train_state.params, None,
|
| 621 |
+
last_obs_mp2, last_dones_mp2, traj_batch_smp1, "agent_1", value_idx=num_prev_trained_conf)
|
| 622 |
+
|
| 623 |
+
def _update_epoch(update_state, unused):
|
| 624 |
+
def _compute_ppo_value_loss(pred_value, traj_batch, target_v):
    """Clipped PPO value loss.

    The new prediction may move at most ``config["CLIP_EPS"]`` away from the
    value stored in ``traj_batch``; the loss is the mean of the element-wise
    maximum of the clipped and unclipped squared errors against ``target_v``.
    """
    clip_eps = config["CLIP_EPS"]
    old_value = traj_batch.value
    clipped_pred = old_value + (pred_value - old_value).clip(-clip_eps, clip_eps)
    unclipped_sq_err = jnp.square(pred_value - target_v)
    clipped_sq_err = jnp.square(clipped_pred - target_v)
    return jnp.maximum(unclipped_sq_err, clipped_sq_err).mean()
|
| 636 |
+
|
| 637 |
+
def _compute_ppo_pg_loss(log_prob, traj_batch, gae):
    """Clipped-surrogate PPO policy loss.

    Advantages are standardised over the batch, the probability ratio is
    taken against the log-probs stored in ``traj_batch``, and the result is
    the negated mean of the smaller of the clipped and unclipped surrogates.
    """
    clip_eps = config["CLIP_EPS"]
    ratio = jnp.exp(log_prob - traj_batch.log_prob)
    # 1e-8 guards against a zero standard deviation
    gae_norm = (gae - gae.mean()) / (gae.std() + 1e-8)
    unclipped_term = ratio * gae_norm
    clipped_term = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae_norm
    return -jnp.mean(jnp.minimum(unclipped_term, clipped_term))
|
| 648 |
+
|
| 649 |
+
def _update_minbatch_conf(train_state_conf, batch_infos):
    """Apply one PPO gradient step to the confederate on one set of minibatches.

    Args:
        train_state_conf: train state whose params are optimised.
        batch_infos: tuple ``(minbatch_xp, minbatch_sp1, minbatch_sp2,
            minbatch_mp1, minbatch_mp2, xp_id, sp_id)``. Each minibatch is a
            4-tuple whose last three entries are the trajectory batch, the
            advantages and the value targets. ``xp_id`` / ``sp_id`` are the
            population IDs used to build the one-hot auxiliary observation.

    Returns:
        ``(train_state_conf, (loss_val, aux_vals))`` where ``aux_vals`` holds
        the value losses, policy-gradient losses and entropies for the XP,
        SP (sp + sp2) and MP (mp + mp2) branches, in that order.
    """
    minbatch_xp, minbatch_sp1, minbatch_sp2, minbatch_mp1, minbatch_mp2, xp_id, sp_id = batch_infos
    _, traj_batch_xp, advantages_xp, returns_xp = minbatch_xp
    _, traj_batch_sp1, advantages_sp1, returns_sp1 = minbatch_sp1
    _, traj_batch_sp2, advantages_sp2, returns_sp2 = minbatch_sp2
    _, traj_batch_mp1, advantages_mp1, returns_mp1 = minbatch_mp1
    _, traj_batch_mp2, advantages_mp2, returns_mp2 = minbatch_mp2

    def _aux_obs_for(agent_id, traj_batch):
        """One-hot of agent_id tiled over the two leading dims of traj_batch.obs."""
        one_hot = jnp.eye(config["POP_SIZE"])[agent_id][jnp.newaxis, jnp.newaxis]
        tiled = jnp.repeat(one_hot, traj_batch.obs.shape[1], axis=1)
        return jnp.repeat(tiled, traj_batch.obs.shape[0], axis=0)

    def _branch_terms(params, traj_batch, gae, target_v, agent_id):
        """Value loss, PG loss and mean entropy of one interaction branch."""
        _, value, pi, _ = policy.get_action_value_policy(
            params=params,
            obs=traj_batch.obs,
            done=traj_batch.done,
            avail_actions=traj_batch.avail_actions,
            hstate=None,
            rng=jax.random.PRNGKey(0),  # only used for action sampling, which is not used here
            aux_obs=_aux_obs_for(agent_id, traj_batch)
        )
        log_prob = pi.log_prob(traj_batch.action)
        value_loss = _compute_ppo_value_loss(value, traj_batch, target_v)
        pg_loss = _compute_ppo_pg_loss(log_prob, traj_batch, gae)
        entropy = jnp.mean(pi.entropy())
        return value_loss, pg_loss, entropy

    def _loss_fn_conf(params, traj_batch_xp, gae_xp, target_v_xp,
                      traj_batch_sp, gae_sp, target_v_sp,
                      traj_batch_sp2, gae_sp2, target_v_sp2,
                      traj_batch_mp, gae_mp, target_v_mp,
                      traj_batch_mp2, gae_mp2, target_v_mp2):
        # The XP branch is conditioned on the XP partner's ID.
        value_loss_xp, pg_loss_xp, entropy_xp = _branch_terms(
            params, traj_batch_xp, gae_xp, target_v_xp, xp_id)

        # The SP and MP branches must be conditioned on sp_id: their rollouts
        # and bootstrap values were produced with that ID, so evaluating them
        # with xp_id would make the PPO ratio and value targets inconsistent.
        value_loss_sp, pg_loss_sp, entropy_sp = _branch_terms(
            params, traj_batch_sp, gae_sp, target_v_sp, sp_id)
        value_loss_sp2, pg_loss_sp2, entropy_sp2 = _branch_terms(
            params, traj_batch_sp2, gae_sp2, target_v_sp2, sp_id)
        value_loss_mp, pg_loss_mp, entropy_mp = _branch_terms(
            params, traj_batch_mp, gae_mp, target_v_mp, sp_id)
        value_loss_mp2, pg_loss_mp2, entropy_mp2 = _branch_terms(
            params, traj_batch_mp2, gae_mp2, target_v_mp2, sp_id)

        xp_pg_weight = -config["COMEDI_ALPHA"]  # negate to minimize the ego agent's PG objective
        sp_pg_weight = 1.0
        mp2_pg_weight = config["COMEDI_BETA"]

        def _weighted(pg_weight, pg_loss, value_loss, entropy):
            return pg_weight * pg_loss + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy

        xp_loss = _weighted(xp_pg_weight, pg_loss_xp, value_loss_xp, entropy_xp)
        sp_loss = _weighted(sp_pg_weight, pg_loss_sp, value_loss_sp, entropy_sp)
        sp2_loss = _weighted(sp_pg_weight, pg_loss_sp2, value_loss_sp2, entropy_sp2)
        mp_loss = _weighted(mp2_pg_weight, pg_loss_mp, value_loss_mp, entropy_mp)
        mp2_loss = _weighted(mp2_pg_weight, pg_loss_mp2, value_loss_mp2, entropy_mp2)

        total_loss = sp_loss + sp2_loss + xp_loss + mp2_loss + mp_loss
        return total_loss, (value_loss_xp, value_loss_sp + value_loss_sp2, value_loss_mp + value_loss_mp2,
                            pg_loss_xp, pg_loss_sp + pg_loss_sp2, pg_loss_mp + pg_loss_mp2,
                            entropy_xp, entropy_sp + entropy_sp2, entropy_mp + entropy_mp2)

    grad_fn = jax.value_and_grad(_loss_fn_conf, has_aux=True)
    (loss_val, aux_vals), grads = grad_fn(
        train_state_conf.params,
        traj_batch_xp, advantages_xp, returns_xp,
        traj_batch_sp1, advantages_sp1, returns_sp1,
        traj_batch_sp2, advantages_sp2, returns_sp2,
        traj_batch_mp1, advantages_mp1, returns_mp1,
        traj_batch_mp2, advantages_mp2, returns_mp2)
    train_state_conf = train_state_conf.apply_gradients(grads=grads)
    return train_state_conf, (loss_val, aux_vals)
|
| 785 |
+
|
| 786 |
+
(
|
| 787 |
+
train_state_conf, traj_batch_xp,
|
| 788 |
+
traj_batch_sp_conf, traj_batch_sp_br,
|
| 789 |
+
traj_batch_mp_conf, traj_batch_mp_br,
|
| 790 |
+
advantages_xp_conf, advantages_sp_conf,
|
| 791 |
+
advantages_sp_br, advantages_mp_conf,
|
| 792 |
+
advantages_mp_br, targets_xp_conf,
|
| 793 |
+
targets_sp_conf, targets_sp_br,
|
| 794 |
+
targets_mp_conf, targets_mp_br,
|
| 795 |
+
rng, xp_id, sp_id
|
| 796 |
+
) = update_state
|
| 797 |
+
|
| 798 |
+
rng, perm_rng_xp, perm_rng_sp_conf, perm_rng_sp_br, perm_rng_mp2_conf, perm_rng_mp2_br = jax.random.split(rng, 6)
|
| 799 |
+
|
| 800 |
+
# Create minibatches for each agent and interaction type
|
| 801 |
+
minibatches_xp = _create_minibatches(
|
| 802 |
+
traj_batch_xp, advantages_xp_conf, targets_xp_conf, None,
|
| 803 |
+
config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_xp
|
| 804 |
+
)
|
| 805 |
+
minibatches_sp_conf = _create_minibatches(
|
| 806 |
+
traj_batch_sp_conf, advantages_sp_conf, targets_sp_conf, None,
|
| 807 |
+
config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_sp_conf
|
| 808 |
+
)
|
| 809 |
+
minibatches_sp_br = _create_minibatches(
|
| 810 |
+
traj_batch_sp_br, advantages_sp_br, targets_sp_br, None,
|
| 811 |
+
config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_sp_br
|
| 812 |
+
)
|
| 813 |
+
minibatches_mp_conf = _create_minibatches(
|
| 814 |
+
traj_batch_mp_conf, advantages_mp_conf, targets_mp_conf, None,
|
| 815 |
+
config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_mp2_conf
|
| 816 |
+
)
|
| 817 |
+
minibatches_mp_br = _create_minibatches(
|
| 818 |
+
traj_batch_mp_br, advantages_mp_br, targets_mp_br, None,
|
| 819 |
+
config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_mp2_br
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
# Update confederate
|
| 823 |
+
repeated_xp_id = jnp.repeat(xp_id, minibatches_xp[1].obs.shape[0], axis=0)
|
| 824 |
+
repeated_sp_id = jnp.repeat(sp_id, minibatches_sp_br[1].obs.shape[0], axis=0)
|
| 825 |
+
train_state_conf, total_loss_conf = jax.lax.scan(
|
| 826 |
+
_update_minbatch_conf, train_state_conf, (
|
| 827 |
+
minibatches_xp, minibatches_sp_conf, minibatches_sp_br,
|
| 828 |
+
minibatches_mp_conf, minibatches_mp_br, repeated_xp_id, repeated_sp_id
|
| 829 |
+
)
|
| 830 |
+
)
|
| 831 |
+
|
| 832 |
+
update_state = (train_state_conf,
|
| 833 |
+
traj_batch_xp, traj_batch_sp_conf, traj_batch_sp_br, traj_batch_mp_conf, traj_batch_mp_br,
|
| 834 |
+
advantages_xp_conf, advantages_sp_conf, advantages_sp_br, advantages_mp_conf, advantages_mp_br,
|
| 835 |
+
targets_xp_conf, targets_sp_conf, targets_sp_br, targets_mp_conf, targets_mp_br,
|
| 836 |
+
rng, xp_id, sp_id
|
| 837 |
+
)
|
| 838 |
+
return update_state, total_loss_conf
|
| 839 |
+
|
| 840 |
+
# 3) PPO update
|
| 841 |
+
rng, sub_rng = jax.random.split(rng, 2)
|
| 842 |
+
update_state = (
|
| 843 |
+
train_state,
|
| 844 |
+
traj_batch_xp, traj_batch_sp_agent0,
|
| 845 |
+
traj_batch_sp_agent1,
|
| 846 |
+
traj_batch_smp0, traj_batch_smp1,
|
| 847 |
+
advantages_xp_conf,
|
| 848 |
+
advantages_sp_conf, advantages_sp_br,
|
| 849 |
+
advantages_mp_conf, advantages_mp_br,
|
| 850 |
+
targets_xp_conf, targets_sp_conf,
|
| 851 |
+
targets_sp_br, targets_mp_conf,
|
| 852 |
+
targets_mp_br, sub_rng,
|
| 853 |
+
max_means_id, num_prev_trained_conf
|
| 854 |
+
)
|
| 855 |
+
update_state, conf_losses = jax.lax.scan(
|
| 856 |
+
_update_epoch, update_state, None, config["UPDATE_EPOCHS"])
|
| 857 |
+
train_state = update_state[0]
|
| 858 |
+
|
| 859 |
+
(
|
| 860 |
+
conf_value_loss_xp, conf_value_loss_sp, conf_value_loss_mp,
|
| 861 |
+
conf_pg_loss_xp, conf_pg_loss_sp, conf_pg_loss_mp,
|
| 862 |
+
conf_entropy_xp, conf_entropy_sp, conf_entropy_mp
|
| 863 |
+
) = conf_losses[1]
|
| 864 |
+
|
| 865 |
+
new_update_runner_state = (
|
| 866 |
+
train_state, pop_buffer,
|
| 867 |
+
env_state_sp, last_obs_sp,
|
| 868 |
+
env_state_xp, last_obs_xp,
|
| 869 |
+
env_state_mp, last_obs_mp,
|
| 870 |
+
env_state_mp2, last_obs_mp2,
|
| 871 |
+
last_dones_xp, last_dones_sp,
|
| 872 |
+
last_dones_mp, last_dones_mp2,
|
| 873 |
+
rng, update_steps+1, num_prev_trained_conf
|
| 874 |
+
)
|
| 875 |
+
|
| 876 |
+
# Metrics
|
| 877 |
+
def mask_and_mean(x, mask):
|
| 878 |
+
return jnp.where(mask, x, 0).sum() / jnp.maximum(1, mask.sum())
|
| 879 |
+
|
| 880 |
+
mask = traj_batch_xp.info.get("returned_episode", jnp.ones_like(traj_batch_xp.reward))
|
| 881 |
+
metric = jax.tree.map(lambda x: mask_and_mean(x, mask), traj_batch_xp.info)
|
| 882 |
+
metric["update_steps"] = update_steps
|
| 883 |
+
metric["value_loss_conf_xp"] = conf_value_loss_xp.mean()
|
| 884 |
+
metric["value_loss_conf_sp"] = conf_value_loss_sp.mean()
|
| 885 |
+
metric["value_loss_conf_mp"] = conf_value_loss_mp.mean()
|
| 886 |
+
|
| 887 |
+
metric["pg_loss_conf_xp"] = conf_pg_loss_xp.mean()
|
| 888 |
+
metric["pg_loss_conf_sp"] = conf_pg_loss_sp.mean()
|
| 889 |
+
metric["pg_loss_conf_mp"] = conf_pg_loss_mp.mean()
|
| 890 |
+
|
| 891 |
+
metric["entropy_conf_xp"] = conf_entropy_xp.mean()
|
| 892 |
+
metric["entropy_conf_sp"] = conf_entropy_sp.mean()
|
| 893 |
+
metric["entropy_conf_mp"] = conf_entropy_mp.mean()
|
| 894 |
+
|
| 895 |
+
metric["average_rewards_ego"] = jnp.mean(traj_batch_xp.reward)
|
| 896 |
+
metric["average_rewards_br_sp"] = jnp.mean(traj_batch_sp_agent1.reward)
|
| 897 |
+
metric["average_rewards_br_mp2"] = jnp.mean(traj_batch_smp1.reward)
|
| 898 |
+
|
| 899 |
+
return (new_update_runner_state, checkpoint_array, ckpt_idx+1), metric
|
| 900 |
+
|
| 901 |
+
# XP eval against all policies in the buffer
|
| 902 |
+
xp_eval_returns = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
|
| 903 |
+
jax.vmap(per_id_run_episode_fixed_rng, in_axes=(None, 0))(
|
| 904 |
+
train_state.params,jnp.arange(config["POP_SIZE"])))
|
| 905 |
+
|
| 906 |
+
# SP performance against itself
|
| 907 |
+
sp_eval_returns = jax.tree.map(lambda x: x.mean(), run_episodes(
|
| 908 |
+
eval_rng, env,
|
| 909 |
+
agent_0_param=train_state.params, agent_0_policy=policy,
|
| 910 |
+
agent_1_param=train_state.params, agent_1_policy=policy,
|
| 911 |
+
max_episode_steps=config["ROLLOUT_LENGTH"],
|
| 912 |
+
num_eps=config["NUM_EVAL_EPISODES"]
|
| 913 |
+
))
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
update_steps = 0
|
| 917 |
+
init_done_xp = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
|
| 918 |
+
init_done_sp = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
|
| 919 |
+
init_done_mp = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
|
| 920 |
+
init_done_mp2 = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
|
| 921 |
+
|
| 922 |
+
update_runner_state = (
|
| 923 |
+
train_state, pop_buffer,
|
| 924 |
+
env_state_sp, obsv_sp,
|
| 925 |
+
env_state_xp, obsv_xp,
|
| 926 |
+
env_state_mp, obsv_mp,
|
| 927 |
+
env_state_mp2, obsv_mp2,
|
| 928 |
+
init_done_xp, init_done_sp,
|
| 929 |
+
init_done_mp, init_done_mp2,
|
| 930 |
+
rng, update_steps,
|
| 931 |
+
num_existing_agents
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
checkpoint_array = init_ckpt_array(train_state.params)
|
| 935 |
+
ckpt_idx = 0
|
| 936 |
+
update_with_ckpt_runner_state = (update_runner_state, checkpoint_array, ckpt_idx, xp_eval_returns, sp_eval_returns)
|
| 937 |
+
|
| 938 |
+
def _update_step_with_ckpt(state_with_ckpt, unused):
|
| 939 |
+
|
| 940 |
+
(update_runner_state, checkpoint_array, ckpt_idx, xp_eval_returns, sp_eval_returns) = state_with_ckpt
|
| 941 |
+
train_state = update_runner_state[0]
|
| 942 |
+
|
| 943 |
+
# Single PPO update
|
| 944 |
+
new_state_with_ckpt, metric = _update_step(
|
| 945 |
+
(update_runner_state, checkpoint_array, ckpt_idx),
|
| 946 |
+
None
|
| 947 |
+
)
|
| 948 |
+
new_update_runner_state = new_state_with_ckpt[0]
|
| 949 |
+
rng, update_steps = new_update_runner_state[-3], new_update_runner_state[-2]
|
| 950 |
+
|
| 951 |
+
# Decide if we store a checkpoint
|
| 952 |
+
# update steps is 1-indexed because it was incremented at the end of the update step
|
| 953 |
+
to_store = jnp.logical_or(jnp.equal(jnp.mod(update_steps-1, ckpt_and_eval_interval), 0),
|
| 954 |
+
jnp.equal(update_steps, config["NUM_UPDATES"]))
|
| 955 |
+
|
| 956 |
+
def store_and_eval_ckpt(args):
|
| 957 |
+
ckpt_arr_conf, rng, cidx, _, _ = args
|
| 958 |
+
new_ckpt_arr_conf = jax.tree.map(
|
| 959 |
+
lambda c_arr, p: c_arr.at[cidx].set(p),
|
| 960 |
+
ckpt_arr_conf, train_state.params
|
| 961 |
+
)
|
| 962 |
+
|
| 963 |
+
# Eval trained agent against all params in the pool
|
| 964 |
+
xp_eval_returns = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
|
| 965 |
+
jax.vmap(per_id_run_episode_fixed_rng, in_axes=(None, 0))(
|
| 966 |
+
train_state.params, jnp.arange(config["POP_SIZE"])))
|
| 967 |
+
# Eval trained agent against itself
|
| 968 |
+
sp_eval_returns = jax.tree.map(lambda x: x.mean(), run_episodes(
|
| 969 |
+
eval_rng, env,
|
| 970 |
+
agent_0_param=train_state.params, agent_0_policy=policy,
|
| 971 |
+
agent_1_param=train_state.params, agent_1_policy=policy,
|
| 972 |
+
max_episode_steps=config["ROLLOUT_LENGTH"],
|
| 973 |
+
num_eps=config["NUM_EVAL_EPISODES"]
|
| 974 |
+
))
|
| 975 |
+
|
| 976 |
+
return (new_ckpt_arr_conf, rng, cidx + 1, xp_eval_returns, sp_eval_returns)
|
| 977 |
+
|
| 978 |
+
def skip_ckpt(args):
|
| 979 |
+
return args
|
| 980 |
+
|
| 981 |
+
rng, store_and_eval_rng = jax.random.split(rng, 2)
|
| 982 |
+
(checkpoint_array, store_and_eval_rng, ckpt_idx, xp_eval_returns, sp_eval_returns) = jax.lax.cond(
|
| 983 |
+
to_store,
|
| 984 |
+
store_and_eval_ckpt,
|
| 985 |
+
skip_ckpt,
|
| 986 |
+
(checkpoint_array, store_and_eval_rng, ckpt_idx, xp_eval_returns, sp_eval_returns)
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
return (new_update_runner_state, checkpoint_array,
|
| 990 |
+
ckpt_idx, xp_eval_returns, sp_eval_returns), (metric, xp_eval_returns, sp_eval_returns)
|
| 991 |
+
|
| 992 |
+
new_update_with_ckpt_runner_state, (metric, xp_eval_returns, sp_eval_returns) = jax.lax.scan(
|
| 993 |
+
_update_step_with_ckpt,
|
| 994 |
+
update_with_ckpt_runner_state,
|
| 995 |
+
xs=None, # No per-step input data
|
| 996 |
+
length=config["NUM_UPDATES"],
|
| 997 |
+
)
|
| 998 |
+
new_update_runner_state, new_checkpoint_array, _, _ ,_ = new_update_with_ckpt_runner_state
|
| 999 |
+
final_train_state = new_update_runner_state[0]
|
| 1000 |
+
|
| 1001 |
+
updated_pop_buffer = partner_population.add_agent(pop_buffer, final_train_state.params)
|
| 1002 |
+
conf_checkpoints = new_checkpoint_array
|
| 1003 |
+
return updated_pop_buffer, (conf_checkpoints, metric, xp_eval_returns, sp_eval_returns)
|
| 1004 |
+
|
| 1005 |
+
rngs = jax.random.split(rng, config["PARTNER_POP_SIZE"])
|
| 1006 |
+
rng, add_conf_iter_rngs = rngs[0], rngs[1:]
|
| 1007 |
+
|
| 1008 |
+
iter_ids = jnp.arange(1, config["PARTNER_POP_SIZE"])
|
| 1009 |
+
final_population_buffer, (conf_checkpoints, metric, xp_eval_returns, sp_eval_returns) = jax.lax.scan(
|
| 1010 |
+
add_conf_policy, population_buffer, (iter_ids, add_conf_iter_rngs)
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
out = {
|
| 1014 |
+
"final_params_conf": final_population_buffer.params,
|
| 1015 |
+
"checkpoints_conf": conf_checkpoints,
|
| 1016 |
+
"metrics": metric,
|
| 1017 |
+
"last_ep_infos_xp": xp_eval_returns,
|
| 1018 |
+
"last_ep_infos_sp": sp_eval_returns
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
return out
|
| 1022 |
+
return train
|
| 1023 |
+
|
| 1024 |
+
train_fn = make_comedi_agents(config)
|
| 1025 |
+
out = train_fn(train_rng)
|
| 1026 |
+
return out
|
| 1027 |
+
|
| 1028 |
+
def get_comedi_population(config, out, env):
    '''
    Get the partner params and partner population for ego training.

    Args:
        config: run configuration. Only ``config["algorithm"]`` is read, for
            ``PARTNER_POP_SIZE`` and the optional ``ACTIVATION`` (default "tanh").
        out: output dict of the CoMeDi training run. Its ``'final_params_conf'``
            entry is returned unchanged as the partner parameters.
        env: environment. The action and observation spaces of ``env.agents[1]``
            determine the partner policy's input and output sizes.

    Returns:
        Tuple ``(partner_params, partner_population)``.
    '''
    algorithm_config = config["algorithm"]
    comedi_pop_size = algorithm_config["PARTNER_POP_SIZE"]

    # partner_params has shape (num_seeds, comedi_pop_size, ...)
    partner_params = out['final_params_conf']

    # The partner is built from agent 1's spaces.
    partner_agent = env.agents[1]
    partner_policy = ActorWithConditionalCriticPolicy(
        action_dim=env.action_space(partner_agent).n,
        obs_dim=env.observation_space(partner_agent).shape[0],
        pop_size=comedi_pop_size,  # used to create onehot agent id
        activation=algorithm_config.get("ACTIVATION", "tanh"),
    )

    # Create partner population
    partner_population = AgentPopulation(
        pop_size=comedi_pop_size,
        policy_cls=partner_policy,
    )

    return partner_params, partner_population
|
| 1051 |
+
|
| 1052 |
+
def run_comedi(config, wandb_logger):
    '''
    Train CoMeDi partners for every seed, save the run, log it, and build the population.

    Args:
        config: run configuration. ``config["algorithm"]`` is shallow-copied into a
            plain dict before being handed to the training function.
        wandb_logger: logger forwarded to training and to ``log_metrics``.

    Returns:
        Tuple ``(partner_params, partner_population)`` from ``get_comedi_population``.
    '''
    algorithm_config = dict(config["algorithm"])

    env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"])
    env = LogWrapper(env)

    log.info("Starting CoMeDi training...")
    start = time.time()

    # Generate multiple random seeds from the base seed
    rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"])
    rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"])

    # Create a vmapped version of train_comedi_partners (one training run per seed)
    with jax.disable_jit(False):
        vmapped_train_fn = jax.jit(
            jax.vmap(
                partial(train_comedi_partners,
                        wandb_logger=wandb_logger,
                        env=env,
                        config=algorithm_config)
            )
        )
        out = vmapped_train_fn(rngs)

    # JAX dispatches work asynchronously, so the call above can return before the
    # computation has finished. Wait for the results so the timing is meaningful.
    out = jax.block_until_ready(out)
    end = time.time()
    log.info(f"CoMeDi training complete in {end - start} seconds")

    metric_names = get_metric_names(algorithm_config["ENV_NAME"])

    # Save FIRST so the checkpoint survives even if metric logging OOMs.
    savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    out_savepath = save_train_run(out, savedir, savename="saved_train_run")
    log_metrics(config, out, wandb_logger, metric_names, out_savepath)
    partner_params, partner_population = get_comedi_population(config, out, env)
    return partner_params, partner_population
|
| 1088 |
+
|
| 1089 |
+
def compute_sp_mask_and_ids(pop_size):
    '''
    Enumerate all (confederate id, ego id) pairs and mark the self-play ones.

    Args:
        pop_size: number of agents; ids range over ``0 .. pop_size - 1``.

    Returns:
        Tuple ``(sp_mask, agent_id_cartesian_product)``:
        - ``agent_id_cartesian_product`` has shape ``(pop_size * pop_size, 2)``.
          Column 0 is the confederate id (varies fastest), column 1 the ego id.
        - ``sp_mask`` is a boolean array of length ``pop_size * pop_size`` that is
          True where both ids of the pair are equal.
    '''
    ids = np.arange(pop_size)
    # Default 'xy' indexing: the first grid varies along columns, the second along rows.
    conf_grid, ego_grid = np.meshgrid(ids, ids)
    agent_id_cartesian_product = np.stack([conf_grid.ravel(), ego_grid.ravel()], axis=-1)
    conf_ids = agent_id_cartesian_product[:, 0]
    ego_ids = agent_id_cartesian_product[:, 1]
    sp_mask = (conf_ids == ego_ids)
    return sp_mask, agent_id_cartesian_product
|
| 1099 |
+
|
| 1100 |
+
def log_metrics(config, outs, logger, metric_names: tuple, out_savepath):
    '''
    Log CoMeDi evaluation curves, per-policy loss plots and the saved-run artifact.

    Args:
        config: run configuration. Reads ``config["algorithm"]["NUM_CHECKPOINTS"]``,
            ``config["logger"]["log_train_out"]`` and
            ``config["local_logger"]["save_train_out"]``.
        outs: training output dict with keys ``"metrics"``, ``"last_ep_infos_sp"``
            and ``"last_ep_infos_xp"``.
        logger: object providing ``log_item``, ``commit`` and ``log_artifact``.
        metric_names: accepted for interface compatibility; not used here.
        out_savepath: location of the already-saved training run. It is published
            as an artifact and/or removed depending on the config flags.
    '''
    metrics = outs["metrics"]
    # Every loss metric is reduced to one scalar per update inside the training scan,
    # so each array is (num_seeds, num_trained, num_updates). num_trained counts the
    # policies added by training and therefore excludes the initial policy.
    _, num_trained, num_updates = metrics["pg_loss_conf_sp"].shape
    # TODO: add the eval_ep_last_info metrics

    ### Log evaluation metrics
    # xp_eval_returns and sp_eval_returns logged at each evaluation only.
    algorithm_config = config["algorithm"]
    # NOTE(review): this must match the interval used during training — confirm.
    ckpt_and_eval_interval = max(1, num_updates // max(1, algorithm_config["NUM_CHECKPOINTS"] - 1))
    # Steps at which store_and_eval_ckpt fires (0-indexed, matching the update_step logged below)
    eval_steps = list(range(0, num_updates, ckpt_and_eval_interval))
    if (num_updates - 1) not in eval_steps:
        eval_steps.append(num_updates - 1)

    # shape (num_seeds, num_trained, num_updates)
    # [pre-scalarized: mean over eval eps and agents taken inside scan]
    all_returns_sp = np.asarray(outs["last_ep_infos_sp"]["returned_episode_returns"])
    # shape (num_seeds, num_trained, num_updates, number of pool slots)
    # [pre-scalarized: mean over eval eps and agents taken inside scan]
    all_returns_xp = np.asarray(outs["last_ep_infos_xp"]["returned_episode_returns"])

    # Average over seeds only (eval episodes and agents already averaged inside scan)
    sp_return_curve = all_returns_sp.mean(axis=0)  # shape (num_trained, num_updates)
    xp_return_curve = all_returns_xp.mean(axis=0)  # shape (num_trained, num_updates, pool slots)

    for num_add_policies in range(num_trained):
        # Only the first (num_add_policies + 1) pool slots are averaged; presumably
        # those are the slots already filled while this policy trains — confirm.
        # The mean does not depend on update_step, so compute it once per policy.
        mean_xp_returns = xp_return_curve[num_add_policies, :, :(num_add_policies + 1)].mean(axis=-1)
        for update_step in eval_steps:
            logger.log_item("Eval/AvgSPReturnCurve", sp_return_curve[num_add_policies, update_step], train_step=update_step)
            logger.log_item("Eval/AvgXPReturnCurve", mean_xp_returns[update_step], train_step=update_step)
    # NOTE(review): indentation was lost in the source; commit is assumed to run
    # once after all evaluation items are logged — confirm against the original file.
    logger.commit()

    ### Log population loss as multi-line plots, where each line is a different population member
    # Each metric has shape (num_seeds, num_trained, num_updates).
    # Average over seeds -> (num_trained, num_updates).
    loss_sources = {
        "ConfPGLossSP": "pg_loss_conf_sp",
        "ConfPGLossXP": "pg_loss_conf_xp",
        "ConfPGLossMP": "pg_loss_conf_mp",
        "ConfValLossSP": "value_loss_conf_sp",
        "ConfValLossXP": "value_loss_conf_xp",
        "ConfValLossMP": "value_loss_conf_mp",
        "EntropySP": "entropy_conf_sp",
        "EntropyXP": "entropy_conf_xp",
        "EntropyMP": "entropy_conf_mp",
    }
    processed_losses = {
        plot_name: np.asarray(metrics[metric_key]).mean(axis=0)
        for plot_name, metric_key in loss_sources.items()
    }

    xs = list(range(num_updates))
    keys = [f"pair {i}" for i in range(num_trained)]

    for loss_name, loss_data in processed_losses.items():
        logger.log_item(
            f"Losses/{loss_name}",
            wandb.plot.line_series(xs=xs, ys=loss_data, keys=keys,
                                   title=loss_name, xname="train_step"),
        )

    ### Log artifacts (already saved by caller; just publish to wandb)
    if config["logger"]["log_train_out"]:
        logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run")

    # Cleanup locally logged out files
    if not config["local_logger"]["save_train_out"]:
        shutil.rmtree(out_savepath)
|
teammate_generation/LBRDiv.py
ADDED
|
@@ -0,0 +1,1098 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''Implementation of the LBRDiv teammate generation algorithm (Rahman et al., AAAI 2024)
|
| 2 |
+
https://ojs.aaai.org/index.php/AAAI/article/view/29702
|
| 3 |
+
|
| 4 |
+
Command to run LBRDiv only on LBF:
|
| 5 |
+
python teammate_generation/run.py algorithm=lbrdiv/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels label=test_lbrdiv run_heldout_eval=false train_ego=false
|
| 6 |
+
|
| 7 |
+
Suggested Debug command:
|
| 8 |
+
python teammate_generation/run.py algorithm=lbrdiv/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels logger.mode=disabled label=debug algorithm.TOTAL_TIMESTEPS=1e5 algorithm.PARTNER_POP_SIZE=2 train_ego=false run_heldout_eval=false
|
| 9 |
+
|
| 10 |
+
Limitations: does not support recurrent actors.
|
| 11 |
+
'''
|
| 12 |
+
import shutil
|
| 13 |
+
import time
|
| 14 |
+
import logging
|
| 15 |
+
from functools import partial
|
| 16 |
+
|
| 17 |
+
import hydra
|
| 18 |
+
import jax
|
| 19 |
+
import jax.numpy as jnp
|
| 20 |
+
import numpy as np
|
| 21 |
+
import optax
|
| 22 |
+
from flax.training.train_state import TrainState
|
| 23 |
+
import wandb
|
| 24 |
+
|
| 25 |
+
from agents.mlp_actor_critic_agent import ActorWithConditionalCriticPolicy
|
| 26 |
+
from agents.population_interface import AgentPopulation
|
| 27 |
+
from common.plot_utils import get_metric_names
|
| 28 |
+
from common.run_episodes import run_episodes
|
| 29 |
+
from common.save_load_utils import save_train_run
|
| 30 |
+
from envs import make_env
|
| 31 |
+
from envs.log_wrapper import LogWrapper
|
| 32 |
+
from marl.ppo_utils import unbatchify, _create_minibatches
|
| 33 |
+
from teammate_generation.BRDiv import _get_all_ids, XPTransition, gather_params
|
| 34 |
+
|
| 35 |
+
log = logging.getLogger(__name__)
|
| 36 |
+
logging.basicConfig(level=logging.INFO)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def train_lbrdiv_partners(train_rng, env, config, conf_policy, br_policy):
|
| 40 |
+
num_agents = env.num_agents
|
| 41 |
+
assert num_agents == 2, "This code assumes the environment has exactly 2 agents."
|
| 42 |
+
|
| 43 |
+
# Define different minibatch sizes for interactions with ego agent and one with BR agent
|
| 44 |
+
config["NUM_GAME_AGENTS"] = num_agents
|
| 45 |
+
config["NUM_CONF_ACTORS"] = config["NUM_ENVS"]
|
| 46 |
+
config["NUM_BR_ACTORS"] = config["NUM_ENVS"]
|
| 47 |
+
config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // (config["ROLLOUT_LENGTH"] * config["NUM_ENVS"])
|
| 48 |
+
|
| 49 |
+
def make_lbrdiv_agents(config):
|
| 50 |
+
def linear_schedule(count):
|
| 51 |
+
frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
|
| 52 |
+
return config["LR"] * frac
|
| 53 |
+
|
| 54 |
+
def train(rng):
|
| 55 |
+
rng, init_conf_rng, init_br_rng = jax.random.split(rng, 3)
|
| 56 |
+
all_conf_init_rngs = jax.random.split(init_conf_rng, config["PARTNER_POP_SIZE"])
|
| 57 |
+
all_br_init_rngs = jax.random.split(init_br_rng, config["PARTNER_POP_SIZE"])
|
| 58 |
+
identity_matrix = jnp.eye(config["PARTNER_POP_SIZE"])
|
| 59 |
+
|
| 60 |
+
init_conf_hstate = conf_policy.init_hstate(config["NUM_CONF_ACTORS"])
|
| 61 |
+
init_br_hstate = br_policy.init_hstate(config["NUM_BR_ACTORS"])
|
| 62 |
+
|
| 63 |
+
def init_train_states(rng_agents, rng_brs):
|
| 64 |
+
def init_single_pair_optimizers(rng_agent, rng_br):
|
| 65 |
+
init_params_conf = conf_policy.init_params(rng_agent)
|
| 66 |
+
init_params_br = br_policy.init_params(rng_br)
|
| 67 |
+
return init_params_conf, init_params_br
|
| 68 |
+
|
| 69 |
+
init_all_networks_and_optimizers = jax.vmap(init_single_pair_optimizers)
|
| 70 |
+
all_conf_params, all_br_params = init_all_networks_and_optimizers(rng_agents, rng_brs)
|
| 71 |
+
|
| 72 |
+
# Define optimizers for both confederate and BR policy
|
| 73 |
+
tx = optax.chain(
|
| 74 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 75 |
+
optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"],
|
| 76 |
+
eps=1e-5),
|
| 77 |
+
)
|
| 78 |
+
tx_br = optax.chain(
|
| 79 |
+
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
|
| 80 |
+
optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"],
|
| 81 |
+
eps=1e-5),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
train_state_conf = TrainState.create(
|
| 85 |
+
apply_fn=conf_policy.network.apply,
|
| 86 |
+
params=all_conf_params,
|
| 87 |
+
tx=tx,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
train_state_br = TrainState.create(
|
| 91 |
+
apply_fn=br_policy.network.apply,
|
| 92 |
+
params=all_br_params,
|
| 93 |
+
tx=tx_br,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
return train_state_conf, train_state_br
|
| 97 |
+
|
| 98 |
+
all_conf_optims, all_br_optims = init_train_states(
|
| 99 |
+
all_conf_init_rngs, all_br_init_rngs
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def forward_pass_conf(params, obs, id, done, avail_actions, hstate, rng):
|
| 103 |
+
act, val, pi, new_hstate = conf_policy.get_action_value_policy(
|
| 104 |
+
params=params,
|
| 105 |
+
obs=obs[jnp.newaxis, ...],
|
| 106 |
+
done=done[jnp.newaxis, ...],
|
| 107 |
+
avail_actions=avail_actions,
|
| 108 |
+
hstate=hstate,
|
| 109 |
+
rng=rng,
|
| 110 |
+
aux_obs=id[jnp.newaxis, ...]
|
| 111 |
+
)
|
| 112 |
+
return act, val, pi, new_hstate
|
| 113 |
+
|
| 114 |
+
def forward_pass_br(params, obs, id, done, avail_actions, hstate, rng):
|
| 115 |
+
act, val, pi, new_hstate = br_policy.get_action_value_policy(
|
| 116 |
+
params=params,
|
| 117 |
+
obs=obs[jnp.newaxis, ...],
|
| 118 |
+
done=done[jnp.newaxis, ...],
|
| 119 |
+
avail_actions=avail_actions,
|
| 120 |
+
hstate=hstate,
|
| 121 |
+
rng=rng,
|
| 122 |
+
aux_obs=id[jnp.newaxis, ...]
|
| 123 |
+
)
|
| 124 |
+
return act, val, pi, new_hstate
|
| 125 |
+
|
| 126 |
+
def _env_step(runner_state, unused):
|
| 127 |
+
"""
|
| 128 |
+
agent_0 = confederate, agent_1 = br
|
| 129 |
+
Returns updated runner_state, and Transitions for agent_0 and agent_1
|
| 130 |
+
"""
|
| 131 |
+
(
|
| 132 |
+
all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids,
|
| 133 |
+
env_state, last_obs, last_done, last_conf_h, last_br_h, rng
|
| 134 |
+
) = runner_state
|
| 135 |
+
rng, act0_rng, act1_rng, step_rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 6)
|
| 136 |
+
|
| 137 |
+
# For done envs, resample both conf and brs
|
| 138 |
+
needs_resample = last_done["__all__"]
|
| 139 |
+
resampled_conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_CONF_ACTORS"],), 0, config["PARTNER_POP_SIZE"])
|
| 140 |
+
resampled_br_ids = jax.random.randint(br_sampling_rng, (config["NUM_BR_ACTORS"],), 0, config["PARTNER_POP_SIZE"])
|
| 141 |
+
|
| 142 |
+
# Determine final indices based on whether resampling was needed for each env
|
| 143 |
+
updated_conf_ids = jnp.where(
|
| 144 |
+
needs_resample,
|
| 145 |
+
resampled_conf_ids, # Use newly sampled index if True
|
| 146 |
+
last_conf_ids # Else, keep index from previous step
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
updated_br_ids = jnp.where(
|
| 150 |
+
needs_resample,
|
| 151 |
+
resampled_br_ids, # Use newly sampled index if True
|
| 152 |
+
last_br_ids # Else, keep index from previous step
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# Reset the hidden states for resampled conf and br if they are not None
|
| 156 |
+
# WARNING: (L)BRDiv was not tested with recurrent actors, so the code for if the hstate is not None may not work
|
| 157 |
+
if last_conf_h is not None:
|
| 158 |
+
updated_conf_h = jnp.where(
|
| 159 |
+
needs_resample,
|
| 160 |
+
init_conf_hstate,
|
| 161 |
+
last_conf_h
|
| 162 |
+
)
|
| 163 |
+
else:
|
| 164 |
+
updated_conf_h = last_conf_h
|
| 165 |
+
|
| 166 |
+
if last_br_h is not None:
|
| 167 |
+
updated_br_h = jnp.where(
|
| 168 |
+
needs_resample,
|
| 169 |
+
init_br_hstate,
|
| 170 |
+
last_br_h
|
| 171 |
+
)
|
| 172 |
+
else:
|
| 173 |
+
updated_br_h = last_br_h
|
| 174 |
+
|
| 175 |
+
# Get the corresponding conf and br params
|
| 176 |
+
updated_conf_params = gather_params(all_train_state_conf.params, updated_conf_ids)
|
| 177 |
+
updated_br_params = gather_params(all_train_state_br.params, updated_br_ids)
|
| 178 |
+
|
| 179 |
+
updated_conf_onehot_ids = identity_matrix[updated_conf_ids]
|
| 180 |
+
updated_br_onehot_ids = identity_matrix[updated_br_ids]
|
| 181 |
+
|
| 182 |
+
# Get available actions for agent 0 from environment state
|
| 183 |
+
avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)
|
| 184 |
+
avail_actions = jax.lax.stop_gradient(avail_actions)
|
| 185 |
+
avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32)
|
| 186 |
+
avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32)
|
| 187 |
+
|
| 188 |
+
# Agent_0 action
|
| 189 |
+
act0_rng = jax.random.split(act0_rng, config["NUM_ENVS"])
|
| 190 |
+
act_0, val_0, pi_0, new_conf_h = jax.vmap(forward_pass_conf)(updated_conf_params,
|
| 191 |
+
last_obs["agent_0"], updated_br_onehot_ids, last_done["agent_0"], avail_actions_0,
|
| 192 |
+
updated_conf_h, act0_rng)
|
| 193 |
+
logp_0 = pi_0.log_prob(act_0)
|
| 194 |
+
act_0, val_0, logp_0 = act_0.squeeze(), val_0.squeeze(), logp_0.squeeze()
|
| 195 |
+
|
| 196 |
+
# Agent_1 action
|
| 197 |
+
act1_rng = jax.random.split(act1_rng, config["NUM_ENVS"])
|
| 198 |
+
act_1, val_1, pi_1, new_br_h = jax.vmap(forward_pass_br)(updated_br_params,
|
| 199 |
+
last_obs["agent_1"], updated_conf_onehot_ids, last_done["agent_1"], avail_actions_1,
|
| 200 |
+
updated_br_h, act1_rng)
|
| 201 |
+
logp_1 = pi_1.log_prob(act_1)
|
| 202 |
+
act_1, val_1, logp_1 = act_1.squeeze(), val_1.squeeze(), logp_1.squeeze()
|
| 203 |
+
|
| 204 |
+
# Combine actions into the env format
|
| 205 |
+
combined_actions = jnp.concatenate([act_0, act_1], axis=0)
|
| 206 |
+
env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents)
|
| 207 |
+
env_act = {k: v.flatten() for k, v in env_act.items()}
|
| 208 |
+
|
| 209 |
+
# Step env
|
| 210 |
+
step_rngs = jax.random.split(step_rng, config["NUM_ENVS"])
|
| 211 |
+
obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))(
|
| 212 |
+
step_rngs, env_state, env_act
|
| 213 |
+
)
|
| 214 |
+
# note that num_actors = num_envs * num_agents
|
| 215 |
+
info_0 = jax.tree.map(lambda x: x[:, 0], info)
|
| 216 |
+
info_1 = jax.tree.map(lambda x: x[:, 1], info)
|
| 217 |
+
|
| 218 |
+
# Store agent_0 data in transition
|
| 219 |
+
transition_0 = XPTransition(
|
| 220 |
+
done=done["agent_0"],
|
| 221 |
+
action=act_0,
|
| 222 |
+
value=val_0,
|
| 223 |
+
self_onehot_id=updated_conf_onehot_ids,
|
| 224 |
+
oppo_onehot_id=updated_br_onehot_ids,
|
| 225 |
+
reward=reward["agent_1"],
|
| 226 |
+
log_prob=logp_0,
|
| 227 |
+
obs=last_obs["agent_0"],
|
| 228 |
+
info=info_0,
|
| 229 |
+
avail_actions=avail_actions_0
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
transition_1 = XPTransition(
|
| 233 |
+
done=done["agent_1"],
|
| 234 |
+
action=act_1,
|
| 235 |
+
value=val_1,
|
| 236 |
+
self_onehot_id=updated_br_onehot_ids,
|
| 237 |
+
oppo_onehot_id=updated_conf_onehot_ids,
|
| 238 |
+
reward=reward["agent_1"],
|
| 239 |
+
log_prob=logp_1,
|
| 240 |
+
obs=last_obs["agent_1"],
|
| 241 |
+
info=info_1,
|
| 242 |
+
avail_actions=avail_actions_1
|
| 243 |
+
)
|
| 244 |
+
new_runner_state = (all_train_state_conf, all_train_state_br, updated_conf_ids, updated_br_ids,
|
| 245 |
+
env_state_next, obs_next, done, new_conf_h, new_br_h, rng)
|
| 246 |
+
return new_runner_state, (transition_0, transition_1)
|
| 247 |
+
|
| 248 |
+
def _calculate_gae(traj_batch, last_val):
    """Compute GAE advantages and value targets for one rollout.

    Args:
        traj_batch: Transition pytree exposing ``done``, ``value`` and
            ``reward``; it is scanned along its leading axis.
        last_val: Value estimate used to bootstrap after the final
            transition of the rollout.

    Returns:
        ``(advantages, targets)`` where ``targets`` is
        ``advantages + traj_batch.value``.
    """
    gamma = config["GAMMA"]
    gae_lambda = config["GAE_LAMBDA"]

    def _backward_step(carry, step):
        # The scan runs in reverse, so the carry holds the advantage and the
        # value of the transition that FOLLOWS ``step`` in time.
        running_gae, value_after = carry
        not_done = 1 - step.done
        td_error = step.reward + gamma * value_after * not_done - step.value
        running_gae = td_error + gamma * gae_lambda * not_done * running_gae
        return (running_gae, step.value), running_gae

    initial_carry = (jnp.zeros_like(last_val), last_val)
    _, advantages = jax.lax.scan(
        _backward_step,
        initial_carry,
        traj_batch,
        reverse=True,
        unroll=16,
    )
    targets = advantages + traj_batch.value
    return advantages, targets
|
| 271 |
+
|
| 272 |
+
def run_all_episodes(rng, train_state_conf, train_state_br):
    """Evaluate every (confederate, best-response) pairing returned by ``_get_all_ids``.

    Args:
        rng: PRNG key; it is split once and the second half is shared by all
            pairings, so every pairing is evaluated with the same key.
        train_state_conf: Train state whose ``params`` hold the confederate
            population.
        train_state_br: Train state whose ``params`` hold the best-response
            population.

    Returns:
        The output of ``run_episodes`` vmapped over all pairings.
    """
    all_conf_ids, all_br_ids = _get_all_ids(config["PARTNER_POP_SIZE"])
    # The original author notes these leaves have shape (pop_size*pop_size, ...).
    paired_conf_params = gather_params(train_state_conf.params, all_conf_ids)
    paired_br_params = gather_params(train_state_br.params, all_br_ids)

    _, shared_eval_rng = jax.random.split(rng)

    def _evaluate_pair(conf_param, br_param):
        return run_episodes(
            shared_eval_rng,
            env,
            conf_param,
            conf_policy,
            br_param,
            br_policy,
            config["ROLLOUT_LENGTH"],
            config["NUM_EVAL_EPISODES"],
        )

    return jax.vmap(_evaluate_pair)(paired_conf_params, paired_br_params)
|
| 289 |
+
|
| 290 |
+
def _update_epoch(update_state, unused):
    """Run one PPO epoch for the confederate and best-response populations.

    The rollout is split into minibatches with fresh permutation keys and
    ``_update_minbatch`` is scanned over them, so every population member
    takes one gradient step per minibatch.

    Args:
        update_state: Tuple ``(train_state_conf, train_state_br,
            traj_batch_conf, traj_batch_br, advantages_conf, advantages_br,
            targets_conf, targets_br, rng, lms_vertical, lms_horizontal)``.
        unused: Placeholder for the ``jax.lax.scan`` input slot.

    Returns:
        ``(update_state, all_losses)``: the state with the updated train
        states and RNG key (every other entry is passed through unchanged)
        and the stacked per-minibatch ``((loss, aux), (loss, aux))`` outputs
        of the confederate and best-response sides.
    """
    def _update_minbatch(all_train_states, all_data):
        """Take one gradient step for every population member on one minibatch."""
        train_state_conf, train_state_br = all_train_states
        minbatch_conf, minbatch_br, lms_vertical, lms_horizontal = all_data

        def _loss_fn(param, agent_policy, minbatch, agent_id, lms_vertical, lms_horizontal):
            """Compute the PPO loss of population member ``agent_id`` on ``minbatch``.

            Samples whose ``self_onehot_id`` does not decode to ``agent_id``
            receive zero weight; if no sample is relevant every loss term is
            forced to zero.

            Returns:
                ``(total_loss, (value_loss, pg_loss, entropy))``.
            """
            init_hstate, traj_batch, gae, target_v = minbatch
            # ``param`` carries a leading singleton axis from ``gather_params``.
            squeezed_param = jax.tree.map(lambda x: jnp.squeeze(x, 0), param)
            _, value, pi, _ = agent_policy.get_action_value_policy(
                params=squeezed_param,
                obs=traj_batch.obs,
                done=traj_batch.done,
                avail_actions=traj_batch.avail_actions,
                hstate=init_hstate,
                rng=jax.random.PRNGKey(0),  # only used for action sampling, which is not used here
                aux_obs=traj_batch.oppo_onehot_id
            )
            log_prob = pi.log_prob(traj_batch.action)

            # Decode the one-hot population ids once and reuse them below.
            int_self_id = jnp.argmax(traj_batch.self_onehot_id, axis=-1)
            int_oppo_id = jnp.argmax(traj_batch.oppo_onehot_id, axis=-1)
            loss_weights = jnp.equal(int_self_id, agent_id).astype(jnp.float32)

            def _masked_mean(per_sample):
                # Mean over the samples generated by ``agent_id``; exactly
                # zero when this member produced no sample in the minibatch.
                averaged = (loss_weights * per_sample).sum() / (loss_weights.sum() + 1e-8)
                return jax.lax.cond(
                    loss_weights.sum() == 0,
                    lambda x: jnp.zeros_like(x).astype(jnp.float32),
                    lambda x: x,
                    averaged,
                )

            # Given a pair of policies that generate SP trajectories,
            # compute the pair's total Lagrange multiplier in the Lagrange dual.
            # Assuming the SP data is generated by population i, the total LMs
            # amount to \sum_{j} lms_vertical[i][j] + \sum_{j} lms_horizontal[i][j]
            def _gather_sp_weights(ids):
                s_id, _ = ids
                return jnp.sum(lms_vertical, axis=-1)[s_id], jnp.sum(lms_horizontal, axis=-1)[s_id]

            # Given a pair of policies that generate XP trajectories,
            # compute the pair's total Lagrange multiplier in the Lagrange dual.
            # Assuming the XP data is generated by the i^th conf policy and the j^th BR policy,
            # the total LMs amount to
            # -lms_vertical[j][i] -lms_horizontal[i][j]
            def _gather_xp_weights(ids):
                s_id, o_id = ids
                return -lms_vertical[s_id][o_id], -lms_horizontal[o_id][s_id]

            def _get_weights(s_id, o_id):
                return jax.lax.cond(
                    jnp.equal(s_id, o_id),
                    _gather_sp_weights,
                    _gather_xp_weights,
                    (s_id, o_id)
                )

            # Value loss (PPO-style clipped value prediction)
            value_pred_clipped = traj_batch.value + (
                value - traj_batch.value
            ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
            value_losses = jnp.square(value - target_v)
            value_losses_clipped = jnp.square(value_pred_clipped - target_v)
            value_loss = _masked_mean(jnp.maximum(value_losses, value_losses_clipped))

            # Apply different loss weights for SP and XP data.
            # Loss weights consist of two parts: the first term is the weighting from the (L)BRDiv loss function,
            # which is based on the sum of Lagrange multipliers for a given confederate-ego pair's expected returns
            # in the Lagrange dual formulation. This is weights1 + weights2 in the code below.
            #
            # The second term is a reweighting term to compensate for the data collection process, which uniformly
            # and independently samples the conf and br ids from 1, ..., n, resulting in P(SP) = 1/n and
            # P(XP) = (n-1)/n. To prevent the XP loss term from dominating the SP loss term, we would like
            # P(SP) = P(XP) = 1/2. Thus the 2nd term of the SP weight is n/2, and the 2nd term of the XP weight
            # is n/(2 * (n-1)).
            n = config["PARTNER_POP_SIZE"]
            is_sp = jnp.equal(int_self_id, int_oppo_id)
            weights1, weights2 = jax.vmap(jax.vmap(_get_weights))(int_self_id, int_oppo_id)
            sp_scale = n / 2
            # With a single population member there is no XP data (``is_sp`` is
            # always true), so the XP scale is never selected; guard the
            # denominator so tracing does not raise ZeroDivisionError.
            xp_scale = n / (2 * max(n - 1, 1))
            actor_weights_sp = (weights1 + weights2) * sp_scale
            actor_weights_xp = (weights1 + weights2) * xp_scale
            actor_weights = jnp.where(is_sp, actor_weights_sp, actor_weights_xp)

            # Policy gradient loss
            ratio = jnp.exp(log_prob - traj_batch.log_prob)
            gae_norm = (gae - gae.mean()) / (gae.std() + 1e-8)
            pg_loss_1 = ratio * actor_weights * gae_norm
            pg_loss_2 = jnp.clip(
                ratio,
                1.0 - config["CLIP_EPS"],
                1.0 + config["CLIP_EPS"]) * actor_weights * gae_norm
            pg_loss = _masked_mean(-jnp.minimum(pg_loss_1, pg_loss_2))

            # Weight entropy based on the SP Lagrange-multiplier totals of the acting member
            all_sp_weights1, all_sp_weights2 = jax.vmap(_gather_sp_weights)((int_self_id, int_self_id))
            entropy_scaler = jnp.maximum(all_sp_weights1, all_sp_weights2)

            # Compute entropy loss
            entropy = _masked_mean(entropy_scaler * pi.entropy())

            total_loss = pg_loss + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy
            return total_loss, (value_loss, pg_loss, entropy)

        # Shape (pop_size, 1): each vmapped call sees an id vector of length 1,
        # which is what ``gather_params`` is given below.
        possible_agent_ids = jnp.expand_dims(jnp.arange(config["PARTNER_POP_SIZE"]), 1)
        grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)

        def gather_conf_params_and_return_grads(agent_id):
            # transposing the lm matrices only on the confederate agent side
            # ensures that both the confederate and br policy that interact
            # to generate a trajectory have the same weights when computing
            # the policy gradient loss.
            param_vector = gather_params(train_state_conf.params, agent_id)
            (loss_val_conf, aux_vals_conf), grads_conf = grad_fn(
                param_vector, conf_policy, minbatch_conf, agent_id,
                jnp.transpose(lms_vertical), jnp.transpose(lms_horizontal)
            )
            return (loss_val_conf, aux_vals_conf), grads_conf

        def gather_br_params_and_return_grads(agent_id):
            param_vector = gather_params(train_state_br.params, agent_id)
            (loss_val_br, aux_vals_br), grads_br = grad_fn(
                param_vector, br_policy, minbatch_br, agent_id,
                lms_vertical, lms_horizontal
            )
            return (loss_val_br, aux_vals_br), grads_br

        (loss_val_conf, aux_vals_conf), grads_conf = jax.vmap(gather_conf_params_and_return_grads)(possible_agent_ids)
        (loss_val_br, aux_vals_br), grads_br = jax.vmap(gather_br_params_and_return_grads)(possible_agent_ids)

        # Drop the singleton axis introduced by gathering one member at a time.
        grads_conf_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_conf)
        grads_br_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_br)
        train_state_conf = train_state_conf.apply_gradients(grads=grads_conf_new)
        train_state_br = train_state_br.apply_gradients(grads=grads_br_new)
        return (train_state_conf, train_state_br), ((loss_val_conf, aux_vals_conf), (loss_val_br, aux_vals_br))

    (
        train_state_conf, train_state_br,
        traj_batch_conf, traj_batch_br,
        advantages_conf, advantages_br,
        targets_conf, targets_br,
        rng, lms_vertical, lms_horizontal
    ) = update_state
    rng, perm_rng_conf, perm_rng_br = jax.random.split(rng, 3)

    minibatches_conf = _create_minibatches(
        traj_batch_conf, advantages_conf, targets_conf, init_conf_hstate,
        config["NUM_CONF_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_conf)
    minibatches_br = _create_minibatches(
        traj_batch_br, advantages_br, targets_br, init_br_hstate,
        config["NUM_BR_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_br)

    # Update both policies
    num_minibatches = minibatches_br[1].obs.shape[0]

    # The multipliers are constant within an epoch; they are tiled so they can
    # be fed through the scan inputs next to the minibatches.
    repeated_lms_vertical = lms_vertical[jnp.newaxis, ...].repeat(num_minibatches, axis=0)
    repeated_lms_horizontal = lms_horizontal[jnp.newaxis, ...].repeat(num_minibatches, axis=0)

    (train_state_conf, train_state_br), all_losses = jax.lax.scan(
        _update_minbatch, (train_state_conf, train_state_br),
        (minibatches_conf, minibatches_br, repeated_lms_vertical, repeated_lms_horizontal)
    )

    update_state = (
        train_state_conf, train_state_br,
        traj_batch_conf, traj_batch_br,
        advantages_conf, advantages_br,
        targets_conf, targets_br,
        rng, lms_vertical, lms_horizontal
    )
    return update_state, all_losses
|
| 474 |
+
|
| 475 |
+
def _update_step(update_runner_state, unused):
    """Run one full training iteration for both populations.

    Stages:
        1. Sample a confederate id and a best-response id per environment and
           roll the environments forward with ``_env_step``.
        2. Bootstrap the final values and compute GAE for both agents.
        3. Scan ``_update_epoch`` for ``config["UPDATE_EPOCHS"]`` epochs.
        4. Take a single projected gradient step on the two Lagrange
           multiplier matrices (after all PPO epochs).

    Args:
        update_runner_state: Tuple ``(all_train_state_conf,
            all_train_state_br, last_env_state, last_obs, last_done,
            last_conf_h, last_br_h, rng, update_steps, lms_vertical,
            lms_horizontal)``.
        unused: Placeholder for the ``jax.lax.scan`` input slot.

    Returns:
        ``(new_runner_state, metric)``; the runner state has the same layout
        as the input with ``update_steps`` incremented by one.
    """
    pop_size = config["PARTNER_POP_SIZE"]
    num_envs = config["NUM_ENVS"]
    tolerance = config["TOLERANCE_FACTOR"]

    (
        all_train_state_conf, all_train_state_br,
        last_env_state, last_obs, last_done, last_conf_h, last_br_h,
        rng, update_steps, lms_vertical, lms_horizontal,
    ) = update_runner_state

    # ---- 1) Rollout with freshly sampled population ids ------------------
    rng, conf_key, br_key = jax.random.split(rng, 3)
    sampled_conf_ids = jax.random.randint(conf_key, (num_envs,), 0, pop_size)
    sampled_br_ids = jax.random.randint(br_key, (num_envs,), 0, pop_size)

    rollout_carry, (traj_batch_conf, traj_batch_br) = jax.lax.scan(
        _env_step,
        (all_train_state_conf, all_train_state_br, sampled_conf_ids, sampled_br_ids,
         last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng),
        None,
        config["ROLLOUT_LENGTH"],
    )
    (all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids,
     last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng) = rollout_carry

    # ---- 2) Bootstrap values and GAE --------------------------------------
    final_avail_actions = jax.vmap(env.get_avail_actions)(last_env_state.env_state)
    # Dummy keys: only the value output of the forward pass is used.
    dummy_rngs = jax.random.split(jax.random.PRNGKey(0), num_envs)

    def _bootstrap_value(forward_pass, stacked_params, member_ids, partner_ids, agent_name, hstate):
        # Value of the final observation for ``agent_name``, conditioned on
        # the one-hot id of the partner it was paired with.
        _, final_value, _, _ = jax.vmap(forward_pass)(
            params=gather_params(stacked_params, member_ids),
            obs=last_obs[agent_name],
            id=identity_matrix[partner_ids],
            done=last_done[agent_name],
            avail_actions=final_avail_actions[agent_name].astype(jnp.float32),
            hstate=hstate,
            rng=dummy_rngs,
        )
        return final_value.squeeze()

    # Confederate (agent_0) advantages from its interaction with the BR policy.
    last_val_conf = _bootstrap_value(
        forward_pass_conf, all_train_state_conf.params,
        last_conf_ids, last_br_ids, "agent_0", last_conf_h,
    )
    advantages_conf, targets_conf = _calculate_gae(traj_batch_conf, last_val_conf)

    # Best-response (agent_1) advantages from its interaction with the confederate.
    last_val_br = _bootstrap_value(
        forward_pass_br, all_train_state_br.params,
        last_br_ids, last_conf_ids, "agent_1", last_br_h,
    )
    advantages_br, targets_br = _calculate_gae(traj_batch_br, last_val_br)

    # ---- 3) PPO epochs -----------------------------------------------------
    rng, update_rng = jax.random.split(rng, 2)
    update_state, all_losses = jax.lax.scan(
        _update_epoch,
        (all_train_state_conf, all_train_state_br,
         traj_batch_conf, traj_batch_br,
         advantages_conf, advantages_br,
         targets_conf, targets_br,
         update_rng, lms_vertical, lms_horizontal),
        None,
        config["UPDATE_EPOCHS"],
    )
    all_train_state_conf, all_train_state_br = update_state[:2]
    lms_vertical, lms_horizontal = update_state[-2:]

    # ---- 4) Lagrange multiplier gradients ----------------------------------
    # Diagonal and off-diagonal pairs use separate vmaps to avoid evaluating
    # both branches of lax.cond for all pop_size^2 elements under vmap.

    def _flat_int_ids(onehot_ids):
        # Collapse all leading axes and decode the one-hot ids to integers.
        return jnp.reshape(onehot_ids, (-1, jnp.shape(onehot_ids)[-1])).argmax(axis=-1)

    def _select_member(stacked_params, member_id):
        # Parameters of a single population member without the gather axis.
        picked = gather_params(stacked_params, jnp.reshape(member_id, (1,)))
        return jax.tree.map(lambda x: jnp.squeeze(x, 0), picked)

    def _flat_br_values(batch, param, partner_one_hot):
        # BR value head on every transition of ``batch`` with the partner id
        # overridden by ``partner_one_hot``; returned flattened to (ts*bs,).
        ts, bs = batch.obs.shape[:2]
        tiled_one_hot = (
            partner_one_hot[jnp.newaxis, jnp.newaxis, ...]
            .repeat(ts, axis=0)
            .repeat(bs, axis=1)
        )
        _, values, _, _ = br_policy.get_action_value_policy(
            params=param,
            obs=batch.obs,
            done=batch.done,
            avail_actions=batch.avail_actions,
            hstate=init_br_hstate,
            rng=jax.random.PRNGKey(0),  # only used for action sampling, which is not used here
            aux_obs=tiled_one_hot,
        )
        return values.reshape(ts * bs)

    def compute_lagrange_grads_same(params_br, batch, target_value, ids):
        """Multiplier gradients contributed by SP data of population ``ids[0]``.

        With R_{i,-j} the ego return of the i-th BR policy against the j-th
        confederate, the vertical row holds R_{i,-i} - R_{i,-j} - tolerance
        and the horizontal row R_{i,-i} - R_{j,-i} - tolerance for every j;
        the j = i entry is constructed to be exactly zero.
        """
        conf_id, _ = ids
        sp_targets = jnp.repeat(jnp.reshape(target_value, (1, -1)), pop_size, axis=0)

        # Vary the partner id fed to population i's own parameters.
        own_params = _select_member(params_br, conf_id)
        values_vary_partner = jax.vmap(
            lambda one_hot: _flat_br_values(batch, own_params, one_hot)
        )(jnp.eye(pop_size))
        values_vary_partner = values_vary_partner.at[conf_id].set(sp_targets[conf_id])

        # Row ``conf_id`` gets +tolerance so that it cancels the -tolerance
        # applied to every row below.
        thresholds = jnp.zeros_like(sp_targets)
        thresholds = thresholds.at[conf_id].set(
            tolerance * jnp.ones_like(thresholds[conf_id])
        )
        tolerance_term = tolerance * jnp.ones_like(thresholds)
        grad_sp_vary_conf = sp_targets + thresholds - (values_vary_partner + tolerance_term)

        # Vary the BR parameters while keeping population i as the partner id.
        conf_one_hot = jnp.eye(pop_size)[conf_id]
        every_member_params = gather_params(params_br, jnp.arange(pop_size))
        values_vary_params = jax.vmap(
            lambda member_param: _flat_br_values(batch, member_param, conf_one_hot)
        )(every_member_params)
        values_vary_params = values_vary_params.at[conf_id].set(sp_targets[conf_id])
        grad_sp_vary_br = sp_targets + thresholds - (values_vary_params + tolerance_term)

        # Only transitions where both sides belong to population i count.
        self_is_conf = jnp.equal(_flat_int_ids(batch.self_onehot_id), conf_id).astype(jnp.float32)
        oppo_is_conf = jnp.equal(_flat_int_ids(batch.oppo_onehot_id), conf_id).astype(jnp.float32)
        sample_weights = self_is_conf * oppo_is_conf
        tiled_weights = jnp.repeat(jnp.expand_dims(sample_weights, axis=0), pop_size, axis=0)

        vertical_row = jnp.sum(grad_sp_vary_conf * tiled_weights, axis=-1) / (jnp.sum(sample_weights) + 1e-8)
        horizontal_row = jnp.sum(grad_sp_vary_br * tiled_weights, axis=-1) / (jnp.sum(sample_weights) + 1e-8)

        blank = jnp.zeros((pop_size, pop_size))
        return blank.at[conf_id].set(vertical_row), blank.at[conf_id].set(horizontal_row)

    def compute_lagrange_grads_diff(params_br, batch, target_returns, ids):
        """Multiplier gradients contributed by XP data of the pair ``(conf_id, br_id)``."""
        conf_id, br_id = ids
        params_of_conf_pop = _select_member(params_br, conf_id)
        params_of_br_pop = _select_member(params_br, br_id)
        flat_returns = jnp.reshape(target_returns, (-1))

        # Compute data weights based on whether the transition was generated
        # by exactly this (conf_id, br_id) pairing.
        oppo_is_conf = jnp.equal(_flat_int_ids(batch.oppo_onehot_id), conf_id).astype(jnp.float32)
        self_is_br = jnp.equal(_flat_int_ids(batch.self_onehot_id), br_id).astype(jnp.float32)
        sample_weights = oppo_is_conf * self_is_br

        one_hots = jnp.eye(pop_size)
        value_sp_pop_is_br = _flat_br_values(batch, params_of_br_pop, one_hots[br_id])
        value_sp_pop_is_not_br = _flat_br_values(batch, params_of_conf_pop, one_hots[conf_id])

        vertical_diff = value_sp_pop_is_br - flat_returns - tolerance
        horizontal_diff = value_sp_pop_is_not_br - flat_returns - tolerance

        total_grad_vertical = (sample_weights * vertical_diff).sum() / (sample_weights.sum() + 1e-8)
        total_grad_horizontal = (sample_weights * horizontal_diff).sum() / (sample_weights.sum() + 1e-8)

        blank = jnp.zeros((pop_size, pop_size))
        return (
            blank.at[br_id, conf_id].set(total_grad_vertical),
            blank.at[conf_id, br_id].set(total_grad_horizontal),
        )

    updated_br_params = all_train_state_br.params

    # Diagonal pairs (conf_id == br_id): vmap over pop_size elements only
    diag_ids = np.arange(pop_size)
    diag_vertical, diag_horizontal = jax.vmap(
        lambda conf_id, br_id: compute_lagrange_grads_same(
            updated_br_params, traj_batch_br, targets_br, (conf_id, br_id)
        )
    )(diag_ids, diag_ids)

    # Off-diagonal pairs (conf_id != br_id): vmap over pop_size*(pop_size-1) elements only
    all_conf_ids_np, all_br_ids_np = _get_all_ids(pop_size)
    off_diag_mask = all_conf_ids_np != all_br_ids_np
    off_vertical, off_horizontal = jax.vmap(
        lambda conf_id, br_id: compute_lagrange_grads_diff(
            updated_br_params, traj_batch_br, targets_br, (conf_id, br_id)
        )
    )(all_conf_ids_np[off_diag_mask], all_br_ids_np[off_diag_mask])

    averaged_grad_vertical = jnp.sum(diag_vertical, axis=0) + jnp.sum(off_vertical, axis=0)
    averaged_grad_horizontal = jnp.sum(diag_horizontal, axis=0) + jnp.sum(off_horizontal, axis=0)

    def _project_multipliers(multipliers, grad):
        # Gradient step, floor at 0 off the diagonal, then pin the diagonal to 0.5.
        stepped = jnp.maximum(
            multipliers - config["LAGRANGE_LR"] * grad,
            0.5 * jnp.eye(pop_size),
        )
        return jnp.fill_diagonal(
            stepped, 0.5 * jnp.ones((pop_size), dtype=jnp.float32),
            inplace=False,
        )

    lms_vertical = _project_multipliers(lms_vertical, averaged_grad_vertical)
    lms_horizontal = _project_multipliers(lms_horizontal, averaged_grad_horizontal)

    conf_losses, br_losses = all_losses
    _, (value_loss_conf, pg_loss_conf, entropy_conf) = conf_losses
    _, (value_loss_br, pg_loss_br, entropy_br) = br_losses

    # ---- Metrics -----------------------------------------------------------
    def mask_and_mean(x, mask):
        return jnp.where(mask, x, 0).sum() / jnp.maximum(1, mask.sum())

    mask = traj_batch_conf.info.get("returned_episode", jnp.ones_like(traj_batch_conf.reward))
    metric = jax.tree.map(lambda x: mask_and_mean(x, mask), traj_batch_conf.info)
    metric["lms_vertical"] = lms_vertical
    metric["lms_horizontal"] = lms_horizontal
    metric["update_steps"] = update_steps
    # Loss terms are stacked over (epoch, minibatch) by the two nested scans;
    # average those two axes away.
    for metric_name, stacked_term in (
        ("value_loss_conf_agent", value_loss_conf),
        ("value_loss_br_agent", value_loss_br),
        ("pg_loss_conf_agent", pg_loss_conf),
        ("pg_loss_br_agent", pg_loss_br),
        ("entropy_conf", entropy_conf),
        ("entropy_br", entropy_br),
    ):
        metric[metric_name] = stacked_term.mean(axis=(0, 1))

    new_runner_state = (
        all_train_state_conf, all_train_state_br,
        last_env_state, last_obs, last_done, last_conf_h, last_br_h,
        rng, update_steps + 1,
        lms_vertical, lms_horizontal,
    )
    return (new_runner_state, metric)
|
| 805 |
+
|
| 806 |
+
# --------------------------
|
| 807 |
+
# PPO Update and Checkpoint saving
|
| 808 |
+
# --------------------------
|
| 809 |
+
# Number of updates between checkpoint/evaluation events.
# -1 because we store a ckpt at the last update.
# The outer max(1, ...) keeps the interval positive when NUM_UPDATES is
# smaller than NUM_CHECKPOINTS - 1; an interval of 0 would turn the
# periodicity test (a modulo by this interval) into a modulo by zero.
ckpt_and_eval_interval = max(
    1, config["NUM_UPDATES"] // max(1, config["NUM_CHECKPOINTS"] - 1)
)
num_ckpts = config["NUM_CHECKPOINTS"]
|
| 811 |
+
|
| 812 |
+
# Build a PyTree that holds parameters for all conf agent checkpoints
|
| 813 |
+
def init_ckpt_array(params_pytree):
    """Allocate zero-filled checkpoint storage mirroring ``params_pytree``.

    Every leaf is replaced by a zero array of the same dtype with an extra
    leading axis of length ``num_ckpts``.
    """
    def _blank_slots(leaf):
        slot_shape = (num_ckpts,) + leaf.shape
        return jnp.zeros(slot_shape, leaf.dtype)

    return jax.tree.map(_blank_slots, params_pytree)
|
| 817 |
+
|
| 818 |
+
def _update_step_with_ckpt(state_with_ckpt, unused):
    """Run one ``_update_step`` and, when due, checkpoint and evaluate.

    A checkpoint is stored (and ``run_all_episodes`` evaluated) when
    ``update_steps - 1`` is a multiple of ``ckpt_and_eval_interval`` or when
    ``update_steps`` equals ``config["NUM_UPDATES"]``.

    Args:
        state_with_ckpt: Tuple ``(update_runner_state, checkpoint_array_conf,
            checkpoint_array_br, ckpt_idx, eval_info)``.
        unused: Placeholder for the ``jax.lax.scan`` input slot.

    Returns:
        ``(new_state_with_ckpt, metric)`` where the state has the same layout
        as the input and ``metric["eval_ep_last_info"]`` holds the most
        recent evaluation result.
    """
    (update_runner_state, checkpoint_array_conf, checkpoint_array_br, ckpt_idx,
     eval_info) = state_with_ckpt

    # Single PPO update
    new_runner_state, metric = _update_step(update_runner_state, None)

    (
        train_state_conf, train_state_br,
        last_env_state, last_obs, last_done, last_conf_h, last_br_h,
        rng, update_steps, lms_vertical, lms_horizontal
    ) = new_runner_state

    # Decide if we store a checkpoint
    # update steps is 1-indexed because it was incremented at the end of the update step
    to_store = jnp.logical_or(
        jnp.equal(jnp.mod(update_steps - 1, ckpt_and_eval_interval), 0),
        jnp.equal(update_steps, config["NUM_UPDATES"]),
    )

    def store_and_eval_ckpt(args):
        ckpt_arr_and_ep_infos, rng, cidx = args
        ckpt_arr_conf, ckpt_arr_br, _ = ckpt_arr_and_ep_infos

        # When NUM_UPDATES is not aligned with the interval, more store
        # events can fire than there are slots. An out-of-range
        # ``.at[...].set`` is silently dropped by JAX, which would lose the
        # final checkpoint; clamp so the latest parameters overwrite the
        # last slot instead.
        slot = jnp.minimum(cidx, max(num_ckpts - 1, 0))

        new_ckpt_arr_conf = jax.tree.map(
            lambda c_arr, p: c_arr.at[slot].set(p),
            ckpt_arr_conf, train_state_conf.params
        )
        new_ckpt_arr_br = jax.tree.map(
            lambda c_arr, p: c_arr.at[slot].set(p),
            ckpt_arr_br, train_state_br.params
        )

        rng, eval_rng = jax.random.split(rng)
        # Average each evaluation statistic over its last two axes.
        ep_last_info = jax.tree.map(
            lambda x: x.mean(axis=(-2, -1)),
            run_all_episodes(eval_rng, train_state_conf, train_state_br)
        )

        return ((new_ckpt_arr_conf, new_ckpt_arr_br, ep_last_info), rng, cidx + 1)

    def skip_ckpt(args):
        # Pass everything through unchanged, including the previous eval info.
        return args

    (checkpoint_array_and_infos, rng, ckpt_idx) = jax.lax.cond(
        to_store,
        store_and_eval_ckpt,
        skip_ckpt,
        ((checkpoint_array_conf, checkpoint_array_br, eval_info), rng, ckpt_idx)
    )
    checkpoint_array_conf, checkpoint_array_br, eval_ep_last_info = checkpoint_array_and_infos

    metric["eval_ep_last_info"] = eval_ep_last_info  # return of confederate

    return ((train_state_conf, train_state_br,
             last_env_state, last_obs, last_done, last_conf_h, last_br_h,
             rng, update_steps, lms_vertical, lms_horizontal),
            checkpoint_array_conf, checkpoint_array_br, ckpt_idx,
            eval_ep_last_info), metric
|
| 872 |
+
|
| 873 |
+
# Initialize checkpoint array
|
| 874 |
+
checkpoint_array_conf = init_ckpt_array(all_conf_optims.params)
|
| 875 |
+
checkpoint_array_br = init_ckpt_array(all_br_optims.params)
|
| 876 |
+
ckpt_idx = 0
|
| 877 |
+
|
| 878 |
+
# Initialize state for scan over _update_step_with_ckpt
|
| 879 |
+
update_steps = 0
|
| 880 |
+
|
| 881 |
+
rng, rng_eval = jax.random.split(rng, 2)
|
| 882 |
+
eval_ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)),
|
| 883 |
+
run_all_episodes(rng_eval, all_conf_optims, all_br_optims))
|
| 884 |
+
|
| 885 |
+
# Initialize environment
|
| 886 |
+
rng, reset_rng = jax.random.split(rng)
|
| 887 |
+
reset_rngs = jax.random.split(reset_rng, config["NUM_ENVS"])
|
| 888 |
+
init_obs, init_env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rngs)
|
| 889 |
+
init_done = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]}
|
| 890 |
+
|
| 891 |
+
# Initialize conf and br hstates
|
| 892 |
+
init_conf_h = conf_policy.init_hstate(config["NUM_CONF_ACTORS"])
|
| 893 |
+
init_br_h = br_policy.init_hstate(config["NUM_BR_ACTORS"])
|
| 894 |
+
|
| 895 |
+
# Initialize LMs
|
| 896 |
+
# lm_vertical[i, j] stores the lagrange multiplier for upholding
|
| 897 |
+
# R_{conf(i), BR(i)} >= R_{conf(j), BR(i)} + tolerance_factor
|
| 898 |
+
|
| 899 |
+
# lm_horizontal[i, j] stores the lagrange multiplier for upholding
|
| 900 |
+
# R_{conf(i), BR(i)} >= R_{conf(i), BR(j)} + tolerance_factor
|
| 901 |
+
|
| 902 |
+
# Diagonal elements of both matrices sum up to 1.
|
| 903 |
+
# Providing a weight of 1 to maximize the SP return from any population
|
| 904 |
+
lagrange_multipliers_vertical = 0.5 * jnp.eye(config["PARTNER_POP_SIZE"])
|
| 905 |
+
lagrange_multipliers_horizontal = 0.5 * jnp.eye(config["PARTNER_POP_SIZE"])
|
| 906 |
+
|
| 907 |
+
update_runner_state = (
|
| 908 |
+
all_conf_optims, all_br_optims,
|
| 909 |
+
init_env_state, init_obs, init_done, init_conf_h, init_br_h,
|
| 910 |
+
rng, update_steps,
|
| 911 |
+
lagrange_multipliers_vertical, lagrange_multipliers_horizontal
|
| 912 |
+
)
|
| 913 |
+
|
| 914 |
+
state_with_ckpt = (
|
| 915 |
+
update_runner_state, checkpoint_array_conf,
|
| 916 |
+
checkpoint_array_br, ckpt_idx, eval_ep_last_info
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
# run training
|
| 920 |
+
state_with_ckpt, metrics = jax.lax.scan(
|
| 921 |
+
_update_step_with_ckpt,
|
| 922 |
+
state_with_ckpt,
|
| 923 |
+
xs=None,
|
| 924 |
+
length=config["NUM_UPDATES"]
|
| 925 |
+
)
|
| 926 |
+
|
| 927 |
+
(
|
| 928 |
+
final_runner_state, checkpoint_array_conf, checkpoint_array_br,
|
| 929 |
+
final_ckpt_idx, all_ep_infos
|
| 930 |
+
) = state_with_ckpt
|
| 931 |
+
|
| 932 |
+
out = {
|
| 933 |
+
"final_params_conf": final_runner_state[0].params,
|
| 934 |
+
"final_params_br": final_runner_state[1].params,
|
| 935 |
+
"checkpoints_conf": checkpoint_array_conf,
|
| 936 |
+
"checkpoints_br": checkpoint_array_br,
|
| 937 |
+
"metrics": metrics, # metrics is from the perspective of the confederate agent (averaged over population)
|
| 938 |
+
"all_pair_returns": all_ep_infos
|
| 939 |
+
}
|
| 940 |
+
return out
|
| 941 |
+
|
| 942 |
+
return train
|
| 943 |
+
# ------------------------------
|
| 944 |
+
# Actually run the adversarial teammate training
|
| 945 |
+
# ------------------------------
|
| 946 |
+
train_fn = make_lbrdiv_agents(config)
|
| 947 |
+
out = train_fn(train_rng)
|
| 948 |
+
return out
|
| 949 |
+
|
| 950 |
+
def get_lbrdiv_population(config, out, env):
    """Return the trained confederate parameters and a population wrapper for ego training.

    Args:
        config: Run configuration; ``config["algorithm"]`` supplies
            ``PARTNER_POP_SIZE`` and, optionally, ``ACTIVATION`` (defaults to "tanh").
        out: Training output; only ``out['final_params_conf']`` is read.
        env: Environment used to look up the action and observation spaces of
            ``env.agents[1]``.

    Returns:
        Tuple ``(partner_params, partner_population)``.
    """
    n_partners = config["algorithm"]["PARTNER_POP_SIZE"]

    # Per the original note, these params are laid out as (num_seeds, pop_size, ...).
    conf_params = out['final_params_conf']

    partner_agent = env.agents[1]
    policy = ActorWithConditionalCriticPolicy(
        action_dim=env.action_space(partner_agent).n,
        obs_dim=env.observation_space(partner_agent).shape[0],
        pop_size=n_partners,  # used to create the one-hot agent id
        activation=config["algorithm"].get("ACTIVATION", "tanh"),
    )

    # Wrap the single policy object so it can stand in for every population member.
    population = AgentPopulation(pop_size=n_partners, policy_cls=policy)

    return conf_params, population
|
| 973 |
+
|
| 974 |
+
def run_lbrdiv(config, wandb_logger):
    """Train LBRDiv partners for every seed, log the results, and return the population.

    Args:
        config: Full run configuration. ``config["algorithm"]`` supplies
            ``ENV_NAME``, ``ENV_KWARGS``, ``TRAIN_SEED``, ``NUM_SEEDS`` and
            ``PARTNER_POP_SIZE``.
        wandb_logger: Logger forwarded to ``log_metrics``.

    Returns:
        The ``(partner_params, partner_population)`` pair built by
        ``get_lbrdiv_population`` from the training output.
    """
    algorithm_config = dict(config["algorithm"])

    env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"])
    env = LogWrapper(env)

    log.info("Starting LBRDiv training...")
    start = time.time()

    # Generate one PRNG key per seed from the base seed
    rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"])
    rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"])

    # Initialize br and conf policies
    conf_policy = ActorWithConditionalCriticPolicy(
        action_dim=env.action_space(env.agents[0]).n,
        obs_dim=env.observation_space(env.agents[0]).shape[0],
        pop_size=algorithm_config["PARTNER_POP_SIZE"],
    )
    br_policy = ActorWithConditionalCriticPolicy(
        action_dim=env.action_space(env.agents[0]).n,
        obs_dim=env.observation_space(env.agents[0]).shape[0],
        pop_size=algorithm_config["PARTNER_POP_SIZE"],
    )

    # Create a vmapped (over seeds) and jitted version of train_lbrdiv_partners
    with jax.disable_jit(False):
        vmapped_train_fn = jax.jit(
            jax.vmap(
                partial(
                    train_lbrdiv_partners,
                    env=env,
                    config=algorithm_config,
                    conf_policy=conf_policy,
                    br_policy=br_policy,
                )
            )
        )
        # JAX dispatches work asynchronously, so the call can return before the
        # device has finished. Block here so the elapsed time logged below
        # covers the actual training computation rather than just the dispatch.
        out = jax.block_until_ready(vmapped_train_fn(rngs))

    end = time.time()
    log.info(f"LBRDiv training complete in {end - start} seconds")

    metric_names = get_metric_names(algorithm_config["ENV_NAME"])
    log_metrics(config, out, wandb_logger, metric_names)

    partner_params, partner_population = get_lbrdiv_population(config, out, env)

    return partner_params, partner_population
|
| 1017 |
+
|
| 1018 |
+
|
| 1019 |
+
def log_metrics(config, outs, logger, metric_names: tuple):
    """Send evaluation curves, losses, Lagrange multipliers and the saved run to the logger.

    Args:
        config: Run configuration; ``config["logger"]["log_train_out"]`` and
            ``config["local_logger"]["save_train_out"]`` control artifact handling.
        outs: Training output. ``outs["metrics"]`` is read here and the whole
            object is handed to ``save_train_run``.
        logger: Object exposing ``log_item``, ``log_xp_matrix``, ``log_artifact``
            and ``commit``.
        metric_names: Not read by this function.

    Raises:
        ValueError: If a seed-averaged loss or Lagrange-multiplier array contains NaN.
    """
    train_metrics = outs["metrics"]
    # The loss arrays are unpacked as (num_seeds, num_updates, pop_size).
    _, num_updates, pop_size = train_metrics["pg_loss_conf_agent"].shape

    def seed_mean(metric_key):
        # Average one metric over the leading (seed) axis.
        return np.asarray(train_metrics[metric_key]).mean(axis=0)

    ### Log evaluation metrics
    # Per the original note: (num_seeds, num_updates, pop_size**2), already
    # averaged over eval episodes and agents inside the training scan.
    eval_returns = np.asarray(train_metrics["eval_ep_last_info"]["returned_episode_returns"])

    conf_ids, br_ids = _get_all_ids(pop_size)
    is_self_play = (conf_ids == br_ids)

    # Average over seeds and over the selected agent pairs.
    sp_curve = eval_returns[:, :, is_self_play].mean(axis=(0, 2))
    xp_curve = eval_returns[:, :, ~is_self_play].mean(axis=(0, 2))

    for step in range(num_updates):
        logger.log_item("Eval/AvgSPReturnCurve", sp_curve[step], train_step=step)
        logger.log_item("Eval/AvgXPReturnCurve", xp_curve[step], train_step=step)
        logger.commit()

    # Final cross-play matrix, averaged over seeds.
    final_xp_matrix = eval_returns[:, -1].mean(axis=0).reshape((pop_size, pop_size))
    logger.log_xp_matrix("Eval/LastXPMatrix", final_xp_matrix)

    ### Log population losses as multi-line plots, one line per population member
    loss_sources = (
        ("ConfPGLoss", "pg_loss_conf_agent"),
        ("BRPGLoss", "pg_loss_br_agent"),
        ("ConfValLoss", "value_loss_conf_agent"),
        ("BRValLoss", "value_loss_br_agent"),
        ("ConfEntropy", "entropy_conf"),
        ("BREntropy", "entropy_br"),
    )
    # Transposed after the seed average so that each row is one plotted line.
    per_member_losses = {label: seed_mean(key).T for label, key in loss_sources}

    xs = list(range(num_updates))
    member_keys = [f"pair {i}" for i in range(pop_size)]
    for loss_name, loss_data in per_member_losses.items():
        if np.isnan(loss_data).any():
            raise ValueError(f"Found nan in loss {loss_name}")
        series = wandb.plot.line_series(
            xs=xs, ys=loss_data, keys=member_keys,
            title=loss_name, xname="train_step",
        )
        logger.log_item(f"Losses/{loss_name}", series)

    ### Lagrange multipliers, averaged over seeds, one line per (i, j) entry
    lm_keys = [f"pair {i}, {j}" for i in range(pop_size) for j in range(pop_size)]
    lm_horizontal = seed_mean("lms_horizontal")
    lm_vertical = seed_mean("lms_vertical")
    lagrange_multipliers = {
        "LMs_Horizontal": lm_horizontal.reshape(lm_horizontal.shape[0], -1).T,
        "LMs_Vertical": lm_vertical.reshape(lm_vertical.shape[0], -1).T,
    }

    for array_name, array_data in lagrange_multipliers.items():
        if np.isnan(array_data).any():
            raise ValueError(f"Found nan in loss {array_name}")
        series = wandb.plot.line_series(
            xs=xs, ys=array_data, keys=lm_keys,
            title=array_name, xname="train_step",
        )
        logger.log_item(f"Losses/{array_name}", series)
    logger.commit()

    ### Log artifacts
    savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    # Save the train run output, then optionally upload it as an artifact.
    out_savepath = save_train_run(outs, savedir, savename="saved_train_run")
    if config["logger"]["log_train_out"]:
        logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run")

    # Remove the local copy unless the config asks to keep it.
    if not config["local_logger"]["save_train_out"]:
        shutil.rmtree(out_savepath)
|
teammate_generation/__init__.py
ADDED
|
File without changes
|
teammate_generation/configs/algorithm/brdiv/_base_.yaml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package algorithm
|
| 2 |
+
# ^ tells hydra to place these values directly under the algorithm key
|
| 3 |
+
ALG: brdiv
|
| 4 |
+
TOTAL_TIMESTEPS: 4.5e7 # divided among each pair of BR and Conf agents
|
| 5 |
+
NUM_CHECKPOINTS: 5
|
| 6 |
+
PARTNER_POP_SIZE: 4
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
# SP weight = 1 + 2*XP weight.
|
| 9 |
+
# Thus, as XP weight -> 0, SP/(SP+XP) -> 1.
|
| 10 |
+
# If XP weight -> infinity, XP/(SP+XP) -> 1/3, and SP/(SP+XP) -> 2/3.
|
| 11 |
+
XP_LOSS_WEIGHTS: 1
|
| 12 |
+
LR: 1e-4
|
| 13 |
+
UPDATE_EPOCHS: 15
|
| 14 |
+
NUM_MINIBATCHES: 4
|
| 15 |
+
GAMMA: 0.99
|
| 16 |
+
GAE_LAMBDA: 0.95
|
| 17 |
+
CLIP_EPS: 0.05
|
| 18 |
+
ENT_COEF: 0.01
|
| 19 |
+
VF_COEF: 0.5
|
| 20 |
+
MAX_GRAD_NORM: 1.0
|
| 21 |
+
ANNEAL_LR: false
|
| 22 |
+
ego_train_algorithm:
|
| 23 |
+
EGO_ACTOR_TYPE: s5
|
| 24 |
+
S5_D_MODEL: 16
|
| 25 |
+
S5_SSM_SIZE: 16
|
| 26 |
+
S5_ACTOR_CRITIC_HIDDEN_DIM: 64
|
| 27 |
+
FC_N_LAYERS: 2
|
| 28 |
+
TOTAL_TIMESTEPS: 1e7
|
| 29 |
+
NUM_CHECKPOINTS: 5
|
| 30 |
+
NUM_ENVS: 8
|
| 31 |
+
LR: 1e-4
|
| 32 |
+
UPDATE_EPOCHS: 15
|
| 33 |
+
NUM_MINIBATCHES: 4
|
| 34 |
+
GAMMA: 0.99
|
| 35 |
+
GAE_LAMBDA: 0.95
|
| 36 |
+
CLIP_EPS: 0.05
|
| 37 |
+
ENT_COEF: 0.01
|
| 38 |
+
VF_COEF: 0.5
|
| 39 |
+
MAX_GRAD_NORM: 1.0
|
| 40 |
+
ANNEAL_LR: true
|
teammate_generation/configs/algorithm/brdiv/hanabi.yaml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- brdiv/_base_
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 5e8
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 128
|
| 8 |
+
XP_LOSS_WEIGHTS: 0.05
|
| 9 |
+
LR: 5e-4
|
| 10 |
+
UPDATE_EPOCHS: 4
|
| 11 |
+
NUM_MINIBATCHES: 4
|
| 12 |
+
CLIP_EPS: 0.2
|
| 13 |
+
ENT_COEF: 0.01
|
| 14 |
+
ANNEAL_LR: true
|
| 15 |
+
GAMMA: 0.999
|
| 16 |
+
GAE_LAMBDA: 0.95
|
| 17 |
+
MAX_GRAD_NORM: 0.5
|
| 18 |
+
ego_train_algorithm:
|
| 19 |
+
TOTAL_TIMESTEPS: 1e8
|
| 20 |
+
LR: 5e-4
|
| 21 |
+
ENT_COEF: 0.01
|
| 22 |
+
CLIP_EPS: 0.2
|
| 23 |
+
ANNEAL_LR: true
|
| 24 |
+
UPDATE_EPOCHS: 4
|
| 25 |
+
GAMMA: 0.999
|
| 26 |
+
GAE_LAMBDA: 0.95
|
| 27 |
+
MAX_GRAD_NORM: 0.5
|
teammate_generation/configs/algorithm/brdiv/lbf/lbf_12x12.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- brdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4.5e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
XP_LOSS_WEIGHTS: 0.05 # 0.1
|
| 9 |
+
LR: 5e-4
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 2 # 4
|
| 12 |
+
CLIP_EPS: 0.05
|
| 13 |
+
ENT_COEF: 0.01
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 3e7
|
| 16 |
+
LR: 1e-4
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/brdiv/lbf/lbf_7x7_nolevels.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- brdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4.5e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
XP_LOSS_WEIGHTS: 0.05 # 0.1
|
| 9 |
+
LR: 5e-4
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 2 # 4
|
| 12 |
+
CLIP_EPS: 0.05
|
| 13 |
+
ENT_COEF: 0.01
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 3e7
|
| 16 |
+
LR: 1e-4
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/brdiv/mini-hanabi.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- brdiv/_base_
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
# Mini-Hanabi (3c/3r/hand3) BRDiv config.
|
| 6 |
+
TOTAL_TIMESTEPS: 1e8
|
| 7 |
+
PARTNER_POP_SIZE: 3
|
| 8 |
+
NUM_ENVS: 128
|
| 9 |
+
XP_LOSS_WEIGHTS: 0.05
|
| 10 |
+
LR: 5e-4
|
| 11 |
+
UPDATE_EPOCHS: 4
|
| 12 |
+
NUM_MINIBATCHES: 4
|
| 13 |
+
CLIP_EPS: 0.2
|
| 14 |
+
ENT_COEF: 0.01
|
| 15 |
+
ANNEAL_LR: true
|
| 16 |
+
GAMMA: 0.999
|
| 17 |
+
GAE_LAMBDA: 0.95
|
| 18 |
+
MAX_GRAD_NORM: 0.5
|
| 19 |
+
ego_train_algorithm:
|
| 20 |
+
TOTAL_TIMESTEPS: 1e8
|
| 21 |
+
LR: 5e-4
|
| 22 |
+
ENT_COEF: 0.01
|
| 23 |
+
CLIP_EPS: 0.2
|
| 24 |
+
ANNEAL_LR: true
|
| 25 |
+
UPDATE_EPOCHS: 4
|
| 26 |
+
GAMMA: 0.999
|
| 27 |
+
GAE_LAMBDA: 0.95
|
| 28 |
+
MAX_GRAD_NORM: 0.5
|
teammate_generation/configs/algorithm/brdiv/overcooked-v1/asymm_advantages.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- brdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4.5e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
XP_LOSS_WEIGHTS: 1
|
| 9 |
+
LR: .0001
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 16
|
| 12 |
+
CLIP_EPS: 0.3
|
| 13 |
+
ENT_COEF: 0.01
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 3e7
|
| 16 |
+
LR: 1e-4
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/brdiv/overcooked-v1/coord_ring.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- brdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 9e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 128
|
| 8 |
+
XP_LOSS_WEIGHTS: 0.007
|
| 9 |
+
LR: 5e-4
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 4
|
| 12 |
+
CLIP_EPS: 0.1
|
| 13 |
+
ENT_COEF: 0.05
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 6e7
|
| 16 |
+
LR: 1e-3
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/brdiv/overcooked-v1/counter_circuit.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- brdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 9e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 128
|
| 8 |
+
XP_LOSS_WEIGHTS: 0.005
|
| 9 |
+
LR: 1e-3
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 8
|
| 12 |
+
CLIP_EPS: 0.01
|
| 13 |
+
ENT_COEF: 0.05
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 6e7
|
| 16 |
+
LR: 1e-3
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/brdiv/overcooked-v1/cramped_room.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- brdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4.5e7
|
| 6 |
+
PARTNER_POP_SIZE: 4
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
# SP weight = 1 + 2*XP weight.
|
| 9 |
+
# Thus, as XP weight -> 0, SP/(SP+XP) -> 1.
|
| 10 |
+
# If XP weight -> infinity, XP/(SP+XP) -> 1/3, and SP/(SP+XP) -> 2/3.
|
| 11 |
+
XP_LOSS_WEIGHTS: 0.5 # 10
|
| 12 |
+
LR: 1e-4
|
| 13 |
+
UPDATE_EPOCHS: 15
|
| 14 |
+
NUM_MINIBATCHES: 16
|
| 15 |
+
CLIP_EPS: 0.05
|
| 16 |
+
ENT_COEF: 0.01
|
| 17 |
+
ego_train_algorithm:
|
| 18 |
+
TOTAL_TIMESTEPS: 3e7
|
| 19 |
+
LR: 1e-4
|
| 20 |
+
ENT_COEF: 0.01
|
| 21 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/brdiv/overcooked-v1/forced_coord.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- brdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 9e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 128
|
| 8 |
+
XP_LOSS_WEIGHTS: 0.01
|
| 9 |
+
LR: 5e-4
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 16
|
| 12 |
+
CLIP_EPS: 0.05
|
| 13 |
+
ENT_COEF: 0.01
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 6e7
|
| 16 |
+
LR: 1e-3
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/comedi/_base_.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package algorithm
|
| 2 |
+
# ^ tells hydra to place these values directly under the algorithm key
|
| 3 |
+
ALG: comedi
|
| 4 |
+
TOTAL_TIMESTEPS_PER_ITERATION: 1.2e7 # number of steps used to train each comedi agent at each iteration
|
| 5 |
+
NUM_CHECKPOINTS: 5
|
| 6 |
+
PARTNER_POP_SIZE: 4
|
| 7 |
+
NUM_ENVS: 48
|
| 8 |
+
LR: 1e-4
|
| 9 |
+
UPDATE_EPOCHS: 15
|
| 10 |
+
NUM_MINIBATCHES: 8
|
| 11 |
+
GAMMA: 0.99
|
| 12 |
+
GAE_LAMBDA: 0.95
|
| 13 |
+
CLIP_EPS: 0.05
|
| 14 |
+
ENT_COEF: 0.01
|
| 15 |
+
VF_COEF: 0.5
|
| 16 |
+
MAX_GRAD_NORM: 1.0
|
| 17 |
+
ANNEAL_LR: false
|
| 18 |
+
ACTOR_TYPE: actor_with_conditional_critic
|
| 19 |
+
NUM_ARGMAX_ROLLOUT_EPS: 20
|
| 20 |
+
COMEDI_ALPHA: 1.0
|
| 21 |
+
COMEDI_BETA: 0.5
|
| 22 |
+
ego_train_algorithm:
|
| 23 |
+
EGO_ACTOR_TYPE: s5
|
| 24 |
+
TOTAL_TIMESTEPS: 1e7
|
| 25 |
+
NUM_CHECKPOINTS: 5
|
| 26 |
+
NUM_ENVS: 8
|
| 27 |
+
LR: 1e-4
|
| 28 |
+
UPDATE_EPOCHS: 15
|
| 29 |
+
NUM_MINIBATCHES: 4
|
| 30 |
+
GAMMA: 0.99
|
| 31 |
+
GAE_LAMBDA: 0.95
|
| 32 |
+
CLIP_EPS: 0.05
|
| 33 |
+
ENT_COEF: 0.01
|
| 34 |
+
VF_COEF: 0.5
|
| 35 |
+
MAX_GRAD_NORM: 1.0
|
| 36 |
+
ANNEAL_LR: true
|
teammate_generation/configs/algorithm/comedi/hanabi.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- comedi/_base_
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS_PER_ITERATION: 2e7
|
| 6 |
+
PARTNER_POP_SIZE: 5
|
| 7 |
+
LR: 5e-4
|
| 8 |
+
UPDATE_EPOCHS: 4
|
| 9 |
+
CLIP_EPS: 0.2
|
| 10 |
+
ENT_COEF: 0.01
|
| 11 |
+
ANNEAL_LR: true
|
| 12 |
+
GAMMA: 0.999
|
| 13 |
+
GAE_LAMBDA: 0.95
|
| 14 |
+
MAX_GRAD_NORM: 0.5
|
| 15 |
+
COMEDI_ALPHA: 0.2
|
| 16 |
+
COMEDI_BETA: 0.4
|
| 17 |
+
ego_train_algorithm:
|
| 18 |
+
TOTAL_TIMESTEPS: 1e8
|
| 19 |
+
LR: 5e-4
|
| 20 |
+
ENT_COEF: 0.01
|
| 21 |
+
CLIP_EPS: 0.2
|
| 22 |
+
ANNEAL_LR: true
|
| 23 |
+
UPDATE_EPOCHS: 4
|
| 24 |
+
GAMMA: 0.999
|
| 25 |
+
GAE_LAMBDA: 0.95
|
| 26 |
+
MAX_GRAD_NORM: 0.5
|
teammate_generation/configs/algorithm/comedi/lbf/lbf_12x12.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- comedi/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS_PER_ITERATION: 6e6
|
| 6 |
+
PARTNER_POP_SIZE: 10
|
| 7 |
+
LR: 5e-4
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
CLIP_EPS: 0.05
|
| 10 |
+
ENT_COEF: 0.001
|
| 11 |
+
COMEDI_ALPHA: 0.2 # weight on XP return
|
| 12 |
+
COMEDI_BETA: 0.4 # weight on SXP return
|
| 13 |
+
ego_train_algorithm:
|
| 14 |
+
TOTAL_TIMESTEPS: 3e7
|
| 15 |
+
LR: 5e-5
|
| 16 |
+
ENT_COEF: 1e-4
|
| 17 |
+
CLIP_EPS: 0.1
|
| 18 |
+
ANNEAL_LR: false
|
teammate_generation/configs/algorithm/comedi/lbf/lbf_7x7_nolevels.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- comedi/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS_PER_ITERATION: 6e6
|
| 6 |
+
PARTNER_POP_SIZE: 10
|
| 7 |
+
LR: 5e-4
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
CLIP_EPS: 0.05
|
| 10 |
+
ENT_COEF: 0.001
|
| 11 |
+
COMEDI_ALPHA: 0.2 # weight on XP return
|
| 12 |
+
COMEDI_BETA: 0.4 # weight on SXP return
|
| 13 |
+
ego_train_algorithm:
|
| 14 |
+
TOTAL_TIMESTEPS: 3e7
|
| 15 |
+
LR: 5e-5
|
| 16 |
+
ENT_COEF: 1e-4
|
| 17 |
+
CLIP_EPS: 0.1
|
| 18 |
+
ANNEAL_LR: false
|
teammate_generation/configs/algorithm/comedi/mini-hanabi.yaml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- comedi/_base_
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
# Mini-Hanabi (3c/3r/hand3) CoMeDi config.
|
| 6 |
+
TOTAL_TIMESTEPS_PER_ITERATION: 2e6
|
| 7 |
+
PARTNER_POP_SIZE: 5
|
| 8 |
+
LR: 5e-4
|
| 9 |
+
UPDATE_EPOCHS: 4
|
| 10 |
+
CLIP_EPS: 0.2
|
| 11 |
+
ENT_COEF: 0.01
|
| 12 |
+
ANNEAL_LR: true
|
| 13 |
+
GAMMA: 0.999
|
| 14 |
+
GAE_LAMBDA: 0.95
|
| 15 |
+
MAX_GRAD_NORM: 0.5
|
| 16 |
+
COMEDI_ALPHA: 0.2
|
| 17 |
+
COMEDI_BETA: 0.4
|
| 18 |
+
ego_train_algorithm:
|
| 19 |
+
TOTAL_TIMESTEPS: 1e8
|
| 20 |
+
LR: 5e-4
|
| 21 |
+
ENT_COEF: 0.01
|
| 22 |
+
CLIP_EPS: 0.2
|
| 23 |
+
ANNEAL_LR: true
|
| 24 |
+
UPDATE_EPOCHS: 4
|
| 25 |
+
GAMMA: 0.999
|
| 26 |
+
GAE_LAMBDA: 0.95
|
| 27 |
+
MAX_GRAD_NORM: 0.5
|
teammate_generation/configs/algorithm/comedi/overcooked-v1/asymm_advantages.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- comedi/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 6e6
|
| 6 |
+
PARTNER_POP_SIZE: 10
|
| 7 |
+
LR: .0001
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
CLIP_EPS: 0.3
|
| 10 |
+
ENT_COEF: 0.01
|
| 11 |
+
ego_train_algorithm:
|
| 12 |
+
TOTAL_TIMESTEPS: 3e7
|
| 13 |
+
LR: 5e-5
|
| 14 |
+
ENT_COEF: .001
|
| 15 |
+
CLIP_EPS: 0.1
|
| 16 |
+
UPDATE_EPOCHS: 10
|
teammate_generation/configs/algorithm/comedi/overcooked-v1/coord_ring.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- comedi/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 1e7
|
| 6 |
+
PARTNER_POP_SIZE: 10
|
| 7 |
+
LR: 5e-4
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
CLIP_EPS: 0.1
|
| 10 |
+
ENT_COEF: 0.05
|
| 11 |
+
ego_train_algorithm:
|
| 12 |
+
TOTAL_TIMESTEPS: 6e7
|
| 13 |
+
LR: 3e-5
|
| 14 |
+
ENT_COEF: .001
|
| 15 |
+
CLIP_EPS: 0.1
|
| 16 |
+
UPDATE_EPOCHS: 10
|
teammate_generation/configs/algorithm/comedi/overcooked-v1/counter_circuit.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- comedi/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 1e7
|
| 6 |
+
PARTNER_POP_SIZE: 10
|
| 7 |
+
LR: 1e-3
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
CLIP_EPS: 0.01 # 0.1
|
| 10 |
+
ENT_COEF: 0.05
|
| 11 |
+
ego_train_algorithm:
|
| 12 |
+
TOTAL_TIMESTEPS: 6e7
|
| 13 |
+
LR: 5e-5
|
| 14 |
+
ENT_COEF: .001
|
| 15 |
+
CLIP_EPS: 0.1
|
| 16 |
+
UPDATE_EPOCHS: 10
|
teammate_generation/configs/algorithm/comedi/overcooked-v1/cramped_room.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- comedi/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 6e6
|
| 6 |
+
PARTNER_POP_SIZE: 10
|
| 7 |
+
LR: 1e-4
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
CLIP_EPS: 0.05
|
| 10 |
+
ENT_COEF: 0.01
|
| 11 |
+
ego_train_algorithm:
|
| 12 |
+
TOTAL_TIMESTEPS: 3e7
|
| 13 |
+
LR: 5e-5
|
| 14 |
+
ANNEAL_LR: false
|
| 15 |
+
ENT_COEF: .001
|
| 16 |
+
CLIP_EPS: 0.1
|
| 17 |
+
UPDATE_EPOCHS: 10
|
teammate_generation/configs/algorithm/comedi/overcooked-v1/forced_coord.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- comedi/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 1e7
|
| 6 |
+
PARTNER_POP_SIZE: 10
|
| 7 |
+
LR: 5e-4
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
CLIP_EPS: 0.05
|
| 10 |
+
ENT_COEF: 0.01
|
| 11 |
+
ego_train_algorithm:
|
| 12 |
+
TOTAL_TIMESTEPS: 6e7
|
| 13 |
+
LR: 1e-5
|
| 14 |
+
ENT_COEF: 1e-4
|
| 15 |
+
CLIP_EPS: 0.1
|
| 16 |
+
UPDATE_EPOCHS: 5
|
teammate_generation/configs/algorithm/fcp/_base_.yaml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package algorithm
|
| 2 |
+
# ^ tells hydra to place these values directly under the algorithm key
|
| 3 |
+
ALG: fcp
|
| 4 |
+
ACTOR_TYPE: mlp
|
| 5 |
+
TOTAL_TIMESTEPS: 1e6 # per PARTNER_POP_SIZE trained
|
| 6 |
+
NUM_CHECKPOINTS: 5
|
| 7 |
+
PARTNER_POP_SIZE: 20 # true partner pop size is PARTNER_POP_SIZE * NUM_CHECKPOINTS
|
| 8 |
+
NUM_ENVS: 8
|
| 9 |
+
LR: 1e-4
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 4
|
| 12 |
+
GAMMA: 0.99
|
| 13 |
+
GAE_LAMBDA: 0.95
|
| 14 |
+
CLIP_EPS: 0.05
|
| 15 |
+
ENT_COEF: 0.01
|
| 16 |
+
VF_COEF: 0.5
|
| 17 |
+
MAX_GRAD_NORM: 1.0
|
| 18 |
+
ANNEAL_LR: true
|
| 19 |
+
ego_train_algorithm:
|
| 20 |
+
EGO_ACTOR_TYPE: s5
|
| 21 |
+
S5_D_MODEL: 16
|
| 22 |
+
S5_SSM_SIZE: 16
|
| 23 |
+
S5_ACTOR_CRITIC_HIDDEN_DIM: 64
|
| 24 |
+
FC_N_LAYERS: 2
|
| 25 |
+
TOTAL_TIMESTEPS: 1e7
|
| 26 |
+
NUM_CHECKPOINTS: 5
|
| 27 |
+
NUM_ENVS: 8
|
| 28 |
+
LR: 1e-4
|
| 29 |
+
UPDATE_EPOCHS: 15
|
| 30 |
+
NUM_MINIBATCHES: 4
|
| 31 |
+
GAMMA: 0.99
|
| 32 |
+
GAE_LAMBDA: 0.95
|
| 33 |
+
CLIP_EPS: 0.05
|
| 34 |
+
ENT_COEF: 0.01
|
| 35 |
+
VF_COEF: 0.5
|
| 36 |
+
MAX_GRAD_NORM: 1.0
|
| 37 |
+
ANNEAL_LR: true
|
teammate_generation/configs/algorithm/fcp/hanabi.yaml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- fcp/_base_
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
# Full 2-player Hanabi FCP config. Trains IPPO partners then ego.
|
| 6 |
+
# Hyperparameters aligned with JaxMARL Hanabi consensus.
|
| 7 |
+
#
|
| 8 |
+
# PARTNER_POP_SIZE=3 (not 10): FCP vmaps across pop size, so 10
|
| 9 |
+
# parallel IPPO instances with 658-dim obs OOMs on 48GB. 3 partners
|
| 10 |
+
# x 5 checkpoints = 15 total partners, enough for diversity.
|
| 11 |
+
TOTAL_TIMESTEPS: 1e9
|
| 12 |
+
PARTNER_POP_SIZE: 3
|
| 13 |
+
LR: 5e-4
|
| 14 |
+
NUM_ENVS: 32
|
| 15 |
+
UPDATE_EPOCHS: 4
|
| 16 |
+
NUM_MINIBATCHES: 4
|
| 17 |
+
CLIP_EPS: 0.2
|
| 18 |
+
ENT_COEF: 0.01
|
| 19 |
+
ANNEAL_LR: true
|
| 20 |
+
GAMMA: 0.999
|
| 21 |
+
GAE_LAMBDA: 0.95
|
| 22 |
+
MAX_GRAD_NORM: 0.5
|
| 23 |
+
ego_train_algorithm:
|
| 24 |
+
TOTAL_TIMESTEPS: 1e9
|
| 25 |
+
LR: 5e-4
|
| 26 |
+
ENT_COEF: 0.01
|
| 27 |
+
CLIP_EPS: 0.2
|
| 28 |
+
ANNEAL_LR: true
|
| 29 |
+
UPDATE_EPOCHS: 4
|
| 30 |
+
GAMMA: 0.999
|
| 31 |
+
GAE_LAMBDA: 0.95
|
| 32 |
+
MAX_GRAD_NORM: 0.5
|
teammate_generation/configs/algorithm/fcp/lbf/lbf_12x12.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- fcp/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 1e6
|
| 6 |
+
LR: .0001
|
| 7 |
+
NUM_ENVS: 8
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
NUM_MINIBATCHES: 4
|
| 10 |
+
CLIP_EPS: 0.03
|
| 11 |
+
ENT_COEF: 0.01
|
| 12 |
+
ego_train_algorithm:
|
| 13 |
+
TOTAL_TIMESTEPS: 3e7
|
| 14 |
+
LR: 1e-4
|
| 15 |
+
ENT_COEF: 0.01
|
| 16 |
+
CLIP_EPS: 0.05
|
| 17 |
+
|
teammate_generation/configs/algorithm/fcp/lbf/lbf_7x7_nolevels.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- fcp/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 1e6
|
| 6 |
+
LR: .0001
|
| 7 |
+
NUM_ENVS: 8
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
NUM_MINIBATCHES: 4
|
| 10 |
+
CLIP_EPS: 0.03
|
| 11 |
+
ENT_COEF: 0.01
|
| 12 |
+
ego_train_algorithm:
|
| 13 |
+
TOTAL_TIMESTEPS: 3e7
|
| 14 |
+
LR: 1e-4
|
| 15 |
+
ENT_COEF: 0.01
|
| 16 |
+
CLIP_EPS: 0.05
|
| 17 |
+
|
teammate_generation/configs/algorithm/fcp/mini-hanabi.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- fcp/_base_
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
# Mini-Hanabi (3c/3r/hand3) FCP config.
|
| 6 |
+
TOTAL_TIMESTEPS: 1e8
|
| 7 |
+
LR: 5e-4
|
| 8 |
+
NUM_ENVS: 128
|
| 9 |
+
UPDATE_EPOCHS: 4
|
| 10 |
+
NUM_MINIBATCHES: 4
|
| 11 |
+
CLIP_EPS: 0.2
|
| 12 |
+
ENT_COEF: 0.01
|
| 13 |
+
ANNEAL_LR: true
|
| 14 |
+
GAMMA: 0.999
|
| 15 |
+
GAE_LAMBDA: 0.95
|
| 16 |
+
MAX_GRAD_NORM: 0.5
|
| 17 |
+
ego_train_algorithm:
|
| 18 |
+
TOTAL_TIMESTEPS: 1e8
|
| 19 |
+
LR: 5e-4
|
| 20 |
+
ENT_COEF: 0.01
|
| 21 |
+
CLIP_EPS: 0.2
|
| 22 |
+
ANNEAL_LR: true
|
| 23 |
+
UPDATE_EPOCHS: 4
|
| 24 |
+
GAMMA: 0.999
|
| 25 |
+
GAE_LAMBDA: 0.95
|
| 26 |
+
MAX_GRAD_NORM: 0.5
|
teammate_generation/configs/algorithm/fcp/overcooked-v1/asymm_advantages.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- fcp/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 2e6
|
| 6 |
+
LR: .0001
|
| 7 |
+
NUM_ENVS: 8
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
NUM_MINIBATCHES: 16
|
| 10 |
+
CLIP_EPS: 0.3
|
| 11 |
+
ENT_COEF: 0.01
|
| 12 |
+
ego_train_algorithm:
|
| 13 |
+
TOTAL_TIMESTEPS: 3e7
|
| 14 |
+
LR: 1e-4
|
| 15 |
+
ENT_COEF: 0.01
|
| 16 |
+
CLIP_EPS: 0.05
|
| 17 |
+
|
teammate_generation/configs/algorithm/fcp/overcooked-v1/coord_ring.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- fcp/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4e6
|
| 6 |
+
LR: 1e-3
|
| 7 |
+
NUM_ENVS: 8
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
NUM_MINIBATCHES: 16
|
| 10 |
+
CLIP_EPS: 0.1
|
| 11 |
+
ENT_COEF: 0.05
|
| 12 |
+
ego_train_algorithm:
|
| 13 |
+
TOTAL_TIMESTEPS: 6e7
|
| 14 |
+
LR: 1e-3
|
| 15 |
+
ENT_COEF: 0.01
|
| 16 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/fcp/overcooked-v1/counter_circuit.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- fcp/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4e6
|
| 6 |
+
LR: 1e-3
|
| 7 |
+
NUM_ENVS: 8
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
NUM_MINIBATCHES: 16
|
| 10 |
+
CLIP_EPS: 0.1
|
| 11 |
+
ENT_COEF: 0.05
|
| 12 |
+
ego_train_algorithm:
|
| 13 |
+
TOTAL_TIMESTEPS: 6e7
|
| 14 |
+
LR: 1e-3
|
| 15 |
+
ENT_COEF: 0.01
|
| 16 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/fcp/overcooked-v1/cramped_room.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- fcp/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 2e6
|
| 6 |
+
LR: .0001
|
| 7 |
+
NUM_ENVS: 8
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
NUM_MINIBATCHES: 16
|
| 10 |
+
CLIP_EPS: 0.2
|
| 11 |
+
ENT_COEF: 0.01
|
| 12 |
+
ego_train_algorithm:
|
| 13 |
+
TOTAL_TIMESTEPS: 3e7
|
| 14 |
+
LR: 1e-4
|
| 15 |
+
ENT_COEF: 0.01
|
| 16 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/fcp/overcooked-v1/forced_coord.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- fcp/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4e6
|
| 6 |
+
LR: 1e-3
|
| 7 |
+
NUM_ENVS: 8
|
| 8 |
+
UPDATE_EPOCHS: 15
|
| 9 |
+
NUM_MINIBATCHES: 16
|
| 10 |
+
CLIP_EPS: 0.1
|
| 11 |
+
ENT_COEF: 0.05
|
| 12 |
+
ego_train_algorithm:
|
| 13 |
+
TOTAL_TIMESTEPS: 6e7
|
| 14 |
+
LR: 1e-3
|
| 15 |
+
ENT_COEF: 0.01
|
| 16 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/lbrdiv/_base_.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package algorithm
|
| 2 |
+
# ^ tells Hydra to place these values directly under the algorithm key
|
| 3 |
+
ALG: lbrdiv
|
| 4 |
+
TOTAL_TIMESTEPS: 4.5e7 # divided among each pair of BR and Conf agents
|
| 5 |
+
NUM_CHECKPOINTS: 5
|
| 6 |
+
PARTNER_POP_SIZE: 4
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
TOLERANCE_FACTOR: 0.1 # require that SP - XP > TOLERANCE_FACTOR
|
| 9 |
+
LAGRANGE_LR: 0.01 # specific to L-BRDiv
|
| 10 |
+
LR: 1e-4
|
| 11 |
+
UPDATE_EPOCHS: 15
|
| 12 |
+
NUM_MINIBATCHES: 4
|
| 13 |
+
GAMMA: 0.99
|
| 14 |
+
GAE_LAMBDA: 0.95
|
| 15 |
+
CLIP_EPS: 0.05
|
| 16 |
+
ENT_COEF: 0.01
|
| 17 |
+
VF_COEF: 0.5
|
| 18 |
+
MAX_GRAD_NORM: 1.0
|
| 19 |
+
ANNEAL_LR: false
|
| 20 |
+
ego_train_algorithm:
|
| 21 |
+
EGO_ACTOR_TYPE: s5
|
| 22 |
+
S5_D_MODEL: 16
|
| 23 |
+
S5_SSM_SIZE: 16
|
| 24 |
+
S5_ACTOR_CRITIC_HIDDEN_DIM: 64
|
| 25 |
+
FC_N_LAYERS: 2
|
| 26 |
+
TOTAL_TIMESTEPS: 1e7
|
| 27 |
+
NUM_CHECKPOINTS: 5
|
| 28 |
+
NUM_ENVS: 8
|
| 29 |
+
LR: 1e-4
|
| 30 |
+
UPDATE_EPOCHS: 15
|
| 31 |
+
NUM_MINIBATCHES: 4
|
| 32 |
+
GAMMA: 0.99
|
| 33 |
+
GAE_LAMBDA: 0.95
|
| 34 |
+
CLIP_EPS: 0.05
|
| 35 |
+
ENT_COEF: 0.01
|
| 36 |
+
VF_COEF: 0.5
|
| 37 |
+
MAX_GRAD_NORM: 1.0
|
| 38 |
+
ANNEAL_LR: true
|
teammate_generation/configs/algorithm/lbrdiv/hanabi.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- lbrdiv/_base_
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 5e8
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 128
|
| 8 |
+
LR: 5e-4
|
| 9 |
+
UPDATE_EPOCHS: 4
|
| 10 |
+
NUM_MINIBATCHES: 4
|
| 11 |
+
CLIP_EPS: 0.2
|
| 12 |
+
ENT_COEF: 0.01
|
| 13 |
+
ANNEAL_LR: true
|
| 14 |
+
GAMMA: 0.999
|
| 15 |
+
GAE_LAMBDA: 0.95
|
| 16 |
+
MAX_GRAD_NORM: 0.5
|
| 17 |
+
ego_train_algorithm:
|
| 18 |
+
TOTAL_TIMESTEPS: 1e8
|
| 19 |
+
LR: 5e-4
|
| 20 |
+
ENT_COEF: 0.01
|
| 21 |
+
CLIP_EPS: 0.2
|
| 22 |
+
ANNEAL_LR: true
|
| 23 |
+
UPDATE_EPOCHS: 4
|
| 24 |
+
GAMMA: 0.999
|
| 25 |
+
GAE_LAMBDA: 0.95
|
| 26 |
+
MAX_GRAD_NORM: 0.5
|
teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_12x12.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- lbrdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4.5e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
LR: 5e-4
|
| 9 |
+
UPDATE_EPOCHS: 15
|
| 10 |
+
NUM_MINIBATCHES: 4
|
| 11 |
+
CLIP_EPS: 0.05
|
| 12 |
+
ENT_COEF: 0.01
|
| 13 |
+
ego_train_algorithm:
|
| 14 |
+
TOTAL_TIMESTEPS: 3e7
|
| 15 |
+
LR: 1e-4
|
| 16 |
+
ENT_COEF: 0.01
|
| 17 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_7x7_nolevels.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- lbrdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4.5e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
LR: 5e-4
|
| 9 |
+
UPDATE_EPOCHS: 15
|
| 10 |
+
NUM_MINIBATCHES: 4
|
| 11 |
+
CLIP_EPS: 0.05
|
| 12 |
+
ENT_COEF: 0.01
|
| 13 |
+
ego_train_algorithm:
|
| 14 |
+
TOTAL_TIMESTEPS: 3e7
|
| 15 |
+
LR: 1e-4
|
| 16 |
+
ENT_COEF: 0.01
|
| 17 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/lbrdiv/mini-hanabi.yaml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- lbrdiv/_base_
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
# Mini-Hanabi (3c/3r/hand3) LBRDiv config.
|
| 6 |
+
TOTAL_TIMESTEPS: 1e8
|
| 7 |
+
PARTNER_POP_SIZE: 3
|
| 8 |
+
NUM_ENVS: 128
|
| 9 |
+
LR: 5e-4
|
| 10 |
+
UPDATE_EPOCHS: 4
|
| 11 |
+
NUM_MINIBATCHES: 4
|
| 12 |
+
CLIP_EPS: 0.2
|
| 13 |
+
ENT_COEF: 0.01
|
| 14 |
+
ANNEAL_LR: true
|
| 15 |
+
GAMMA: 0.999
|
| 16 |
+
GAE_LAMBDA: 0.95
|
| 17 |
+
MAX_GRAD_NORM: 0.5
|
| 18 |
+
ego_train_algorithm:
|
| 19 |
+
TOTAL_TIMESTEPS: 1e8
|
| 20 |
+
LR: 5e-4
|
| 21 |
+
ENT_COEF: 0.01
|
| 22 |
+
CLIP_EPS: 0.2
|
| 23 |
+
ANNEAL_LR: true
|
| 24 |
+
UPDATE_EPOCHS: 4
|
| 25 |
+
GAMMA: 0.999
|
| 26 |
+
GAE_LAMBDA: 0.95
|
| 27 |
+
MAX_GRAD_NORM: 0.5
|
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/asymm_advantages.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- lbrdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4.5e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR
|
| 9 |
+
LR: .0001
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 16
|
| 12 |
+
CLIP_EPS: 0.3
|
| 13 |
+
ENT_COEF: 0.01
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 3e7
|
| 16 |
+
LR: 1e-4
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/coord_ring.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- lbrdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 9e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 128
|
| 8 |
+
TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR
|
| 9 |
+
LR: 5e-4
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 4
|
| 12 |
+
CLIP_EPS: 0.1
|
| 13 |
+
ENT_COEF: 0.05
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 6e7
|
| 16 |
+
LR: 1e-3
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/counter_circuit.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- lbrdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 9e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 128
|
| 8 |
+
TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR
|
| 9 |
+
LR: 1e-3
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 8
|
| 12 |
+
CLIP_EPS: 0.01
|
| 13 |
+
ENT_COEF: 0.05
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 6e7
|
| 16 |
+
LR: 1e-3
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/cramped_room.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- lbrdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 4.5e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 64
|
| 8 |
+
TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR
|
| 9 |
+
LR: 1e-4
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 16
|
| 12 |
+
CLIP_EPS: 0.05
|
| 13 |
+
ENT_COEF: 0.01
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 3e7
|
| 16 |
+
LR: 1e-4
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/forced_coord.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- lbrdiv/_base_
|
| 3 |
+
- _self_ # values from this file override the values from the base file
|
| 4 |
+
|
| 5 |
+
TOTAL_TIMESTEPS: 9e7
|
| 6 |
+
PARTNER_POP_SIZE: 3
|
| 7 |
+
NUM_ENVS: 128
|
| 8 |
+
TOLERANCE_FACTOR: 5.0 # require that SP - XP > TOLERANCE_FACTOR
|
| 9 |
+
LR: 5e-4
|
| 10 |
+
UPDATE_EPOCHS: 15
|
| 11 |
+
NUM_MINIBATCHES: 16
|
| 12 |
+
CLIP_EPS: 0.05
|
| 13 |
+
ENT_COEF: 0.01
|
| 14 |
+
ego_train_algorithm:
|
| 15 |
+
TOTAL_TIMESTEPS: 6e7
|
| 16 |
+
LR: 1e-3
|
| 17 |
+
ENT_COEF: 0.01
|
| 18 |
+
CLIP_EPS: 0.05
|
teammate_generation/configs/base_config_teammate.yaml
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- task: lbf/lbf_7x7_nolevels # task configs
|
| 3 |
+
- algorithm@algorithm: fcp/${task} # task-specific algorithm configs
|
| 4 |
+
- hydra: hydra_simple
|
| 5 |
+
- ../../evaluation/configs/global_heldout_settings
|
| 6 |
+
- _self_
|
| 7 |
+
|
| 8 |
+
ENV_NAME: ${task.ENV_NAME}
|
| 9 |
+
ENV_KWARGS: ${task.ENV_KWARGS}
|
| 10 |
+
ROLLOUT_LENGTH: ${task.ROLLOUT_LENGTH}
|
| 11 |
+
TASK_NAME: ${task.TASK_NAME}
|
| 12 |
+
|
| 13 |
+
# training settings
|
| 14 |
+
train_ego: true # whether to train the ego agent
|
| 15 |
+
run_heldout_eval: true # whether to run a heldout evaluation of the ego agent
|
| 16 |
+
|
| 17 |
+
# teammate generation settings
|
| 18 |
+
algorithm:
|
| 19 |
+
NUM_EVAL_EPISODES: 20 # used during training
|
| 20 |
+
TRAIN_SEED: 20374 # 112358 # 20374
|
| 21 |
+
NUM_SEEDS: 1
|
| 22 |
+
ENV_NAME: ${ENV_NAME}
|
| 23 |
+
ENV_KWARGS: ${ENV_KWARGS}
|
| 24 |
+
ROLLOUT_LENGTH: ${ROLLOUT_LENGTH}
|
| 25 |
+
# ego training settings
|
| 26 |
+
ego_train_algorithm:
|
| 27 |
+
NUM_EGO_TRAIN_SEEDS: 1 # per seed of teammate generation
|
| 28 |
+
NUM_EVAL_EPISODES: 20
|
| 29 |
+
TRAIN_SEED: 204829
|
| 30 |
+
ENV_NAME: ${ENV_NAME}
|
| 31 |
+
ENV_KWARGS: ${ENV_KWARGS}
|
| 32 |
+
ROLLOUT_LENGTH: ${ROLLOUT_LENGTH}
|
| 33 |
+
|
| 34 |
+
label: "default_label"
|
| 35 |
+
name: ${TASK_NAME}/${algorithm.ALG}/${label}
|
| 36 |
+
|
| 37 |
+
# wandb settings
|
| 38 |
+
logger:
|
| 39 |
+
project: aht-benchmark
|
| 40 |
+
entity: aht-project
|
| 41 |
+
tags:
|
| 42 |
+
- ${algorithm.ALG}
|
| 43 |
+
- ${TASK_NAME}
|
| 44 |
+
- seed=${algorithm.TRAIN_SEED}
|
| 45 |
+
- ${label}
|
| 46 |
+
mode: offline # options: online, offline, disabled
|
| 47 |
+
verbose: false
|
| 48 |
+
log_train_out: true # whether to log the out dictionary
|
| 49 |
+
log_eval_out: true # whether to log the eval metrics
|
| 50 |
+
|
| 51 |
+
# Local logger
|
| 52 |
+
local_logger:
|
| 53 |
+
save_train_out: true
|
| 54 |
+
save_eval_out: true
|
teammate_generation/configs/hydra/hydra_simple.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
job:
|
| 2 |
+
chdir: true
|
| 3 |
+
run:
|
| 4 |
+
dir: results/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
| 5 |
+
sweep:
|
| 6 |
+
dir: results_sweep/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
| 7 |
+
subdir: ${run.seed}
|
teammate_generation/configs/task/hanabi.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hanabi: teammate generation task config.
|
| 2 |
+
# Mirrors ego_agent_training/configs/task/hanabi.yaml because
|
| 3 |
+
# teammate_generation methods (FCP, BRDiv, LBRDiv, CoMeDi) call into
|
| 4 |
+
# ego_agent_training as a subroutine, which asserts num_agents == 2.
|
| 5 |
+
# ENV_KWARGS below sets num_agents: 2, so this assertion is satisfied.
|
| 6 |
+
ENV_NAME: hanabi
|
| 7 |
+
ROLLOUT_LENGTH: 128
|
| 8 |
+
ENV_KWARGS:
|
| 9 |
+
num_agents: 2
|
| 10 |
+
num_colors: 5
|
| 11 |
+
num_ranks: 5
|
| 12 |
+
hand_size: 5
|
| 13 |
+
max_info_tokens: 8
|
| 14 |
+
max_life_tokens: 3
|
| 15 |
+
num_cards_of_rank: [3, 2, 2, 2, 1]
|
| 16 |
+
TASK_NAME: hanabi
|
teammate_generation/configs/task/lbf/lbf_12x12.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ENV_NAME: lbf
|
| 2 |
+
ROLLOUT_LENGTH: 128
|
| 3 |
+
ENV_KWARGS:
|
| 4 |
+
grid_size: 12
|
| 5 |
+
num_food: 6
|
| 6 |
+
different_levels: true
|
| 7 |
+
TASK_NAME: lbf/lbf_12x12
|
teammate_generation/configs/task/lbf/lbf_7x7_nolevels.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ENV_NAME: lbf
|
| 2 |
+
ROLLOUT_LENGTH: 128
|
| 3 |
+
ENV_KWARGS: {}
|
| 4 |
+
TASK_NAME: lbf/lbf_7x7_nolevels
|
teammate_generation/configs/task/mini-hanabi.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mini-Hanabi: teammate generation task config.
|
| 2 |
+
# Mirrors ego_agent_training/configs/task/mini-hanabi.yaml.
|
| 3 |
+
ENV_NAME: hanabi
|
| 4 |
+
ROLLOUT_LENGTH: 128
|
| 5 |
+
ENV_KWARGS:
|
| 6 |
+
num_agents: 2
|
| 7 |
+
num_colors: 3
|
| 8 |
+
num_ranks: 3
|
| 9 |
+
hand_size: 3
|
| 10 |
+
max_info_tokens: 5
|
| 11 |
+
max_life_tokens: 3
|
| 12 |
+
num_cards_of_rank: [2, 2, 1]
|
| 13 |
+
TASK_NAME: mini-hanabi
|