from typing import Optional, Tuple

import math
import torch
import torch.nn as nn
from ding.torch_utils import MLP, ResBlock
from ding.utils import MODEL_REGISTRY, SequenceType

from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork
from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean


# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document.
@MODEL_REGISTRY.register('StochasticMuZeroModel')
class StochasticMuZeroModel(nn.Module):

    def __init__(
        self,
        observation_shape: SequenceType = (12, 96, 96),
        action_space_size: int = 6,
        chance_space_size: int = 2,
        num_res_blocks: int = 1,
        num_channels: int = 64,
        reward_head_channels: int = 16,
        value_head_channels: int = 16,
        policy_head_channels: int = 16,
        fc_reward_layers: SequenceType = [32],
        fc_value_layers: SequenceType = [32],
        fc_policy_layers: SequenceType = [32],
        reward_support_size: int = 601,
        value_support_size: int = 601,
        proj_hid: int = 1024,
        proj_out: int = 1024,
        pred_hid: int = 512,
        pred_out: int = 1024,
        self_supervised_learning_loss: bool = False,
        categorical_distribution: bool = True,
        activation: nn.Module = nn.ReLU(inplace=True),
        last_linear_layer_init_zero: bool = True,
        state_norm: bool = False,
        downsample: bool = False,
        *args,
        **kwargs
    ):
| """ | |
| Overview: | |
| The definition of the neural network model used in Stochastic MuZero, | |
| which is proposed in the paper https://openreview.net/pdf?id=X6D9bAHhBQ1. | |
| Stochastic MuZero model consists of a representation network, a dynamics network and a prediction network. | |
| The networks are built on convolution residual blocks and fully connected layers. | |
| Arguments: | |
| - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for Atari. | |
| - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. | |
| - chance_space_size: (:obj:`int`): Chance space size, the action space for decision node, usually an integer | |
| number for discrete action space. | |
| - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. | |
| - num_channels (:obj:`int`): The channels of hidden states. | |
| - reward_head_channels (:obj:`int`): The channels of reward head. | |
| - value_head_channels (:obj:`int`): The channels of value head. | |
| - policy_head_channels (:obj:`int`): The channels of policy head. | |
| - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). | |
| - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). | |
| - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). | |
| - reward_support_size (:obj:`int`): The size of categorical reward output | |
| - value_support_size (:obj:`int`): The size of categorical value output. | |
| - proj_hid (:obj:`int`): The size of projection hidden layer. | |
| - proj_out (:obj:`int`): The size of projection output layer. | |
| - pred_hid (:obj:`int`): The size of prediction hidden layer. | |
| - pred_out (:obj:`int`): The size of prediction output layer. | |
| - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ | |
| in Stochastic MuZero model, default set it to False. | |
| - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical \ | |
| distribution for value and reward. | |
| - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ | |
| operation to speedup, e.g. ReLU(inplace=True). | |
| - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \ | |
| dynamics/prediction mlp, default set it to True. | |
| - state_norm (:obj:`bool`): Whether to use normalization for hidden states, default set it to False. | |
| - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ | |
| defaults to True. This option is often used in video games like Atari. In board games like go, \ | |
| we don't need this module. | |
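        Examples:
            >>> # A minimal construction sketch; the argument values below are illustrative, not prescriptive.
            >>> model = StochasticMuZeroModel(observation_shape=(12, 96, 96), action_space_size=6, chance_space_size=2)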
| """ | |
        super(StochasticMuZeroModel, self).__init__()
        self.categorical_distribution = categorical_distribution
        if self.categorical_distribution:
            self.reward_support_size = reward_support_size
            self.value_support_size = value_support_size
        else:
            self.reward_support_size = 1
            self.value_support_size = 1

        self.action_space_size = action_space_size
        self.chance_space_size = chance_space_size
        self.proj_hid = proj_hid
        self.proj_out = proj_out
        self.pred_hid = pred_hid
        self.pred_out = pred_out
        self.self_supervised_learning_loss = self_supervised_learning_loss
        self.last_linear_layer_init_zero = last_linear_layer_init_zero
        self.state_norm = state_norm
        self.downsample = downsample

        flatten_output_size_for_reward_head = (
            (reward_head_channels * math.ceil(observation_shape[1] / 16) *
             math.ceil(observation_shape[2] / 16)) if downsample else
            (reward_head_channels * observation_shape[1] * observation_shape[2])
        )
        flatten_output_size_for_value_head = (
            (value_head_channels * math.ceil(observation_shape[1] / 16) *
             math.ceil(observation_shape[2] / 16)) if downsample else
            (value_head_channels * observation_shape[1] * observation_shape[2])
        )
        flatten_output_size_for_policy_head = (
            (policy_head_channels * math.ceil(observation_shape[1] / 16) *
             math.ceil(observation_shape[2] / 16)) if downsample else
            (policy_head_channels * observation_shape[1] * observation_shape[2])
        )

        self.representation_network = RepresentationNetwork(
            observation_shape,
            num_res_blocks,
            num_channels,
            downsample,
        )
        self.chance_encoder = ChanceEncoder(observation_shape, chance_space_size)
        self.dynamics_network = DynamicsNetwork(
            num_res_blocks,
            num_channels + 1,
            reward_head_channels,
            fc_reward_layers,
            self.reward_support_size,
            flatten_output_size_for_reward_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )
        self.prediction_network = PredictionNetwork(
            observation_shape,
            action_space_size,
            num_res_blocks,
            num_channels,
            value_head_channels,
            policy_head_channels,
            fc_value_layers,
            fc_policy_layers,
            self.value_support_size,
            flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )
        self.afterstate_dynamics_network = AfterstateDynamicsNetwork(
            num_res_blocks,
            num_channels + 1,
            reward_head_channels,
            fc_reward_layers,
            self.reward_support_size,
            flatten_output_size_for_reward_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )
        self.afterstate_prediction_network = AfterstatePredictionNetwork(
            chance_space_size,
            num_res_blocks,
            num_channels,
            value_head_channels,
            policy_head_channels,
            fc_value_layers,
            fc_policy_layers,
            self.value_support_size,
            flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head,
            last_linear_layer_init_zero=self.last_linear_layer_init_zero,
        )

        if self.self_supervised_learning_loss:
            # projection used in EfficientZero
            if self.downsample:
                # In Atari, if the observation_shape is set to (12, 96, 96), which indicates the original shape of
                # (3, 96, 96) with frame_stack_num=4, then due to downsampling the encoding of the observation
                # (latent_state) is (64, 96/16, 96/16), where 64 is the number of channels and 96/16 is the spatial
                # size of the latent state. Thus, self.projection_input_dim = 64 * 6 * 6 = 2304.
                ceil_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16)
                self.projection_input_dim = num_channels * ceil_size
            else:
                self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2]

            self.projection = nn.Sequential(
                nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
                nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation,
                nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out)
            )
            self.prediction_head = nn.Sequential(
                nn.Linear(self.proj_out, self.pred_hid),
                nn.BatchNorm1d(self.pred_hid),
                activation,
                nn.Linear(self.pred_hid, self.pred_out),
            )

    def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput:
        """
        Overview:
            Initial inference of the Stochastic MuZero model, which is the first step of its rollout.
            We first use the representation network to encode the observation into the ``latent_state``,
            then use the prediction network to predict the ``value`` and ``policy_logits`` of the ``latent_state``.
        Arguments:
            - obs (:obj:`torch.Tensor`): The 2D image observation data.
        Returns (MZNetworkOutput):
            - value (:obj:`torch.Tensor`): The output value of the input state, to help policy improvement and evaluation.
            - reward (:obj:`torch.Tensor`): The predicted reward of the input state and selected action. \
                In initial inference, we set it to the zero vector.
            - policy_logits (:obj:`torch.Tensor`): The output logits to select a discrete action.
            - latent_state (:obj:`torch.Tensor`): The encoded latent state of the input observation.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of channels of the latent state, H_ and W_ are the height and width of the latent state.
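        Examples:
            >>> # An illustrative shape sketch, assuming the default constructor arguments.
            >>> model = StochasticMuZeroModel()
            >>> obs = torch.randn(8, 12, 96, 96)
            >>> output = model.initial_inference(obs)
            >>> output.value.shape  # torch.Size([8, 601])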
| """ | |
| batch_size = obs.size(0) | |
| latent_state = self._representation(obs) | |
| policy_logits, value = self._prediction(latent_state) | |
| return MZNetworkOutput( | |
| value, | |
| [0. for _ in range(batch_size)], | |
| policy_logits, | |
| latent_state, | |
| ) | |

    def recurrent_inference(self, state: torch.Tensor, option: torch.Tensor,
                            afterstate: bool = False) -> MZNetworkOutput:
        """
        Overview:
            Recurrent inference of the Stochastic MuZero model, which is the rollout step of the model.
            A rollout step is split into two stages. If ``afterstate`` is False, ``state`` is a latent state and
            ``option`` is an action: the afterstate dynamics network predicts the afterstate and ``reward``, and the
            afterstate prediction network predicts the ``value`` and the chance ``policy_logits`` of that afterstate.
            If ``afterstate`` is True, ``state`` is an afterstate and ``option`` is a chance outcome: the dynamics
            network predicts the ``next_latent_state`` and ``reward``, and the prediction network predicts the
            ``value`` and ``policy_logits`` of that next latent state.
        Arguments:
            - state (:obj:`torch.Tensor`): The encoded latent state of the input observation, or the afterstate.
            - option (:obj:`torch.Tensor`): The action to roll out, or the chance outcome used to predict the next \
                latent state.
            - afterstate (:obj:`bool`): Whether ``state`` is an afterstate and ``option`` a chance outcome.
        Returns (MZNetworkOutput):
            - value (:obj:`torch.Tensor`): The output value of the input state, to help policy improvement and evaluation.
            - reward (:obj:`torch.Tensor`): The predicted reward of the input state and selected option.
            - policy_logits (:obj:`torch.Tensor`): The output logits to select a discrete action or chance outcome.
            - latent_state (:obj:`torch.Tensor`): The predicted next latent state (or afterstate).
        Shapes:
            - state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number of \
                channels of the latent state, H_ and W_ are the height and width of the latent state.
            - option (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of channels of the latent state, H_ and W_ are the height and width of the latent state.
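        Examples:
            >>> # An illustrative two-stage rollout sketch, assuming the default constructor arguments.
            >>> model = StochasticMuZeroModel()
            >>> latent_state = model.initial_inference(torch.randn(8, 12, 96, 96)).latent_state
            >>> action = torch.randint(0, model.action_space_size, (8, ))
            >>> afterstate_output = model.recurrent_inference(latent_state, action, afterstate=False)
            >>> chance = torch.randint(0, model.chance_space_size, (8, ))
            >>> next_output = model.recurrent_inference(afterstate_output.latent_state, chance, afterstate=True)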
| """ | |
| if afterstate: | |
| # state is afterstate, option is chance | |
| next_latent_state, reward = self._dynamics(state, option) | |
| policy_logits, value = self._prediction(next_latent_state) | |
| return MZNetworkOutput(value, reward, policy_logits, next_latent_state) | |
| else: | |
| # state is latent_state, option is action | |
| next_afterstate, reward = self._afterstate_dynamics(state, option) | |
| policy_logits, value = self._afterstate_prediction(next_afterstate) | |
| return MZNetworkOutput(value, reward, policy_logits, next_afterstate) | |

    def _representation(self, observation: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Use the representation network to encode the observation into a latent state.
        Arguments:
            - observation (:obj:`torch.Tensor`): The 2D image observation data.
        Returns:
            - latent_state (:obj:`torch.Tensor`): The encoded latent state of the input observation.
        Shapes:
            - observation (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size.
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of channels of the latent state, H_ and W_ are the height and width of the latent state.
        """
        latent_state = self.representation_network(observation)
        if self.state_norm:
            latent_state = renormalize(latent_state)
        return latent_state

    def chance_encode(self, observation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Use the chance encoder to encode the observation into a chance encoding and its one-hot \
            (straight-through) version.
        """
        output = self.chance_encoder(observation)
        return output

    def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Use the prediction network to predict ``policy_logits`` and ``value``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoded latent state of the input observation.
        Returns:
            - policy_logits (:obj:`torch.Tensor`): The output logits to select a discrete action.
            - value (:obj:`torch.Tensor`): The output value of the input state, to help policy improvement and evaluation.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of channels of the latent state, H_ and W_ are the height and width of the latent state.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
        """
        return self.prediction_network(latent_state)

    def _afterstate_prediction(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Use the afterstate prediction network to predict the chance ``policy_logits`` and ``value`` of an afterstate.
        Arguments:
            - afterstate (:obj:`torch.Tensor`): The afterstate obtained from the afterstate dynamics network.
        Returns:
            - policy_logits (:obj:`torch.Tensor`): The output logits to select a discrete chance outcome.
            - value (:obj:`torch.Tensor`): The output value of the afterstate, to help policy improvement and evaluation.
        Shapes:
            - afterstate (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of channels of the afterstate, H_ and W_ are the height and width of the afterstate.
            - policy_logits (:obj:`torch.Tensor`): :math:`(B, chance_space_size)`, where B is batch_size.
            - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
        """
        return self.afterstate_prediction_network(afterstate)

    def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Concatenate ``latent_state`` (here an afterstate, passed as ``state`` from ``recurrent_inference``) and
            ``action`` (here a chance outcome) and use the dynamics network to predict ``next_latent_state``
            and ``reward``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The afterstate to roll forward.
            - action (:obj:`torch.Tensor`): The chance outcome used to predict the next latent state.
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep.
            - reward (:obj:`torch.Tensor`): The predicted reward of the current state and selected option.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of channels of the latent state, H_ and W_ are the height and width of the latent state.
            - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the \
                number of channels of the latent state, H_ and W_ are the height and width of the latent state.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
        """
        # NOTE: the discrete action encoding type is important for some environments.
        # Discrete action space: the final action_encoding shape is
        # (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1).
        action_encoding = (
            torch.ones((
                latent_state.shape[0],
                1,
                latent_state.shape[2],
                latent_state.shape[3],
            )).to(action.device).float()
        )
        if len(action.shape) == 2:
            # (batch_size, action_dim) -> (batch_size, action_dim, 1)
            # e.g., torch.Size([8, 1]) -> torch.Size([8, 1, 1])
            action = action.unsqueeze(-1)
        elif len(action.shape) == 1:
            # (batch_size,) -> (batch_size, action_dim=1, 1)
            # e.g., torch.Size([8]) -> torch.Size([8, 1, 1])
            action = action.unsqueeze(-1).unsqueeze(-1)

        # action[:, 0, None, None] shape: (batch_size, action_dim, 1, 1), e.g. (8, 1, 1, 1).
        # The final action_encoding shape is (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1),
        # where each element is normalized as action[i] / chance_space_size, since the option encoded here is a
        # chance outcome.
        action_encoding = (action[:, 0, None, None] * action_encoding / self.chance_space_size)

        # state_action_encoding shape: (batch_size, latent_state[1] + 1, latent_state[2], latent_state[3])
        state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)
        next_latent_state, reward = self.dynamics_network(state_action_encoding)
        if self.state_norm:
            next_latent_state = renormalize(next_latent_state)
        return next_latent_state, reward

    def _afterstate_dynamics(self, latent_state: torch.Tensor,
                             action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Concatenate ``latent_state`` and ``action`` and use the afterstate dynamics network to predict the
            afterstate and ``reward``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoded latent state of the input observation.
            - action (:obj:`torch.Tensor`): The action to roll out.
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): The predicted afterstate.
            - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and selected action.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of channels of the latent state, H_ and W_ are the height and width of the latent state.
            - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size.
            - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the \
                number of channels of the latent state, H_ and W_ are the height and width of the latent state.
            - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
        """
        # NOTE: the discrete action encoding type is important for some environments.
        # Discrete action space: the final action_encoding shape is
        # (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1).
        action_encoding = (
            torch.ones((
                latent_state.shape[0],
                1,
                latent_state.shape[2],
                latent_state.shape[3],
            )).to(action.device).float()
        )
        if len(action.shape) == 2:
            # (batch_size, action_dim) -> (batch_size, action_dim, 1)
            # e.g., torch.Size([8, 1]) -> torch.Size([8, 1, 1])
            action = action.unsqueeze(-1)
        elif len(action.shape) == 1:
            # (batch_size,) -> (batch_size, action_dim=1, 1)
            # e.g., torch.Size([8]) -> torch.Size([8, 1, 1])
            action = action.unsqueeze(-1).unsqueeze(-1)

        # action[:, 0, None, None] shape: (batch_size, action_dim, 1, 1), e.g. (8, 1, 1, 1).
        # The final action_encoding shape is (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1),
        # where each element is normalized as action[i] / action_space_size.
        action_encoding = (action[:, 0, None, None] * action_encoding / self.action_space_size)

        # state_action_encoding shape: (batch_size, latent_state[1] + 1, latent_state[2], latent_state[3])
        state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)
        next_latent_state, reward = self.afterstate_dynamics_network(state_action_encoding)
        if self.state_norm:
            next_latent_state = renormalize(next_latent_state)
        return next_latent_state, reward

    def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor:
        """
        Overview:
            Project the latent state to a lower dimension to calculate the self-supervised loss, which is involved
            in EfficientZero.
            For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``.
        Arguments:
            - latent_state (:obj:`torch.Tensor`): The encoded latent state of the input observation.
            - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result.
        Returns:
            - proj (:obj:`torch.Tensor`): The result embedding vector of the projection operation.
        Shapes:
            - latent_state (:obj:`torch.Tensor`): :math:`(B, C_, H_, W_)`, where B is batch_size, C_ is the number \
                of channels of the latent state, H_ and W_ are the height and width of the latent state.
            - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size.
        Examples:
            >>> latent_state = torch.randn(256, 64, 6, 6)
            >>> output = self.project(latent_state)
            >>> output.shape  # (256, 1024)

        .. note::
            For Atari, observation_shape = (12, 96, 96), where the original shape is (3, 96, 96) and
            frame_stack_num = 4. If downsample is True, latent_state.shape is
            (batch_size, num_channel, obs_shape[1] / 16, obs_shape[2] / 16), i.e. (256, 64, 96 / 16, 96 / 16) =
            (256, 64, 6, 6). The latent state is reshaped from (256, 64, 6, 6) to (256, 64 * 6 * 6) = (256, 2304),
            so self.projection_input_dim = 64 * 6 * 6 = 2304 and self.projection_output_dim = 1024.
        """
        latent_state = latent_state.reshape(latent_state.shape[0], -1)
        proj = self.projection(latent_state)

        if with_grad:
            # with grad, use the prediction head
            return self.prediction_head(proj)
        else:
            return proj.detach()

    def get_params_mean(self) -> float:
        return get_params_mean(self)


class DynamicsNetwork(nn.Module):

    def __init__(
        self,
        num_res_blocks: int,
        num_channels: int,
        reward_head_channels: int,
        fc_reward_layers: SequenceType,
        output_support_size: int,
        flatten_output_size_for_reward_head: int,
        last_linear_layer_init_zero: bool = True,
        activation: Optional[nn.Module] = nn.ReLU(inplace=True),
    ):
| """ | |
| Overview: | |
| The definition of dynamics network in Stochastic MuZero algorithm, which is used to predict next latent state and | |
| reward given current latent state and action. | |
| Arguments: | |
| - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. | |
| - num_channels (:obj:`int`): The channels of input, including obs and action encoding. | |
| - reward_head_channels (:obj:`int`): The channels of reward head. | |
| - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). | |
| - output_support_size (:obj:`int`): The size of categorical reward output. | |
| - flatten_output_size_for_reward_head (:obj:`int`): The flatten size of output for reward head, i.e., \ | |
| the input size of reward head. | |
| - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \ | |
| reward mlp, default set it to True. | |
| - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ | |
| operation to speedup, e.g. ReLU(inplace=True). | |
| """ | |
        super().__init__()
        self.num_channels = num_channels
        self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head

        self.conv = nn.Conv2d(num_channels, num_channels - 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(num_channels - 1)
        self.resblocks = nn.ModuleList(
            [
                ResBlock(
                    in_channels=num_channels - 1, activation=activation, norm_type='BN', res_type='basic', bias=False
                ) for _ in range(num_res_blocks)
            ]
        )

        self.conv1x1_reward = nn.Conv2d(num_channels - 1, reward_head_channels, 1)
        self.bn_reward = nn.BatchNorm2d(reward_head_channels)
        self.fc_reward_head = MLP(
            self.flatten_output_size_for_reward_head,
            hidden_channels=fc_reward_layers[0],
            layer_num=len(fc_reward_layers) + 1,
            out_channels=output_support_size,
            activation=activation,
            norm_type='BN',
            output_activation=False,
            output_norm=False,
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )
        self.activation = activation

    def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Forward computation of the dynamics network. Predict the next latent state and reward given the
            state-action encoding.
        Arguments:
            - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \
                the latent state and the action encoding, with shape (batch_size, num_channels, height, width).
        Returns:
            - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape \
                (batch_size, num_channels - 1, height, width).
            - reward (:obj:`torch.Tensor`): The predicted reward, with shape (batch_size, output_support_size).
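        Examples:
            >>> # An illustrative shape sketch: 65 input channels = 64 latent-state channels + 1 action-encoding channel.
            >>> network = DynamicsNetwork(num_res_blocks=1, num_channels=65, reward_head_channels=16,
            >>>     fc_reward_layers=[32], output_support_size=601, flatten_output_size_for_reward_head=16 * 6 * 6)
            >>> state_action_encoding = torch.randn(8, 65, 6, 6)
            >>> next_latent_state, reward = network(state_action_encoding)
            >>> next_latent_state.shape, reward.shape  # (torch.Size([8, 64, 6, 6]), torch.Size([8, 601]))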
| """ | |
| # take the state encoding (latent_state), state_action_encoding[:, -1, :, :] is action encoding | |
| latent_state = state_action_encoding[:, :-1, :, :] | |
| x = self.conv(state_action_encoding) | |
| x = self.bn(x) | |
| # the residual link: add state encoding to the state_action encoding | |
| x += latent_state | |
| x = self.activation(x) | |
| for block in self.resblocks: | |
| x = block(x) | |
| next_latent_state = x | |
| x = self.conv1x1_reward(next_latent_state) | |
| x = self.bn_reward(x) | |
| x = self.activation(x) | |
| x = x.view(-1, self.flatten_output_size_for_reward_head) | |
| # use the fully connected layer to predict reward | |
| reward = self.fc_reward_head(x) | |
| return next_latent_state, reward | |
| def get_dynamic_mean(self) -> float: | |
| return get_dynamic_mean(self) | |
| def get_reward_mean(self) -> float: | |
| return get_reward_mean(self) | |


# TODO(pu): customize different afterstate dynamics network
AfterstateDynamicsNetwork = DynamicsNetwork


class AfterstatePredictionNetwork(nn.Module):

    def __init__(
        self,
        action_space_size: int,
        num_res_blocks: int,
        num_channels: int,
        value_head_channels: int,
        policy_head_channels: int,
        fc_value_layers: SequenceType,
        fc_policy_layers: SequenceType,
        output_support_size: int,
        flatten_output_size_for_value_head: int,
        flatten_output_size_for_policy_head: int,
        last_linear_layer_init_zero: bool = True,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
| """ | |
| Overview: | |
| The definition of afterstate policy and value prediction network, which is used to predict value and policy by the | |
| given afterstate. | |
| Arguments: | |
| - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. | |
| - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. | |
| - num_channels (:obj:`int`): The channels of hidden states. | |
| - value_head_channels (:obj:`int`): The channels of value head. | |
| - policy_head_channels (:obj:`int`): The channels of policy head. | |
| - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). | |
| - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). | |
| - output_support_size (:obj:`int`): The size of categorical value output. | |
| - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ | |
| - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ | |
| of the value head. | |
| - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ | |
| of the policy head. | |
| - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \ | |
| dynamics/prediction mlp, default set it to True. | |
| - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ | |
| operation to speedup, e.g. ReLU(inplace=True). | |
| """ | |
        super(AfterstatePredictionNetwork, self).__init__()
        self.resblocks = nn.ModuleList(
            [
                ResBlock(in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False)
                for _ in range(num_res_blocks)
            ]
        )

        self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
        self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
        self.bn_value = nn.BatchNorm2d(value_head_channels)
        self.bn_policy = nn.BatchNorm2d(policy_head_channels)
        self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
        self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
        self.activation = activation

        self.fc_value = MLP(
            in_channels=self.flatten_output_size_for_value_head,
            hidden_channels=fc_value_layers[0],
            out_channels=output_support_size,
            layer_num=len(fc_value_layers) + 1,
            activation=self.activation,
            norm_type='BN',
            output_activation=False,
            output_norm=False,
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )
        self.fc_policy = MLP(
            in_channels=self.flatten_output_size_for_policy_head,
            hidden_channels=fc_policy_layers[0],
            out_channels=action_space_size,
            layer_num=len(fc_policy_layers) + 1,
            activation=self.activation,
            norm_type='BN',
            output_activation=False,
            output_norm=False,
            last_linear_layer_init_zero=last_linear_layer_init_zero
        )

    def forward(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Forward computation of the afterstate prediction network.
        Arguments:
            - afterstate (:obj:`torch.Tensor`): input tensor with shape (B, num_channels, height, width).
        Returns:
            - afterstate_policy_logits (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
            - afterstate_value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
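        Examples:
            >>> # An illustrative shape sketch with 64-channel afterstates of spatial size 6x6 and a chance space of 2.
            >>> network = AfterstatePredictionNetwork(action_space_size=2, num_res_blocks=1, num_channels=64,
            >>>     value_head_channels=16, policy_head_channels=16, fc_value_layers=[32], fc_policy_layers=[32],
            >>>     output_support_size=601, flatten_output_size_for_value_head=16 * 6 * 6,
            >>>     flatten_output_size_for_policy_head=16 * 6 * 6)
            >>> afterstate = torch.randn(8, 64, 6, 6)
            >>> policy_logits, value = network(afterstate)
            >>> policy_logits.shape, value.shape  # (torch.Size([8, 2]), torch.Size([8, 601]))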
| """ | |
| for res_block in self.resblocks: | |
| afterstate = res_block(afterstate) | |
| value = self.conv1x1_value(afterstate) | |
| value = self.bn_value(value) | |
| value = self.activation(value) | |
| policy = self.conv1x1_policy(afterstate) | |
| policy = self.bn_policy(policy) | |
| policy = self.activation(policy) | |
| value = value.reshape(-1, self.flatten_output_size_for_value_head) | |
| policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) | |
| afterstate_value = self.fc_value(value) | |
| afterstate_policy_logits = self.fc_policy(policy) | |
| return afterstate_policy_logits, afterstate_value | |


class ChanceEncoderBackbone(nn.Module):
    """
    Overview:
        The definition of the chance encoder backbone network, which encodes a pair of concatenated (image)
        observations into a chance encoding; hence the first conv layer takes ``input_dimensions[0] * 2`` channels.
    Arguments:
        - input_dimensions (:obj:`tuple`): The dimension of a single observation, e.g. (C, H, W).
        - chance_encoding_dim (:obj:`int`): The dimension of the chance encoding.
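    Examples:
        >>> # An illustrative sketch: two stacked (3, 8, 8) observations give 6 input channels.
        >>> backbone = ChanceEncoderBackbone(input_dimensions=(3, 8, 8), chance_encoding_dim=4)
        >>> x = torch.randn(8, 6, 8, 8)
        >>> backbone(x).shape  # torch.Size([8, 4])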
| """ | |
| def __init__(self, input_dimensions, chance_encoding_dim=4): | |
| super(ChanceEncoderBackbone, self).__init__() | |
| self.conv1 = nn.Conv2d(input_dimensions[0] * 2, 32, 3, padding=1) | |
| self.conv2 = nn.Conv2d(32, 64, 3, padding=1) | |
| self.fc1 = nn.Linear(64 * input_dimensions[1] * input_dimensions[2], 128) | |
| self.fc2 = nn.Linear(128, 64) | |
| self.fc3 = nn.Linear(64, chance_encoding_dim) | |
| def forward(self, x): | |
| x = torch.relu(self.conv1(x)) | |
| x = torch.relu(self.conv2(x)) | |
| x = x.view(x.shape[0], -1) | |
| x = torch.relu(self.fc1(x)) | |
| x = torch.relu(self.fc2(x)) | |
| x = self.fc3(x) | |
| return x | |


class ChanceEncoderBackboneMLP(nn.Module):
    """
    Overview:
        The definition of the chance encoder backbone network, which encodes a (vector) observation into a
        chance encoding.
    Arguments:
        - input_dimensions (:obj:`int`): The dimension of the flattened vector observation.
        - chance_encoding_dim (:obj:`int`): The dimension of the chance encoding.
    """

    def __init__(self, input_dimensions, chance_encoding_dim=4):
        super(ChanceEncoderBackboneMLP, self).__init__()
        self.fc1 = nn.Linear(input_dimensions, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, chance_encoding_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class ChanceEncoder(nn.Module):

    def __init__(self, input_dimensions, action_dimension, encoder_backbone_type='conv'):
        super().__init__()
        # Specify the action space for the model.
        self.action_space = action_dimension
        if encoder_backbone_type == 'conv':
            # Define the encoder, which transforms observations into a latent space.
            self.encoder = ChanceEncoderBackbone(input_dimensions, action_dimension)
        elif encoder_backbone_type == 'mlp':
            self.encoder = ChanceEncoderBackboneMLP(input_dimensions, action_dimension)
        else:
            raise ValueError('Encoder backbone type not supported')
        # Use the Straight-Through Estimator for backpropagation through the one-hot argmax.
        self.onehot_argmax = StraightThroughEstimator()

    def forward(self, observations):
        """
        Overview:
            Forward method for the ChanceEncoder. This method takes an observation, applies the encoder to \
            transform it into a latent chance encoding, and then applies the StraightThroughEstimator to \
            that encoding. \
            References: Planning in Stochastic Environments with a Learned Model (ICLR 2022), page 5,
            Chance Outcomes section.
        Arguments:
            - observations (:obj:`torch.Tensor`): Observation tensor.
        Returns:
            - chance_encoding (:obj:`torch.Tensor`): Encoding of the input observation tensor.
            - chance_onehot (:obj:`torch.Tensor`): One-hot argmax of ``chance_encoding``, computed with a \
                straight-through gradient.
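        Examples:
            >>> # An illustrative sketch with the conv backbone and a pair of stacked (3, 8, 8) observations.
            >>> encoder = ChanceEncoder(input_dimensions=(3, 8, 8), action_dimension=4)
            >>> observations = torch.randn(8, 6, 8, 8)
            >>> chance_encoding, chance_onehot = encoder(observations)
            >>> chance_onehot.sum(dim=-1)  # every row sums to 1, since each row is one-hot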
| """ | |
| # Apply the encoder to the observation | |
| chance_encoding = self.encoder(observations) | |
| # Apply one-hot argmax to the encoding | |
| chance_onehot = self.onehot_argmax(chance_encoding) | |
| return chance_encoding, chance_onehot | |


class StraightThroughEstimator(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Overview:
            Forward method for the StraightThroughEstimator. This applies the one-hot argmax \
            function to the input tensor.
        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor.
        Returns:
            - (:obj:`torch.Tensor`): Transformed tensor after applying one-hot argmax.
        """
        # Apply one-hot argmax to the input.
        x = OnehotArgmax.apply(x)
        return x


class OnehotArgmax(torch.autograd.Function):
    """
    Overview:
        Custom PyTorch function for one-hot argmax. This function transforms the input tensor \
        into a one-hot tensor where the index with the maximum value in the original tensor is \
        set to 1 and all other indices are set to 0. It allows gradients to flow to the encoder \
        during backpropagation.
        For more information, refer to: \
        https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
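    Examples:
        >>> # An illustrative sketch: forward yields a one-hot tensor, backward passes gradients straight through.
        >>> logits = torch.tensor([[0.1, 2.0, 0.3]], requires_grad=True)
        >>> onehot = OnehotArgmax.apply(logits)  # tensor([[0., 1., 0.]])
        >>> onehot.sum().backward()
        >>> logits.grad  # tensor([[1., 1., 1.]])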
| """ | |
| def forward(ctx, input): | |
| """ | |
| Overview: | |
| Forward method for the one-hot argmax function. This method transforms the input \ | |
| tensor into a one-hot tensor. | |
| Arguments: | |
| - ctx (:obj:`context`): A context object that can be used to stash information for | |
| backward computation. | |
| - input (:obj:`torch.Tensor`): Input tensor. | |
| Returns: | |
| - (:obj:`torch.Tensor`): One-hot tensor. | |
| """ | |
| # Transform the input tensor to a one-hot tensor | |
| return torch.zeros_like(input).scatter_(-1, torch.argmax(input, dim=-1, keepdim=True), 1.) | |

    @staticmethod
    def backward(ctx, grad_output):
        """
        Overview:
            Backward method for the one-hot argmax function. This method allows gradients to flow \
            to the encoder during backpropagation (straight-through: the gradient of the output is \
            passed to the input unchanged).
        Arguments:
            - ctx (:obj:`context`): A context object that was stashed in the forward pass.
            - grad_output (:obj:`torch.Tensor`): The gradient of the output tensor.
        Returns:
            - (:obj:`torch.Tensor`): The gradient of the input tensor.
        """
        return grad_output