diff --git a/teammate_generation/BRDiv.py b/teammate_generation/BRDiv.py new file mode 100644 index 0000000000000000000000000000000000000000..92c4a71c0697895c2f3d60fbc65338a3ac3de087 --- /dev/null +++ b/teammate_generation/BRDiv.py @@ -0,0 +1,832 @@ +'''Implementation of the BRDiv teammate generation algorithm (Rahman et al., TMLR 2023) +https://arxiv.org/abs/2207.14138 + +Command to run BRDiv only on LBF: +python teammate_generation/run.py algorithm=brdiv/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels label=test_brdiv run_heldout_eval=false train_ego=false + +Limitations: does not support recurrent actors. +''' +import shutil +import time +import logging +from typing import NamedTuple +from functools import partial + +import hydra +import jax +import jax.numpy as jnp +import numpy as np +import optax +from flax.training.train_state import TrainState +import wandb + +from agents.mlp_actor_critic_agent import ActorWithConditionalCriticPolicy +from agents.population_interface import AgentPopulation +from common.plot_utils import get_metric_names +from common.run_episodes import run_episodes +from common.save_load_utils import save_train_run +from envs import make_env +from envs.log_wrapper import LogWrapper +from marl.ppo_utils import unbatchify, _create_minibatches + +log = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +class XPTransition(NamedTuple): + done: jnp.ndarray + action: jnp.ndarray + value: jnp.ndarray + self_onehot_id: jnp.ndarray + oppo_onehot_id: jnp.ndarray + reward: jnp.ndarray + log_prob: jnp.ndarray + obs: jnp.ndarray + info: jnp.ndarray + avail_actions: jnp.ndarray + +def _get_all_ids(pop_size): + cross_product = np.meshgrid( + np.arange(pop_size), + np.arange(pop_size) + ) + agent_id_cartesian_product = np.stack([g.ravel() for g in cross_product], axis=-1) + all_conf_ids = agent_id_cartesian_product[:, 1] + all_br_ids = agent_id_cartesian_product[:, 0] + return all_conf_ids, all_br_ids + +def gather_params(partner_params_pytree, idx_vec): + """ + partner_params_pytree: pytree with all partner params. Each leaf has shape (n_seeds, m_ckpts, ...). + idx_vec: a vector of indices with shape (num_envs,) each in [0, n_seeds*m_ckpts). + + Return a new pytree where each leaf has shape (num_envs, ...). Each leaf has a sampled + partner's parameters for each environment. + """ + # We'll define a function that gathers from each leaf + # where leaf has shape (n_seeds, m_ckpts, ...), we want [idx_vec[i]] for each i. + # We'll vmap a slicing function. + def gather_leaf(leaf): + def slice_one(idx): + return leaf[idx] # shape (...) + return jax.vmap(slice_one)(idx_vec) + + return jax.tree.map(gather_leaf, partner_params_pytree) + +def train_brdiv_partners(train_rng, env, config, conf_policy, br_policy): + num_agents = env.num_agents + assert num_agents == 2, "This code assumes the environment has exactly 2 agents." + + # Define different minibatch sizes for interactions with ego agent and one with BR agent + config["NUM_GAME_AGENTS"] = num_agents + config["NUM_CONF_ACTORS"] = config["NUM_ENVS"] + config["NUM_BR_ACTORS"] = config["NUM_ENVS"] + config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // (config["ROLLOUT_LENGTH"] * config["NUM_ENVS"]) + + def make_brdiv_agents(config): + def linear_schedule(count): + frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"] + return config["LR"] * frac + + def train(rng): + rng, init_conf_rng, init_br_rng = jax.random.split(rng, 3) + all_conf_init_rngs = jax.random.split(init_conf_rng, config["PARTNER_POP_SIZE"]) + all_br_init_rngs = jax.random.split(init_br_rng, config["PARTNER_POP_SIZE"]) + identity_matrix = jnp.eye(config["PARTNER_POP_SIZE"]) + + init_conf_hstate = conf_policy.init_hstate(config["NUM_CONF_ACTORS"]) + init_br_hstate = br_policy.init_hstate(config["NUM_BR_ACTORS"]) + + def init_train_states(rng_agents, rng_brs): + def init_single_pair_optimizers(rng_agent, rng_br): + init_params_conf = conf_policy.init_params(rng_agent) + init_params_br = br_policy.init_params(rng_br) + return init_params_conf, init_params_br + + init_all_networks_and_optimizers = jax.vmap(init_single_pair_optimizers) + all_conf_params, all_br_params = init_all_networks_and_optimizers(rng_agents, rng_brs) + + # Define optimizers for both confederate and BR policy + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"], + eps=1e-5), + ) + tx_br = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"], + eps=1e-5), + ) + + train_state_conf = TrainState.create( + apply_fn=conf_policy.network.apply, + params=all_conf_params, + tx=tx, + ) + + train_state_br = TrainState.create( + apply_fn=br_policy.network.apply, + params=all_br_params, + tx=tx_br, + ) + + return train_state_conf, train_state_br + + all_conf_optims, all_br_optims = init_train_states( + all_conf_init_rngs, all_br_init_rngs + ) + + def forward_pass_conf(params, obs, id, done, avail_actions, hstate, rng): + act, val, pi, new_hstate = conf_policy.get_action_value_policy( + params=params, + obs=obs[jnp.newaxis, ...], + done=done[jnp.newaxis, ...], + avail_actions=avail_actions, + hstate=hstate, + rng=rng, + aux_obs=id[jnp.newaxis, ...] + ) + return act, val, pi, new_hstate + + def forward_pass_br(params, obs, id, done, avail_actions, hstate, rng): + act, val, pi, new_hstate = br_policy.get_action_value_policy( + params=params, + obs=obs[jnp.newaxis, ...], + done=done[jnp.newaxis, ...], + avail_actions=avail_actions, + hstate=hstate, + rng=rng, + aux_obs=id[jnp.newaxis, ...] + ) + return act, val, pi, new_hstate + + def _env_step(runner_state, unused): + """ + agent_0 = confederate, agent_1 = br + Returns updated runner_state, and Transitions for agent_0 and agent_1 + """ + ( + all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids, + env_state, last_obs, last_done, last_conf_h, last_br_h, rng + ) = runner_state + rng, act0_rng, act1_rng, step_rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 6) + + # For done envs, resample both conf and brs + needs_resample = last_done["__all__"] + resampled_conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_CONF_ACTORS"],), 0, config["PARTNER_POP_SIZE"]) + resampled_br_ids = jax.random.randint(br_sampling_rng, (config["NUM_BR_ACTORS"],), 0, config["PARTNER_POP_SIZE"]) + + # Determine final indices based on whether resampling was needed for each env + updated_conf_ids = jnp.where( + needs_resample, + resampled_conf_ids, # Use newly sampled index if True + last_conf_ids # Else, keep index from previous step + ) + + updated_br_ids = jnp.where( + needs_resample, + resampled_br_ids, # Use newly sampled index if True + last_br_ids # Else, keep index from previous step + ) + + # Reset the hidden states for resampled conf and br if they are not None + # WARNING: BRDiv was not tested with recurrent actors, so the code for if the hstate is not None may not work + if last_conf_h is not None: + updated_conf_h = jnp.where( + needs_resample, + init_conf_hstate, + last_conf_h + ) + else: + updated_conf_h = last_conf_h + + if last_br_h is not None: + updated_br_h = jnp.where( + needs_resample, + init_br_hstate, + last_br_h + ) + else: + updated_br_h = last_br_h + + # Get the corresponding conf and br params + updated_conf_params = gather_params(all_train_state_conf.params, updated_conf_ids) + updated_br_params = gather_params(all_train_state_br.params, updated_br_ids) + + updated_conf_onehot_ids = identity_matrix[updated_conf_ids] + updated_br_onehot_ids = identity_matrix[updated_br_ids] + + # Get available actions for agent 0 from environment state + avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state) + avail_actions = jax.lax.stop_gradient(avail_actions) + avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32) + avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32) + + # Agent_0 action + act0_rng = jax.random.split(act0_rng, config["NUM_ENVS"]) + act_0, val_0, pi_0, new_conf_h = jax.vmap(forward_pass_conf)(updated_conf_params, + last_obs["agent_0"], updated_br_onehot_ids, last_done["agent_0"], avail_actions_0, + updated_conf_h, act0_rng) + logp_0 = pi_0.log_prob(act_0) + act_0, val_0, logp_0 = act_0.squeeze(), val_0.squeeze(), logp_0.squeeze() + + # Agent_1 action + act1_rng = jax.random.split(act1_rng, config["NUM_ENVS"]) + act_1, val_1, pi_1, new_br_h = jax.vmap(forward_pass_br)(updated_br_params, + last_obs["agent_1"], updated_conf_onehot_ids, last_done["agent_1"], avail_actions_1, + updated_br_h, act1_rng) + logp_1 = pi_1.log_prob(act_1) + act_1, val_1, logp_1 = act_1.squeeze(), val_1.squeeze(), logp_1.squeeze() + + # Combine actions into the env format + combined_actions = jnp.concatenate([act_0, act_1], axis=0) + env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents) + env_act = {k: v.flatten() for k, v in env_act.items()} + + # Step env + step_rngs = jax.random.split(step_rng, config["NUM_ENVS"]) + obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))( + step_rngs, env_state, env_act + ) + # note that num_actors = num_envs * num_agents + info_0 = jax.tree.map(lambda x: x[:, 0], info) + info_1 = jax.tree.map(lambda x: x[:, 1], info) + + def _compute_rewards(conf_id, br_id, agent_rew): + return jax.lax.cond(jnp.equal( + jnp.argmax(conf_id, axis=-1), jnp.argmax(br_id, axis=-1) + ), + lambda x: x, + lambda x: -x, + agent_rew + ) + + agent_0_rews = jax.vmap(_compute_rewards)(updated_conf_onehot_ids, updated_br_onehot_ids, reward["agent_1"]) + agent_1_rews = jax.vmap(_compute_rewards)(updated_conf_onehot_ids, updated_br_onehot_ids, reward["agent_0"]) + + # Store agent_0 data in transition + transition_0 = XPTransition( + done=done["agent_0"], + action=act_0, + value=val_0, + self_onehot_id=updated_conf_onehot_ids, + oppo_onehot_id=updated_br_onehot_ids, + reward=agent_0_rews, + log_prob=logp_0, + obs=last_obs["agent_0"], + info=info_0, + avail_actions=avail_actions_0 + ) + + transition_1 = XPTransition( + done=done["agent_1"], + action=act_1, + value=val_1, + self_onehot_id=updated_br_onehot_ids, + oppo_onehot_id=updated_conf_onehot_ids, + reward=agent_1_rews, + log_prob=logp_1, + obs=last_obs["agent_1"], + info=info_1, + avail_actions=avail_actions_1 + ) + new_runner_state = (all_train_state_conf, all_train_state_br, updated_conf_ids, updated_br_ids, + env_state_next, obs_next, done, new_conf_h, new_br_h, rng) + return new_runner_state, (transition_0, transition_1) + + def _calculate_gae(traj_batch, last_val): + def _get_advantages(gae_and_next_value, transition): + gae, next_value = gae_and_next_value + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + delta = reward + config["GAMMA"] * next_value * (1 - done) - value + gae = ( + delta + + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae + ) + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + def run_all_episodes(rng, train_state_conf, train_state_br): + conf_ids, br_ids = _get_all_ids(config["PARTNER_POP_SIZE"]) + gathered_conf_model_params = gather_params(train_state_conf.params, conf_ids) + gathered_br_model_params = gather_params(train_state_br.params, br_ids) + + rng, eval_rng = jax.random.split(rng) + def run_episodes_fixed_rng(conf_param, br_param): + return run_episodes( + eval_rng, env, + conf_param, conf_policy, + br_param, br_policy, + config["ROLLOUT_LENGTH"], config["NUM_EVAL_EPISODES"], + ) + ep_infos = jax.vmap(run_episodes_fixed_rng)( + gathered_conf_model_params, gathered_br_model_params, # leaves where shape is (pop_size*pop_size, ...) + ) + return ep_infos + + def _update_epoch(update_state, unused): + def _update_minbatch(all_train_states, all_data): + train_state_conf, train_state_br = all_train_states + minbatch_conf, minbatch_br = all_data + + def _loss_fn(param, agent_policy, minbatch, agent_id): + '''Compute loss for agent corresponding to agent_id. + ''' + init_hstate, traj_batch, gae, target_v = minbatch + # get policy and value of confederate versus ego and best response agents respectively + squeezed_param = jax.tree.map(lambda x: jnp.squeeze(x, 0), param) + _, value, pi, _ = agent_policy.get_action_value_policy( + params=squeezed_param, + obs=traj_batch.obs, + done=traj_batch.done, + avail_actions=traj_batch.avail_actions, + hstate=init_hstate, + rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here + aux_obs=traj_batch.oppo_onehot_id + ) + log_prob = pi.log_prob(traj_batch.action) + + is_relevant = jnp.equal( + jnp.argmax(traj_batch.self_onehot_id, axis=-1), + agent_id + ) + loss_weights = jnp.where(is_relevant, 1, 0).astype(jnp.float32) + + # Value loss + value_pred_clipped = traj_batch.value + ( + value - traj_batch.value + ).clip( + -config["CLIP_EPS"], config["CLIP_EPS"]) + value_losses = jnp.square(value - target_v) + value_losses_clipped = jnp.square(value_pred_clipped - target_v) + value_loss = jax.lax.cond( + loss_weights.sum() == 0, + lambda x: jnp.zeros_like(x).astype(jnp.float32), + lambda x: x, + (loss_weights * jnp.maximum(value_losses, value_losses_clipped)).sum() / (loss_weights.sum() + 1e-8) + ) + + n = config["PARTNER_POP_SIZE"] + # Apply different loss weights for SP and XP data + # Loss weights consist of two parts: the first term is the weighting from the BRDiv loss fucntion + # The second term is a reweighting term to compensate for the data collection process, which uniformly and independently + # samples the conf and br ids from 1, ..., n, resulting in P(SP) = 1/n and P(XP) = (n-1)/n. + # To prevent the XP loss term from dominating the SP loss term, we would like P(SP) = P(XP) = 1/2. + # Thus, we set the 2nd term of the SP weight to n/2, and the 2nd term of the XP weight to n/(2 * (n-1)). + + is_sp = jnp.equal(jnp.argmax(traj_batch.self_onehot_id, axis=-1), jnp.argmax(traj_batch.oppo_onehot_id, axis=-1)) + sp_weight = (1 + 2*config["XP_LOSS_WEIGHTS"]) * (n/2) + xp_weight = config["XP_LOSS_WEIGHTS"] * (n / (2 * (n-1))) + actor_weights = jnp.where(is_sp, sp_weight, xp_weight) + + # Policy gradient loss + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae_norm = (gae - gae.mean()) / (gae.std() + 1e-8) + pg_loss_1 = ratio * gae_norm * actor_weights + pg_loss_2 = jnp.clip( + ratio, + 1.0 - config["CLIP_EPS"], + 1.0 + config["CLIP_EPS"]) * gae_norm * actor_weights + pg_loss = jax.lax.cond( + loss_weights.sum() == 0, + lambda x: jnp.zeros_like(x).astype(jnp.float32), + lambda x: x, + -( + loss_weights*jnp.minimum(pg_loss_1, pg_loss_2) + ).sum()/(loss_weights.sum() + 1e-8) + ) + + # Entropy + entropy = jax.lax.cond( + loss_weights.sum() == 0, + lambda x: jnp.zeros_like(x).astype(jnp.float32), + lambda x: x, + (loss_weights * pi.entropy()).sum()/(loss_weights.sum() + 1e-8) + ) + + total_loss = pg_loss + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy + return total_loss, (value_loss, pg_loss, entropy) + + possible_agent_ids = jnp.expand_dims(jnp.arange(config["PARTNER_POP_SIZE"]), 1) + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + + def gather_conf_params_and_return_grads(agent_id): + param_vector = gather_params(train_state_conf.params, agent_id) + (loss_val_conf, aux_vals_conf), grads_conf = grad_fn( + param_vector, conf_policy, minbatch_conf, agent_id + ) + return (loss_val_conf, aux_vals_conf), grads_conf + + def gather_br_params_and_return_grads(agent_id): + param_vector = gather_params(train_state_br.params, agent_id) + (loss_val_br, aux_vals_br), grads_br = grad_fn( + param_vector, br_policy, minbatch_br, agent_id + ) + return (loss_val_br, aux_vals_br), grads_br + + (loss_val_conf, aux_vals_conf), grads_conf = jax.vmap(gather_conf_params_and_return_grads)(possible_agent_ids) + (loss_val_br, aux_vals_br), grads_br = jax.vmap(gather_br_params_and_return_grads)(possible_agent_ids) + + grads_conf_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_conf) + grads_br_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_br) + train_state_conf = train_state_conf.apply_gradients(grads=grads_conf_new) + train_state_br = train_state_br.apply_gradients(grads=grads_br_new) + return (train_state_conf, train_state_br), ((loss_val_conf, aux_vals_conf), (loss_val_br, aux_vals_br)) + + ( + train_state_conf, train_state_br, + traj_batch_conf, traj_batch_br, + advantages_conf, advantages_br, + targets_conf, targets_br, + rng + ) = update_state + rng, perm_rng_conf, perm_rng_br = jax.random.split(rng, 3) + + minibatches_conf = _create_minibatches(traj_batch_conf, advantages_conf, targets_conf, init_conf_hstate, + config["NUM_CONF_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_conf) + minibatches_br = _create_minibatches(traj_batch_br, advantages_br, targets_br, init_br_hstate, + config["NUM_BR_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_br) + + # Update both policies + (train_state_conf, train_state_br), all_losses = jax.lax.scan( + _update_minbatch, (train_state_conf, train_state_br), (minibatches_conf, minibatches_br) + ) + + update_state = (train_state_conf, train_state_br, + traj_batch_conf, traj_batch_br, + advantages_conf, advantages_br, + targets_conf, targets_br, + rng + ) + return update_state, all_losses + + def _update_step(update_runner_state, unused): + """ + 1. Collect rollouts + 2. Compute advantage + 3. PPO updates + """ + ( + all_train_state_conf, all_train_state_br, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, + rng, update_steps + ) = update_runner_state + + rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 3) + + conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_ENVS"],), 0, config["PARTNER_POP_SIZE"]) + br_ids = jax.random.randint(br_sampling_rng, (config["NUM_ENVS"],), 0, config["PARTNER_POP_SIZE"]) + + runner_state = ( + all_train_state_conf, all_train_state_br, conf_ids, br_ids, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng + ) + runner_state, traj_batch = jax.lax.scan( + _env_step, runner_state, None, config["ROLLOUT_LENGTH"]) + (all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng) = runner_state + + # Get the last conf and br params and ids + last_conf_params = gather_params(all_train_state_conf.params, last_conf_ids) + last_br_params = gather_params(all_train_state_br.params, last_br_ids) + + last_conf_one_hots = identity_matrix[last_conf_ids] + last_br_one_hots = identity_matrix[last_br_ids] + + # Get agent 0 and agent 1 trajectories from interaction between conf policy and its BR policy. + traj_batch_conf, traj_batch_br = traj_batch + + # Compute advantage for confederate agent from interaction with br policy + avail_actions_0 = jax.vmap(env.get_avail_actions)(last_env_state.env_state)["agent_0"].astype(jnp.float32) + _, last_val_conf, _, _ = jax.vmap(forward_pass_conf)( + params=last_conf_params, + obs=last_obs["agent_0"], + id=last_br_one_hots, + done=last_done["agent_0"], + avail_actions=avail_actions_0, + hstate=last_conf_h, + rng=jax.random.split(jax.random.PRNGKey(0), config["NUM_ENVS"]) # Dummy key since we're just extracting the value + ) + last_val_conf = last_val_conf.squeeze() + advantages_conf, targets_conf = _calculate_gae(traj_batch_conf, last_val_conf) + + # Compute advantage for br policy from interaction with confederate agent + avail_actions_1 = jax.vmap(env.get_avail_actions)(last_env_state.env_state)["agent_1"].astype(jnp.float32) + _, last_val_br, _, _ = jax.vmap(forward_pass_br)( + params=last_br_params, + obs=last_obs["agent_1"], + id=last_conf_one_hots, + done=last_done["agent_1"], + avail_actions=avail_actions_1, + hstate=last_br_h, + rng=jax.random.split(jax.random.PRNGKey(0), config["NUM_ENVS"]) # Dummy key since we're just extracting the value + ) + last_val_br = last_val_br.squeeze() + advantages_br, targets_br = _calculate_gae(traj_batch_br, last_val_br) + + # 3) PPO update + rng, update_rng = jax.random.split(rng, 2) + update_state = ( + all_train_state_conf, all_train_state_br, + traj_batch_conf, traj_batch_br, + advantages_conf, advantages_br, + targets_conf, targets_br, + update_rng + ) + + update_state, all_losses = jax.lax.scan( + _update_epoch, update_state, None, config["UPDATE_EPOCHS"]) + all_train_state_conf, all_train_state_br = update_state[:2] + (_, (value_loss_conf, pg_loss_conf, entropy_conf)), (_, (value_loss_br, pg_loss_br, entropy_br)) = all_losses + + # Metrics + def mask_and_mean(x, mask): + return jnp.where(mask, x, 0).sum() / jnp.maximum(1, mask.sum()) + + mask = traj_batch_conf.info.get("returned_episode", jnp.ones_like(traj_batch_conf.reward)) + metric = jax.tree.map(lambda x: mask_and_mean(x, mask), traj_batch_conf.info) + metric["update_steps"] = update_steps + metric["value_loss_conf_agent"] = value_loss_conf.mean(axis=(0, 1)) + metric["value_loss_br_agent"] = value_loss_br.mean(axis=(0, 1)) + + metric["pg_loss_conf_agent"] = pg_loss_conf.mean(axis=(0, 1)) + metric["pg_loss_br_agent"] = pg_loss_br.mean(axis=(0, 1)) + + metric["entropy_conf"] = entropy_conf.mean(axis=(0, 1)) + metric["entropy_br"] = entropy_br.mean(axis=(0, 1)) + + new_runner_state = ( + all_train_state_conf, all_train_state_br, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, + rng, update_steps + 1 + ) + return (new_runner_state, metric) + + # -------------------------- + # PPO Update and Checkpoint saving + # -------------------------- + ckpt_and_eval_interval = config["NUM_UPDATES"] // max(1, config["NUM_CHECKPOINTS"] - 1) # -1 because we store a ckpt at the last update + num_ckpts = config["NUM_CHECKPOINTS"] + + # Build a PyTree that holds parameters for all conf agent checkpoints + def init_ckpt_array(params_pytree): + return jax.tree.map( + lambda x: jnp.zeros((num_ckpts,) + x.shape, x.dtype), + params_pytree) + + def _update_step_with_ckpt(state_with_ckpt, unused): + (update_runner_state, checkpoint_array_conf, checkpoint_array_br, ckpt_idx, + eval_info) = state_with_ckpt + + # Single PPO update + new_runner_state, metric = _update_step(update_runner_state, None) + + train_state_conf, train_state_br, last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng, update_steps = new_runner_state + + # Decide if we store a checkpoint + # update steps is 1-indexed because it was incremented at the end of the update step + to_store = jnp.logical_or(jnp.equal(jnp.mod(update_steps-1, ckpt_and_eval_interval), 0), + jnp.equal(update_steps, config["NUM_UPDATES"])) + + def store_and_eval_ckpt(args): + ckpt_arr_and_ep_infos, rng, cidx = args + ckpt_arr_conf, ckpt_arr_br, _ = ckpt_arr_and_ep_infos + new_ckpt_arr_conf = jax.tree.map( + lambda c_arr, p: c_arr.at[cidx].set(p), + ckpt_arr_conf, train_state_conf.params + ) + new_ckpt_arr_br = jax.tree.map( + lambda c_arr, p: c_arr.at[cidx].set(p), + ckpt_arr_br, train_state_br.params + ) + + rng, eval_rng = jax.random.split(rng) + ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)), + run_all_episodes(eval_rng, train_state_conf, train_state_br)) + + return ((new_ckpt_arr_conf, new_ckpt_arr_br, ep_last_info), rng, cidx + 1) + + def skip_ckpt(args): + return args + + (checkpoint_array_and_infos, rng, ckpt_idx) = jax.lax.cond( + to_store, + store_and_eval_ckpt, + skip_ckpt, + ((checkpoint_array_conf, checkpoint_array_br, eval_info), rng, ckpt_idx) + ) + checkpoint_array_conf, checkpoint_array_br, eval_ep_last_info = checkpoint_array_and_infos + + metric["eval_ep_last_info"] = eval_ep_last_info # return of confederate + + return ((train_state_conf, train_state_br, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng, update_steps), + checkpoint_array_conf, checkpoint_array_br, ckpt_idx, + eval_ep_last_info), metric + + # Initialize checkpoint array + checkpoint_array_conf = init_ckpt_array(all_conf_optims.params) + checkpoint_array_br = init_ckpt_array(all_br_optims.params) + ckpt_idx = 0 + + # Initialize state for scan over _update_step_with_ckpt + update_steps = 0 + + rng, rng_eval = jax.random.split(rng, 2) + eval_ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)), + run_all_episodes(rng_eval, all_conf_optims, all_br_optims)) + + # Initialize environment + rng, reset_rng = jax.random.split(rng) + reset_rngs = jax.random.split(reset_rng, config["NUM_ENVS"]) + init_obs, init_env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rngs) + init_done = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]} + + # Initialize conf and br hstates + init_conf_h = conf_policy.init_hstate(config["NUM_CONF_ACTORS"]) + init_br_h = br_policy.init_hstate(config["NUM_BR_ACTORS"]) + + update_runner_state = ( + all_conf_optims, all_br_optims, + init_env_state, init_obs, init_done, init_conf_h, init_br_h, + rng, update_steps + ) + + state_with_ckpt = ( + update_runner_state, checkpoint_array_conf, + checkpoint_array_br, ckpt_idx, eval_ep_last_info + ) + + # run training + state_with_ckpt, metrics = jax.lax.scan( + _update_step_with_ckpt, + state_with_ckpt, + xs=None, + length=config["NUM_UPDATES"] + ) + + ( + final_runner_state, checkpoint_array_conf, checkpoint_array_br, + final_ckpt_idx, all_ep_infos + ) = state_with_ckpt + + out = { + "final_params_conf": final_runner_state[0].params, + "final_params_br": final_runner_state[1].params, + "checkpoints_conf": checkpoint_array_conf, + "checkpoints_br": checkpoint_array_br, + "metrics": metrics, # metrics is from the perspective of the confederate agent (averaged over population) + "all_pair_returns": all_ep_infos + } + return out + + return train + # ------------------------------ + # Actually run the adversarial teammate training + # ------------------------------ + train_fn = make_brdiv_agents(config) + out = train_fn(train_rng) + return out + +def get_brdiv_population(config, out, env): + ''' + Get the partner params and partner population for ego training. + ''' + brdiv_pop_size = config["algorithm"]["PARTNER_POP_SIZE"] + + # partner_params has shape (num_seeds, brdiv_pop_size, ...) + partner_params = out['final_params_conf'] + + partner_policy = ActorWithConditionalCriticPolicy( + action_dim=env.action_space(env.agents[1]).n, + obs_dim=env.observation_space(env.agents[1]).shape[0], + pop_size=brdiv_pop_size, # used to create onehot agent id + activation=config["algorithm"].get("ACTIVATION", "tanh") + ) + + # Create partner population + partner_population = AgentPopulation( + pop_size=brdiv_pop_size, + policy_cls=partner_policy + ) + + return partner_params, partner_population + +def run_brdiv(config, wandb_logger): + algorithm_config = dict(config["algorithm"]) + + env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"]) + env = LogWrapper(env) + + log.info("Starting BRDiv training...") + start = time.time() + + # Generate multiple random seeds from the base seed + rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"]) + rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"]) + + # Initialize br and conf policies + conf_policy = ActorWithConditionalCriticPolicy( + action_dim=env.action_space(env.agents[0]).n, + obs_dim=env.observation_space(env.agents[0]).shape[0], + pop_size=algorithm_config["PARTNER_POP_SIZE"], + ) + br_policy = ActorWithConditionalCriticPolicy( + action_dim=env.action_space(env.agents[0]).n, + obs_dim=env.observation_space(env.agents[0]).shape[0], + pop_size=algorithm_config["PARTNER_POP_SIZE"], + ) + + # Create a vmapped version of train_brdiv_partners + with jax.disable_jit(False): + vmapped_train_fn = jax.jit( + jax.vmap( + partial(train_brdiv_partners, env=env, config=algorithm_config, conf_policy=conf_policy, br_policy=br_policy) + ) + ) + out = vmapped_train_fn(rngs) + + end = time.time() + log.info(f"BRDiv training complete in {end - start} seconds") + + metric_names = get_metric_names(algorithm_config["ENV_NAME"]) + log_metrics(config, out, wandb_logger, metric_names) + + partner_params, partner_population = get_brdiv_population(config, out, env) + + return partner_params, partner_population + + +def log_metrics(config, outs, logger, metric_names: tuple): + metrics = outs["metrics"] + # metrics now has shape (num_seeds, num_updates, pop_size) + num_seeds, num_updates, pop_size = metrics["pg_loss_conf_agent"].shape # number of trained pairs + + ### Log evaluation metrics + # we plot XP return curves separately from SP return curves + # shape (num_seeds, num_updates, (pop_size)^2) [pre-scalarized: mean over eval eps and agents taken inside scan] + all_returns = np.asarray(metrics["eval_ep_last_info"]["returned_episode_returns"]) + xs = list(range(num_updates)) + + all_conf_ids, all_br_ids = _get_all_ids(pop_size) + sp_mask = (all_conf_ids == all_br_ids) + sp_returns = all_returns[:, :, sp_mask] + xp_returns = all_returns[:, :, ~sp_mask] + + # Average over seeds and agent pairs (eval episodes and agents already averaged inside scan) + sp_return_curve = sp_returns.mean(axis=(0, 2)) + xp_return_curve = xp_returns.mean(axis=(0, 2)) + + for step in range(num_updates): + logger.log_item("Eval/AvgSPReturnCurve", sp_return_curve[step], train_step=step) + logger.log_item("Eval/AvgXPReturnCurve", xp_return_curve[step], train_step=step) + logger.commit() + + # log final XP matrix to wandb - average over seeds + last_returns_array = all_returns[:, -1].mean(axis=0) + last_returns_array = np.reshape(last_returns_array, (pop_size, pop_size)) + logger.log_xp_matrix("Eval/LastXPMatrix", last_returns_array) + + ### Log population loss as multi-line plots, where each line is a different population member + # shape (num_seeds, num_updates, update_epochs, num_minibatches, pop_size) + # Average over seeds + processed_losses = { + "ConfPGLoss": np.asarray(metrics["pg_loss_conf_agent"]).mean(axis=0).transpose(), + "BRPGLoss": np.asarray(metrics["pg_loss_br_agent"]).mean(axis=0).transpose(), + "ConfValLoss": np.asarray(metrics["value_loss_conf_agent"]).mean(axis=0).transpose(), + "BRValLoss": np.asarray(metrics["value_loss_br_agent"]).mean(axis=0).transpose(), + "ConfEntropy": np.asarray(metrics["entropy_conf"]).mean(axis=0).transpose(), + "BREntropy": np.asarray(metrics["entropy_br"]).mean(axis=0).transpose(), + } + + xs = list(range(num_updates)) + keys = [f"pair {i}" for i in range(pop_size)] + for loss_name, loss_data in processed_losses.items(): + if np.isnan(loss_data).any(): + raise ValueError(f"Found nan in loss {loss_name}") + logger.log_item(f"Losses/{loss_name}", + wandb.plot.line_series(xs=xs, ys=loss_data, keys=keys, + title=loss_name, xname="train_step") + ) + + ### Log artifacts + savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + # Save train run output and log to wandb as artifact + out_savepath = save_train_run(outs, savedir, savename="saved_train_run") + if config["logger"]["log_train_out"]: + logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run") + + # Cleanup locally logged out files + if not config["local_logger"]["save_train_out"]: + shutil.rmtree(out_savepath) diff --git a/teammate_generation/CoMeDi.py b/teammate_generation/CoMeDi.py new file mode 100644 index 0000000000000000000000000000000000000000..95fe1c1230d66c5e0f5624c65168c27c119266e5 --- /dev/null +++ b/teammate_generation/CoMeDi.py @@ -0,0 +1,1161 @@ +'''Implementation of the CoMeDi teammate generation algorithm (Sarkar et al. NeurIPS 2023) +https://openreview.net/forum?id=MljeRycu9s + +Command to run CoMeDi only on LBF: +python teammate_generation/run.py algorithm=comedi/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels label=test_comedi run_heldout_eval=false train_ego=false + +Limitations: does not support recurrent actors. +''' +from functools import partial +import logging +import shutil +import time +from typing import NamedTuple + +from flax.training.train_state import TrainState +import hydra +import jax +import jax.numpy as jnp +import numpy as np +import optax +import wandb + +from agents.mlp_actor_critic_agent import ActorWithConditionalCriticPolicy +from agents.initialize_agents import initialize_actor_with_conditional_critic +from agents.population_interface import AgentPopulation +from agents.population_buffer import BufferedPopulation +from common.save_load_utils import save_train_run +from common.plot_utils import get_metric_names +from common.run_episodes import run_episodes +from envs import make_env +from envs.log_wrapper import LogWrapper, LogEnvState +from marl.ippo import make_train as make_ppo_train +from marl.ppo_utils import Transition, unbatchify, _create_minibatches + +log = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +class ResetTransition(NamedTuple): + '''Stores extra information for resetting agents to a point in some trajectory.''' + env_state: LogEnvState + conf_obs: jnp.ndarray + partner_obs: jnp.ndarray + conf_done: jnp.ndarray + partner_done: jnp.ndarray + conf_hstate: jnp.ndarray + partner_hstate: jnp.ndarray + +def train_comedi_partners(train_rng, wandb_logger, env, config): + num_agents = env.num_agents + assert num_agents == 2, "This code assumes the environment has exactly 2 agents." + + # Define 4 types of rollouts: SP, XP, MP, MP2 + config["NUM_GAME_AGENTS"] = num_agents + + config["NUM_ACTORS"] = num_agents * config["NUM_ENVS"] + # Right now assume control of both agent and its BR + config["NUM_CONTROLLED_ACTORS"] = config["NUM_ACTORS"] + + # Compute numbber of updates PER outermost iteration + # Calculate timesteps per update + # 1. Overhead from population selection rollouts + # We divide by 2 because for ease in Jax, this implementation uses a vmap over PARTNER_POP_SIZE to + # evaluate the agent generated at each outermost iteration against all previously + # generated agents, but a non-Jax implementation would only need to evaluate against + # *previously* generated agents. + selection_steps = config["PARTNER_POP_SIZE"] * config["NUM_ARGMAX_ROLLOUT_EPS"] * config["ROLLOUT_LENGTH"] // 2 + # 2. Training rollouts: 4 distinct rollout phases (SP, XP, MP, MP2) each using NUM_ENVS + training_steps = 4 * config["ROLLOUT_LENGTH"] * config["NUM_ENVS"] + + steps_per_update = selection_steps + training_steps + config["NUM_UPDATES"] = int(config["TOTAL_TIMESTEPS_PER_ITERATION"] // steps_per_update) + + def make_comedi_agents(config): + def linear_schedule(count): + frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"] + return config["LR"] * frac + + def train_init_ippo_partners(config, partner_rng, env): + ''' + Train a pool IPPO agents w/parameter sharing. + Returns out, a dictionary of the model checkpoints, final parameters, and metrics. + ''' + # POP_SIZE is referenced throughout the CoMeDi training loops + config["POP_SIZE"] = config["PARTNER_POP_SIZE"] + # Use a local copy for warmup-specific overrides to avoid + # mutating the shared config (ACTOR_TYPE, TOTAL_TIMESTEPS) + warmup_config = dict(config) + warmup_config["TOTAL_TIMESTEPS"] = config["TOTAL_TIMESTEPS_PER_ITERATION"] + warmup_config["ACTOR_TYPE"] = "pseudo_actor_with_conditional_critic" + out = make_ppo_train(warmup_config, env, wandb_logger)(partner_rng) + return out + + def train(rng): + # Start by training a single PPO agent via self-play + rng, init_ppo_rng, init_conf_rng = jax.random.split(rng, 3) + + init_ppo_partner = train_init_ippo_partners(config, init_ppo_rng, env) + + # Initialize a population buffer + dummy_policy, dummy_init_params = initialize_actor_with_conditional_critic(config, env, init_conf_rng) + partner_population = BufferedPopulation( + max_pop_size=config["PARTNER_POP_SIZE"], + policy_cls=dummy_policy, + ) + + population_buffer = partner_population.reset_buffer(dummy_init_params) + population_buffer = partner_population.add_agent(population_buffer, init_ppo_partner["final_params"]) + + def add_conf_policy(pop_buffer, func_input): + num_existing_agents, rng = func_input + rng, init_conf_rng = jax.random.split(rng) + + # Create new confederate agent policy and critic + policy, init_params = initialize_actor_with_conditional_critic( + config, env, init_conf_rng + ) + + # Create a train_state and optimizer for the newly initialzied model + if config["ANNEAL_LR"]: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["LR"], eps=1e-5)) + + train_state = TrainState.create( + apply_fn=policy.network.apply, + params=init_params, + tx=tx, + ) + + # Reset envs for SP, XP, and MP + rng, reset_rng_eval, reset_rng_sp, reset_rng_xp, reset_rng_mp, reset_rng_mp2 = jax.random.split(rng, 6) + + reset_rngs_sps = jax.random.split(reset_rng_sp, config["NUM_ENVS"]) + reset_rngs_xps = jax.random.split(reset_rng_xp, config["NUM_ENVS"]) + reset_rngs_mps = jax.random.split(reset_rng_mp, config["NUM_ENVS"]) + reset_rngs_mps2 = jax.random.split(reset_rng_mp2, config["NUM_ENVS"]) + + obsv_xp, env_state_xp = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_sps) + obsv_sp, env_state_sp = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_xps) + obsv_mp, env_state_mp = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_mps) + obsv_mp2, env_state_mp2 = jax.vmap(env.reset, in_axes=(0,))(reset_rngs_mps2) + + # build a pytree that can hold the parameters for all checkpoints. + ckpt_and_eval_interval = config["NUM_UPDATES"] // max(1, config["NUM_CHECKPOINTS"] - 1) + num_ckpts = config["NUM_CHECKPOINTS"] + def init_ckpt_array(params_pytree): + return jax.tree.map( + lambda x: jnp.zeros((num_ckpts,) + x.shape, x.dtype), + params_pytree + ) + + # define evaluation function + rng, eval_rng = jax.random.split(rng, 2) + def per_id_run_episode_fixed_rng(agent0_param, agent1_id): + agent1_param = partner_population.gather_agent_params(pop_buffer, + agent_indices=agent1_id * jnp.ones((1,), dtype=np.int32)) + agent1_param = jax.tree_map(lambda y: jnp.squeeze(y, 0), agent1_param) + all_outs = run_episodes( + rng=eval_rng, env=env, + agent_0_param=agent0_param, agent_0_policy=policy, + agent_1_param=agent1_param, agent_1_policy=policy, + max_episode_steps=config["ROLLOUT_LENGTH"], + num_eps=config["NUM_ARGMAX_ROLLOUT_EPS"] + ) + return all_outs + + def _update_step(update_with_ckpt_runner_state, unused): + update_runner_state, checkpoint_array, ckpt_idx = update_with_ckpt_runner_state + ( + train_state, pop_buffer, + env_state_sp, obsv_sp, + env_state_xp, obsv_xp, + env_state_mp, obsv_mp, + env_state_mp2, obsv_mp2, + last_dones_xp, + last_dones_sp, + last_dones_mp, + last_dones_mp2, + rng, update_steps, + num_prev_trained_conf + ) = update_runner_state + + # Identify the expected returns from the newly trained policy + # when interacting with the previously generated confederate + # policies + valid_sampling_indices = jnp.arange(config["POP_SIZE"]) + run_all_rollouts = jax.vmap(per_id_run_episode_fixed_rng, in_axes=(None, 0))( + train_state.params,valid_sampling_indices) + + # Mask out the XP returns against invalid policies + # resulting from IDs that are yet set to a specific + # confederate params + all_mean_returns = run_all_rollouts["returned_episode_returns"][:, :, 0].mean(axis=-1) + masked_mean_returns = jnp.where( + valid_sampling_indices >= num_prev_trained_conf, -jnp.inf, all_mean_returns + ) + + # Pick the right confederate params to act as the XP agent + max_means_id = masked_mean_returns.argmax() + xp_param = jax.tree_map( + lambda x: jnp.squeeze(x, 0), + partner_population.gather_agent_params(pop_buffer, + agent_indices=max_means_id * jnp.ones((1,), dtype=np.int32)) + ) + + rng, rng_xp, rng_sp, rng_mp, rng_mp2 = jax.random.split(rng, 5) + + def _env_step_conf_ego(runner_state, unused): + """ + agent_0 = confederate, agent_1 = ego + Returns updated runner_state and a Transition for the confederate. + """ + train_state, xp_param, xp_id, env_state, last_obs, last_dones, rng = runner_state + rng, act_rng, partner_rng, step_rng = jax.random.split(rng, 4) + + obs_0 = last_obs["agent_0"] + obs_1 = last_obs["agent_1"] + + # Get available actions for agent 0 from environment state + avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state) + avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32) + avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32) + + # Add one-hot ID of XP teammate + xp_one_hot_id = jnp.eye(config["POP_SIZE"])[xp_id] + xp_one_hot_id = jnp.expand_dims( + jnp.expand_dims( + xp_one_hot_id, 0 + ), 0 + ) + + # Agent_0 (confederate) action using policy interface + aux_obs = jnp.repeat(xp_one_hot_id, config["NUM_ENVS"], axis=1) + act_0, val_0, pi_0, _ = policy.get_action_value_policy( + params=train_state.params, + obs=obs_0.reshape(1, config["NUM_ENVS"], -1), + done=last_dones["agent_0"].reshape(1, config["NUM_ENVS"]), + avail_actions=jax.lax.stop_gradient(avail_actions_0), + hstate=None, + rng=act_rng, + aux_obs=aux_obs + ) + logp_0 = pi_0.log_prob(act_0) + + act_0 = act_0.squeeze() + logp_0 = logp_0.squeeze() + val_0 = val_0.squeeze() + + # Agent_1 (ego) action using policy interface + act_1, _, _, _ = policy.get_action_value_policy( + params=xp_param, + obs=obs_1.reshape(1, config["NUM_ENVS"], -1), + done=last_dones["agent_1"].reshape(1, config["NUM_ENVS"]), + avail_actions=jax.lax.stop_gradient(avail_actions_1), + hstate=None, + rng=partner_rng, + aux_obs=aux_obs + ) + act_1 = act_1.squeeze() + + # Combine actions into the env format + combined_actions = jnp.concatenate([act_0, act_1], axis=0) # shape (2*num_envs,) + env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents) + env_act = {k: v.flatten() for k, v in env_act.items()} + + # Step env + step_rngs = jax.random.split(step_rng, config["NUM_ENVS"]) + obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))( + step_rngs, env_state, env_act + ) + # note that num_actors = num_envs * num_agents + info_0 = jax.tree.map(lambda x: x[:, 0], info) + + # Store agent_0 data in transition + transition = Transition( + done=done["agent_0"], + action=act_0, + value=val_0, + reward=reward["agent_1"], + log_prob=logp_0, + obs=obs_0, + info=info_0, + avail_actions=avail_actions_0 + ) + new_runner_state = (train_state, xp_param, xp_id, env_state_next, obs_next, done, rng) + return new_runner_state, transition + + def _env_step_conf_br(runner_state, unused): + """ + agent_0 = confederate, agent_1 = best response + Returns updated runner_state, and Transitions for the confederate and best response. + """ + train_state, env_state, last_obs, last_dones, rng, current_trained_pop_id, reset_traj_batch = runner_state + rng, conf_rng, br_rng, step_rng = jax.random.split(rng, 4) + + def gather_sampled(data_pytree, flat_indices, first_nonbatch_dim: int): + '''Will treat all dimensions up to the first_nonbatch_dim as batch dimensions. ''' + batch_size = config["ROLLOUT_LENGTH"] * config["NUM_ENVS"] + flat_data = jax.tree.map(lambda x: x.reshape(batch_size, *x.shape[first_nonbatch_dim:]), data_pytree) + sampled_data = jax.tree.map(lambda x: x[flat_indices], flat_data) # Shape (N, ...) + return sampled_data + + if reset_traj_batch is not None: + rng, sample_rng = jax.random.split(rng) + needs_resample = last_dones["__all__"] # shape (N,) bool + + total_reset_states = config["ROLLOUT_LENGTH"] * config["NUM_ENVS"] + sampled_indices = jax.random.randint(sample_rng, shape=(config["NUM_ENVS"],), minval=0, + maxval=total_reset_states) + + # Gather sampled leaves from each data pytree + sampled_env_state = gather_sampled(reset_traj_batch.env_state, sampled_indices, first_nonbatch_dim=2) + sampled_conf_obs = gather_sampled(reset_traj_batch.conf_obs, sampled_indices, first_nonbatch_dim=2) + sampled_br_obs = gather_sampled(reset_traj_batch.partner_obs, sampled_indices, first_nonbatch_dim=2) + sampled_conf_done = gather_sampled(reset_traj_batch.conf_done, sampled_indices, first_nonbatch_dim=2) + sampled_br_done = gather_sampled(reset_traj_batch.partner_done, sampled_indices, first_nonbatch_dim=2) + + # for done environments, select data corresponding to the reset_traj_batch states + env_state = jax.tree.map( + lambda sampled, original: jnp.where( + needs_resample.reshape((-1,) + (1,) * (original.ndim - 1)), + sampled, original + ), + sampled_env_state, + env_state + ) + obs_0 = jnp.where(needs_resample[:, jnp.newaxis], sampled_conf_obs, last_obs["agent_0"]) + obs_1 = jnp.where(needs_resample[:, jnp.newaxis], sampled_br_obs, last_obs["agent_1"]) + + dones_0 = jnp.where(needs_resample, sampled_conf_done, last_dones["agent_0"]) + dones_1 = jnp.where(needs_resample, sampled_br_done, last_dones["agent_1"]) + + else: + + # Reset conf-br data collection from conf-ego states + obs_0, obs_1 = last_obs["agent_0"], last_obs["agent_1"] + dones_0, dones_1 = last_dones["agent_0"], last_dones["agent_1"] + + # Get available actions for agent 0 from environment state + avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state) + avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32) + avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32) + + # Agent_0 (confederate) action + # Add one-hot ID of XP teammate + sp_one_hot_id = jnp.eye(config["POP_SIZE"])[current_trained_pop_id] + sp_one_hot_id = jnp.expand_dims( + jnp.expand_dims( + sp_one_hot_id, 0 + ), 0 + ) + + aux_obs = jnp.repeat(sp_one_hot_id, config["NUM_ENVS"], 1) + act_0, val_0, pi_0, _ = policy.get_action_value_policy( + params=train_state.params, + obs=obs_0.reshape(1, config["NUM_ENVS"], -1), + done=dones_0.reshape(1, config["NUM_ENVS"]), + avail_actions=jax.lax.stop_gradient(avail_actions_0), + hstate=None, + rng=conf_rng, + aux_obs=aux_obs + ) + logp_0 = pi_0.log_prob(act_0) + + act_0 = act_0.squeeze() + logp_0 = logp_0.squeeze() + val_0 = val_0.squeeze() + + # Agent 1 (best response) action + act_1, val_1, pi_1, _ = policy.get_action_value_policy( + params=train_state.params, + obs=obs_1.reshape(1, config["NUM_ENVS"], -1), + done=dones_1.reshape(1, config["NUM_ENVS"]), + avail_actions=jax.lax.stop_gradient(avail_actions_1), + hstate=None, + rng=br_rng, + aux_obs=aux_obs + ) + logp_1 = pi_1.log_prob(act_1) + + act_1 = act_1.squeeze() + logp_1 = logp_1.squeeze() + val_1 = val_1.squeeze() + + # Combine actions into the env format + combined_actions = jnp.concatenate([act_0, act_1], axis=0) # shape (2*num_envs,) + env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents) + env_act = {k: v.flatten() for k, v in env_act.items()} + + # Step env + step_rngs = jax.random.split(step_rng, config["NUM_ENVS"]) + obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))( + step_rngs, env_state, env_act + ) + info_0 = jax.tree.map(lambda x: x[:, 0], info) + info_1 = jax.tree.map(lambda x: x[:, 1], info) + + # Store agent_0 (confederate) data in transition + transition_0 = Transition( + done=done["agent_0"], + action=act_0, + value=val_0, + reward=reward["agent_0"], + log_prob=logp_0, + obs=obs_0, + info=info_0, + avail_actions=avail_actions_0 + ) + # Store agent_1 (best response) data in transition + transition_1 = Transition( + done=done["agent_1"], + action=act_1, + value=val_1, + reward=reward["agent_1"], + log_prob=logp_1, + obs=obs_1, + info=info_1, + avail_actions=avail_actions_1 + ) + # Pass reset_traj_batch and init_br_hstate through unchanged in the state tuple + new_runner_state = (train_state, env_state_next, obs_next, done, rng, current_trained_pop_id, reset_traj_batch) + return new_runner_state, (transition_0, transition_1) + + def _env_step_mixed(runner_state, unused): + """ + agent_0 = confederate, agent_1 = ego OR best response + Returns a ResetTransition for resetting to env states encountered here. + """ + train_state_conf, ego_param, env_state, last_obs, last_dones, rng, current_trained_pop_id = runner_state + rng, act_rng, ego_act_rng, br_act_rng, partner_choice_rng, step_rng = jax.random.split(rng, 6) + + obs_0 = last_obs["agent_0"] + obs_1 = last_obs["agent_1"] + + # Get available actions for agent 0 from environment state + avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state) + avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32) + avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32) + + xp_one_hot_id = jnp.eye(config["POP_SIZE"])[current_trained_pop_id] + xp_one_hot_id = jnp.expand_dims( + jnp.expand_dims( + xp_one_hot_id, 0 + ), 0 + ) + + # Agent_0 (confederate) action using policy interface + aux_obs = jnp.repeat(xp_one_hot_id, config["NUM_ENVS"], axis=1) + + # Agent_0 (confederate) action using policy interface + act_0, val_0, pi_0, _ = policy.get_action_value_policy( + params=train_state_conf.params, + obs=obs_0.reshape(1, config["NUM_ENVS"], -1), + done=last_dones["agent_0"].reshape(1, config["NUM_ENVS"]), + avail_actions=jax.lax.stop_gradient(avail_actions_0), + hstate=None, + rng=act_rng, + aux_obs=aux_obs + ) + logp_0 = pi_0.log_prob(act_0) + + act_0 = act_0.squeeze() + logp_0 = logp_0.squeeze() + val_0 = val_0.squeeze() + + ### Compute both the ego action and the best response action + act_ego, _, _, _ = policy.get_action_value_policy( + params=ego_param, + obs=obs_1.reshape(1, config["NUM_ENVS"], -1), + done=last_dones["agent_1"].reshape(1, config["NUM_ENVS"]), + avail_actions=jax.lax.stop_gradient(avail_actions_1), + hstate=None, + rng=ego_act_rng, + aux_obs=aux_obs + ) + act_br, _, _, _ = policy.get_action_value_policy( + params=train_state.params, + obs=obs_1.reshape(1, config["NUM_ENVS"], -1), + done=last_dones["agent_1"].reshape(1, config["NUM_ENVS"]), + avail_actions=jax.lax.stop_gradient(avail_actions_1), + hstate=None, + rng=br_act_rng, + aux_obs=aux_obs + ) + + act_ego = act_ego.squeeze() + act_br = act_br.squeeze() + # Agent 1 (ego or best response) action - choose between ego and best response + partner_choice = jax.random.randint(partner_choice_rng, shape=(config["NUM_ENVS"],), minval=0, maxval=2) + act_1 = jnp.where(partner_choice == 0, act_ego, act_br) + + # Combine actions into the env format + combined_actions = jnp.concatenate([act_0, act_1], axis=0) + env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents) + env_act = {k: v.flatten() for k, v in env_act.items()} + + # Step env + step_rngs = jax.random.split(step_rng, config["NUM_ENVS"]) + obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))( + step_rngs, env_state, env_act + ) + + reset_transition = ResetTransition( + # all of these are from before env step + env_state=env_state, + conf_obs=obs_0, + partner_obs=obs_1, + conf_done=last_dones["agent_0"], + partner_done=last_dones["agent_1"], + conf_hstate=None, + # we record the best response hstate because we use it to reset the best response + partner_hstate=None + ) + new_runner_state = (train_state_conf, ego_param, env_state_next, obs_next, done, rng, current_trained_pop_id) + return new_runner_state, reset_transition + + # Do XP rollout (based on train_state params and the param in pop_buffer identified in Step 1) + runner_state_xp = (train_state, xp_param, max_means_id, env_state_xp, obsv_xp, last_dones_xp, rng_xp) + runner_state_xp, traj_batch_xp = jax.lax.scan( + _env_step_conf_ego, runner_state_xp, None, config["ROLLOUT_LENGTH"]) + (train_state, xp_param, max_means_id, env_state_xp, last_obs_xp, last_dones_xp, rng_xp) = runner_state_xp + + # Do self-play (based on train_state params) rollout like in the IPPO code + runner_state_sp = (train_state, env_state_sp, obsv_sp, last_dones_sp, rng_sp, num_prev_trained_conf, None) + runner_state_sp, (traj_batch_sp_agent0, traj_batch_sp_agent1) = jax.lax.scan( + _env_step_conf_br, runner_state_sp, None, config["ROLLOUT_LENGTH"]) + (train_state, env_state_sp, last_obs_sp, last_dones_sp, rng_sp, num_prev_trained_conf, mp_traj_batch) = runner_state_sp + + # Step 4 + # Do MP rollout (based on train_state params and the param in pop_buffer identified in Step 1) + runner_state_mp = (train_state, xp_param, env_state_mp, obsv_mp, last_dones_mp, rng_mp, num_prev_trained_conf) + runner_state_mp, traj_batch_mp = jax.lax.scan( + _env_step_mixed, runner_state_mp, None, config["ROLLOUT_LENGTH"]) + (train_state, xp_param, env_state_mp, last_obs_mp, last_dones_mp, rng_mp, num_prev_trained_conf) = runner_state_mp + + runner_state_smp = (train_state, env_state_mp2, obsv_mp2, last_dones_mp2, rng_mp2, num_prev_trained_conf, traj_batch_mp) + runner_state_smp, (traj_batch_smp0, traj_batch_smp1) = jax.lax.scan( + _env_step_conf_br, runner_state_smp, None, config["ROLLOUT_LENGTH"]) + (train_state, env_state_mp2, last_obs_mp2, last_dones_mp2, rng_mp2, num_prev_trained_conf, mp2_traj_batch) = runner_state_smp + + def _calculate_gae(traj_batch, last_val): + def _get_advantages(gae_and_next_value, transition): + gae, next_value = gae_and_next_value + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + delta = reward + config["GAMMA"] * next_value * (1 - done) - value + gae = ( + delta + + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae + ) + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + def _compute_advantages_and_targets(env_state, policy, policy_params, policy_hstate, + last_obs, last_dones, traj_batch, agent_name, value_idx=None): + '''Value_idx argument is to support the ActorWithDoubleCritic (confederate) policy, which + has two value heads. Value head 0 models the ego agent while value head 1 models the best response.''' + avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state)[agent_name].astype(jnp.float32) + + # Add one-hot ID of interaction teammate + xp_one_hot_id = jnp.eye(config["POP_SIZE"])[value_idx] + xp_one_hot_id = jnp.expand_dims( + jnp.expand_dims( + xp_one_hot_id, 0 + ), 0 + ) + + # Agent_0 (confederate) action using policy interface + aux_obs = jnp.repeat(xp_one_hot_id, last_obs[agent_name].shape[0], axis=1) + + _, vals, _, _ = policy.get_action_value_policy( + params=policy_params, + obs=last_obs[agent_name].reshape(1, last_obs[agent_name].shape[0], -1), + done=last_dones[agent_name].reshape(1, last_obs[agent_name].shape[0]), + avail_actions=jax.lax.stop_gradient(avail_actions), + hstate=policy_hstate, + rng=jax.random.PRNGKey(0), # dummy key as we don't sample actions + aux_obs=aux_obs + ) + last_val = vals.squeeze() + advantages, targets = _calculate_gae(traj_batch, last_val) + return advantages, targets + + # 5a) Compute conf advantages for XP (conf-ego) interaction + advantages_xp_conf, targets_xp_conf = _compute_advantages_and_targets( + env_state_xp, policy, train_state.params, None, + last_obs_xp, last_dones_xp, traj_batch_xp, "agent_0", value_idx=max_means_id) + + # 5b) Compute conf and br advantages for SP (conf-br) interaction + advantages_sp_conf, targets_sp_conf = _compute_advantages_and_targets( + env_state_sp, policy, train_state.params, None, + last_obs_sp, last_dones_sp, traj_batch_sp_agent0, "agent_0", value_idx=num_prev_trained_conf) + + advantages_sp_br, targets_sp_br = _compute_advantages_and_targets( + env_state_sp, policy, train_state.params, None, + last_obs_sp, last_dones_sp, traj_batch_sp_agent1, "agent_1", value_idx=num_prev_trained_conf) + + # 5c) Compute advantages from MP interactions + advantages_mp_conf, targets_mp_conf = _compute_advantages_and_targets( + env_state_mp2, policy, train_state.params, None, + last_obs_mp2, last_dones_mp2, traj_batch_smp0, "agent_0", value_idx=num_prev_trained_conf) + + advantages_mp_br, targets_mp_br = _compute_advantages_and_targets( + env_state_mp2, policy, train_state.params, None, + last_obs_mp2, last_dones_mp2, traj_batch_smp1, "agent_1", value_idx=num_prev_trained_conf) + + def _update_epoch(update_state, unused): + def _compute_ppo_value_loss(pred_value, traj_batch, target_v): + '''Value loss function for PPO''' + value_pred_clipped = traj_batch.value + ( + pred_value - traj_batch.value + ).clip( + -config["CLIP_EPS"], config["CLIP_EPS"]) + value_losses = jnp.square(pred_value - target_v) + value_losses_clipped = jnp.square(value_pred_clipped - target_v) + value_loss = ( + jnp.maximum(value_losses, value_losses_clipped).mean() + ) + return value_loss + + def _compute_ppo_pg_loss(log_prob, traj_batch, gae): + '''Policy gradient loss function for PPO''' + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae_norm = (gae - gae.mean()) / (gae.std() + 1e-8) + pg_loss_1 = ratio * gae_norm + pg_loss_2 = jnp.clip( + ratio, + 1.0 - config["CLIP_EPS"], + 1.0 + config["CLIP_EPS"]) * gae_norm + pg_loss = -jnp.mean(jnp.minimum(pg_loss_1, pg_loss_2)) + return pg_loss + + def _update_minbatch_conf(train_state_conf, batch_infos): + minbatch_xp, minbatch_sp1, minbatch_sp2, minbatch_mp1, minbatch_mp2, xp_id, sp_id = batch_infos + _, traj_batch_xp, advantages_xp, returns_xp = minbatch_xp + _, traj_batch_sp1, advantages_sp1, returns_sp1 = minbatch_sp1 + _, traj_batch_sp2, advantages_sp2, returns_sp2 = minbatch_sp2 + _, traj_batch_mp1, advantages_mp1, returns_mp1 = minbatch_mp1 + _, traj_batch_mp2, advantages_mp2, returns_mp2 = minbatch_mp2 + + def _loss_fn_conf(params, traj_batch_xp, gae_xp, target_v_xp, + traj_batch_sp, gae_sp, target_v_sp, + traj_batch_sp2, gae_sp2, target_v_sp2, + traj_batch_mp, gae_mp, target_v_mp, + traj_batch_mp2, gae_mp2, target_v_mp2): + # get policy and value of confederate versus ego and best response agents respectively + xp_one_hot_id = jnp.eye(config["POP_SIZE"])[xp_id] + xp_one_hot_id = jnp.expand_dims( + jnp.expand_dims( + xp_one_hot_id, 0 + ), 0 + ) + + sp_one_hot_id = jnp.eye(config["POP_SIZE"])[sp_id] + sp_one_hot_id = jnp.expand_dims( + jnp.expand_dims( + sp_one_hot_id, 0 + ), 0 + ) + + # Agent_0 (confederate) action using policy interface + aux_obs_xp = jnp.repeat(xp_one_hot_id, traj_batch_xp.obs.shape[1], axis=1) + aux_obs_xp = jnp.repeat(aux_obs_xp, traj_batch_xp.obs.shape[0], axis=0) + + _, value_xp, pi_xp, _ = policy.get_action_value_policy( + params=params, + obs=traj_batch_xp.obs, + done=traj_batch_xp.done, + avail_actions=traj_batch_xp.avail_actions, + hstate=None, + rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here + aux_obs=aux_obs_xp + ) + + aux_obs_sp = jnp.repeat(xp_one_hot_id, traj_batch_sp.obs.shape[1], axis=1) + aux_obs_sp = jnp.repeat(aux_obs_sp, traj_batch_sp.obs.shape[0], axis=0) + _, value_sp, pi_sp, _ = policy.get_action_value_policy( + params=params, + obs=traj_batch_sp.obs, + done=traj_batch_sp.done, + avail_actions=traj_batch_sp.avail_actions, + hstate=None, + rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here + aux_obs=aux_obs_sp + ) + + _, value_sp2, pi_sp2, _ = policy.get_action_value_policy( + params=params, + obs=traj_batch_sp2.obs, + done=traj_batch_sp2.done, + avail_actions=traj_batch_sp2.avail_actions, + hstate=None, + rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here + aux_obs=aux_obs_sp + ) + + _, value_mp, pi_mp, _ = policy.get_action_value_policy( + params=params, + obs=traj_batch_mp.obs, + done=traj_batch_mp.done, + avail_actions=traj_batch_mp.avail_actions, + hstate=None, + rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here + aux_obs=aux_obs_sp + ) + + _, value_mp2, pi_mp2, _ = policy.get_action_value_policy( + params=params, + obs=traj_batch_mp2.obs, + done=traj_batch_mp2.done, + avail_actions=traj_batch_mp2.avail_actions, + hstate=None, + rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here + aux_obs=aux_obs_sp + ) + + log_prob_xp = pi_xp.log_prob(traj_batch_xp.action) + log_prob_sp = pi_sp.log_prob(traj_batch_sp.action) + log_prob_sp2 = pi_sp2.log_prob(traj_batch_sp2.action) + log_prob_mp = pi_mp.log_prob(traj_batch_mp.action) + log_prob_mp2 = pi_mp2.log_prob(traj_batch_mp2.action) + + + value_loss_xp = _compute_ppo_value_loss(value_xp, traj_batch_xp, target_v_xp) + value_loss_sp = _compute_ppo_value_loss(value_sp, traj_batch_sp, target_v_sp) + value_loss_sp2 = _compute_ppo_value_loss(value_sp2, traj_batch_sp2, target_v_sp2) + value_loss_mp = _compute_ppo_value_loss(value_mp, traj_batch_mp, target_v_mp) + value_loss_mp2 = _compute_ppo_value_loss(value_mp2, traj_batch_mp2, target_v_mp2) + + pg_loss_xp = _compute_ppo_pg_loss(log_prob_xp, traj_batch_xp, gae_xp) + pg_loss_sp = _compute_ppo_pg_loss(log_prob_sp, traj_batch_sp, gae_sp) + pg_loss_sp2 = _compute_ppo_pg_loss(log_prob_sp2, traj_batch_sp2, gae_sp2) + pg_loss_mp = _compute_ppo_pg_loss(log_prob_mp, traj_batch_mp, gae_mp) + pg_loss_mp2 = _compute_ppo_pg_loss(log_prob_mp2, traj_batch_mp2, gae_mp2) + + + # Entropy for interaction with ego agent + entropy_xp = jnp.mean(pi_xp.entropy()) + entropy_sp = jnp.mean(pi_sp.entropy()) + entropy_sp2 = jnp.mean(pi_sp2.entropy()) + entropy_mp = jnp.mean(pi_mp.entropy()) + entropy_mp2 = jnp.mean(pi_mp2.entropy()) + + xp_pg_weight = -config["COMEDI_ALPHA"] # negate to minimize the ego agent's PG objective + sp_pg_weight = 1.0 + mp2_pg_weight = config["COMEDI_BETA"] + + xp_loss = xp_pg_weight * pg_loss_xp + config["VF_COEF"] * value_loss_xp - config["ENT_COEF"] * entropy_xp + sp_loss = sp_pg_weight * pg_loss_sp + config["VF_COEF"] * value_loss_sp - config["ENT_COEF"] * entropy_sp + sp2_loss = sp_pg_weight * pg_loss_sp2 + config["VF_COEF"] * value_loss_sp2 - config["ENT_COEF"] * entropy_sp2 + mp_loss = mp2_pg_weight * pg_loss_mp + config["VF_COEF"] * value_loss_mp - config["ENT_COEF"] * entropy_mp + mp2_loss = mp2_pg_weight * pg_loss_mp2 + config["VF_COEF"] * value_loss_mp2 - config["ENT_COEF"] * entropy_mp2 + + total_loss = sp_loss + sp2_loss + xp_loss + mp2_loss + mp_loss + return total_loss, (value_loss_xp, value_loss_sp + value_loss_sp2, value_loss_mp + value_loss_mp2, + pg_loss_xp, pg_loss_sp + pg_loss_sp2, pg_loss_mp + pg_loss_mp2, + entropy_xp, entropy_sp + entropy_sp2, entropy_mp + entropy_mp2) + + grad_fn = jax.value_and_grad(_loss_fn_conf, has_aux=True) + (loss_val, aux_vals), grads = grad_fn( + train_state_conf.params, + traj_batch_xp, advantages_xp, returns_xp, + traj_batch_sp1, advantages_sp1, returns_sp1, + traj_batch_sp2, advantages_sp2, returns_sp2, + traj_batch_mp1, advantages_mp1, returns_mp1, + traj_batch_mp2, advantages_mp2, returns_mp2) + train_state_conf = train_state_conf.apply_gradients(grads=grads) + return train_state_conf, (loss_val, aux_vals) + + ( + train_state_conf, traj_batch_xp, + traj_batch_sp_conf, traj_batch_sp_br, + traj_batch_mp_conf, traj_batch_mp_br, + advantages_xp_conf, advantages_sp_conf, + advantages_sp_br, advantages_mp_conf, + advantages_mp_br, targets_xp_conf, + targets_sp_conf, targets_sp_br, + targets_mp_conf, targets_mp_br, + rng, xp_id, sp_id + ) = update_state + + rng, perm_rng_xp, perm_rng_sp_conf, perm_rng_sp_br, perm_rng_mp2_conf, perm_rng_mp2_br = jax.random.split(rng, 6) + + # Create minibatches for each agent and interaction type + minibatches_xp = _create_minibatches( + traj_batch_xp, advantages_xp_conf, targets_xp_conf, None, + config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_xp + ) + minibatches_sp_conf = _create_minibatches( + traj_batch_sp_conf, advantages_sp_conf, targets_sp_conf, None, + config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_sp_conf + ) + minibatches_sp_br = _create_minibatches( + traj_batch_sp_br, advantages_sp_br, targets_sp_br, None, + config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_sp_br + ) + minibatches_mp_conf = _create_minibatches( + traj_batch_mp_conf, advantages_mp_conf, targets_mp_conf, None, + config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_mp2_conf + ) + minibatches_mp_br = _create_minibatches( + traj_batch_mp_br, advantages_mp_br, targets_mp_br, None, + config["NUM_ENVS"], config["NUM_MINIBATCHES"], perm_rng_mp2_br + ) + + # Update confederate + repeated_xp_id = jnp.repeat(xp_id, minibatches_xp[1].obs.shape[0], axis=0) + repeated_sp_id = jnp.repeat(sp_id, minibatches_sp_br[1].obs.shape[0], axis=0) + train_state_conf, total_loss_conf = jax.lax.scan( + _update_minbatch_conf, train_state_conf, ( + minibatches_xp, minibatches_sp_conf, minibatches_sp_br, + minibatches_mp_conf, minibatches_mp_br, repeated_xp_id, repeated_sp_id + ) + ) + + update_state = (train_state_conf, + traj_batch_xp, traj_batch_sp_conf, traj_batch_sp_br, traj_batch_mp_conf, traj_batch_mp_br, + advantages_xp_conf, advantages_sp_conf, advantages_sp_br, advantages_mp_conf, advantages_mp_br, + targets_xp_conf, targets_sp_conf, targets_sp_br, targets_mp_conf, targets_mp_br, + rng, xp_id, sp_id + ) + return update_state, total_loss_conf + + # 3) PPO update + rng, sub_rng = jax.random.split(rng, 2) + update_state = ( + train_state, + traj_batch_xp, traj_batch_sp_agent0, + traj_batch_sp_agent1, + traj_batch_smp0, traj_batch_smp1, + advantages_xp_conf, + advantages_sp_conf, advantages_sp_br, + advantages_mp_conf, advantages_mp_br, + targets_xp_conf, targets_sp_conf, + targets_sp_br, targets_mp_conf, + targets_mp_br, sub_rng, + max_means_id, num_prev_trained_conf + ) + update_state, conf_losses = jax.lax.scan( + _update_epoch, update_state, None, config["UPDATE_EPOCHS"]) + train_state = update_state[0] + + ( + conf_value_loss_xp, conf_value_loss_sp, conf_value_loss_mp, + conf_pg_loss_xp, conf_pg_loss_sp, conf_pg_loss_mp, + conf_entropy_xp, conf_entropy_sp, conf_entropy_mp + ) = conf_losses[1] + + new_update_runner_state = ( + train_state, pop_buffer, + env_state_sp, last_obs_sp, + env_state_xp, last_obs_xp, + env_state_mp, last_obs_mp, + env_state_mp2, last_obs_mp2, + last_dones_xp, last_dones_sp, + last_dones_mp, last_dones_mp2, + rng, update_steps+1, num_prev_trained_conf + ) + + # Metrics + def mask_and_mean(x, mask): + return jnp.where(mask, x, 0).sum() / jnp.maximum(1, mask.sum()) + + mask = traj_batch_xp.info.get("returned_episode", jnp.ones_like(traj_batch_xp.reward)) + metric = jax.tree.map(lambda x: mask_and_mean(x, mask), traj_batch_xp.info) + metric["update_steps"] = update_steps + metric["value_loss_conf_xp"] = conf_value_loss_xp.mean() + metric["value_loss_conf_sp"] = conf_value_loss_sp.mean() + metric["value_loss_conf_mp"] = conf_value_loss_mp.mean() + + metric["pg_loss_conf_xp"] = conf_pg_loss_xp.mean() + metric["pg_loss_conf_sp"] = conf_pg_loss_sp.mean() + metric["pg_loss_conf_mp"] = conf_pg_loss_mp.mean() + + metric["entropy_conf_xp"] = conf_entropy_xp.mean() + metric["entropy_conf_sp"] = conf_entropy_sp.mean() + metric["entropy_conf_mp"] = conf_entropy_mp.mean() + + metric["average_rewards_ego"] = jnp.mean(traj_batch_xp.reward) + metric["average_rewards_br_sp"] = jnp.mean(traj_batch_sp_agent1.reward) + metric["average_rewards_br_mp2"] = jnp.mean(traj_batch_smp1.reward) + + return (new_update_runner_state, checkpoint_array, ckpt_idx+1), metric + + # XP eval against all policies in the buffer + xp_eval_returns = jax.tree.map(lambda x: x.mean(axis=(-2, -1)), + jax.vmap(per_id_run_episode_fixed_rng, in_axes=(None, 0))( + train_state.params,jnp.arange(config["POP_SIZE"]))) + + # SP performance against itself + sp_eval_returns = jax.tree.map(lambda x: x.mean(), run_episodes( + eval_rng, env, + agent_0_param=train_state.params, agent_0_policy=policy, + agent_1_param=train_state.params, agent_1_policy=policy, + max_episode_steps=config["ROLLOUT_LENGTH"], + num_eps=config["NUM_EVAL_EPISODES"] + )) + + + update_steps = 0 + init_done_xp = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]} + init_done_sp = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]} + init_done_mp = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]} + init_done_mp2 = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]} + + update_runner_state = ( + train_state, pop_buffer, + env_state_sp, obsv_sp, + env_state_xp, obsv_xp, + env_state_mp, obsv_mp, + env_state_mp2, obsv_mp2, + init_done_xp, init_done_sp, + init_done_mp, init_done_mp2, + rng, update_steps, + num_existing_agents + ) + + checkpoint_array = init_ckpt_array(train_state.params) + ckpt_idx = 0 + update_with_ckpt_runner_state = (update_runner_state, checkpoint_array, ckpt_idx, xp_eval_returns, sp_eval_returns) + + def _update_step_with_ckpt(state_with_ckpt, unused): + + (update_runner_state, checkpoint_array, ckpt_idx, xp_eval_returns, sp_eval_returns) = state_with_ckpt + train_state = update_runner_state[0] + + # Single PPO update + new_state_with_ckpt, metric = _update_step( + (update_runner_state, checkpoint_array, ckpt_idx), + None + ) + new_update_runner_state = new_state_with_ckpt[0] + rng, update_steps = new_update_runner_state[-3], new_update_runner_state[-2] + + # Decide if we store a checkpoint + # update steps is 1-indexed because it was incremented at the end of the update step + to_store = jnp.logical_or(jnp.equal(jnp.mod(update_steps-1, ckpt_and_eval_interval), 0), + jnp.equal(update_steps, config["NUM_UPDATES"])) + + def store_and_eval_ckpt(args): + ckpt_arr_conf, rng, cidx, _, _ = args + new_ckpt_arr_conf = jax.tree.map( + lambda c_arr, p: c_arr.at[cidx].set(p), + ckpt_arr_conf, train_state.params + ) + + # Eval trained agent against all params in the pool + xp_eval_returns = jax.tree.map(lambda x: x.mean(axis=(-2, -1)), + jax.vmap(per_id_run_episode_fixed_rng, in_axes=(None, 0))( + train_state.params, jnp.arange(config["POP_SIZE"]))) + # Eval trained agent against itself + sp_eval_returns = jax.tree.map(lambda x: x.mean(), run_episodes( + eval_rng, env, + agent_0_param=train_state.params, agent_0_policy=policy, + agent_1_param=train_state.params, agent_1_policy=policy, + max_episode_steps=config["ROLLOUT_LENGTH"], + num_eps=config["NUM_EVAL_EPISODES"] + )) + + return (new_ckpt_arr_conf, rng, cidx + 1, xp_eval_returns, sp_eval_returns) + + def skip_ckpt(args): + return args + + rng, store_and_eval_rng = jax.random.split(rng, 2) + (checkpoint_array, store_and_eval_rng, ckpt_idx, xp_eval_returns, sp_eval_returns) = jax.lax.cond( + to_store, + store_and_eval_ckpt, + skip_ckpt, + (checkpoint_array, store_and_eval_rng, ckpt_idx, xp_eval_returns, sp_eval_returns) + ) + + return (new_update_runner_state, checkpoint_array, + ckpt_idx, xp_eval_returns, sp_eval_returns), (metric, xp_eval_returns, sp_eval_returns) + + new_update_with_ckpt_runner_state, (metric, xp_eval_returns, sp_eval_returns) = jax.lax.scan( + _update_step_with_ckpt, + update_with_ckpt_runner_state, + xs=None, # No per-step input data + length=config["NUM_UPDATES"], + ) + new_update_runner_state, new_checkpoint_array, _, _ ,_ = new_update_with_ckpt_runner_state + final_train_state = new_update_runner_state[0] + + updated_pop_buffer = partner_population.add_agent(pop_buffer, final_train_state.params) + conf_checkpoints = new_checkpoint_array + return updated_pop_buffer, (conf_checkpoints, metric, xp_eval_returns, sp_eval_returns) + + rngs = jax.random.split(rng, config["PARTNER_POP_SIZE"]) + rng, add_conf_iter_rngs = rngs[0], rngs[1:] + + iter_ids = jnp.arange(1, config["PARTNER_POP_SIZE"]) + final_population_buffer, (conf_checkpoints, metric, xp_eval_returns, sp_eval_returns) = jax.lax.scan( + add_conf_policy, population_buffer, (iter_ids, add_conf_iter_rngs) + ) + + out = { + "final_params_conf": final_population_buffer.params, + "checkpoints_conf": conf_checkpoints, + "metrics": metric, + "last_ep_infos_xp": xp_eval_returns, + "last_ep_infos_sp": sp_eval_returns + } + + return out + return train + + train_fn = make_comedi_agents(config) + out = train_fn(train_rng) + return out + +def get_comedi_population(config, out, env): + ''' + Get the partner params and partner population for ego training. + ''' + comedi_pop_size = config["algorithm"]["PARTNER_POP_SIZE"] + + # partner_params has shape (num_seeds, comedi_pop_size, ...) + partner_params = out['final_params_conf'] + + partner_policy = ActorWithConditionalCriticPolicy( + action_dim=env.action_space(env.agents[1]).n, + obs_dim=env.observation_space(env.agents[1]).shape[0], + pop_size=comedi_pop_size, # used to create onehot agent id + activation=config["algorithm"].get("ACTIVATION", "tanh") + ) + + # Create partner population + partner_population = AgentPopulation( + pop_size=comedi_pop_size, + policy_cls=partner_policy + ) + + return partner_params, partner_population + +def run_comedi(config, wandb_logger): + algorithm_config = dict(config["algorithm"]) + + env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"]) + env = LogWrapper(env) + + log.info("Starting CoMeDi training...") + start = time.time() + + # Generate multiple random seeds from the base seed + rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"]) + rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"]) + + # Create a vmapped version of train_comedi_partners + with jax.disable_jit(False): + vmapped_train_fn = jax.jit( + jax.vmap( + partial(train_comedi_partners, + wandb_logger=wandb_logger, + env=env, + config=algorithm_config) + ) + ) + out = vmapped_train_fn(rngs) + + end = time.time() + log.info(f"CoMeDi training complete in {end - start} seconds") + + metric_names = get_metric_names(algorithm_config["ENV_NAME"]) + + # Save FIRST so the checkpoint survives even if metric logging OOMs. + savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + out_savepath = save_train_run(out, savedir, savename="saved_train_run") + log_metrics(config, out, wandb_logger, metric_names, out_savepath) + partner_params, partner_population = get_comedi_population(config, out, env) + return partner_params, partner_population + +def compute_sp_mask_and_ids(pop_size): + cross_product = np.meshgrid( + np.arange(pop_size), + np.arange(pop_size) + ) + agent_id_cartesian_product = np.stack([g.ravel() for g in cross_product], axis=-1) + conf_ids = agent_id_cartesian_product[:, 0] + ego_ids = agent_id_cartesian_product[:, 1] + sp_mask = (conf_ids == ego_ids) + return sp_mask, agent_id_cartesian_product + +def log_metrics(config, outs, logger, metric_names: tuple, out_savepath): + metrics = outs["metrics"] + # trained_pop_size excludes the initial policy + num_seeds, pop_size, num_updates = metrics["pg_loss_conf_sp"].shape + # TODO: add the eval_ep_last_info metrics + + ### Log evaluation metrics + # xp_eval_returns and sp_eval_returns logged at each evaluation only. + algorithm_config = config["algorithm"] + ckpt_and_eval_interval = max(1, num_updates // max(1, algorithm_config["NUM_CHECKPOINTS"] - 1)) + # Steps at which store_and_eval_ckpt fires (0-indexed, matching the update_step logged below) + eval_steps = list(range(0, num_updates, ckpt_and_eval_interval)) + if (num_updates - 1) not in eval_steps: + eval_steps.append(num_updates - 1) + + # shape (num_seeds, pop_size - 1, num_updates) [pre-scalarized: mean over eval eps and agents taken inside scan] + all_returns_sp = np.asarray(outs["last_ep_infos_sp"]["returned_episode_returns"]) + # shape (num_seeds, pop_size - 1, num_updates, pop_size) [pre-scalarized: mean over eval eps and agents taken inside scan] + all_returns_xp = np.asarray(outs["last_ep_infos_xp"]["returned_episode_returns"]) + + # Average over seeds only (eval episodes and agents already averaged inside scan) + sp_return_curve = all_returns_sp.mean(axis=0) # shape (pop_size - 1, num_updates) + xp_return_curve = all_returns_xp.mean(axis=0) # shape (pop_size - 1, num_updates, pop_size) + + for num_add_policies in range(pop_size): + for update_step in eval_steps: + logger.log_item("Eval/AvgSPReturnCurve", sp_return_curve[num_add_policies, update_step], train_step=update_step) + mean_xp_returns = xp_return_curve[num_add_policies, :, :(num_add_policies+1)].mean(axis=-1) + logger.log_item("Eval/AvgXPReturnCurve", mean_xp_returns[update_step], train_step=update_step) + logger.commit() + + ### Log population loss as multi-line plots, where each line is a different population member + # both xp and xp metrics has shape (num_seeds, pop_size - 1, num_updates, update_epochs, num_minibatches) + # Average over seeds + processed_losses = { + "ConfPGLossSP": np.asarray(metrics["pg_loss_conf_sp"]).mean(axis=0), # desired shape (pop_size - 1, num_updates) + "ConfPGLossXP": np.asarray(metrics["pg_loss_conf_xp"]).mean(axis=0), + "ConfPGLossMP": np.asarray(metrics["pg_loss_conf_mp"]).mean(axis=0), + "ConfValLossSP": np.asarray(metrics["value_loss_conf_sp"]).mean(axis=0), + "ConfValLossXP": np.asarray(metrics["value_loss_conf_xp"]).mean(axis=0), + "ConfValLossMP": np.asarray(metrics["value_loss_conf_mp"]).mean(axis=0), + "EntropySP": np.asarray(metrics["entropy_conf_sp"]).mean(axis=0), + "EntropyXP": np.asarray(metrics["entropy_conf_xp"]).mean(axis=0), + "EntropyMP": np.asarray(metrics["entropy_conf_mp"]).mean(axis=0), + } + + xs = list(range(num_updates)) + keys = [f"pair {i}" for i in range(pop_size)] + + for loss_name, loss_data in processed_losses.items(): + logger.log_item(f"Losses/{loss_name}", + wandb.plot.line_series(xs=xs, ys=loss_data, keys=keys, + title=loss_name, xname="train_step") + ) + + ### Log artifacts (already saved by caller; just publish to wandb) + if config["logger"]["log_train_out"]: + logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run") + + # Cleanup locally logged out files + if not config["local_logger"]["save_train_out"]: + shutil.rmtree(out_savepath) diff --git a/teammate_generation/LBRDiv.py b/teammate_generation/LBRDiv.py new file mode 100644 index 0000000000000000000000000000000000000000..14b7df0d9395c6d0ca020f7047b7dff17869c55f --- /dev/null +++ b/teammate_generation/LBRDiv.py @@ -0,0 +1,1098 @@ +'''Implementation of the LBRDiv teammate generation algorithm (Rahman et al., AAAI 2024) +https://ojs.aaai.org/index.php/AAAI/article/view/29702 + +Command to run LBRDiv only on LBF: +python teammate_generation/run.py algorithm=lbrdiv/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels label=test_lbrdiv run_heldout_eval=false train_ego=false + +Suggested Debug command: +python teammate_generation/run.py algorithm=lbrdiv/lbf/lbf_7x7_nolevels task=lbf/lbf_7x7_nolevels logger.mode=disabled label=debug algorithm.TOTAL_TIMESTEPS=1e5 algorithm.PARTNER_POP_SIZE=2 train_ego=false run_heldout_eval=false + +Limitations: does not support recurrent actors. +''' +import shutil +import time +import logging +from functools import partial + +import hydra +import jax +import jax.numpy as jnp +import numpy as np +import optax +from flax.training.train_state import TrainState +import wandb + +from agents.mlp_actor_critic_agent import ActorWithConditionalCriticPolicy +from agents.population_interface import AgentPopulation +from common.plot_utils import get_metric_names +from common.run_episodes import run_episodes +from common.save_load_utils import save_train_run +from envs import make_env +from envs.log_wrapper import LogWrapper +from marl.ppo_utils import unbatchify, _create_minibatches +from teammate_generation.BRDiv import _get_all_ids, XPTransition, gather_params + +log = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def train_lbrdiv_partners(train_rng, env, config, conf_policy, br_policy): + num_agents = env.num_agents + assert num_agents == 2, "This code assumes the environment has exactly 2 agents." + + # Define different minibatch sizes for interactions with ego agent and one with BR agent + config["NUM_GAME_AGENTS"] = num_agents + config["NUM_CONF_ACTORS"] = config["NUM_ENVS"] + config["NUM_BR_ACTORS"] = config["NUM_ENVS"] + config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // (config["ROLLOUT_LENGTH"] * config["NUM_ENVS"]) + + def make_lbrdiv_agents(config): + def linear_schedule(count): + frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"] + return config["LR"] * frac + + def train(rng): + rng, init_conf_rng, init_br_rng = jax.random.split(rng, 3) + all_conf_init_rngs = jax.random.split(init_conf_rng, config["PARTNER_POP_SIZE"]) + all_br_init_rngs = jax.random.split(init_br_rng, config["PARTNER_POP_SIZE"]) + identity_matrix = jnp.eye(config["PARTNER_POP_SIZE"]) + + init_conf_hstate = conf_policy.init_hstate(config["NUM_CONF_ACTORS"]) + init_br_hstate = br_policy.init_hstate(config["NUM_BR_ACTORS"]) + + def init_train_states(rng_agents, rng_brs): + def init_single_pair_optimizers(rng_agent, rng_br): + init_params_conf = conf_policy.init_params(rng_agent) + init_params_br = br_policy.init_params(rng_br) + return init_params_conf, init_params_br + + init_all_networks_and_optimizers = jax.vmap(init_single_pair_optimizers) + all_conf_params, all_br_params = init_all_networks_and_optimizers(rng_agents, rng_brs) + + # Define optimizers for both confederate and BR policy + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"], + eps=1e-5), + ) + tx_br = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule if config["ANNEAL_LR"] else config["LR"], + eps=1e-5), + ) + + train_state_conf = TrainState.create( + apply_fn=conf_policy.network.apply, + params=all_conf_params, + tx=tx, + ) + + train_state_br = TrainState.create( + apply_fn=br_policy.network.apply, + params=all_br_params, + tx=tx_br, + ) + + return train_state_conf, train_state_br + + all_conf_optims, all_br_optims = init_train_states( + all_conf_init_rngs, all_br_init_rngs + ) + + def forward_pass_conf(params, obs, id, done, avail_actions, hstate, rng): + act, val, pi, new_hstate = conf_policy.get_action_value_policy( + params=params, + obs=obs[jnp.newaxis, ...], + done=done[jnp.newaxis, ...], + avail_actions=avail_actions, + hstate=hstate, + rng=rng, + aux_obs=id[jnp.newaxis, ...] + ) + return act, val, pi, new_hstate + + def forward_pass_br(params, obs, id, done, avail_actions, hstate, rng): + act, val, pi, new_hstate = br_policy.get_action_value_policy( + params=params, + obs=obs[jnp.newaxis, ...], + done=done[jnp.newaxis, ...], + avail_actions=avail_actions, + hstate=hstate, + rng=rng, + aux_obs=id[jnp.newaxis, ...] + ) + return act, val, pi, new_hstate + + def _env_step(runner_state, unused): + """ + agent_0 = confederate, agent_1 = br + Returns updated runner_state, and Transitions for agent_0 and agent_1 + """ + ( + all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids, + env_state, last_obs, last_done, last_conf_h, last_br_h, rng + ) = runner_state + rng, act0_rng, act1_rng, step_rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 6) + + # For done envs, resample both conf and brs + needs_resample = last_done["__all__"] + resampled_conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_CONF_ACTORS"],), 0, config["PARTNER_POP_SIZE"]) + resampled_br_ids = jax.random.randint(br_sampling_rng, (config["NUM_BR_ACTORS"],), 0, config["PARTNER_POP_SIZE"]) + + # Determine final indices based on whether resampling was needed for each env + updated_conf_ids = jnp.where( + needs_resample, + resampled_conf_ids, # Use newly sampled index if True + last_conf_ids # Else, keep index from previous step + ) + + updated_br_ids = jnp.where( + needs_resample, + resampled_br_ids, # Use newly sampled index if True + last_br_ids # Else, keep index from previous step + ) + + # Reset the hidden states for resampled conf and br if they are not None + # WARNING: (L)BRDiv was not tested with recurrent actors, so the code for if the hstate is not None may not work + if last_conf_h is not None: + updated_conf_h = jnp.where( + needs_resample, + init_conf_hstate, + last_conf_h + ) + else: + updated_conf_h = last_conf_h + + if last_br_h is not None: + updated_br_h = jnp.where( + needs_resample, + init_br_hstate, + last_br_h + ) + else: + updated_br_h = last_br_h + + # Get the corresponding conf and br params + updated_conf_params = gather_params(all_train_state_conf.params, updated_conf_ids) + updated_br_params = gather_params(all_train_state_br.params, updated_br_ids) + + updated_conf_onehot_ids = identity_matrix[updated_conf_ids] + updated_br_onehot_ids = identity_matrix[updated_br_ids] + + # Get available actions for agent 0 from environment state + avail_actions = jax.vmap(env.get_avail_actions)(env_state.env_state) + avail_actions = jax.lax.stop_gradient(avail_actions) + avail_actions_0 = avail_actions["agent_0"].astype(jnp.float32) + avail_actions_1 = avail_actions["agent_1"].astype(jnp.float32) + + # Agent_0 action + act0_rng = jax.random.split(act0_rng, config["NUM_ENVS"]) + act_0, val_0, pi_0, new_conf_h = jax.vmap(forward_pass_conf)(updated_conf_params, + last_obs["agent_0"], updated_br_onehot_ids, last_done["agent_0"], avail_actions_0, + updated_conf_h, act0_rng) + logp_0 = pi_0.log_prob(act_0) + act_0, val_0, logp_0 = act_0.squeeze(), val_0.squeeze(), logp_0.squeeze() + + # Agent_1 action + act1_rng = jax.random.split(act1_rng, config["NUM_ENVS"]) + act_1, val_1, pi_1, new_br_h = jax.vmap(forward_pass_br)(updated_br_params, + last_obs["agent_1"], updated_conf_onehot_ids, last_done["agent_1"], avail_actions_1, + updated_br_h, act1_rng) + logp_1 = pi_1.log_prob(act_1) + act_1, val_1, logp_1 = act_1.squeeze(), val_1.squeeze(), logp_1.squeeze() + + # Combine actions into the env format + combined_actions = jnp.concatenate([act_0, act_1], axis=0) + env_act = unbatchify(combined_actions, env.agents, config["NUM_ENVS"], num_agents) + env_act = {k: v.flatten() for k, v in env_act.items()} + + # Step env + step_rngs = jax.random.split(step_rng, config["NUM_ENVS"]) + obs_next, env_state_next, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0))( + step_rngs, env_state, env_act + ) + # note that num_actors = num_envs * num_agents + info_0 = jax.tree.map(lambda x: x[:, 0], info) + info_1 = jax.tree.map(lambda x: x[:, 1], info) + + # Store agent_0 data in transition + transition_0 = XPTransition( + done=done["agent_0"], + action=act_0, + value=val_0, + self_onehot_id=updated_conf_onehot_ids, + oppo_onehot_id=updated_br_onehot_ids, + reward=reward["agent_1"], + log_prob=logp_0, + obs=last_obs["agent_0"], + info=info_0, + avail_actions=avail_actions_0 + ) + + transition_1 = XPTransition( + done=done["agent_1"], + action=act_1, + value=val_1, + self_onehot_id=updated_br_onehot_ids, + oppo_onehot_id=updated_conf_onehot_ids, + reward=reward["agent_1"], + log_prob=logp_1, + obs=last_obs["agent_1"], + info=info_1, + avail_actions=avail_actions_1 + ) + new_runner_state = (all_train_state_conf, all_train_state_br, updated_conf_ids, updated_br_ids, + env_state_next, obs_next, done, new_conf_h, new_br_h, rng) + return new_runner_state, (transition_0, transition_1) + + def _calculate_gae(traj_batch, last_val): + def _get_advantages(gae_and_next_value, transition): + gae, next_value = gae_and_next_value + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + delta = reward + config["GAMMA"] * next_value * (1 - done) - value + gae = ( + delta + + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae + ) + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + def run_all_episodes(rng, train_state_conf, train_state_br): + conf_ids, br_ids = _get_all_ids(config["PARTNER_POP_SIZE"]) + gathered_conf_model_params = gather_params(train_state_conf.params, conf_ids) + gathered_br_model_params = gather_params(train_state_br.params, br_ids) + + rng, eval_rng = jax.random.split(rng) + def run_episodes_fixed_rng(conf_param, br_param): + return run_episodes( + eval_rng, env, + conf_param, conf_policy, + br_param, br_policy, + config["ROLLOUT_LENGTH"], config["NUM_EVAL_EPISODES"], + ) + ep_infos = jax.vmap(run_episodes_fixed_rng)( + gathered_conf_model_params, gathered_br_model_params, # leaves where shape is (pop_size*pop_size, ...) + ) + return ep_infos + + def _update_epoch(update_state, unused): + def _update_minbatch(all_train_states, all_data): + train_state_conf, train_state_br = all_train_states + minbatch_conf, minbatch_br, lms_vertical, lms_horizontal = all_data + + def _loss_fn(param, agent_policy, minbatch, agent_id, lms_vertical, lms_horizontal): + '''Compute loss for agent corresponding to agent_id. + ''' + init_hstate, traj_batch, gae, target_v = minbatch + # get policy and value of confederate versus ego and best response agents respectively + squeezed_param = jax.tree.map(lambda x: jnp.squeeze(x, 0), param) + _, value, pi, _ = agent_policy.get_action_value_policy( + params=squeezed_param, + obs=traj_batch.obs, + done=traj_batch.done, + avail_actions=traj_batch.avail_actions, + hstate=init_hstate, + rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here + aux_obs=traj_batch.oppo_onehot_id + ) + log_prob = pi.log_prob(traj_batch.action) + + is_relevant = jnp.equal( + jnp.argmax(traj_batch.self_onehot_id, axis=-1), + agent_id + ) + loss_weights = jnp.where(is_relevant, 1, 0).astype(jnp.float32) + int_self_id = jnp.argmax(traj_batch.self_onehot_id, axis=-1) + int_oppo_id = jnp.argmax(traj_batch.oppo_onehot_id, axis=-1) + + # Given a pair of policies that generate SP trajectories, + # compute the pair's total Lagrange multiplier in the Lagrange dual. + # Assuming the SP data is generated by population i, the total LMs + # amounts to \sum_{j}*lms_vertical[i][j] + \sum_{j}*lms_horizontal[i][j] + + def _gather_sp_weights(ids): + s_id, _ = ids + return jnp.sum(lms_vertical, axis=-1)[s_id], jnp.sum(lms_horizontal, axis=-1)[s_id] + + # Given a pair of policies that generate XP trajectories, + # compute the pair's total Lagrange multiplier in the Lagrange dual. + # Assuming the XP data is generated by the i^th conf policy and the j^th BR policy, + # the total LMs amounts to + # -lms_vertical[j][i] -lms_horizontal[i][j] + + def _gather_xp_weights(ids): + s_id, o_id = ids + return -lms_vertical[s_id][o_id], -lms_horizontal[o_id][s_id] + + def _get_weights(s_id, o_id): + return jax.lax.cond( + jnp.equal(s_id, o_id), + _gather_sp_weights, + _gather_xp_weights, + (s_id, o_id) + ) + + # Value loss + value_pred_clipped = traj_batch.value + ( + value - traj_batch.value + ).clip( + -config["CLIP_EPS"], config["CLIP_EPS"]) + value_losses = jnp.square(value - target_v) + value_losses_clipped = jnp.square(value_pred_clipped - target_v) + value_loss = jax.lax.cond( + loss_weights.sum() == 0, + lambda x: jnp.zeros_like(x).astype(jnp.float32), + lambda x: x, + (loss_weights * jnp.maximum(value_losses, value_losses_clipped)).sum() / (loss_weights.sum() + 1e-8) + ) + + # # Apply different loss weights for SP and XP data + # # Loss weights consist of two parts: the first term is the weighting from the (L)BRDiv loss fucntion + # # which is based on the sum of Lagrange multipliers for a given confederate-ego pair expected returns + # # in the Lagrange dual formulation. This is indicated by weights1 + weights2 in the code below. + + # # The second term is a reweighting term to compensate for the data collection process, which uniformly and independently + # # samples the conf and br ids from 1, ..., n, resulting in P(SP) = 1/n and P(XP) = (n-1)/n. + # # To prevent the XP loss term from dominating the SP loss term, we would like P(SP) = P(XP) = 1/2. + # # Thus, we set the 2nd term of the SP weight to n/2, and the 2nd term of the XP weight to n/(2 * (n-1)). + + n = config["PARTNER_POP_SIZE"] + is_sp = jnp.equal(jnp.argmax(traj_batch.self_onehot_id, axis=-1), jnp.argmax(traj_batch.oppo_onehot_id, axis=-1)) + weights1, weights2 = jax.vmap(jax.vmap(_get_weights))(int_self_id, int_oppo_id) + actor_weights_sp = (weights1 + weights2) * (n/2) + actor_weights_xp = (weights1 + weights2) * (n / (2 * (n-1))) + actor_weights = jnp.where(is_sp, actor_weights_sp, actor_weights_xp) + + # Policy gradient loss + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae_norm = (gae - gae.mean()) / (gae.std() + 1e-8) + pg_loss_1 = ratio * actor_weights * gae_norm + pg_loss_2 = jnp.clip( + ratio, + 1.0 - config["CLIP_EPS"], + 1.0 + config["CLIP_EPS"]) * actor_weights * gae_norm + pg_loss = jax.lax.cond( + loss_weights.sum() == 0, + lambda x: jnp.zeros_like(x).astype(jnp.float32), + lambda x: x, + -( + loss_weights * jnp.minimum(pg_loss_1, pg_loss_2) + ).sum()/(loss_weights.sum() + 1e-8) + ) + + # Weight entropy based on actor weights + all_sp_weights1, all_sp_weights2 = jax.vmap(_gather_sp_weights)((int_self_id, int_self_id)) + entropy_scaler = jnp.maximum(all_sp_weights1, all_sp_weights2) + + # Compute entropy loss + entropy = jax.lax.cond( + loss_weights.sum() == 0, + lambda x: jnp.zeros_like(x).astype(jnp.float32), + lambda x: x, + (loss_weights * entropy_scaler * pi.entropy()).sum()/(loss_weights.sum() + 1e-8) + ) + + total_loss = pg_loss + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy + return total_loss, (value_loss, pg_loss, entropy) + + possible_agent_ids = jnp.expand_dims(jnp.arange(config["PARTNER_POP_SIZE"]), 1) + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + + def gather_conf_params_and_return_grads(agent_id): + # transposing the lm matrices only on the confederate agent side + # ensures that both the confederate and br policy that interact + # to generate a trajectory have the same weights when computing + # the policy gradient loss. + param_vector = gather_params(train_state_conf.params, agent_id) + (loss_val_conf, aux_vals_conf), grads_conf = grad_fn( + param_vector, conf_policy, minbatch_conf, agent_id, + jnp.transpose(lms_vertical), jnp.transpose(lms_horizontal) + ) + return (loss_val_conf, aux_vals_conf), grads_conf + + def gather_br_params_and_return_grads(agent_id): + param_vector = gather_params(train_state_br.params, agent_id) + (loss_val_br, aux_vals_br), grads_br = grad_fn( + param_vector, br_policy, minbatch_br, agent_id, + lms_vertical, lms_horizontal + ) + return (loss_val_br, aux_vals_br), grads_br + + (loss_val_conf, aux_vals_conf), grads_conf = jax.vmap(gather_conf_params_and_return_grads)(possible_agent_ids) + (loss_val_br, aux_vals_br), grads_br = jax.vmap(gather_br_params_and_return_grads)(possible_agent_ids) + + grads_conf_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_conf) + grads_br_new = jax.tree.map(lambda x: jnp.squeeze(x, 1), grads_br) + train_state_conf = train_state_conf.apply_gradients(grads=grads_conf_new) + train_state_br = train_state_br.apply_gradients(grads=grads_br_new) + return (train_state_conf, train_state_br), ((loss_val_conf, aux_vals_conf), (loss_val_br, aux_vals_br)) + + ( + train_state_conf, train_state_br, + traj_batch_conf, traj_batch_br, + advantages_conf, advantages_br, + targets_conf, targets_br, + rng, lms_vertical, lms_horizontal + ) = update_state + rng, perm_rng_conf, perm_rng_br = jax.random.split(rng, 3) + + minibatches_conf = _create_minibatches(traj_batch_conf, advantages_conf, targets_conf, init_conf_hstate, + config["NUM_CONF_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_conf) + minibatches_br = _create_minibatches(traj_batch_br, advantages_br, targets_br, init_br_hstate, + config["NUM_BR_ACTORS"], config["NUM_MINIBATCHES"], perm_rng_br) + + # Update both policies + num_minibatches = minibatches_br[1].obs.shape[0] + + repeated_lms_vertical = lms_vertical[jnp.newaxis, ...].repeat(num_minibatches, axis=0) + repeated_lms_horizontal = lms_horizontal[jnp.newaxis, ...].repeat(num_minibatches, axis=0) + + (train_state_conf, train_state_br), all_losses = jax.lax.scan( + _update_minbatch, (train_state_conf, train_state_br), + (minibatches_conf, minibatches_br, repeated_lms_vertical, repeated_lms_horizontal) + ) + + update_state = (train_state_conf, train_state_br, + traj_batch_conf, traj_batch_br, + advantages_conf, advantages_br, + targets_conf, targets_br, + rng, lms_vertical, lms_horizontal + ) + return update_state, all_losses + + def _update_step(update_runner_state, unused): + """ + 1. Collect rollouts + 2. Compute advantage + 3. PPO updates (UPDATE_EPOCHS epochs) + 4. Lagrange multiplier update (once, after all PPO epochs) + """ + ( + all_train_state_conf, all_train_state_br, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, + rng, update_steps, lms_vertical, lms_horizontal + ) = update_runner_state + + rng, conf_sampling_rng, br_sampling_rng = jax.random.split(rng, 3) + + conf_ids = jax.random.randint(conf_sampling_rng, (config["NUM_ENVS"],), 0, config["PARTNER_POP_SIZE"]) + br_ids = jax.random.randint(br_sampling_rng, (config["NUM_ENVS"],), 0, config["PARTNER_POP_SIZE"]) + + runner_state = ( + all_train_state_conf, all_train_state_br, conf_ids, br_ids, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng + ) + runner_state, traj_batch = jax.lax.scan( + _env_step, runner_state, None, config["ROLLOUT_LENGTH"]) + (all_train_state_conf, all_train_state_br, last_conf_ids, last_br_ids, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, rng) = runner_state + + # Get the last conf and br params and ids + last_conf_params = gather_params(all_train_state_conf.params, last_conf_ids) + last_br_params = gather_params(all_train_state_br.params, last_br_ids) + + last_conf_one_hots = identity_matrix[last_conf_ids] + last_br_one_hots = identity_matrix[last_br_ids] + + # Get agent 0 and agent 1 trajectories from interaction between conf policy and its BR policy. + traj_batch_conf, traj_batch_br = traj_batch + + # Compute advantage for confederate agent from interaction with br policy + avail_actions_0 = jax.vmap(env.get_avail_actions)(last_env_state.env_state)["agent_0"].astype(jnp.float32) + _, last_val_conf, _, _ = jax.vmap(forward_pass_conf)( + params=last_conf_params, + obs=last_obs["agent_0"], + id=last_br_one_hots, + done=last_done["agent_0"], + avail_actions=avail_actions_0, + hstate=last_conf_h, + rng=jax.random.split(jax.random.PRNGKey(0), config["NUM_ENVS"]) # Dummy key since we're just extracting the value + ) + last_val_conf = last_val_conf.squeeze() + advantages_conf, targets_conf = _calculate_gae(traj_batch_conf, last_val_conf) + + # Compute advantage for br policy from interaction with confederate agent + avail_actions_1 = jax.vmap(env.get_avail_actions)(last_env_state.env_state)["agent_1"].astype(jnp.float32) + _, last_val_br, _, _ = jax.vmap(forward_pass_br)( + params=last_br_params, + obs=last_obs["agent_1"], + id=last_conf_one_hots, + done=last_done["agent_1"], + avail_actions=avail_actions_1, + hstate=last_br_h, + rng=jax.random.split(jax.random.PRNGKey(0), config["NUM_ENVS"]) # Dummy key since we're just extracting the value + ) + last_val_br = last_val_br.squeeze() + advantages_br, targets_br = _calculate_gae(traj_batch_br, last_val_br) + + # 3) PPO update + rng, update_rng = jax.random.split(rng, 2) + update_state = ( + all_train_state_conf, all_train_state_br, + traj_batch_conf, traj_batch_br, + advantages_conf, advantages_br, + targets_conf, targets_br, + update_rng, lms_vertical, lms_horizontal + ) + + update_state, all_losses = jax.lax.scan( + _update_epoch, update_state, None, config["UPDATE_EPOCHS"]) + all_train_state_conf, all_train_state_br = update_state[:2] + lms_vertical, lms_horizontal = update_state[-2:] + + # Compute Lagrange gradient updates once per update step (after all PPO epochs). + # Diagonal and off-diagonal pairs use separate vmaps to avoid evaluating both + # branches of lax.cond for all pop_size^2 elements under vmap. + def compute_lagrange_grads_same(params_br, batch, target_value, ids): + conf_id, br_id = ids + + all_target_value = jnp.reshape(target_value, (-1, 1)) + repeated_value_sp = jnp.repeat( + jnp.reshape(all_target_value, (1, -1)), + config["PARTNER_POP_SIZE"], + axis=0 + ) + + relevant_conf_params = gather_params(params_br, jnp.reshape(conf_id, (1,))) + relevant_conf_params = jax.tree.map(lambda x: jnp.squeeze(x, 0), relevant_conf_params) + def _get_value_xp_vary_conf(param, agent_onehot_id): + ts, bs = batch.obs.shape[:2] + agent_onehot_id = agent_onehot_id[jnp.newaxis, jnp.newaxis, ...].repeat(ts, axis=0).repeat(bs, axis=1) + _, value_xp_vary_conf, _, _ = br_policy.get_action_value_policy( + params=param, + obs=batch.obs, + done=batch.done, + avail_actions=batch.avail_actions, + hstate=init_br_hstate, + rng=jax.random.PRNGKey(0), + aux_obs=agent_onehot_id + ) + return value_xp_vary_conf.reshape(ts*bs) + + all_possible_value_xp_vary_conf = jax.vmap( + lambda agent_id: _get_value_xp_vary_conf(relevant_conf_params, agent_id) + )(jnp.eye(config["PARTNER_POP_SIZE"])) + all_possible_value_xp_vary_conf = all_possible_value_xp_vary_conf.at[conf_id].set( + repeated_value_sp[conf_id] + ) + + offsetting_thresholds = jnp.zeros_like(repeated_value_sp) + offsetting_thresholds = offsetting_thresholds.at[conf_id].set( + config["TOLERANCE_FACTOR"] * jnp.ones_like(offsetting_thresholds[conf_id]) + ) + grad_sp_vary_conf = repeated_value_sp + offsetting_thresholds - ( + all_possible_value_xp_vary_conf + config["TOLERANCE_FACTOR"] * jnp.ones_like(offsetting_thresholds) + ) + + ##### Compute grad_sp_vary_br + # This code tries to measure the expected returns of the ego agent had the BR policy been + # substituted by another BR policy + + # Lets say that R_{i,-j} is the ego agent's returns when following the BR policy of the i^th pair + # againts the confederate policy of the j^th pair. + + # Then grad_sp_vary_conf computes R_{i,-i} - R_{i,-j} - tolerance factor + # for all possible j (note for j=i, we sub in + # R_{i,-i} with the target returns + tolerance factor so that R_{i,-i} - R_{i,-j} = 0) + + # Meanwhile grad_sp_vary_br below computes R_{i,-i} - R_{j,-i} - tolerance factor + # for all possible j. + + # Vary the BR policy parameters (j) used in value computation + # Use the experience generating pop id (batch.self_onehot_id) as the conf ID. + + relevant_params = gather_params(params_br, jnp.arange(config["PARTNER_POP_SIZE"])) + def _get_value_xp_vary_br(param): + ts, bs = batch.obs.shape[:2] + conf_one_hot = jnp.eye(config["PARTNER_POP_SIZE"])[conf_id] + conf_one_hot = conf_one_hot[jnp.newaxis, jnp.newaxis, ...].repeat(ts, axis=0).repeat(bs, axis=1) + _, value_xp_vary_br, _, _ = br_policy.get_action_value_policy( + params=param, + obs=batch.obs, + done=batch.done, + avail_actions=batch.avail_actions, + hstate=init_br_hstate, + rng=jax.random.PRNGKey(0), # only used for action sampling, which is not used here + aux_obs=conf_one_hot + ) + return value_xp_vary_br.reshape(ts*bs) + + all_possible_value_xp_vary_br = jax.vmap( + lambda param: _get_value_xp_vary_br(param) + )(relevant_params) + all_possible_value_xp_vary_br = jnp.reshape( + all_possible_value_xp_vary_br, (config["PARTNER_POP_SIZE"], -1) + ) + all_possible_value_xp_vary_br = all_possible_value_xp_vary_br.at[conf_id].set( + repeated_value_sp[conf_id] + ) + + grad_sp_vary_br = repeated_value_sp + offsetting_thresholds - ( + all_possible_value_xp_vary_br + config["TOLERANCE_FACTOR"] * jnp.ones_like(offsetting_thresholds) + ) + + all_self_id_int = jnp.reshape( + batch.self_onehot_id, (-1, jnp.shape(batch.self_onehot_id)[-1]) + ).argmax(axis=-1) + all_oppo_id_int = jnp.reshape( + batch.oppo_onehot_id, (-1, jnp.shape(batch.oppo_onehot_id)[-1]) + ).argmax(axis=-1) + + self_is_conf = jnp.equal(all_self_id_int, conf_id).astype(jnp.float32) + oppo_is_conf = jnp.equal(all_oppo_id_int, conf_id).astype(jnp.float32) + loss_weights = self_is_conf * oppo_is_conf + repeated_loss_weights = jnp.repeat( + jnp.expand_dims(loss_weights, axis=0), + config["PARTNER_POP_SIZE"], + axis=0 + ) + + # Compute vertical and horizontal gradient + vertical_grads = jnp.sum(grad_sp_vary_conf * repeated_loss_weights, axis=-1) / (jnp.sum(loss_weights) + 1e-8) + horizontal_grads = jnp.sum(grad_sp_vary_br * repeated_loss_weights, axis=-1) / (jnp.sum(loss_weights) + 1e-8) + + output_grad_matrix_vertical = jnp.zeros((config["PARTNER_POP_SIZE"], config["PARTNER_POP_SIZE"])) + output_grad_matrix_horizontal = jnp.zeros((config["PARTNER_POP_SIZE"], config["PARTNER_POP_SIZE"])) + output_grad_matrix_vertical = output_grad_matrix_vertical.at[conf_id].set(vertical_grads) + output_grad_matrix_horizontal = output_grad_matrix_horizontal.at[conf_id].set(horizontal_grads) + return output_grad_matrix_vertical, output_grad_matrix_horizontal + + def compute_lagrange_grads_diff(params_br, batch, target_returns, ids): + conf_id, br_id = ids + param_conf_id = gather_params(params_br, jnp.reshape(conf_id, (1,))) + param_br_id = gather_params(params_br, jnp.reshape(br_id, (1,))) + param_br_id = jax.tree.map(lambda x: jnp.squeeze(x, 0), param_br_id) + param_conf_id = jax.tree.map(lambda x: jnp.squeeze(x, 0), param_conf_id) + + all_self_id_int = jnp.reshape( + batch.self_onehot_id, (-1, jnp.shape(batch.self_onehot_id)[-1]) + ).argmax(axis=-1) + all_oppo_id_int = jnp.reshape( + batch.oppo_onehot_id, (-1, jnp.shape(batch.oppo_onehot_id)[-1]) + ).argmax(axis=-1) + all_target_returns = jnp.reshape(target_returns, (-1)) + + # Compute data weights based on whether selected ID + # is relevant for the gradient computation process + oppo_is_conf = jnp.equal(all_oppo_id_int, conf_id).astype(jnp.float32) + self_is_br = jnp.equal(all_self_id_int, br_id).astype(jnp.float32) + loss_weights = oppo_is_conf * self_is_br + + ts, bs = batch.obs.shape[:2] + conf_one_hot = jnp.eye(config["PARTNER_POP_SIZE"])[conf_id] + conf_one_hot = conf_one_hot[jnp.newaxis, jnp.newaxis, ...].repeat(ts, axis=0).repeat(bs, axis=1) + br_one_hot = jnp.eye(config["PARTNER_POP_SIZE"])[br_id] + br_one_hot = br_one_hot[jnp.newaxis, jnp.newaxis, ...].repeat(ts, axis=0).repeat(bs, axis=1) + + _, value_sp_pop_is_br, _, _ = br_policy.get_action_value_policy( + params=param_br_id, + obs=batch.obs, + done=batch.done, + avail_actions=batch.avail_actions, + hstate=init_br_hstate, + rng=jax.random.PRNGKey(0), + aux_obs=br_one_hot + ) + value_sp_pop_is_br = value_sp_pop_is_br.reshape(bs*ts) + + _, value_sp_pop_is_not_br, _, _ = br_policy.get_action_value_policy( + params=param_conf_id, + obs=batch.obs, + done=batch.done, + avail_actions=batch.avail_actions, + hstate=init_br_hstate, + rng=jax.random.PRNGKey(0), + aux_obs=conf_one_hot + ) + value_sp_pop_is_not_br = value_sp_pop_is_not_br.reshape(bs*ts) + + vertical_diff = value_sp_pop_is_br - all_target_returns - config["TOLERANCE_FACTOR"] + horizontal_diff = value_sp_pop_is_not_br - all_target_returns - config["TOLERANCE_FACTOR"] + + total_grad_vertical = (loss_weights * vertical_diff).sum() / (loss_weights.sum() + 1e-8) + total_grad_horizontal = (loss_weights * horizontal_diff).sum() / (loss_weights.sum() + 1e-8) + + output_grad_matrix_vertical = jnp.zeros((config["PARTNER_POP_SIZE"], config["PARTNER_POP_SIZE"])) + output_grad_matrix_horizontal = jnp.zeros((config["PARTNER_POP_SIZE"], config["PARTNER_POP_SIZE"])) + output_grad_matrix_vertical = output_grad_matrix_vertical.at[br_id, conf_id].set(total_grad_vertical) + output_grad_matrix_horizontal = output_grad_matrix_horizontal.at[conf_id, br_id].set(total_grad_horizontal) + return output_grad_matrix_vertical, output_grad_matrix_horizontal + + # Diagonal pairs (conf_id == br_id): vmap over pop_size elements only + diag_ids = np.arange(config["PARTNER_POP_SIZE"]) + diag_lagrange_grads = jax.vmap( + lambda conf_id, br_id: compute_lagrange_grads_same( + all_train_state_br.params, traj_batch_br, targets_br, (conf_id, br_id) + ) + )(diag_ids, diag_ids) + + # Off-diagonal pairs (conf_id != br_id): vmap over pop_size*(pop_size-1) elements only + all_conf_ids_np, all_br_ids_np = _get_all_ids(config["PARTNER_POP_SIZE"]) + off_diag_mask = all_conf_ids_np != all_br_ids_np + off_diag_conf_ids = all_conf_ids_np[off_diag_mask] + off_diag_br_ids = all_br_ids_np[off_diag_mask] + off_diag_lagrange_grads = jax.vmap( + lambda conf_id, br_id: compute_lagrange_grads_diff( + all_train_state_br.params, traj_batch_br, targets_br, (conf_id, br_id) + ) + )(off_diag_conf_ids, off_diag_br_ids) + + averaged_grad_vertical = ( + jnp.sum(diag_lagrange_grads[0], axis=0) + + jnp.sum(off_diag_lagrange_grads[0], axis=0) + ) + averaged_grad_horizontal = ( + jnp.sum(diag_lagrange_grads[1], axis=0) + + jnp.sum(off_diag_lagrange_grads[1], axis=0) + ) + + lms_vertical = jnp.maximum( + lms_vertical - config["LAGRANGE_LR"] * averaged_grad_vertical, + 0.5 * jnp.eye(config["PARTNER_POP_SIZE"]) + ) + lms_vertical = jnp.fill_diagonal( + lms_vertical, 0.5 * jnp.ones((config["PARTNER_POP_SIZE"]), dtype=jnp.float32), + inplace=False + ) + lms_horizontal = jnp.maximum( + lms_horizontal - config["LAGRANGE_LR"] * averaged_grad_horizontal, + 0.5 * jnp.eye(config["PARTNER_POP_SIZE"]), + ) + lms_horizontal = jnp.fill_diagonal( + lms_horizontal, 0.5 * jnp.ones((config["PARTNER_POP_SIZE"]), dtype=jnp.float32), + inplace=False + ) + + (_, (value_loss_conf, pg_loss_conf, entropy_conf)), (_, (value_loss_br, pg_loss_br, entropy_br)) = all_losses + + # Metrics + def mask_and_mean(x, mask): + return jnp.where(mask, x, 0).sum() / jnp.maximum(1, mask.sum()) + + mask = traj_batch_conf.info.get("returned_episode", jnp.ones_like(traj_batch_conf.reward)) + metric = jax.tree.map(lambda x: mask_and_mean(x, mask), traj_batch_conf.info) + metric["lms_vertical"] = lms_vertical + metric["lms_horizontal"] = lms_horizontal + metric["update_steps"] = update_steps + metric["value_loss_conf_agent"] = value_loss_conf.mean(axis=(0, 1)) + metric["value_loss_br_agent"] = value_loss_br.mean(axis=(0, 1)) + + metric["pg_loss_conf_agent"] = pg_loss_conf.mean(axis=(0, 1)) + metric["pg_loss_br_agent"] = pg_loss_br.mean(axis=(0, 1)) + + metric["entropy_conf"] = entropy_conf.mean(axis=(0, 1)) + metric["entropy_br"] = entropy_br.mean(axis=(0, 1)) + + new_runner_state = ( + all_train_state_conf, all_train_state_br, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, + rng, update_steps + 1, + lms_vertical, lms_horizontal + ) + return (new_runner_state, metric) + + # -------------------------- + # PPO Update and Checkpoint saving + # -------------------------- + ckpt_and_eval_interval = config["NUM_UPDATES"] // max(1, config["NUM_CHECKPOINTS"] - 1) # -1 because we store a ckpt at the last update + num_ckpts = config["NUM_CHECKPOINTS"] + + # Build a PyTree that holds parameters for all conf agent checkpoints + def init_ckpt_array(params_pytree): + return jax.tree.map( + lambda x: jnp.zeros((num_ckpts,) + x.shape, x.dtype), + params_pytree) + + def _update_step_with_ckpt(state_with_ckpt, unused): + (update_runner_state, checkpoint_array_conf, checkpoint_array_br, ckpt_idx, + eval_info) = state_with_ckpt + + # Single PPO update + new_runner_state, metric = _update_step(update_runner_state, None) + + ( + train_state_conf, train_state_br, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, + rng, update_steps, lms_vertical, lms_horizontal + ) = new_runner_state + + # Decide if we store a checkpoint + # update steps is 1-indexed because it was incremented at the end of the update step + to_store = jnp.logical_or(jnp.equal(jnp.mod(update_steps-1, ckpt_and_eval_interval), 0), + jnp.equal(update_steps, config["NUM_UPDATES"])) + + def store_and_eval_ckpt(args): + ckpt_arr_and_ep_infos, rng, cidx = args + ckpt_arr_conf, ckpt_arr_br, _ = ckpt_arr_and_ep_infos + new_ckpt_arr_conf = jax.tree.map( + lambda c_arr, p: c_arr.at[cidx].set(p), + ckpt_arr_conf, train_state_conf.params + ) + new_ckpt_arr_br = jax.tree.map( + lambda c_arr, p: c_arr.at[cidx].set(p), + ckpt_arr_br, train_state_br.params + ) + + rng, eval_rng = jax.random.split(rng) + ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)), + run_all_episodes(eval_rng, train_state_conf, train_state_br)) + + return ((new_ckpt_arr_conf, new_ckpt_arr_br, ep_last_info), rng, cidx + 1) + + def skip_ckpt(args): + return args + + (checkpoint_array_and_infos, rng, ckpt_idx) = jax.lax.cond( + to_store, + store_and_eval_ckpt, + skip_ckpt, + ((checkpoint_array_conf, checkpoint_array_br, eval_info), rng, ckpt_idx) + ) + checkpoint_array_conf, checkpoint_array_br, eval_ep_last_info = checkpoint_array_and_infos + + metric["eval_ep_last_info"] = eval_ep_last_info # return of confederate + + return ((train_state_conf, train_state_br, + last_env_state, last_obs, last_done, last_conf_h, last_br_h, + rng, update_steps, lms_vertical, lms_horizontal), + checkpoint_array_conf, checkpoint_array_br, ckpt_idx, + eval_ep_last_info), metric + + # Initialize checkpoint array + checkpoint_array_conf = init_ckpt_array(all_conf_optims.params) + checkpoint_array_br = init_ckpt_array(all_br_optims.params) + ckpt_idx = 0 + + # Initialize state for scan over _update_step_with_ckpt + update_steps = 0 + + rng, rng_eval = jax.random.split(rng, 2) + eval_ep_last_info = jax.tree.map(lambda x: x.mean(axis=(-2, -1)), + run_all_episodes(rng_eval, all_conf_optims, all_br_optims)) + + # Initialize environment + rng, reset_rng = jax.random.split(rng) + reset_rngs = jax.random.split(reset_rng, config["NUM_ENVS"]) + init_obs, init_env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rngs) + init_done = {k: jnp.zeros((config["NUM_ENVS"]), dtype=bool) for k in env.agents + ["__all__"]} + + # Initialize conf and br hstates + init_conf_h = conf_policy.init_hstate(config["NUM_CONF_ACTORS"]) + init_br_h = br_policy.init_hstate(config["NUM_BR_ACTORS"]) + + # Initialize LMs + # lm_vertical[i, j] stores the lagrange multiplier for upholding + # R_{conf(i), BR(i)} >= R_{conf(j), BR(i)} + tolerance_factor + + # lm_horizontal[i, j] stores the lagrange multiplier for upholding + # R_{conf(i), BR(i)} >= R_{conf(i), BR(j)} + tolerance_factor + + # Diagonal elements of both matrices sum up to 1. + # Providing a weight of 1 to maximize the SP return from any population + lagrange_multipliers_vertical = 0.5 * jnp.eye(config["PARTNER_POP_SIZE"]) + lagrange_multipliers_horizontal = 0.5 * jnp.eye(config["PARTNER_POP_SIZE"]) + + update_runner_state = ( + all_conf_optims, all_br_optims, + init_env_state, init_obs, init_done, init_conf_h, init_br_h, + rng, update_steps, + lagrange_multipliers_vertical, lagrange_multipliers_horizontal + ) + + state_with_ckpt = ( + update_runner_state, checkpoint_array_conf, + checkpoint_array_br, ckpt_idx, eval_ep_last_info + ) + + # run training + state_with_ckpt, metrics = jax.lax.scan( + _update_step_with_ckpt, + state_with_ckpt, + xs=None, + length=config["NUM_UPDATES"] + ) + + ( + final_runner_state, checkpoint_array_conf, checkpoint_array_br, + final_ckpt_idx, all_ep_infos + ) = state_with_ckpt + + out = { + "final_params_conf": final_runner_state[0].params, + "final_params_br": final_runner_state[1].params, + "checkpoints_conf": checkpoint_array_conf, + "checkpoints_br": checkpoint_array_br, + "metrics": metrics, # metrics is from the perspective of the confederate agent (averaged over population) + "all_pair_returns": all_ep_infos + } + return out + + return train + # ------------------------------ + # Actually run the adversarial teammate training + # ------------------------------ + train_fn = make_lbrdiv_agents(config) + out = train_fn(train_rng) + return out + +def get_lbrdiv_population(config, out, env): + ''' + Get the partner params and partner population for ego training. + ''' + pop_size = config["algorithm"]["PARTNER_POP_SIZE"] + + # partner_params has shape (num_seeds, pop_size, ...) + partner_params = out['final_params_conf'] + + partner_policy = ActorWithConditionalCriticPolicy( + action_dim=env.action_space(env.agents[1]).n, + obs_dim=env.observation_space(env.agents[1]).shape[0], + pop_size=pop_size, # used to create onehot agent id + activation=config["algorithm"].get("ACTIVATION", "tanh") + ) + + # Create partner population + partner_population = AgentPopulation( + pop_size=pop_size, + policy_cls=partner_policy + ) + + return partner_params, partner_population + +def run_lbrdiv(config, wandb_logger): + algorithm_config = dict(config["algorithm"]) + + env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"]) + env = LogWrapper(env) + + log.info("Starting LBRDiv training...") + start = time.time() + + # Generate multiple random seeds from the base seed + rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"]) + rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"]) + + # Initialize br and conf policies + conf_policy = ActorWithConditionalCriticPolicy( + action_dim=env.action_space(env.agents[0]).n, + obs_dim=env.observation_space(env.agents[0]).shape[0], + pop_size=algorithm_config["PARTNER_POP_SIZE"], + ) + br_policy = ActorWithConditionalCriticPolicy( + action_dim=env.action_space(env.agents[0]).n, + obs_dim=env.observation_space(env.agents[0]).shape[0], + pop_size=algorithm_config["PARTNER_POP_SIZE"], + ) + + # Create a vmapped version of train_lbrdiv_partners + with jax.disable_jit(False): + vmapped_train_fn = jax.jit( + jax.vmap( + partial(train_lbrdiv_partners, env=env, config=algorithm_config, conf_policy=conf_policy, br_policy=br_policy) + ) + ) + out = vmapped_train_fn(rngs) + + end = time.time() + log.info(f"LBRDiv training complete in {end - start} seconds") + + metric_names = get_metric_names(algorithm_config["ENV_NAME"]) + log_metrics(config, out, wandb_logger, metric_names) + + partner_params, partner_population = get_lbrdiv_population(config, out, env) + + return partner_params, partner_population + + +def log_metrics(config, outs, logger, metric_names: tuple): + metrics = outs["metrics"] + # metrics now has shape (num_seeds, num_updates, pop_size) + num_seeds, num_updates, pop_size = metrics["pg_loss_conf_agent"].shape # number of trained pairs + + ### Log evaluation metrics + # shape (num_seeds, num_updates, (pop_size)^2) [pre-scalarized: mean over eval eps and agents taken inside scan] + all_returns = np.asarray(metrics["eval_ep_last_info"]["returned_episode_returns"]) + xs = list(range(num_updates)) + + all_conf_ids, all_br_ids = _get_all_ids(pop_size) + sp_mask = (all_conf_ids == all_br_ids) + sp_returns = all_returns[:, :, sp_mask] + xp_returns = all_returns[:, :, ~sp_mask] + + # Average over seeds and agent pairs (eval episodes and agents already averaged inside scan) + sp_return_curve = sp_returns.mean(axis=(0, 2)) + xp_return_curve = xp_returns.mean(axis=(0, 2)) + + for step in range(num_updates): + logger.log_item("Eval/AvgSPReturnCurve", sp_return_curve[step], train_step=step) + logger.log_item("Eval/AvgXPReturnCurve", xp_return_curve[step], train_step=step) + logger.commit() + + # log final XP matrix to wandb - average over seeds + last_returns_array = all_returns[:, -1].mean(axis=0) + last_returns_array = np.reshape(last_returns_array, (pop_size, pop_size)) + logger.log_xp_matrix("Eval/LastXPMatrix", last_returns_array) + + ### Log population loss as multi-line plots, where each line is a different population member + # shape (num_seeds, num_updates, update_epochs, num_minibatches, pop_size) + # Average over seeds + processed_losses = { + "ConfPGLoss": np.asarray(metrics["pg_loss_conf_agent"]).mean(axis=0).transpose(), + "BRPGLoss": np.asarray(metrics["pg_loss_br_agent"]).mean(axis=0).transpose(), + "ConfValLoss": np.asarray(metrics["value_loss_conf_agent"]).mean(axis=0).transpose(), + "BRValLoss": np.asarray(metrics["value_loss_br_agent"]).mean(axis=0).transpose(), + "ConfEntropy": np.asarray(metrics["entropy_conf"]).mean(axis=0).transpose(), + "BREntropy": np.asarray(metrics["entropy_br"]).mean(axis=0).transpose(), + } + + xs = list(range(num_updates)) + keys = [f"pair {i}" for i in range(pop_size)] + for loss_name, loss_data in processed_losses.items(): + if np.isnan(loss_data).any(): + raise ValueError(f"Found nan in loss {loss_name}") + logger.log_item(f"Losses/{loss_name}", + wandb.plot.line_series(xs=xs, ys=loss_data, keys=keys, + title=loss_name, xname="train_step") + ) + + # Average over seeds for Lagrange multipliers + lm_keys = [f"pair {i}, {j}" for i in range(pop_size) for j in range(pop_size)] + lm_horizontal = np.asarray(metrics["lms_horizontal"]).mean(axis=0) + lm_vertical = np.asarray(metrics["lms_vertical"]).mean(axis=0) + lagrange_multipliers = { + "LMs_Horizontal": np.reshape(lm_horizontal, (lm_horizontal.shape[0], -1)).transpose(), + "LMs_Vertical": np.reshape(lm_vertical, (lm_vertical.shape[0], -1)).transpose() + } + + for array_name, array_data in lagrange_multipliers.items(): + if np.isnan(array_data).any(): + raise ValueError(f"Found nan in loss {array_name}") + logger.log_item( + f"Losses/{array_name}", + wandb.plot.line_series(xs=xs, ys=array_data, keys=lm_keys, + title=array_name, xname="train_step") + ) + logger.commit() + + ### Log artifacts + savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + # Save train run output and log to wandb as artifact + out_savepath = save_train_run(outs, savedir, savename="saved_train_run") + if config["logger"]["log_train_out"]: + logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run") + + # Cleanup locally logged out files + if not config["local_logger"]["save_train_out"]: + shutil.rmtree(out_savepath) diff --git a/teammate_generation/__init__.py b/teammate_generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/teammate_generation/configs/algorithm/brdiv/_base_.yaml b/teammate_generation/configs/algorithm/brdiv/_base_.yaml new file mode 100644 index 0000000000000000000000000000000000000000..65dee33b7ee0751fd08111b011b7f505c81f7507 --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/_base_.yaml @@ -0,0 +1,40 @@ +# @package algorithm +# ^ tells hydra to place these value directly under algorithm key +ALG: brdiv +TOTAL_TIMESTEPS: 4.5e7 # divided among each pair of BR and Conf agents +NUM_CHECKPOINTS: 5 +PARTNER_POP_SIZE: 4 +NUM_ENVS: 64 +# SP weight = 1 + 2*XP weight. +# Thus, as XP weight -> 0, SP/(SP+XP) -> 1. +# If XP weight -> infinity, XP/(SP+XP) -> 1/3, and SP/(SP+XP) -> 2/3. +XP_LOSS_WEIGHTS: 1 +LR: 1e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 4 +GAMMA: 0.99 +GAE_LAMBDA: 0.95 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +VF_COEF: 0.5 +MAX_GRAD_NORM: 1.0 +ANNEAL_LR: false +ego_train_algorithm: + EGO_ACTOR_TYPE: s5 + S5_D_MODEL: 16 + S5_SSM_SIZE: 16 + S5_ACTOR_CRITIC_HIDDEN_DIM: 64 + FC_N_LAYERS: 2 + TOTAL_TIMESTEPS: 1e7 + NUM_CHECKPOINTS: 5 + NUM_ENVS: 8 + LR: 1e-4 + UPDATE_EPOCHS: 15 + NUM_MINIBATCHES: 4 + GAMMA: 0.99 + GAE_LAMBDA: 0.95 + CLIP_EPS: 0.05 + ENT_COEF: 0.01 + VF_COEF: 0.5 + MAX_GRAD_NORM: 1.0 + ANNEAL_LR: true \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/brdiv/hanabi.yaml b/teammate_generation/configs/algorithm/brdiv/hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f93ca5b385d4de6ffcab502a862743de5b25281 --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/hanabi.yaml @@ -0,0 +1,27 @@ +defaults: + - brdiv/_base_ + - _self_ + +TOTAL_TIMESTEPS: 5e8 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +XP_LOSS_WEIGHTS: 0.05 +LR: 5e-4 +UPDATE_EPOCHS: 4 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.2 +ENT_COEF: 0.01 +ANNEAL_LR: true +GAMMA: 0.999 +GAE_LAMBDA: 0.95 +MAX_GRAD_NORM: 0.5 +ego_train_algorithm: + TOTAL_TIMESTEPS: 1e8 + LR: 5e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.2 + ANNEAL_LR: true + UPDATE_EPOCHS: 4 + GAMMA: 0.999 + GAE_LAMBDA: 0.95 + MAX_GRAD_NORM: 0.5 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/brdiv/lbf/lbf_12x12.yaml b/teammate_generation/configs/algorithm/brdiv/lbf/lbf_12x12.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25c394bdad98c5cbe21b3d35a2922e4c58eb1586 --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/lbf/lbf_12x12.yaml @@ -0,0 +1,18 @@ +defaults: + - brdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4.5e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 64 +XP_LOSS_WEIGHTS: 0.05 # 0.1 +LR: 5e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 2 # 4 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/brdiv/lbf/lbf_7x7_nolevels.yaml b/teammate_generation/configs/algorithm/brdiv/lbf/lbf_7x7_nolevels.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25c394bdad98c5cbe21b3d35a2922e4c58eb1586 --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/lbf/lbf_7x7_nolevels.yaml @@ -0,0 +1,18 @@ +defaults: + - brdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4.5e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 64 +XP_LOSS_WEIGHTS: 0.05 # 0.1 +LR: 5e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 2 # 4 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/brdiv/mini-hanabi.yaml b/teammate_generation/configs/algorithm/brdiv/mini-hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d15766ae40974e372775b92d904ccf94fd274e35 --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/mini-hanabi.yaml @@ -0,0 +1,28 @@ +defaults: + - brdiv/_base_ + - _self_ + +# Mini-Hanabi (3c/3r/hand3) BRDiv config. +TOTAL_TIMESTEPS: 1e8 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +XP_LOSS_WEIGHTS: 0.05 +LR: 5e-4 +UPDATE_EPOCHS: 4 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.2 +ENT_COEF: 0.01 +ANNEAL_LR: true +GAMMA: 0.999 +GAE_LAMBDA: 0.95 +MAX_GRAD_NORM: 0.5 +ego_train_algorithm: + TOTAL_TIMESTEPS: 1e8 + LR: 5e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.2 + ANNEAL_LR: true + UPDATE_EPOCHS: 4 + GAMMA: 0.999 + GAE_LAMBDA: 0.95 + MAX_GRAD_NORM: 0.5 diff --git a/teammate_generation/configs/algorithm/brdiv/overcooked-v1/asymm_advantages.yaml b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/asymm_advantages.yaml new file mode 100644 index 0000000000000000000000000000000000000000..16fbbf4253ecc7fc57d9db24b44810a8fe37980f --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/asymm_advantages.yaml @@ -0,0 +1,18 @@ +defaults: + - brdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4.5e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 64 +XP_LOSS_WEIGHTS: 1 +LR: .0001 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.3 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/brdiv/overcooked-v1/coord_ring.yaml b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/coord_ring.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9565248bbbad09f370a708e9d5d8daa7680c57bf --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/coord_ring.yaml @@ -0,0 +1,18 @@ +defaults: + - brdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 9e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +XP_LOSS_WEIGHTS: 0.007 +LR: 5e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.1 +ENT_COEF: 0.05 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-3 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/brdiv/overcooked-v1/counter_circuit.yaml b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/counter_circuit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af8b5bb6e930fb26ec7937c8762ae24835ff0f08 --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/counter_circuit.yaml @@ -0,0 +1,18 @@ +defaults: + - brdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 9e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +XP_LOSS_WEIGHTS: 0.005 +LR: 1e-3 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 8 +CLIP_EPS: 0.01 +ENT_COEF: 0.05 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-3 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/brdiv/overcooked-v1/cramped_room.yaml b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/cramped_room.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b3062eb6828c97edb2a87dee171c346a1e0d41c --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/cramped_room.yaml @@ -0,0 +1,21 @@ +defaults: + - brdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4.5e7 +PARTNER_POP_SIZE: 4 +NUM_ENVS: 64 +# SP weight = 1 + 2*XP weight. +# Thus, as XP weight -> 0, SP/(SP+XP) -> 1. +# If XP weight -> infinity, XP/(SP+XP) -> 1/3, and SP/(SP+XP) -> 2/3. +XP_LOSS_WEIGHTS: 0.5 # 10 +LR: 1e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/brdiv/overcooked-v1/forced_coord.yaml b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/forced_coord.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26ea89240c103799cb09dd98adce55edede4d5bc --- /dev/null +++ b/teammate_generation/configs/algorithm/brdiv/overcooked-v1/forced_coord.yaml @@ -0,0 +1,18 @@ +defaults: + - brdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 9e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +XP_LOSS_WEIGHTS: 0.01 +LR: 5e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-3 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/comedi/_base_.yaml b/teammate_generation/configs/algorithm/comedi/_base_.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e7d4e547b0eafd98dc98731b1cd218cb404ba7c0 --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/_base_.yaml @@ -0,0 +1,36 @@ +# @package algorithm +# ^ tells hydra to place these value directly under algorithm key +ALG: comedi +TOTAL_TIMESTEPS_PER_ITERATION: 1.2e7 # number of steps used to train each comedi agent at each iteration +NUM_CHECKPOINTS: 5 +PARTNER_POP_SIZE: 4 +NUM_ENVS: 48 +LR: 1e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 8 +GAMMA: 0.99 +GAE_LAMBDA: 0.95 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +VF_COEF: 0.5 +MAX_GRAD_NORM: 1.0 +ANNEAL_LR: false +ACTOR_TYPE: actor_with_conditional_critic +NUM_ARGMAX_ROLLOUT_EPS: 20 +COMEDI_ALPHA: 1.0 +COMEDI_BETA: 0.5 +ego_train_algorithm: + EGO_ACTOR_TYPE: s5 + TOTAL_TIMESTEPS: 1e7 + NUM_CHECKPOINTS: 5 + NUM_ENVS: 8 + LR: 1e-4 + UPDATE_EPOCHS: 15 + NUM_MINIBATCHES: 4 + GAMMA: 0.99 + GAE_LAMBDA: 0.95 + CLIP_EPS: 0.05 + ENT_COEF: 0.01 + VF_COEF: 0.5 + MAX_GRAD_NORM: 1.0 + ANNEAL_LR: true \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/comedi/hanabi.yaml b/teammate_generation/configs/algorithm/comedi/hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a2d1bf9d02da9df4655b4dbf64a961af5d601464 --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/hanabi.yaml @@ -0,0 +1,26 @@ +defaults: + - comedi/_base_ + - _self_ + +TOTAL_TIMESTEPS_PER_ITERATION: 2e7 +PARTNER_POP_SIZE: 5 +LR: 5e-4 +UPDATE_EPOCHS: 4 +CLIP_EPS: 0.2 +ENT_COEF: 0.01 +ANNEAL_LR: true +GAMMA: 0.999 +GAE_LAMBDA: 0.95 +MAX_GRAD_NORM: 0.5 +COMEDI_ALPHA: 0.2 +COMEDI_BETA: 0.4 +ego_train_algorithm: + TOTAL_TIMESTEPS: 1e8 + LR: 5e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.2 + ANNEAL_LR: true + UPDATE_EPOCHS: 4 + GAMMA: 0.999 + GAE_LAMBDA: 0.95 + MAX_GRAD_NORM: 0.5 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/comedi/lbf/lbf_12x12.yaml b/teammate_generation/configs/algorithm/comedi/lbf/lbf_12x12.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4af767bd8a07b304a2bb472819433aef7f9f92ab --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/lbf/lbf_12x12.yaml @@ -0,0 +1,18 @@ +defaults: + - comedi/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS_PER_ITERATION: 6e6 +PARTNER_POP_SIZE: 10 +LR: 5e-4 +UPDATE_EPOCHS: 15 +CLIP_EPS: 0.05 +ENT_COEF: 0.001 +COMEDI_ALPHA: 0.2 # weight on XP return +COMEDI_BETA: 0.4 # weight on SXP return +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 5e-5 + ENT_COEF: 1e-4 + CLIP_EPS: 0.1 + ANNEAL_LR: false \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/comedi/lbf/lbf_7x7_nolevels.yaml b/teammate_generation/configs/algorithm/comedi/lbf/lbf_7x7_nolevels.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4af767bd8a07b304a2bb472819433aef7f9f92ab --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/lbf/lbf_7x7_nolevels.yaml @@ -0,0 +1,18 @@ +defaults: + - comedi/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS_PER_ITERATION: 6e6 +PARTNER_POP_SIZE: 10 +LR: 5e-4 +UPDATE_EPOCHS: 15 +CLIP_EPS: 0.05 +ENT_COEF: 0.001 +COMEDI_ALPHA: 0.2 # weight on XP return +COMEDI_BETA: 0.4 # weight on SXP return +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 5e-5 + ENT_COEF: 1e-4 + CLIP_EPS: 0.1 + ANNEAL_LR: false \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/comedi/mini-hanabi.yaml b/teammate_generation/configs/algorithm/comedi/mini-hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..255c06334f5bf7bf529108c13b71a32f6e0eb4c4 --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/mini-hanabi.yaml @@ -0,0 +1,27 @@ +defaults: + - comedi/_base_ + - _self_ + +# Mini-Hanabi (3c/3r/hand3) CoMeDi config. +TOTAL_TIMESTEPS_PER_ITERATION: 2e6 +PARTNER_POP_SIZE: 5 +LR: 5e-4 +UPDATE_EPOCHS: 4 +CLIP_EPS: 0.2 +ENT_COEF: 0.01 +ANNEAL_LR: true +GAMMA: 0.999 +GAE_LAMBDA: 0.95 +MAX_GRAD_NORM: 0.5 +COMEDI_ALPHA: 0.2 +COMEDI_BETA: 0.4 +ego_train_algorithm: + TOTAL_TIMESTEPS: 1e8 + LR: 5e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.2 + ANNEAL_LR: true + UPDATE_EPOCHS: 4 + GAMMA: 0.999 + GAE_LAMBDA: 0.95 + MAX_GRAD_NORM: 0.5 diff --git a/teammate_generation/configs/algorithm/comedi/overcooked-v1/asymm_advantages.yaml b/teammate_generation/configs/algorithm/comedi/overcooked-v1/asymm_advantages.yaml new file mode 100644 index 0000000000000000000000000000000000000000..281099384362bdb1fb3963ab0ddc341de14a5183 --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/overcooked-v1/asymm_advantages.yaml @@ -0,0 +1,16 @@ +defaults: + - comedi/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 6e6 +PARTNER_POP_SIZE: 10 +LR: .0001 +UPDATE_EPOCHS: 15 +CLIP_EPS: 0.3 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 5e-5 + ENT_COEF: .001 + CLIP_EPS: 0.1 + UPDATE_EPOCHS: 10 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/comedi/overcooked-v1/coord_ring.yaml b/teammate_generation/configs/algorithm/comedi/overcooked-v1/coord_ring.yaml new file mode 100644 index 0000000000000000000000000000000000000000..151da46cc7990485fc29a52af16c59b5fdbb1e42 --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/overcooked-v1/coord_ring.yaml @@ -0,0 +1,16 @@ +defaults: + - comedi/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 1e7 +PARTNER_POP_SIZE: 10 +LR: 5e-4 +UPDATE_EPOCHS: 15 +CLIP_EPS: 0.1 +ENT_COEF: 0.05 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 3e-5 + ENT_COEF: .001 + CLIP_EPS: 0.1 + UPDATE_EPOCHS: 10 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/comedi/overcooked-v1/counter_circuit.yaml b/teammate_generation/configs/algorithm/comedi/overcooked-v1/counter_circuit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0281cff716f73e15001f2fbb597a6e22ed46140d --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/overcooked-v1/counter_circuit.yaml @@ -0,0 +1,16 @@ +defaults: + - comedi/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 1e7 +PARTNER_POP_SIZE: 10 +LR: 1e-3 +UPDATE_EPOCHS: 15 +CLIP_EPS: 0.01 # 0.1 +ENT_COEF: 0.05 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 5e-5 + ENT_COEF: .001 + CLIP_EPS: 0.1 + UPDATE_EPOCHS: 10 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/comedi/overcooked-v1/cramped_room.yaml b/teammate_generation/configs/algorithm/comedi/overcooked-v1/cramped_room.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9f8ea3a89e6d91185d29e108ead2822796602686 --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/overcooked-v1/cramped_room.yaml @@ -0,0 +1,17 @@ +defaults: + - comedi/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 6e6 +PARTNER_POP_SIZE: 10 +LR: 1e-4 +UPDATE_EPOCHS: 15 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 5e-5 + ANNEAL_LR: false + ENT_COEF: .001 + CLIP_EPS: 0.1 + UPDATE_EPOCHS: 10 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/comedi/overcooked-v1/forced_coord.yaml b/teammate_generation/configs/algorithm/comedi/overcooked-v1/forced_coord.yaml new file mode 100644 index 0000000000000000000000000000000000000000..acf4416d40c88ea2974074c29831b2d761d2600e --- /dev/null +++ b/teammate_generation/configs/algorithm/comedi/overcooked-v1/forced_coord.yaml @@ -0,0 +1,16 @@ +defaults: + - comedi/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 1e7 +PARTNER_POP_SIZE: 10 +LR: 5e-4 +UPDATE_EPOCHS: 15 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-5 + ENT_COEF: 1e-4 + CLIP_EPS: 0.1 + UPDATE_EPOCHS: 5 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/fcp/_base_.yaml b/teammate_generation/configs/algorithm/fcp/_base_.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee3826f496cf6214dad6f9c3e94d7eef94e61b08 --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/_base_.yaml @@ -0,0 +1,37 @@ +# @package algorithm +# ^ tells hydra to place these value directly under algorithm key +ALG: fcp +ACTOR_TYPE: mlp +TOTAL_TIMESTEPS: 1e6 # per PARTNER_POP_SIZE trained +NUM_CHECKPOINTS: 5 +PARTNER_POP_SIZE: 20 # true partner pop size is PARTNER_POP_SIZE * NUM_CHECKPOINTS +NUM_ENVS: 8 +LR: 1e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 4 +GAMMA: 0.99 +GAE_LAMBDA: 0.95 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +VF_COEF: 0.5 +MAX_GRAD_NORM: 1.0 +ANNEAL_LR: true +ego_train_algorithm: + EGO_ACTOR_TYPE: s5 + S5_D_MODEL: 16 + S5_SSM_SIZE: 16 + S5_ACTOR_CRITIC_HIDDEN_DIM: 64 + FC_N_LAYERS: 2 + TOTAL_TIMESTEPS: 1e7 + NUM_CHECKPOINTS: 5 + NUM_ENVS: 8 + LR: 1e-4 + UPDATE_EPOCHS: 15 + NUM_MINIBATCHES: 4 + GAMMA: 0.99 + GAE_LAMBDA: 0.95 + CLIP_EPS: 0.05 + ENT_COEF: 0.01 + VF_COEF: 0.5 + MAX_GRAD_NORM: 1.0 + ANNEAL_LR: true \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/fcp/hanabi.yaml b/teammate_generation/configs/algorithm/fcp/hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98de934d3d79ccc9f1199d5c426884a546a1fe81 --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/hanabi.yaml @@ -0,0 +1,32 @@ +defaults: + - fcp/_base_ + - _self_ + +# Full 2-player Hanabi FCP config. Trains IPPO partners then ego. +# Hyperparameters aligned with JaxMARL Hanabi consensus. +# +# PARTNER_POP_SIZE=3 (not 10): FCP vmaps across pop size, so 10 +# parallel IPPO instances with 658-dim obs OOMs on 48GB. 3 partners +# x 5 checkpoints = 15 total partners, enough for diversity. +TOTAL_TIMESTEPS: 1e9 +PARTNER_POP_SIZE: 3 +LR: 5e-4 +NUM_ENVS: 32 +UPDATE_EPOCHS: 4 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.2 +ENT_COEF: 0.01 +ANNEAL_LR: true +GAMMA: 0.999 +GAE_LAMBDA: 0.95 +MAX_GRAD_NORM: 0.5 +ego_train_algorithm: + TOTAL_TIMESTEPS: 1e9 + LR: 5e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.2 + ANNEAL_LR: true + UPDATE_EPOCHS: 4 + GAMMA: 0.999 + GAE_LAMBDA: 0.95 + MAX_GRAD_NORM: 0.5 diff --git a/teammate_generation/configs/algorithm/fcp/lbf/lbf_12x12.yaml b/teammate_generation/configs/algorithm/fcp/lbf/lbf_12x12.yaml new file mode 100644 index 0000000000000000000000000000000000000000..131a240d1c5de6dfa51d2378420f70624d721db1 --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/lbf/lbf_12x12.yaml @@ -0,0 +1,17 @@ +defaults: + - fcp/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 1e6 +LR: .0001 +NUM_ENVS: 8 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.03 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 + diff --git a/teammate_generation/configs/algorithm/fcp/lbf/lbf_7x7_nolevels.yaml b/teammate_generation/configs/algorithm/fcp/lbf/lbf_7x7_nolevels.yaml new file mode 100644 index 0000000000000000000000000000000000000000..131a240d1c5de6dfa51d2378420f70624d721db1 --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/lbf/lbf_7x7_nolevels.yaml @@ -0,0 +1,17 @@ +defaults: + - fcp/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 1e6 +LR: .0001 +NUM_ENVS: 8 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.03 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 + diff --git a/teammate_generation/configs/algorithm/fcp/mini-hanabi.yaml b/teammate_generation/configs/algorithm/fcp/mini-hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cc04326a25ea03047bbf8e1f6dfd69d1f8c85720 --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/mini-hanabi.yaml @@ -0,0 +1,26 @@ +defaults: + - fcp/_base_ + - _self_ + +# Mini-Hanabi (3c/3r/hand3) FCP config. +TOTAL_TIMESTEPS: 1e8 +LR: 5e-4 +NUM_ENVS: 128 +UPDATE_EPOCHS: 4 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.2 +ENT_COEF: 0.01 +ANNEAL_LR: true +GAMMA: 0.999 +GAE_LAMBDA: 0.95 +MAX_GRAD_NORM: 0.5 +ego_train_algorithm: + TOTAL_TIMESTEPS: 1e8 + LR: 5e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.2 + ANNEAL_LR: true + UPDATE_EPOCHS: 4 + GAMMA: 0.999 + GAE_LAMBDA: 0.95 + MAX_GRAD_NORM: 0.5 diff --git a/teammate_generation/configs/algorithm/fcp/overcooked-v1/asymm_advantages.yaml b/teammate_generation/configs/algorithm/fcp/overcooked-v1/asymm_advantages.yaml new file mode 100644 index 0000000000000000000000000000000000000000..101514ef99cd22b600e42b3ba4bbb9bb291902c0 --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/overcooked-v1/asymm_advantages.yaml @@ -0,0 +1,17 @@ +defaults: + - fcp/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 2e6 +LR: .0001 +NUM_ENVS: 8 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.3 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 + diff --git a/teammate_generation/configs/algorithm/fcp/overcooked-v1/coord_ring.yaml b/teammate_generation/configs/algorithm/fcp/overcooked-v1/coord_ring.yaml new file mode 100644 index 0000000000000000000000000000000000000000..718efcc4b3435869bbd82c783d9f8668e730189b --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/overcooked-v1/coord_ring.yaml @@ -0,0 +1,16 @@ +defaults: + - fcp/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4e6 +LR: 1e-3 +NUM_ENVS: 8 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.1 +ENT_COEF: 0.05 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-3 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/fcp/overcooked-v1/counter_circuit.yaml b/teammate_generation/configs/algorithm/fcp/overcooked-v1/counter_circuit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..718efcc4b3435869bbd82c783d9f8668e730189b --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/overcooked-v1/counter_circuit.yaml @@ -0,0 +1,16 @@ +defaults: + - fcp/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4e6 +LR: 1e-3 +NUM_ENVS: 8 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.1 +ENT_COEF: 0.05 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-3 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/fcp/overcooked-v1/cramped_room.yaml b/teammate_generation/configs/algorithm/fcp/overcooked-v1/cramped_room.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d684bbd094b2080e873b4a4c3c5b319cc19bc4be --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/overcooked-v1/cramped_room.yaml @@ -0,0 +1,16 @@ +defaults: + - fcp/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 2e6 +LR: .0001 +NUM_ENVS: 8 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.2 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/fcp/overcooked-v1/forced_coord.yaml b/teammate_generation/configs/algorithm/fcp/overcooked-v1/forced_coord.yaml new file mode 100644 index 0000000000000000000000000000000000000000..718efcc4b3435869bbd82c783d9f8668e730189b --- /dev/null +++ b/teammate_generation/configs/algorithm/fcp/overcooked-v1/forced_coord.yaml @@ -0,0 +1,16 @@ +defaults: + - fcp/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4e6 +LR: 1e-3 +NUM_ENVS: 8 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.1 +ENT_COEF: 0.05 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-3 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/lbrdiv/_base_.yaml b/teammate_generation/configs/algorithm/lbrdiv/_base_.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d3cbb8da3a8403c14c37cd5baee81eb63c55e390 --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/_base_.yaml @@ -0,0 +1,38 @@ +# @package algorithm +# ^ tells hydra to place these value directly under algorithm key +ALG: lbrdiv +TOTAL_TIMESTEPS: 4.5e7 # divided among each pair of BR and Conf agents +NUM_CHECKPOINTS: 5 +PARTNER_POP_SIZE: 4 +NUM_ENVS: 64 +TOLERANCE_FACTOR: 0.1 # require that SP - XP > TOLERANCE_FACTOR +LAGRANGE_LR: 0.01 # specific to L-BRDiv +LR: 1e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 4 +GAMMA: 0.99 +GAE_LAMBDA: 0.95 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +VF_COEF: 0.5 +MAX_GRAD_NORM: 1.0 +ANNEAL_LR: false +ego_train_algorithm: + EGO_ACTOR_TYPE: s5 + S5_D_MODEL: 16 + S5_SSM_SIZE: 16 + S5_ACTOR_CRITIC_HIDDEN_DIM: 64 + FC_N_LAYERS: 2 + TOTAL_TIMESTEPS: 1e7 + NUM_CHECKPOINTS: 5 + NUM_ENVS: 8 + LR: 1e-4 + UPDATE_EPOCHS: 15 + NUM_MINIBATCHES: 4 + GAMMA: 0.99 + GAE_LAMBDA: 0.95 + CLIP_EPS: 0.05 + ENT_COEF: 0.01 + VF_COEF: 0.5 + MAX_GRAD_NORM: 1.0 + ANNEAL_LR: true \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/lbrdiv/hanabi.yaml b/teammate_generation/configs/algorithm/lbrdiv/hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..541432754816dc94eb2d308436029ca46df24d84 --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/hanabi.yaml @@ -0,0 +1,26 @@ +defaults: + - lbrdiv/_base_ + - _self_ + +TOTAL_TIMESTEPS: 5e8 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +LR: 5e-4 +UPDATE_EPOCHS: 4 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.2 +ENT_COEF: 0.01 +ANNEAL_LR: true +GAMMA: 0.999 +GAE_LAMBDA: 0.95 +MAX_GRAD_NORM: 0.5 +ego_train_algorithm: + TOTAL_TIMESTEPS: 1e8 + LR: 5e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.2 + ANNEAL_LR: true + UPDATE_EPOCHS: 4 + GAMMA: 0.999 + GAE_LAMBDA: 0.95 + MAX_GRAD_NORM: 0.5 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_12x12.yaml b/teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_12x12.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af5b4b534fdaa434e2b80421cf6714eced131985 --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_12x12.yaml @@ -0,0 +1,17 @@ +defaults: + - lbrdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4.5e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 64 +LR: 5e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_7x7_nolevels.yaml b/teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_7x7_nolevels.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af5b4b534fdaa434e2b80421cf6714eced131985 --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/lbf/lbf_7x7_nolevels.yaml @@ -0,0 +1,17 @@ +defaults: + - lbrdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4.5e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 64 +LR: 5e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 \ No newline at end of file diff --git a/teammate_generation/configs/algorithm/lbrdiv/mini-hanabi.yaml b/teammate_generation/configs/algorithm/lbrdiv/mini-hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2bec20d6ed13ee1e684ad02c469984b937e68cf3 --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/mini-hanabi.yaml @@ -0,0 +1,27 @@ +defaults: + - lbrdiv/_base_ + - _self_ + +# Mini-Hanabi (3c/3r/hand3) LBRDiv config. +TOTAL_TIMESTEPS: 1e8 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +LR: 5e-4 +UPDATE_EPOCHS: 4 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.2 +ENT_COEF: 0.01 +ANNEAL_LR: true +GAMMA: 0.999 +GAE_LAMBDA: 0.95 +MAX_GRAD_NORM: 0.5 +ego_train_algorithm: + TOTAL_TIMESTEPS: 1e8 + LR: 5e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.2 + ANNEAL_LR: true + UPDATE_EPOCHS: 4 + GAMMA: 0.999 + GAE_LAMBDA: 0.95 + MAX_GRAD_NORM: 0.5 diff --git a/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/asymm_advantages.yaml b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/asymm_advantages.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c52f80ee19d5985331a8f38ead2578b1182ad213 --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/asymm_advantages.yaml @@ -0,0 +1,18 @@ +defaults: + - lbrdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4.5e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 64 +TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR +LR: .0001 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.3 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/coord_ring.yaml b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/coord_ring.yaml new file mode 100644 index 0000000000000000000000000000000000000000..caa01181c4e8c95a378805e7d7ff807d8e7d9139 --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/coord_ring.yaml @@ -0,0 +1,18 @@ +defaults: + - lbrdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 9e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR +LR: 5e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 4 +CLIP_EPS: 0.1 +ENT_COEF: 0.05 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-3 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/counter_circuit.yaml b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/counter_circuit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..815e1b897cbc7a7172860b8ad92a772be2d7591d --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/counter_circuit.yaml @@ -0,0 +1,18 @@ +defaults: + - lbrdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 9e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR +LR: 1e-3 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 8 +CLIP_EPS: 0.01 +ENT_COEF: 0.05 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-3 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/cramped_room.yaml b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/cramped_room.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2cf177aeee9f34d19c2143c14cabe146388b0d32 --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/cramped_room.yaml @@ -0,0 +1,18 @@ +defaults: + - lbrdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 4.5e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 64 +TOLERANCE_FACTOR: 10.0 # require that SP - XP > TOLERANCE_FACTOR +LR: 1e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 3e7 + LR: 1e-4 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/forced_coord.yaml b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/forced_coord.yaml new file mode 100644 index 0000000000000000000000000000000000000000..070c587969077b30f1c8705ad7b4a087126e5ea7 --- /dev/null +++ b/teammate_generation/configs/algorithm/lbrdiv/overcooked-v1/forced_coord.yaml @@ -0,0 +1,18 @@ +defaults: + - lbrdiv/_base_ + - _self_ # values from this file override the values from the base file + +TOTAL_TIMESTEPS: 9e7 +PARTNER_POP_SIZE: 3 +NUM_ENVS: 128 +TOLERANCE_FACTOR: 5.0 # require that SP - XP > TOLERANCE_FACTOR +LR: 5e-4 +UPDATE_EPOCHS: 15 +NUM_MINIBATCHES: 16 +CLIP_EPS: 0.05 +ENT_COEF: 0.01 +ego_train_algorithm: + TOTAL_TIMESTEPS: 6e7 + LR: 1e-3 + ENT_COEF: 0.01 + CLIP_EPS: 0.05 diff --git a/teammate_generation/configs/base_config_teammate.yaml b/teammate_generation/configs/base_config_teammate.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd7bc76d492a84616a28bd305dd157bfacde9c7a --- /dev/null +++ b/teammate_generation/configs/base_config_teammate.yaml @@ -0,0 +1,54 @@ +defaults: + - task: lbf/lbf_7x7_nolevels # task configs + - algorithm@algorithm: fcp/${task} # task-specific algorithm configs + - hydra: hydra_simple + - ../../evaluation/configs/global_heldout_settings + - _self_ + +ENV_NAME: ${task.ENV_NAME} +ENV_KWARGS: ${task.ENV_KWARGS} +ROLLOUT_LENGTH: ${task.ROLLOUT_LENGTH} +TASK_NAME: ${task.TASK_NAME} + +# training settings +train_ego: true # whether to train the ego agent +run_heldout_eval: true # whether to run a heldout evaluation of the ego agent + +# teammate generation settings +algorithm: + NUM_EVAL_EPISODES: 20 # used during training + TRAIN_SEED: 20374 # 112358 # 20374 + NUM_SEEDS: 1 + ENV_NAME: ${ENV_NAME} + ENV_KWARGS: ${ENV_KWARGS} + ROLLOUT_LENGTH: ${ROLLOUT_LENGTH} + # ego training settings + ego_train_algorithm: + NUM_EGO_TRAIN_SEEDS: 1 # per seed of teammate generation + NUM_EVAL_EPISODES: 20 + TRAIN_SEED: 204829 + ENV_NAME: ${ENV_NAME} + ENV_KWARGS: ${ENV_KWARGS} + ROLLOUT_LENGTH: ${ROLLOUT_LENGTH} + +label: "default_label" +name: ${TASK_NAME}/${algorithm.ALG}/${label} + +# wandb settings +logger: + project: aht-benchmark + entity: aht-project + tags: + - ${algorithm.ALG} + - ${TASK_NAME} + - seed=${algorithm.TRAIN_SEED} + - ${label} + mode: offline # options: online, offline, disabled + verbose: false + log_train_out: true # whether to log the out dictionary + log_eval_out: true # whether to log the eval metrics + +# Local logger +local_logger: + save_train_out: true + save_eval_out: true diff --git a/teammate_generation/configs/hydra/hydra_simple.yaml b/teammate_generation/configs/hydra/hydra_simple.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cf45f7d837365a2223968db5b447fb2847e8df6c --- /dev/null +++ b/teammate_generation/configs/hydra/hydra_simple.yaml @@ -0,0 +1,7 @@ +job: + chdir: true +run: + dir: results/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: results_sweep/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${run.seed} diff --git a/teammate_generation/configs/task/hanabi.yaml b/teammate_generation/configs/task/hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b3ef90d9f5c6eaecb774c73222263c07f274fc9 --- /dev/null +++ b/teammate_generation/configs/task/hanabi.yaml @@ -0,0 +1,16 @@ +# Hanabi: teammate generation task config. +# Mirrors ego_agent_training/configs/task/hanabi.yaml because +# teammate_generation methods (FCP, BRDiv, LBRDiv, CoMeDi) call into +# ego_agent_training as a subroutine, which asserts num_agents == 2. +# Hanabi is natively 2-player so this is satisfied by default. +ENV_NAME: hanabi +ROLLOUT_LENGTH: 128 +ENV_KWARGS: + num_agents: 2 + num_colors: 5 + num_ranks: 5 + hand_size: 5 + max_info_tokens: 8 + max_life_tokens: 3 + num_cards_of_rank: [3, 2, 2, 2, 1] +TASK_NAME: hanabi diff --git a/teammate_generation/configs/task/lbf/lbf_12x12.yaml b/teammate_generation/configs/task/lbf/lbf_12x12.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dfd28b2ecf0a9326f5c155e9aea7bb44ad8366af --- /dev/null +++ b/teammate_generation/configs/task/lbf/lbf_12x12.yaml @@ -0,0 +1,7 @@ +ENV_NAME: lbf +ROLLOUT_LENGTH: 128 +ENV_KWARGS: + grid_size: 12 + num_food: 6 + different_levels: true +TASK_NAME: lbf/lbf_12x12 diff --git a/teammate_generation/configs/task/lbf/lbf_7x7_nolevels.yaml b/teammate_generation/configs/task/lbf/lbf_7x7_nolevels.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0c4cd73b7a095d09a5ac0f831183dab4b02fd228 --- /dev/null +++ b/teammate_generation/configs/task/lbf/lbf_7x7_nolevels.yaml @@ -0,0 +1,4 @@ +ENV_NAME: lbf +ROLLOUT_LENGTH: 128 +ENV_KWARGS: {} +TASK_NAME: lbf/lbf_7x7_nolevels diff --git a/teammate_generation/configs/task/mini-hanabi.yaml b/teammate_generation/configs/task/mini-hanabi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dab5dac1f4ef37ca051ac1bd7557250d77975b48 --- /dev/null +++ b/teammate_generation/configs/task/mini-hanabi.yaml @@ -0,0 +1,13 @@ +# Mini-Hanabi: teammate generation task config. +# Mirrors ego_agent_training/configs/task/mini-hanabi.yaml. +ENV_NAME: hanabi +ROLLOUT_LENGTH: 128 +ENV_KWARGS: + num_agents: 2 + num_colors: 3 + num_ranks: 3 + hand_size: 3 + max_info_tokens: 5 + max_life_tokens: 3 + num_cards_of_rank: [2, 2, 1] +TASK_NAME: mini-hanabi diff --git a/teammate_generation/configs/task/overcooked-v1/asymm_advantages.yaml b/teammate_generation/configs/task/overcooked-v1/asymm_advantages.yaml new file mode 100644 index 0000000000000000000000000000000000000000..532df4dc33dcd5182a799ce4fe24fbca17970e0c --- /dev/null +++ b/teammate_generation/configs/task/overcooked-v1/asymm_advantages.yaml @@ -0,0 +1,6 @@ +ENV_NAME: overcooked-v1 +ROLLOUT_LENGTH: 400 # rollout length must be greater than episode length +ENV_KWARGS: + layout: asymm_advantages + random_obj_state: true +TASK_NAME: overcooked-v1/asymm_advantages diff --git a/teammate_generation/configs/task/overcooked-v1/coord_ring.yaml b/teammate_generation/configs/task/overcooked-v1/coord_ring.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8c0b2d63cb30c5c6b8bffc548bd6872f6c2eecf --- /dev/null +++ b/teammate_generation/configs/task/overcooked-v1/coord_ring.yaml @@ -0,0 +1,14 @@ +ENV_NAME: overcooked-v1 +ROLLOUT_LENGTH: 400 # rollout length must be greater than episode length +ENV_KWARGS: + layout: coord_ring + random_obj_state: true + do_reward_shaping: true + reward_shaping_params: + PLACEMENT_IN_POT_REW: .5 # reward for putting ingredients + PLATE_PICKUP_REWARD: .1 # reward for picking up a plate + SOUP_PICKUP_REWARD: 1. # reward for picking up a ready soup + ONION_PICKUP_REWARD: .1 + COUNTER_PICKUP_REWARD: 0 + COUNTER_DROP_REWARD: 0 +TASK_NAME: overcooked-v1/coord_ring diff --git a/teammate_generation/configs/task/overcooked-v1/counter_circuit.yaml b/teammate_generation/configs/task/overcooked-v1/counter_circuit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0948e9f611d41121a93a9d5e65f251b9d749d87a --- /dev/null +++ b/teammate_generation/configs/task/overcooked-v1/counter_circuit.yaml @@ -0,0 +1,14 @@ +ENV_NAME: overcooked-v1 +ROLLOUT_LENGTH: 400 # rollout length must be greater than episode length +ENV_KWARGS: + layout: counter_circuit + random_obj_state: true + do_reward_shaping: true + reward_shaping_params: + PLACEMENT_IN_POT_REW: .5 # reward for putting ingredients + PLATE_PICKUP_REWARD: .1 # reward for picking up a plate + SOUP_PICKUP_REWARD: 1. # reward for picking up a ready soup + ONION_PICKUP_REWARD: .1 + COUNTER_PICKUP_REWARD: 0 + COUNTER_DROP_REWARD: 0 +TASK_NAME: overcooked-v1/counter_circuit diff --git a/teammate_generation/configs/task/overcooked-v1/cramped_room.yaml b/teammate_generation/configs/task/overcooked-v1/cramped_room.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c48190253e6ec321ce22a146f6a0e9fafb437c3a --- /dev/null +++ b/teammate_generation/configs/task/overcooked-v1/cramped_room.yaml @@ -0,0 +1,14 @@ +ENV_NAME: overcooked-v1 +ROLLOUT_LENGTH: 400 +ENV_KWARGS: + layout: cramped_room + random_obj_state: true + do_reward_shaping: true + reward_shaping_params: + PLACEMENT_IN_POT_REW: .5 # reward for putting ingredients + PLATE_PICKUP_REWARD: .1 # reward for picking up a plate + SOUP_PICKUP_REWARD: 1. # reward for picking up a ready soup + ONION_PICKUP_REWARD: .1 + COUNTER_PICKUP_REWARD: 0 + COUNTER_DROP_REWARD: 0 +TASK_NAME: overcooked-v1/cramped_room diff --git a/teammate_generation/configs/task/overcooked-v1/forced_coord.yaml b/teammate_generation/configs/task/overcooked-v1/forced_coord.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a3af83a96adf7f00788cb362a6a24245823ac8c9 --- /dev/null +++ b/teammate_generation/configs/task/overcooked-v1/forced_coord.yaml @@ -0,0 +1,14 @@ +ENV_NAME: overcooked-v1 +ROLLOUT_LENGTH: 400 # rollout length must be greater than episode length +ENV_KWARGS: + layout: forced_coord + random_obj_state: true + do_reward_shaping: true + reward_shaping_params: + PLACEMENT_IN_POT_REW: .5 # reward for putting ingredients + PLATE_PICKUP_REWARD: .1 # reward for picking up a plate + SOUP_PICKUP_REWARD: 1. # reward for picking up a ready soup + ONION_PICKUP_REWARD: .1 + COUNTER_PICKUP_REWARD: 0 + COUNTER_DROP_REWARD: 0 +TASK_NAME: overcooked-v1/forced_coord diff --git a/teammate_generation/experiments.sh b/teammate_generation/experiments.sh new file mode 100644 index 0000000000000000000000000000000000000000..60ecbc0ff7eac70f243b901c6435a488d4f1684b --- /dev/null +++ b/teammate_generation/experiments.sh @@ -0,0 +1,74 @@ +#!/bin/bash + +# Algorithm to run +algo="comedi" +label="heldout_teammates" +num_seeds=1 +save_local_outs=true +save_online_outs=false +wandb_mode=online + +# Create log directory if it doesn't exist +mkdir -p results/teammate_generation_logs/${algo}/${label} + +# Get current timestamp for log file +timestamp=$(date +"%Y%m%d_%H%M%S") +log_file="results/teammate_generation_logs/${algo}/${label}/experiment_${timestamp}.log" + +# Tasks to run +tasks=( + # "overcooked-v1/asymm_advantages" + # "overcooked-v1/coord_ring" + # "overcooked-v1/counter_circuit" + # "overcooked-v1/cramped_room" + # "overcooked-v1/forced_coord" + # "lbf/lbf_7x7_nolevels" + "lbf/lbf_12x12" +) + +# Function to log messages +log() { + local message="$1" + local timestamp=$(date +"%Y-%m-%d %H:%M:%S") + echo "[${timestamp}] ${message}" | tee -a "${log_file}" +} + +# Initialize counters +success_count=0 +failure_count=0 + +# Run experiments for each task +for task in "${tasks[@]}"; do + log "Starting task: ${algo}/${task}" + + if python teammate_generation/run.py algorithm="${algo}/${task}" task="${task}" \ + label="${label}" \ + algorithm.NUM_SEEDS="${num_seeds}" \ + logger.mode="${wandb_mode}" \ + logger.log_train_out="${save_online_outs}" \ + logger.log_eval_out="${save_online_outs}" \ + local_logger.save_train_out="${save_local_outs}" \ + local_logger.save_eval_out="${save_local_outs}" \ + 2>> "${log_file}"; then + log "✅ Successfully completed task: ${algo}/${task}" + ((success_count++)) + else + log "❌ Failed to complete task: ${algo}/${task}" + ((failure_count++)) + fi +done + +# Print final summary +log "Experiment Summary:" +log "Total tasks attempted: ${#tasks[@]}" +log "Successful tasks: ${success_count}" +log "Failed tasks: ${failure_count}" + +if [ ${failure_count} -gt 0 ]; then + log "Warning: Some tasks failed. Check the log file for details: ${log_file}" + exit 1 +else + log "All tasks completed successfully!" + exit 0 +fi + diff --git a/teammate_generation/fcp.py b/teammate_generation/fcp.py new file mode 100644 index 0000000000000000000000000000000000000000..6321174d2dea1bd2fdb16a2261d595d7bcdbe3de --- /dev/null +++ b/teammate_generation/fcp.py @@ -0,0 +1,116 @@ +'''Implementation of the Fictitious Co-Play teammate generation algorithm (Strouse et al. NeurIPS 2021) +https://proceedings.neurips.cc/paper/2021/hash/797134c3e42371bb4979a462eb2f042a-Abstract.html +''' +import shutil +import time +import logging +from functools import partial + +import jax +import hydra +import numpy as np +from agents.mlp_actor_critic_agent import MLPActorCriticPolicy +from agents.population_interface import AgentPopulation +from envs import make_env +from envs.log_wrapper import LogWrapper +from marl.ippo import make_train as make_ppo_train +from common.plot_utils import get_metric_names +from common.save_load_utils import save_train_run + +log = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def get_fcp_population(config, out, env): + ''' + For each seeed, flatten the partner pool for for ego training. + ''' + num_seeds = config["algorithm"]["NUM_SEEDS"] + fcp_pop_size = config["algorithm"]["PARTNER_POP_SIZE"] * config["algorithm"]["NUM_CHECKPOINTS"] + + partner_params = out['checkpoints'] # shape is (num_seeds, partner_pop_size, num_ckpts, ...) + flattened_partner_params = jax.tree.map(lambda x: x.reshape(num_seeds, fcp_pop_size, *x.shape[3:]), partner_params) + + partner_policy = MLPActorCriticPolicy( + action_dim=env.action_space(env.agents[1]).n, + obs_dim=env.observation_space(env.agents[1]).shape[0], + activation=config["algorithm"].get("ACTIVATION", "tanh") + ) + + # Create partner population + partner_population = AgentPopulation( + pop_size=fcp_pop_size, + policy_cls=partner_policy + ) + + return flattened_partner_params, partner_population + +def train_fcp_partners(rng, env, algorithm_config, wandb_logger): + '''Single seed of training an FCP pool.''' + rngs = jax.random.split(rng, algorithm_config["PARTNER_POP_SIZE"]) + train_jit = jax.jit(jax.vmap(make_ppo_train(algorithm_config, env, logger=wandb_logger))) + out = train_jit(rngs) + return out + +def run_fcp(config, wandb_logger): + ''' + Train a pool of partners for FCP. Return checkpoints for all partners. + Returns out, a dictionary of the final train_state, metrics, and checkpoints. + ''' + algorithm_config = config["algorithm"] + rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"]) + rngs = jax.random.split(rng, algorithm_config["NUM_SEEDS"]) + + env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"]) + env = LogWrapper(env) + + start_time = time.time() + with jax.disable_jit(False): + vmapped_train_fn = jax.jit( + jax.vmap( + partial(train_fcp_partners, + env=env, + algorithm_config=algorithm_config, + wandb_logger=wandb_logger) + ) + ) + out = vmapped_train_fn(rngs) + end_time = time.time() + log.info(f"Training FCP partners took {end_time - start_time:.2f} seconds.") + + flattened_partner_params, partner_population = get_fcp_population(config, out, env) + + # Save FIRST so the checkpoint survives even if metric logging OOMs + # on long runs. Same pattern as teammate_generation/train_ego.py. + savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + out_savepath = save_train_run(out, savedir, savename="saved_train_run") + log_metrics(config, out, wandb_logger, out_savepath) + + return flattened_partner_params, partner_population + +def log_metrics(config, out, logger, out_savepath): + '''Log statistics and log saved train run to wandb as artifact.''' + metric_names = get_metric_names(config["ENV_NAME"]) + # After mask_and_mean in ippo, metrics have shape + # (num_seeds, partner_pop_size, num_partner_updates) + partner_metrics = out["metrics"] + num_partner_updates = partner_metrics["returned_episode_returns"].shape[2] + + # Average over seeds and pop members → (num_partner_updates,) + partner_stat_means = { + stat_name: np.mean(np.asarray(partner_metrics[stat_name]), axis=(0, 1)) + for stat_name in metric_names + if stat_name in partner_metrics + } + + for step in range(num_partner_updates): + for stat_name, stat_data in partner_stat_means.items(): + logger.log_item(f"Train/Partner_{stat_name}", stat_data[step], train_step=step) + + logger.commit() + + if config["logger"]["log_train_out"]: + logger.log_artifact(name="saved_train_run", path=out_savepath, type_name="train_run") + # Cleanup locally logged out file + if not config["local_logger"]["save_train_out"]: + shutil.rmtree(out_savepath) diff --git a/teammate_generation/run.py b/teammate_generation/run.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0bb8021a5ec4ca4e61f675d2e627b9ffb766a6 --- /dev/null +++ b/teammate_generation/run.py @@ -0,0 +1,43 @@ +'''Main entry point for running teammate generation algorithms.''' +import hydra +from omegaconf import OmegaConf + +from evaluation.heldout_eval import run_heldout_evaluation, log_heldout_metrics +from common.plot_utils import get_metric_names +from common.wandb_visualizations import Logger +from teammate_generation.BRDiv import run_brdiv +from teammate_generation.LBRDiv import run_lbrdiv +from teammate_generation.CoMeDi import run_comedi +from teammate_generation.fcp import run_fcp +from teammate_generation.train_ego import train_ego_agent + + +@hydra.main(version_base=None, config_path="configs", config_name="base_config_teammate") +def run_training(cfg): + print(OmegaConf.to_yaml(cfg, resolve=True)) + wandb_logger = Logger(cfg) + cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + + # train partner population + if cfg["algorithm"]["ALG"] == "brdiv": + partner_params, partner_population = run_brdiv(cfg, wandb_logger) + elif cfg["algorithm"]["ALG"] == "fcp": + partner_params, partner_population = run_fcp(cfg, wandb_logger) + elif cfg["algorithm"]["ALG"] == "lbrdiv": + partner_params, partner_population = run_lbrdiv(cfg, wandb_logger) + elif cfg["algorithm"]["ALG"] == "comedi": + partner_params, partner_population = run_comedi(cfg, wandb_logger) + else: + raise NotImplementedError("Selected method not implemented.") + + metric_names = get_metric_names(cfg["task"]["ENV_NAME"]) + if cfg["train_ego"]: + ego_params, ego_policy, init_ego_params = train_ego_agent(cfg, wandb_logger, partner_params, partner_population) + + if cfg["run_heldout_eval"]: + eval_metrics, ego_names, heldout_names = run_heldout_evaluation(cfg, ego_policy, ego_params, init_ego_params, ego_as_2d=False) + log_heldout_metrics(cfg, wandb_logger, eval_metrics, ego_names, heldout_names, metric_names, ego_as_2d=False) + wandb_logger.close() + +if __name__ == '__main__': + run_training() diff --git a/teammate_generation/train_ego.py b/teammate_generation/train_ego.py new file mode 100644 index 0000000000000000000000000000000000000000..476546979451a1ceb3b56cb5ae9690518e8c8731 --- /dev/null +++ b/teammate_generation/train_ego.py @@ -0,0 +1,136 @@ +import shutil +import time +import logging + +import jax +import numpy as np +import hydra + +from envs import make_env +from envs.log_wrapper import LogWrapper + +from ego_agent_training.ppo_ego import train_ppo_ego_agent +from ego_agent_training.utils import initialize_ego_agent +from common.plot_utils import get_metric_names, get_stats +from common.save_load_utils import save_train_run + +log = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +def train_ego_agent(config, logger, partner_params, partner_population): + ''' + Train PPO ego agent against a population of partner agents. + Args: + config: dict, config for the training + partner_params: partner parameters pytree with shape (num_seeds, pop_size, ...) + partner_population: partner population object + ''' + algorithm_config = config["algorithm"]["ego_train_algorithm"] + env = make_env(algorithm_config["ENV_NAME"], algorithm_config["ENV_KWARGS"]) + env = LogWrapper(env) + + num_seeds = jax.tree.leaves(partner_params)[0].shape[0] + + rng = jax.random.PRNGKey(algorithm_config["TRAIN_SEED"]) + rng, init_rng = jax.random.split(rng, 2) + train_rngs = jax.random.split(rng, num_seeds) + + + log.info("Starting ego agent training...") + start_time = time.time() + + def train_ego_fn(rng, partner_params): + rng, init_rng, train_rng = jax.random.split(rng, 3) + ego_policy, init_ego_params = initialize_ego_agent(algorithm_config, env, init_rng) + + return train_ppo_ego_agent( + config=algorithm_config, + env=env, + train_rng=train_rng, + ego_policy=ego_policy, + init_ego_params=init_ego_params, + n_ego_train_seeds=algorithm_config["NUM_EGO_TRAIN_SEEDS"], # PER provided partner params + partner_population=partner_population, + partner_params=partner_params + ) + + # Run the training + vmapped_train_fn = jax.jit(jax.vmap(train_ego_fn, in_axes=(0, 0))) + out = vmapped_train_fn(train_rngs, partner_params) + log.info(f"Ego agent training completed in {time.time() - start_time:.2f} seconds") + + # Prepare ego params and policy for heldout evaluation + num_seeds, num_ego_train_seeds = jax.tree.leaves(out["final_params"])[0].shape[:2] + ego_params = jax.tree.map(lambda x: x.reshape(num_seeds*num_ego_train_seeds, *x.shape[2:]), + out["final_params"]) + ego_policy, init_ego_params = initialize_ego_agent(algorithm_config, env, init_rng) + + # Save checkpoint BEFORE logging metrics. Metrics logging reshapes + # large tensors that can OOM on long runs (e.g. 1e9 steps produces + # ~7 GB eval_ep_last_info). Saving first ensures the checkpoint + # survives even if metrics logging fails. + # Hit this the hard way: lost ~20h of FCP 1e9 ego training to an OOM + # in log_ego_metrics AFTER training finished, before the save. Order + # matters here: save, then log. And always .device_get() before big + # reshapes. + savedir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + out_savepath = save_train_run(out, savedir, savename="ego_train_run") + log.info(f"Saved ego checkpoint to {out_savepath}") + + # Log metrics (transfer to CPU to avoid GPU OOM on reshape) + metric_names = get_metric_names(algorithm_config["ENV_NAME"]) + log_ego_metrics(config, out, logger, metric_names, out_savepath) + + return ego_params, ego_policy, init_ego_params + +def log_ego_metrics(config, out, logger, metric_names: tuple, out_savepath: str): + '''Log metrics for the ego agent returned by the above train_ego_agent function. + ''' + # Transfer metrics to CPU before reshaping to avoid GPU OOM on long runs. + train_metrics = jax.device_get(out["metrics"]) + + # each leaf of out["metrics"] has shape (num_seeds, num_ego_train_seeds, num_updates, ...) + # we combine the first two dimensions together to get a single seeds dimension, + num_seeds, num_ego_train_seeds = train_metrics["returned_episode_returns"].shape[:2] + train_metrics = jax.tree.map(lambda x: x.reshape(num_seeds * num_ego_train_seeds, *x.shape[2:]), + train_metrics) + + #### Extract train metrics #### + train_stats = get_stats(train_metrics, metric_names) + # each key in train_stats is a metric name, and the value is an array of shape (num_seeds, num_updates, 2) + # where the last dimension contains the mean and std of the metric + train_stats = {k: np.mean(np.array(v), axis=0) for k, v in train_stats.items()} + + all_ego_value_losses = np.asarray(train_metrics["value_loss"]) # shape (num_seeds, num_updates) + all_ego_actor_losses = np.asarray(train_metrics["actor_loss"]) # shape (num_seeds, num_updates) + all_ego_entropy_losses = np.asarray(train_metrics["entropy_loss"]) # shape (num_seeds, num_updates) + + # Process eval return metrics - average across ego seeds, eval episodes, training partners + # and num_agents per game for each checkpoint + all_ego_returns = np.asarray(train_metrics["eval_ep_last_info"]["returned_episode_returns"]) # shape (num_seeds, num_updates) [pre-scalarized: mean over partners, eval eps, and agents taken inside scan] + average_ego_rets_per_iter = np.mean(all_ego_returns, axis=0) + + # Process loss metrics - average across ego seeds + average_ego_value_losses = np.mean(all_ego_value_losses, axis=0) + average_ego_actor_losses = np.mean(all_ego_actor_losses, axis=0) + average_ego_entropy_losses = np.mean(all_ego_entropy_losses, axis=0) + + # Log metrics for each update step + num_updates = len(average_ego_value_losses) + for step in range(num_updates): + for stat_name, stat_data in train_stats.items(): + # second dimension contains the mean and std of the metric + stat_mean = stat_data[step, 0] + logger.log_item(f"Train/Ego_{stat_name}", stat_mean, train_step=step, commit=True) + + logger.log_item("Eval/EgoReturn", average_ego_rets_per_iter[step], train_step=step, commit=True) + logger.log_item("Train/EgoValueLoss", average_ego_value_losses[step], train_step=step, commit=True) + logger.log_item("Train/EgoActorLoss", average_ego_actor_losses[step], train_step=step, commit=True) + logger.log_item("Train/EgoEntropyLoss", average_ego_entropy_losses[step], train_step=step, commit=True) + + logger.commit() + + if config["logger"]["log_train_out"]: + logger.log_artifact(name="ego_train_run", path=out_savepath, type_name="train_run") + if not config["local_logger"]["save_train_out"]: + shutil.rmtree(out_savepath) \ No newline at end of file