File size: 3,017 Bytes
5146e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import abc
from functools import partial
from typing import Dict, Optional, Tuple

import chex
import jax
import jax.numpy as jnp


class AgentPolicy(abc.ABC):
    """Abstract base class for a policy.

    Subclasses must implement ``get_action``. ``get_action_value_policy``,
    ``init_hstate`` and ``init_params`` have base-class defaults that simply
    return ``None``; subclasses may override them.

    Note: the ``jax.jit`` decorators on the base-class methods below are not
    inherited by overrides -- defining the method in a subclass replaces the
    attribute entirely -- so a subclass that wants its implementation compiled
    must apply ``jax.jit`` itself.
    """

    def __init__(self, action_dim, obs_dim):
        """Store the action/observation space sizes.

        Args:
            action_dim: int, dimension of the action space
            obs_dim: int, dimension of the observation space
        """
        self.action_dim = action_dim
        self.obs_dim = obs_dim

    @abc.abstractmethod
    # static_argnums=(0,) makes `self` a compile-time constant, so instances
    # must be hashable (the default identity hash satisfies this).
    @partial(jax.jit, static_argnums=(0,))
    def get_action(self, params, obs, done, avail_actions, hstate, rng,
                   aux_obs=None, env_state=None, test_mode=False) -> Tuple[int, chex.Array]:
        """Compute only an action (and the next hidden state).

        Unlike ``get_action_value_policy``, no value or policy output is
        returned.

        Args:
            params (dict): The parameters of the policy.
            obs (chex.Array): The observation.
            done (chex.Array): The done flag.
            avail_actions (chex.Array): The available actions.
            hstate (chex.Array): The hidden state.
            rng (jax.random.PRNGKey): The random key.
            aux_obs (chex.Array, optional): An optional auxiliary vector to
                append to the observation. Defaults to None.
            env_state (optional): The environment state. Defaults to None.
            test_mode (bool, optional): Flag forwarded to the implementation;
                presumably switches to evaluation-time behaviour -- confirm
                in each subclass. Defaults to False.
        Returns:
            Tuple[int, chex.Array]: A tuple containing the action and the new hidden state.
        """
        pass

    @partial(jax.jit, static_argnums=(0,))
    def get_action_value_policy(self, params, obs, done, avail_actions, hstate, rng,
                                aux_obs=None, env_state=None) -> Tuple[int, chex.Array, chex.Array, chex.Array]:
        """Compute the action, value, and policy for one step.

        The base-class implementation does nothing and returns ``None``;
        subclasses that can produce value/policy outputs should override it.

        Args:
            params (dict): The parameters of the policy.
            obs (chex.Array): The observation.
            done (chex.Array): The done flag.
            avail_actions (chex.Array): The available actions.
            hstate (chex.Array): The hidden state.
            rng (jax.random.PRNGKey): The random key.
            aux_obs (chex.Array, optional): An optional auxiliary vector to
                append to the observation. Defaults to None.
            env_state (optional): The environment state. Defaults to None.
        Returns:
            Tuple[int, chex.Array, chex.Array, chex.Array]:
                A tuple containing the action, value, policy, and new hidden state
                (``None`` from this base implementation).
        """
        pass

    def init_hstate(self, batch_size, aux_info: Optional[dict] = None) -> Optional[chex.Array]:
        """Initialize the hidden state for the policy.

        The base implementation returns ``None`` (no hidden state); recurrent
        subclasses override it.

        Args:
            batch_size: int, the batch size of the hidden state
            aux_info: any auxiliary information needed to initialize the hidden state at the
                start of an episode (e.g. the agent id). Defaults to None.
        Returns:
            Optional[chex.Array]: the initialized hidden state, or ``None`` here.
        """
        return None

    def init_params(self, rng) -> Optional[Dict]:
        """Initialize the parameters for the policy.

        The base implementation returns ``None`` (no parameters); subclasses
        with learnable parameters override it.

        Args:
            rng (jax.random.PRNGKey): The random key.
        Returns:
            Optional[Dict]: the initialized parameters, or ``None`` here.
        """
        return None