Spaces:
Running
Running
File size: 3,017 Bytes
5146e76 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 | import abc
from functools import partial
from typing import Dict, Optional, Tuple

import chex
import jax
import jax.numpy as jnp
class AgentPolicy(abc.ABC):
    """Abstract base class for a policy.

    ``get_action`` is the only abstract method, so every concrete subclass
    must implement it. ``get_action_value_policy``, ``init_hstate`` and
    ``init_params`` have placeholder implementations that return ``None``.
    """

    def __init__(self, action_dim, obs_dim):
        """Store the dimensions of the action and observation spaces.

        Args:
            action_dim: int, dimension of the action space
            obs_dim: int, dimension of the observation space
        """
        self.action_dim = action_dim
        self.obs_dim = obs_dim

    # ``static_argnums=(0,)`` declares ``self`` as a static argument of the
    # jitted function. ``abc.abstractmethod`` is applied on top of the jitted
    # callable; an overriding method in a subclass replaces this attribute
    # entirely, so it has to apply its own ``jax.jit`` if it wants one.
    @abc.abstractmethod
    @partial(jax.jit, static_argnums=(0,))
    def get_action(self, params, obs, done, avail_actions, hstate, rng,
                   aux_obs=None, env_state=None, test_mode=False) -> Tuple[int, chex.Array]:
        """Compute only an action.

        The action is computed from an observation, done flag, available
        actions, hidden state, and random key.

        Args:
            params (dict): The parameters of the policy.
            obs (chex.Array): The observation.
            done (chex.Array): The done flag.
            avail_actions (chex.Array): The available actions.
            hstate (chex.Array): The hidden state.
            rng (jax.random.PRNGKey): The random key.
            aux_obs (chex.Array): An optional auxiliary vector to append to
                the observation.
            env_state (chex.Array): The environment state (optional).
            test_mode (bool): Defaults to False. Its effect is defined by the
                implementing subclass (presumably it selects evaluation-time
                behaviour; confirm against the subclasses).

        Returns:
            Tuple[int, chex.Array]: A tuple containing the action and the
            new hidden state.
        """
        pass

    @partial(jax.jit, static_argnums=(0,))
    def get_action_value_policy(self, params, obs, done, avail_actions, hstate, rng,
                                aux_obs=None, env_state=None) -> Tuple[int, chex.Array, chex.Array, chex.Array]:
        """Compute the action, value, and policy.

        They are computed from an observation, done flag, available actions,
        hidden state, and random key.

        This method is not abstract: the implementation here is a placeholder
        that returns ``None``. Subclasses that support this call override it
        and return the tuple described below.

        Args:
            params (dict): The parameters of the policy.
            obs (chex.Array): The observation.
            done (chex.Array): The done flag.
            avail_actions (chex.Array): The available actions.
            hstate (chex.Array): The hidden state.
            rng (jax.random.PRNGKey): The random key.
            aux_obs (chex.Array): An optional auxiliary vector to append to
                the observation.
            env_state (chex.Array): The environment state (optional).

        Returns:
            Tuple[int, chex.Array, chex.Array, chex.Array]:
            A tuple containing the action, value, policy, and new hidden
            state.
        """
        pass

    def init_hstate(self, batch_size, aux_info: Optional[dict] = None) -> Optional[chex.Array]:
        """Initialize the hidden state for the policy.

        The implementation here ignores its arguments and returns ``None``.

        Args:
            batch_size: int, the batch size of the hidden state
            aux_info: any auxiliary information needed to initialize the
                hidden state at the start of an episode (e.g. the agent id).

        Returns:
            Optional[chex.Array]: the initialized hidden state; ``None`` in
            this base implementation.
        """
        return None

    def init_params(self, rng) -> Optional[Dict]:
        """Initialize the parameters for the policy.

        The implementation here ignores ``rng`` and returns ``None``.

        Args:
            rng: The random key.

        Returns:
            Optional[Dict]: the initialized parameters; ``None`` in this base
            implementation.
        """
        return None
|