from functools import partial import jax class AgentPopulation: '''Base class for a population of homogeneous agents TODO: develop more complex population classes that can handle heterogeneous agents ''' def __init__(self, pop_size, policy_cls): ''' Args: pop_size: int, number of agents in the population policy_cls: an instance of the AgentPolicy class. The policy class for the population of agents ''' self.pop_size = pop_size self.policy_cls = policy_cls # AgentPolicy class def sample_agent_indices(self, n, rng): '''Sample n indices from the population, with replacement.''' return jax.random.randint(rng, (n,), 0, self.pop_size) def gather_agent_params(self, pop_params, agent_indices): '''Gather the parameters of the agents specified by agent_indices. Args: pop_params: pytree of parameters for the population of agents of shape (pop_size, ...). agent_indices: indices with shape (num_envs,), each in [0, pop_size) ''' def gather_leaf(leaf): # leaf shape: (num_envs, ...) return jax.vmap(lambda idx: leaf[idx])(agent_indices) return jax.tree.map(gather_leaf, pop_params) def get_actions(self, pop_params, agent_indices, obs, done, avail_actions, hstate, rng, env_state=None, aux_obs=None, test_mode=False): ''' Get the actions of the agents specified by agent_indices. Args: pop_params: pytree of parameters for the population of agents of shape (pop_size, ...). agent_indices: indices with shape (num_envs,), each in [0, pop_size) obs: observations with shape (num_envs, ...) done: done flags with shape (num_envs,) avail_actions: available actions with shape (num_envs, num_actions) hstate: hidden state with shape (num_envs, ...) or None if policy doesn't use hidden state rng: random key env_state: environment state with shape (num_envs, ...) or None if policy doesn't use env state aux_obs: an optional auxiliary vector to append to the observation Returns: actions: actions with shape (num_envs,) new_hstate: new hidden state with shape (num_envs, ...) or None ''' gathered_params = self.gather_agent_params(pop_params, agent_indices) num_envs = agent_indices.squeeze().shape[0] rngs_batched = jax.random.split(rng, num_envs) vmapped_get_action = jax.vmap(partial(self.policy_cls.get_action, aux_obs=aux_obs, env_state=env_state, test_mode=test_mode)) actions, new_hstate = vmapped_get_action( gathered_params, obs, done, avail_actions, hstate, rngs_batched) return actions, new_hstate def init_hstate(self, n: int, aux_info: dict=None): '''Initialize the hidden state for n members of the population.''' return self.policy_cls.init_hstate(n, aux_info)