"""
Contains torch Modules for policy networks. These networks take an
observation dictionary as input (and possibly additional conditioning,
such as subgoal or goal dictionaries) and produce action predictions,
samples, or distributions as outputs. Note that actions
are assumed to lie in [-1, 1], and most networks will have a final
tanh activation to help ensure this range.
"""
import textwrap
import numpy as np
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

import robomimic.utils.tensor_utils as TensorUtils
from robomimic.models.base_nets import Module
from robomimic.models.transformers import GPT_Backbone
from robomimic.models.obs_nets import MIMO_MLP, RNN_MIMO_MLP, MIMO_Transformer, ObservationDecoder
from robomimic.models.vae_nets import VAE
from robomimic.models.distributions import TanhWrappedDistribution


class ActorNetwork(MIMO_MLP):
    """
    A basic policy network that predicts actions from observations.
    Can optionally be goal conditioned on future observations.
    """
    def __init__(
        self,
        obs_shapes,
        ac_dim,
        mlp_layer_dims,
        goal_shapes=None,
        encoder_kwargs=None,
    ):
        """
        Args:
            obs_shapes (OrderedDict): a dictionary that maps observation keys to
                expected shapes for observations.

            ac_dim (int): dimension of action space.

            mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.

            goal_shapes (OrderedDict): a dictionary that maps observation keys to
                expected shapes for goal observations.

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-observation key information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        assert isinstance(obs_shapes, OrderedDict)
        self.obs_shapes = obs_shapes
        self.ac_dim = ac_dim

        # set up different observation groups for @MIMO_MLP
        observation_group_shapes = OrderedDict()
        observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)

        # goal conditioning is optional - only enabled when non-empty goal shapes are provided
        self._is_goal_conditioned = False
        if goal_shapes is not None and len(goal_shapes) > 0:
            assert isinstance(goal_shapes, OrderedDict)
            self._is_goal_conditioned = True
            self.goal_shapes = OrderedDict(goal_shapes)
            observation_group_shapes["goal"] = OrderedDict(self.goal_shapes)
        else:
            self.goal_shapes = OrderedDict()

        # subclasses may override @_get_output_shapes to predict distribution parameters instead
        output_shapes = self._get_output_shapes()
        super(ActorNetwork, self).__init__(
            input_obs_group_shapes=observation_group_shapes,
            output_shapes=output_shapes,
            layer_dims=mlp_layer_dims,
            encoder_kwargs=encoder_kwargs,
        )

    def _get_output_shapes(self):
        """
        Allow subclasses to re-define outputs from @MIMO_MLP, since we won't
        always directly predict actions, but may instead predict the parameters
        of a action distribution.
        """
        return OrderedDict(action=(self.ac_dim,))

    def output_shape(self, input_shape=None):
        """Returns output shape - a flat action vector of dimension @self.ac_dim."""
        return [self.ac_dim]

    def forward(self, obs_dict, goal_dict=None):
        """
        Predict a batch of actions from observations (and optional goals).

        Args:
            obs_dict (dict): batch of observations
            goal_dict (dict): if not None, batch of goal observations

        Returns:
            actions (torch.Tensor): batch of actions in [-1, 1]
        """
        actions = super(ActorNetwork, self).forward(obs=obs_dict, goal=goal_dict)["action"]
        # apply tanh squashing to ensure actions are in [-1, 1]
        return torch.tanh(actions)

    def _to_string(self):
        """Info to pretty print."""
        return "action_dim={}".format(self.ac_dim)


class PerturbationActorNetwork(ActorNetwork):
    """
    An action perturbation network - primarily used in BCQ.
    It takes states and actions and returns action perturbations.
    """
    def __init__(
        self,
        obs_shapes,
        ac_dim,
        mlp_layer_dims,
        perturbation_scale=0.05,
        goal_shapes=None,
        encoder_kwargs=None,
    ):
        """
        Args:
            obs_shapes (OrderedDict): a dictionary that maps observation keys to
                expected shapes for observations.

            ac_dim (int): dimension of action space.

            mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.

            perturbation_scale (float): the perturbation network output is always squashed to
                lie in +/- @perturbation_scale. The final action output is equal to the original
                input action added to the output perturbation (and clipped to lie in [-1, 1]).

            goal_shapes (OrderedDict): a dictionary that maps modality to
                expected shapes for goal observations.

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-modality information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        self.perturbation_scale = perturbation_scale

        # add in action as a modality - the network consumes (obs, action) pairs
        new_obs_shapes = OrderedDict(obs_shapes)
        new_obs_shapes["action"] = (ac_dim,)

        # pass to super class to instantiate network
        super(PerturbationActorNetwork, self).__init__(
            obs_shapes=new_obs_shapes,
            ac_dim=ac_dim,
            mlp_layer_dims=mlp_layer_dims,
            goal_shapes=goal_shapes,
            encoder_kwargs=encoder_kwargs,
        )

    def forward(self, obs_dict, acts, goal_dict=None):
        """Forward pass through perturbation actor."""
        # add in actions
        inputs = dict(obs_dict)
        inputs["action"] = acts

        # superclass forward tanh-squashes output, so perturbations lie in [-1, 1] before scaling
        perturbations = super(PerturbationActorNetwork, self).forward(inputs, goal_dict)
        # add perturbations from network to original actions, and ensure the new actions lie in [-1, 1]
        output_actions = acts + self.perturbation_scale * perturbations
        output_actions = output_actions.clamp(-1.0, 1.0)
        return output_actions

    def _to_string(self):
        """Info to pretty print."""
        return "action_dim={}, perturbation_scale={}".format(self.ac_dim, self.perturbation_scale)


class GaussianActorNetwork(ActorNetwork):
    """
    Variant of actor network that learns a diagonal unimodal Gaussian distribution over actions.
""" def __init__( self, obs_shapes, ac_dim, mlp_layer_dims, fixed_std=False, std_activation="softplus", init_last_fc_weight=None, init_std=0.3, mean_limits=(-9.0, 9.0), std_limits=(0.007, 7.5), low_noise_eval=True, use_tanh=False, goal_shapes=None, encoder_kwargs=None, ): """ Args: obs_shapes (OrderedDict): a dictionary that maps modality to expected shapes for observations. ac_dim (int): dimension of action space. mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes. fixed_std (bool): if True, std is not learned, but kept constant at @init_std std_activation (None or str): type of activation to use for std deviation. Options are: None: no activation applied (not recommended unless using fixed std) `'softplus'`: Only applicable if not using fixed std. Softplus activation applied, after which the output is scaled by init_std / softplus(0) `'exp'`: Only applicable if not using fixed std. Exp applied; this corresponds to network output as being interpreted as log_std instead of std NOTE: In all cases, the final result is clipped to be within @std_limits init_last_fc_weight (None or float): if specified, will intialize the final layer network weights to be uniformly sampled from [-init_weight, init_weight] init_std (None or float): approximate initial scaling for standard deviation outputs from network. If None mean_limits (2-array): (min, max) to clamp final mean output by std_limits (2-array): (min, max) to clamp final std output by low_noise_eval (float): if True, model will output means of Gaussian distribution at eval time. use_tanh (bool): if True, use a tanh-Gaussian distribution goal_shapes (OrderedDict): a dictionary that maps modality to expected shapes for goal observations. encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be nested dictionary containing relevant per-modality information for encoder networks. 
Should be of form: obs_modality1: dict feature_dimension: int core_class: str core_kwargs: dict ... ... obs_randomizer_class: str obs_randomizer_kwargs: dict ... ... obs_modality2: dict ... """ # parameters specific to Gaussian actor self.fixed_std = fixed_std self.init_std = init_std self.mean_limits = np.array(mean_limits) self.std_limits = np.array(std_limits) # Define activations to use def softplus_scaled(x): out = F.softplus(x) out = out * (self.init_std / F.softplus(torch.zeros(1).to(x.device))) return out self.activations = { None: lambda x: x, "softplus": softplus_scaled, "exp": torch.exp, } assert std_activation in self.activations, \ "std_activation must be one of: {}; instead got: {}".format(self.activations.keys(), std_activation) self.std_activation = std_activation if not self.fixed_std else None self.low_noise_eval = low_noise_eval self.use_tanh = use_tanh super(GaussianActorNetwork, self).__init__( obs_shapes=obs_shapes, ac_dim=ac_dim, mlp_layer_dims=mlp_layer_dims, goal_shapes=goal_shapes, encoder_kwargs=encoder_kwargs, ) # If initialization weight was specified, make sure all final layer network weights are specified correctly if init_last_fc_weight is not None: with torch.no_grad(): for name, layer in self.nets["decoder"].nets.items(): torch.nn.init.uniform_(layer.weight, -init_last_fc_weight, init_last_fc_weight) torch.nn.init.uniform_(layer.bias, -init_last_fc_weight, init_last_fc_weight) def _get_output_shapes(self): """ Tells @MIMO_MLP superclass about the output dictionary that should be generated at the last layer. Network outputs parameters of Gaussian distribution. """ return OrderedDict( mean=(self.ac_dim,), scale=(self.ac_dim,), ) def forward_train(self, obs_dict, goal_dict=None): """ Return full Gaussian distribution, which is useful for computing quantities necessary at train-time, like log-likelihood, KL divergence, etc. 
Args: obs_dict (dict): batch of observations goal_dict (dict): if not None, batch of goal observations Returns: dist (Distribution): Gaussian distribution """ out = MIMO_MLP.forward(self, obs=obs_dict, goal=goal_dict) mean = out["mean"] # Use either constant std or learned std depending on setting scale = out["scale"] if not self.fixed_std else torch.ones_like(mean) * self.init_std # Clamp the mean mean = torch.clamp(mean, min=self.mean_limits[0], max=self.mean_limits[1]) # apply tanh squashing to mean if not using tanh-Gaussian to ensure mean is in [-1, 1] if not self.use_tanh: mean = torch.tanh(mean) # Calculate scale if self.low_noise_eval and (not self.training): # override std value so that you always approximately sample the mean scale = torch.ones_like(mean) * 1e-4 else: # Post-process the scale accordingly scale = self.activations[self.std_activation](scale) # Clamp the scale scale = torch.clamp(scale, min=self.std_limits[0], max=self.std_limits[1]) # the Independent call will make it so that `batch_shape` for dist will be equal to batch size # while `event_shape` will be equal to action dimension - ensuring that log-probability # computations are summed across the action dimension dist = D.Normal(loc=mean, scale=scale) dist = D.Independent(dist, 1) if self.use_tanh: # Wrap distribution with Tanh dist = TanhWrappedDistribution(base_dist=dist, scale=1.) return dist def forward(self, obs_dict, goal_dict=None): """ Samples actions from the policy distribution. Args: obs_dict (dict): batch of observations goal_dict (dict): if not None, batch of goal observations Returns: action (torch.Tensor): batch of actions from policy distribution """ dist = self.forward_train(obs_dict, goal_dict) if self.low_noise_eval and (not self.training): if self.use_tanh: # # scaling factor lets us output actions like [-1. 1.] and is consistent with the distribution transform # return (1. 
+ 1e-6) * torch.tanh(dist.base_dist.mean) return torch.tanh(dist.mean) return dist.mean return dist.sample() def _to_string(self): """Info to pretty print.""" msg = "action_dim={}\nfixed_std={}\nstd_activation={}\ninit_std={}\nmean_limits={}\nstd_limits={}\nlow_noise_eval={}".format( self.ac_dim, self.fixed_std, self.std_activation, self.init_std, self.mean_limits, self.std_limits, self.low_noise_eval) return msg class GMMActorNetwork(ActorNetwork): """ Variant of actor network that learns a multimodal Gaussian mixture distribution over actions. """ def __init__( self, obs_shapes, ac_dim, mlp_layer_dims, num_modes=5, min_std=0.01, std_activation="softplus", low_noise_eval=True, use_tanh=False, goal_shapes=None, encoder_kwargs=None, ): """ Args: obs_shapes (OrderedDict): a dictionary that maps modality to expected shapes for observations. ac_dim (int): dimension of action space. mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes. num_modes (int): number of GMM modes min_std (float): minimum std output from network std_activation (None or str): type of activation to use for std deviation. Options are: `'softplus'`: Softplus activation applied `'exp'`: Exp applied; this corresponds to network output being interpreted as log_std instead of std low_noise_eval (float): if True, model will sample from GMM with low std, so that one of the GMM modes will be sampled (approximately) use_tanh (bool): if True, use a tanh-Gaussian distribution goal_shapes (OrderedDict): a dictionary that maps modality to expected shapes for goal observations. encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be nested dictionary containing relevant per-modality information for encoder networks. Should be of form: obs_modality1: dict feature_dimension: int core_class: str core_kwargs: dict ... ... obs_randomizer_class: str obs_randomizer_kwargs: dict ... ... obs_modality2: dict ... 
        """
        # parameters specific to GMM actor
        self.num_modes = num_modes
        self.min_std = min_std
        self.low_noise_eval = low_noise_eval
        self.use_tanh = use_tanh

        # Define activations to use
        self.activations = {
            "softplus": F.softplus,
            "exp": torch.exp,
        }
        assert std_activation in self.activations, \
            "std_activation must be one of: {}; instead got: {}".format(self.activations.keys(), std_activation)
        self.std_activation = std_activation

        super(GMMActorNetwork, self).__init__(
            obs_shapes=obs_shapes,
            ac_dim=ac_dim,
            mlp_layer_dims=mlp_layer_dims,
            goal_shapes=goal_shapes,
            encoder_kwargs=encoder_kwargs,
        )

    def _get_output_shapes(self):
        """
        Tells @MIMO_MLP superclass about the output dictionary that should be generated
        at the last layer. Network outputs parameters of GMM distribution.
        """
        return OrderedDict(
            mean=(self.num_modes, self.ac_dim),
            scale=(self.num_modes, self.ac_dim),
            logits=(self.num_modes,),
        )

    def forward_train(self, obs_dict, goal_dict=None):
        """
        Return full GMM distribution, which is useful for computing
        quantities necessary at train-time, like log-likelihood, KL
        divergence, etc.

        Args:
            obs_dict (dict): batch of observations
            goal_dict (dict): if not None, batch of goal observations

        Returns:
            dist (Distribution): GMM distribution
        """
        out = MIMO_MLP.forward(self, obs=obs_dict, goal=goal_dict)
        means = out["mean"]
        scales = out["scale"]
        logits = out["logits"]

        # apply tanh squashing to means if not using tanh-GMM to ensure means are in [-1, 1]
        if not self.use_tanh:
            means = torch.tanh(means)

        # Calculate scale
        if self.low_noise_eval and (not self.training):
            # low-noise for all Gaussian dists
            scales = torch.ones_like(means) * 1e-4
        else:
            # post-process the scale accordingly
            scales = self.activations[self.std_activation](scales) + self.min_std

        # mixture components - make sure that `batch_shape` for the distribution is equal
        # to (batch_size, num_modes) since MixtureSameFamily expects this shape
        component_distribution = D.Normal(loc=means, scale=scales)
        component_distribution = D.Independent(component_distribution, 1)

        # unnormalized logits to categorical distribution for mixing the modes
        mixture_distribution = D.Categorical(logits=logits)

        dist = D.MixtureSameFamily(
            mixture_distribution=mixture_distribution,
            component_distribution=component_distribution,
        )

        if self.use_tanh:
            # Wrap distribution with Tanh
            dist = TanhWrappedDistribution(base_dist=dist, scale=1.)

        return dist

    def forward(self, obs_dict, goal_dict=None):
        """
        Samples actions from the policy distribution.

        Args:
            obs_dict (dict): batch of observations
            goal_dict (dict): if not None, batch of goal observations

        Returns:
            action (torch.Tensor): batch of actions from policy distribution
        """
        dist = self.forward_train(obs_dict, goal_dict)
        return dist.sample()

    def _to_string(self):
        """Info to pretty print."""
        return "action_dim={}\nnum_modes={}\nmin_std={}\nstd_activation={}\nlow_noise_eval={}".format(
            self.ac_dim, self.num_modes, self.min_std, self.std_activation, self.low_noise_eval)


class RNNActorNetwork(RNN_MIMO_MLP):
    """
    An RNN policy network that predicts actions from observations.
    """
    def __init__(
        self,
        obs_shapes,
        ac_dim,
        mlp_layer_dims,
        rnn_hidden_dim,
        rnn_num_layers,
        rnn_type="LSTM",  # [LSTM, GRU]
        rnn_kwargs=None,
        goal_shapes=None,
        encoder_kwargs=None,
    ):
        """
        Args:
            obs_shapes (OrderedDict): a dictionary that maps modality to
                expected shapes for observations.

            ac_dim (int): dimension of action space.

            mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes.

            rnn_hidden_dim (int): RNN hidden dimension

            rnn_num_layers (int): number of RNN layers

            rnn_type (str): [LSTM, GRU]

            rnn_kwargs (dict): kwargs for the torch.nn.LSTM / GRU

            goal_shapes (OrderedDict): a dictionary that maps modality to
                expected shapes for goal observations.

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-modality information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        self.ac_dim = ac_dim

        assert isinstance(obs_shapes, OrderedDict)
        self.obs_shapes = obs_shapes

        # set up different observation groups for @RNN_MIMO_MLP
        observation_group_shapes = OrderedDict()
        observation_group_shapes["obs"] = OrderedDict(self.obs_shapes)

        # goal conditioning is optional - only enabled when non-empty goal shapes are provided
        self._is_goal_conditioned = False
        if goal_shapes is not None and len(goal_shapes) > 0:
            assert isinstance(goal_shapes, OrderedDict)
            self._is_goal_conditioned = True
            self.goal_shapes = OrderedDict(goal_shapes)
            observation_group_shapes["goal"] = OrderedDict(self.goal_shapes)
        else:
            self.goal_shapes = OrderedDict()

        # subclasses may override @_get_output_shapes to predict distribution parameters instead
        output_shapes = self._get_output_shapes()
        super(RNNActorNetwork, self).__init__(
            input_obs_group_shapes=observation_group_shapes,
            output_shapes=output_shapes,
            mlp_layer_dims=mlp_layer_dims,
            mlp_activation=nn.ReLU,
            mlp_layer_func=nn.Linear,
            rnn_hidden_dim=rnn_hidden_dim,
            rnn_num_layers=rnn_num_layers,
            rnn_type=rnn_type,
            rnn_kwargs=rnn_kwargs,
            per_step=True,
            encoder_kwargs=encoder_kwargs,
        )

    def _get_output_shapes(self):
        """
        Allow subclasses to re-define outputs from @RNN_MIMO_MLP, since we won't
        always directly predict actions, but may instead predict the parameters
        of a action distribution.
        """
        return OrderedDict(action=(self.ac_dim,))

    def output_shape(self, input_shape):
        # note: @input_shape should be dictionary (key: mod)
        # infers temporal dimension from input shape
        mod = list(self.obs_shapes.keys())[0]
        T = input_shape[mod][0]
        TensorUtils.assert_size_at_dim(input_shape, size=T, dim=0,
                msg="RNNActorNetwork: input_shape inconsistent in temporal dimension")
        return [T, self.ac_dim]

    def forward(self, obs_dict, goal_dict=None, rnn_init_state=None, return_state=False):
        """
        Forward a sequence of inputs through the RNN and the per-step network.

        Args:
            obs_dict (dict): batch of observations - each tensor in the dictionary
                should have leading dimensions batch and time [B, T, ...]
            goal_dict (dict): if not None, batch of goal observations
            rnn_init_state: rnn hidden state, initialize to zero state if set to None
            return_state (bool): whether to return hidden state

        Returns:
            actions (torch.Tensor): predicted action sequence
            rnn_state: return rnn state at the end if return_state is set to True
        """
        if self._is_goal_conditioned:
            assert goal_dict is not None
            # repeat the goal observation in time to match dimension with obs_dict
            mod = list(obs_dict.keys())[0]
            goal_dict = TensorUtils.unsqueeze_expand_at(goal_dict, size=obs_dict[mod].shape[1], dim=1)

        outputs = super(RNNActorNetwork, self).forward(
            obs=obs_dict, goal=goal_dict, rnn_init_state=rnn_init_state, return_state=return_state)

        if return_state:
            actions, state = outputs
        else:
            actions = outputs
            state = None

        # apply tanh squashing to ensure actions are in [-1, 1]
        actions = torch.tanh(actions["action"])

        if return_state:
            return actions, state
        else:
            return actions

    def forward_step(self, obs_dict, goal_dict=None, rnn_state=None):
        """
        Unroll RNN over single timestep to get actions.

        Args:
            obs_dict (dict): batch of observations. Should not contain time dimension.
            goal_dict (dict): if not None, batch of goal observations
            rnn_state: rnn hidden state, initialize to zero state if set to None

        Returns:
            actions (torch.Tensor): batch of actions - does not contain time dimension
            state: updated rnn state
        """
        # add a singleton time dimension, run the sequence forward, then strip it back off
        obs_dict = TensorUtils.to_sequence(obs_dict)
        action, state = self.forward(
            obs_dict, goal_dict, rnn_init_state=rnn_state, return_state=True)
        return action[:, 0], state

    def _to_string(self):
        """Info to pretty print."""
        return "action_dim={}".format(self.ac_dim)


class RNNGMMActorNetwork(RNNActorNetwork):
    """
    An RNN GMM policy network that predicts sequences of action distributions from observation sequences.
""" def __init__( self, obs_shapes, ac_dim, mlp_layer_dims, rnn_hidden_dim, rnn_num_layers, rnn_type="LSTM", # [LSTM, GRU] rnn_kwargs=None, num_modes=5, min_std=0.01, std_activation="softplus", low_noise_eval=True, use_tanh=False, goal_shapes=None, encoder_kwargs=None, ): """ Args: rnn_hidden_dim (int): RNN hidden dimension rnn_num_layers (int): number of RNN layers rnn_type (str): [LSTM, GRU] rnn_kwargs (dict): kwargs for the torch.nn.LSTM / GRU num_modes (int): number of GMM modes min_std (float): minimum std output from network std_activation (None or str): type of activation to use for std deviation. Options are: `'softplus'`: Softplus activation applied `'exp'`: Exp applied; this corresponds to network output being interpreted as log_std instead of std low_noise_eval (float): if True, model will sample from GMM with low std, so that one of the GMM modes will be sampled (approximately) use_tanh (bool): if True, use a tanh-Gaussian distribution encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be nested dictionary containing relevant per-modality information for encoder networks. Should be of form: obs_modality1: dict feature_dimension: int core_class: str core_kwargs: dict ... ... obs_randomizer_class: str obs_randomizer_kwargs: dict ... ... obs_modality2: dict ... 
""" # parameters specific to GMM actor self.num_modes = num_modes self.min_std = min_std self.low_noise_eval = low_noise_eval self.use_tanh = use_tanh # Define activations to use self.activations = { "softplus": F.softplus, "exp": torch.exp, } assert std_activation in self.activations, \ "std_activation must be one of: {}; instead got: {}".format(self.activations.keys(), std_activation) self.std_activation = std_activation super(RNNGMMActorNetwork, self).__init__( obs_shapes=obs_shapes, ac_dim=ac_dim, mlp_layer_dims=mlp_layer_dims, rnn_hidden_dim=rnn_hidden_dim, rnn_num_layers=rnn_num_layers, rnn_type=rnn_type, rnn_kwargs=rnn_kwargs, goal_shapes=goal_shapes, encoder_kwargs=encoder_kwargs, ) def _get_output_shapes(self): """ Tells @MIMO_MLP superclass about the output dictionary that should be generated at the last layer. Network outputs parameters of GMM distribution. """ return OrderedDict( mean=(self.num_modes, self.ac_dim), scale=(self.num_modes, self.ac_dim), logits=(self.num_modes,), ) def forward_train(self, obs_dict, goal_dict=None, rnn_init_state=None, return_state=False): """ Return full GMM distribution, which is useful for computing quantities necessary at train-time, like log-likelihood, KL divergence, etc. 
Args: obs_dict (dict): batch of observations goal_dict (dict): if not None, batch of goal observations rnn_init_state: rnn hidden state, initialize to zero state if set to None return_state (bool): whether to return hidden state Returns: dists (Distribution): sequence of GMM distributions over the timesteps rnn_state: return rnn state at the end if return_state is set to True """ if self._is_goal_conditioned: assert goal_dict is not None # repeat the goal observation in time to match dimension with obs_dict mod = list(obs_dict.keys())[0] goal_dict = TensorUtils.unsqueeze_expand_at(goal_dict, size=obs_dict[mod].shape[1], dim=1) outputs = RNN_MIMO_MLP.forward( self, obs=obs_dict, goal=goal_dict, rnn_init_state=rnn_init_state, return_state=return_state) if return_state: outputs, state = outputs else: state = None means = outputs["mean"] scales = outputs["scale"] logits = outputs["logits"] # apply tanh squashing to mean if not using tanh-GMM to ensure means are in [-1, 1] if not self.use_tanh: means = torch.tanh(means) if self.low_noise_eval and (not self.training): # low-noise for all Gaussian dists scales = torch.ones_like(means) * 1e-4 else: # post-process the scale accordingly scales = self.activations[self.std_activation](scales) + self.min_std # mixture components - make sure that `batch_shape` for the distribution is equal # to (batch_size, timesteps, num_modes) since MixtureSameFamily expects this shape component_distribution = D.Normal(loc=means, scale=scales) component_distribution = D.Independent(component_distribution, 1) # shift action dim to event shape # unnormalized logits to categorical distribution for mixing the modes mixture_distribution = D.Categorical(logits=logits) dists = D.MixtureSameFamily( mixture_distribution=mixture_distribution, component_distribution=component_distribution, ) if self.use_tanh: # Wrap distribution with Tanh dists = TanhWrappedDistribution(base_dist=dists, scale=1.) 
if return_state: return dists, state else: return dists def forward(self, obs_dict, goal_dict=None, rnn_init_state=None, return_state=False): """ Samples actions from the policy distribution. Args: obs_dict (dict): batch of observations goal_dict (dict): if not None, batch of goal observations Returns: action (torch.Tensor): batch of actions from policy distribution """ out = self.forward_train(obs_dict=obs_dict, goal_dict=goal_dict, rnn_init_state=rnn_init_state, return_state=return_state) if return_state: ad, state = out return ad.sample(), state return out.sample() def forward_train_step(self, obs_dict, goal_dict=None, rnn_state=None): """ Unroll RNN over single timestep to get action GMM distribution, which is useful for computing quantities necessary at train-time, like log-likelihood, KL divergence, etc. Args: obs_dict (dict): batch of observations. Should not contain time dimension. goal_dict (dict): if not None, batch of goal observations rnn_state: rnn hidden state, initialize to zero state if set to None Returns: ad (Distribution): GMM action distributions state: updated rnn state """ obs_dict = TensorUtils.to_sequence(obs_dict) ad, state = self.forward_train( obs_dict, goal_dict, rnn_init_state=rnn_state, return_state=True) # to squeeze time dimension, make another action distribution assert ad.component_distribution.base_dist.loc.shape[1] == 1 assert ad.component_distribution.base_dist.scale.shape[1] == 1 assert ad.mixture_distribution.logits.shape[1] == 1 component_distribution = D.Normal( loc=ad.component_distribution.base_dist.loc.squeeze(1), scale=ad.component_distribution.base_dist.scale.squeeze(1), ) component_distribution = D.Independent(component_distribution, 1) mixture_distribution = D.Categorical(logits=ad.mixture_distribution.logits.squeeze(1)) ad = D.MixtureSameFamily( mixture_distribution=mixture_distribution, component_distribution=component_distribution, ) return ad, state def forward_step(self, obs_dict, goal_dict=None, rnn_state=None): 
""" Unroll RNN over single timestep to get sampled actions. Args: obs_dict (dict): batch of observations. Should not contain time dimension. goal_dict (dict): if not None, batch of goal observations rnn_state: rnn hidden state, initialize to zero state if set to None Returns: acts (torch.Tensor): batch of actions - does not contain time dimension state: updated rnn state """ obs_dict = TensorUtils.to_sequence(obs_dict) acts, state = self.forward( obs_dict, goal_dict, rnn_init_state=rnn_state, return_state=True) assert acts.shape[1] == 1 return acts[:, 0], state def _to_string(self): """Info to pretty print.""" msg = "action_dim={}, std_activation={}, low_noise_eval={}, num_nodes={}, min_std={}".format( self.ac_dim, self.std_activation, self.low_noise_eval, self.num_modes, self.min_std) return msg class TransformerActorNetwork(MIMO_Transformer): """ An Transformer policy network that predicts actions from observation sequences (assumed to be frame stacked from previous observations) and possible from previous actions as well (in an autoregressive manner). """ def __init__( self, obs_shapes, ac_dim, transformer_embed_dim, transformer_num_layers, transformer_num_heads, transformer_context_length, transformer_emb_dropout=0.1, transformer_attn_dropout=0.1, transformer_block_output_dropout=0.1, transformer_sinusoidal_embedding=False, transformer_activation="gelu", transformer_nn_parameter_for_timesteps=False, goal_shapes=None, encoder_kwargs=None, ): """ Args: obs_shapes (OrderedDict): a dictionary that maps modality to expected shapes for observations. ac_dim (int): dimension of action space. transformer_embed_dim (int): dimension for embeddings used by transformer transformer_num_layers (int): number of transformer blocks to stack transformer_num_heads (int): number of attention heads for each transformer block - must divide @transformer_embed_dim evenly. Self-attention is computed over this many partitions of the embedding dimension separately. 
transformer_context_length (int): expected length of input sequences transformer_embedding_dropout (float): dropout probability for embedding inputs in transformer transformer_attn_dropout (float): dropout probability for attention outputs for each transformer block transformer_block_output_dropout (float): dropout probability for final outputs for each transformer block goal_shapes (OrderedDict): a dictionary that maps modality to expected shapes for goal observations. encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be nested dictionary containing relevant per-modality information for encoder networks. Should be of form: obs_modality1: dict feature_dimension: int core_class: str core_kwargs: dict ... ... obs_randomizer_class: str obs_randomizer_kwargs: dict ... ... obs_modality2: dict ... """ self.ac_dim = ac_dim assert isinstance(obs_shapes, OrderedDict) self.obs_shapes = obs_shapes self.transformer_nn_parameter_for_timesteps = transformer_nn_parameter_for_timesteps # set up different observation groups for @RNN_MIMO_MLP observation_group_shapes = OrderedDict() observation_group_shapes["obs"] = OrderedDict(self.obs_shapes) self._is_goal_conditioned = False if goal_shapes is not None and len(goal_shapes) > 0: assert isinstance(goal_shapes, OrderedDict) self._is_goal_conditioned = True self.goal_shapes = OrderedDict(goal_shapes) observation_group_shapes["goal"] = OrderedDict(self.goal_shapes) else: self.goal_shapes = OrderedDict() output_shapes = self._get_output_shapes() super(TransformerActorNetwork, self).__init__( input_obs_group_shapes=observation_group_shapes, output_shapes=output_shapes, transformer_embed_dim=transformer_embed_dim, transformer_num_layers=transformer_num_layers, transformer_num_heads=transformer_num_heads, transformer_context_length=transformer_context_length, transformer_emb_dropout=transformer_emb_dropout, transformer_attn_dropout=transformer_attn_dropout, 
transformer_block_output_dropout=transformer_block_output_dropout, transformer_sinusoidal_embedding=transformer_sinusoidal_embedding, transformer_activation=transformer_activation, transformer_nn_parameter_for_timesteps=transformer_nn_parameter_for_timesteps, encoder_kwargs=encoder_kwargs, ) def _get_output_shapes(self): """ Allow subclasses to re-define outputs from @MIMO_Transformer, since we won't always directly predict actions, but may instead predict the parameters of a action distribution. """ output_shapes = OrderedDict(action=(self.ac_dim,)) return output_shapes def output_shape(self, input_shape): # note: @input_shape should be dictionary (key: mod) # infers temporal dimension from input shape mod = list(self.obs_shapes.keys())[0] T = input_shape[mod][0] TensorUtils.assert_size_at_dim(input_shape, size=T, dim=0, msg="TransformerActorNetwork: input_shape inconsistent in temporal dimension") return [T, self.ac_dim] def forward(self, obs_dict, actions=None, goal_dict=None): """ Forward a sequence of inputs through the Transformer. Args: obs_dict (dict): batch of observations - each tensor in the dictionary should have leading dimensions batch and time [B, T, ...] 
                actions (torch.Tensor): batch of actions of shape [B, T, D]
            goal_dict (dict): if not None, batch of goal observations
        Returns:
            outputs (torch.Tensor or dict): contains predicted action sequence, or dictionary
                with predicted action sequence and predicted observation sequences
        """
        if self._is_goal_conditioned:
            assert goal_dict is not None
            # repeat the goal observation in time to match dimension with obs_dict
            mod = list(obs_dict.keys())[0]
            goal_dict = TensorUtils.unsqueeze_expand_at(goal_dict, size=obs_dict[mod].shape[1], dim=1)

        forward_kwargs = dict(obs=obs_dict, goal=goal_dict)
        outputs = super(TransformerActorNetwork, self).forward(**forward_kwargs)

        # apply tanh squashing to ensure actions are in [-1, 1]
        outputs["action"] = torch.tanh(outputs["action"])

        return outputs["action"] # only action sequences

    def _to_string(self):
        """Info to pretty print."""
        return "action_dim={}".format(self.ac_dim)


class TransformerGMMActorNetwork(TransformerActorNetwork):
    """
    A Transformer GMM policy network that predicts sequences of action distributions
    from observation sequences (assumed to be frame stacked from previous observations).
    """
    def __init__(
        self,
        obs_shapes,
        ac_dim,
        transformer_embed_dim,
        transformer_num_layers,
        transformer_num_heads,
        transformer_context_length,
        transformer_emb_dropout=0.1,
        transformer_attn_dropout=0.1,
        transformer_block_output_dropout=0.1,
        transformer_sinusoidal_embedding=False,
        transformer_activation="gelu",
        transformer_nn_parameter_for_timesteps=False,
        num_modes=5,
        min_std=0.01,
        std_activation="softplus",
        low_noise_eval=True,
        use_tanh=False,
        goal_shapes=None,
        encoder_kwargs=None,
    ):
        """
        Args:
            obs_shapes (OrderedDict): a dictionary that maps modality to
                expected shapes for observations.

            ac_dim (int): dimension of action space.

            transformer_embed_dim (int): dimension for embeddings used by transformer

            transformer_num_layers (int): number of transformer blocks to stack

            transformer_num_heads (int): number of attention heads for each
                transformer block - must divide @transformer_embed_dim evenly. Self-attention is 
                computed over this many partitions of the embedding dimension separately.

            transformer_context_length (int): expected length of input sequences

            transformer_embedding_dropout (float): dropout probability for embedding inputs in transformer

            transformer_attn_dropout (float): dropout probability for attention outputs for each transformer block

            transformer_block_output_dropout (float): dropout probability for final outputs for each transformer block

            num_modes (int): number of GMM modes

            min_std (float): minimum std output from network

            std_activation (None or str): type of activation to use for std deviation. Options are:

                `'softplus'`: Softplus activation applied

                `'exp'`: Exp applied; this corresponds to network output being interpreted as log_std instead of std

            low_noise_eval (float): if True, model will sample from GMM with low std, so that
                one of the GMM modes will be sampled (approximately)

            use_tanh (bool): if True, use a tanh-Gaussian distribution

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-modality information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
""" # parameters specific to GMM actor self.num_modes = num_modes self.min_std = min_std self.low_noise_eval = low_noise_eval self.use_tanh = use_tanh # Define activations to use self.activations = { "softplus": F.softplus, "exp": torch.exp, } assert std_activation in self.activations, \ "std_activation must be one of: {}; instead got: {}".format(self.activations.keys(), std_activation) self.std_activation = std_activation super(TransformerGMMActorNetwork, self).__init__( obs_shapes=obs_shapes, ac_dim=ac_dim, transformer_embed_dim=transformer_embed_dim, transformer_num_layers=transformer_num_layers, transformer_num_heads=transformer_num_heads, transformer_context_length=transformer_context_length, transformer_emb_dropout=transformer_emb_dropout, transformer_attn_dropout=transformer_attn_dropout, transformer_block_output_dropout=transformer_block_output_dropout, transformer_sinusoidal_embedding=transformer_sinusoidal_embedding, transformer_activation=transformer_activation, transformer_nn_parameter_for_timesteps=transformer_nn_parameter_for_timesteps, encoder_kwargs=encoder_kwargs, goal_shapes=goal_shapes, ) def _get_output_shapes(self): """ Tells @MIMO_Transformer superclass about the output dictionary that should be generated at the last layer. Network outputs parameters of GMM distribution. """ return OrderedDict( mean=(self.num_modes, self.ac_dim), scale=(self.num_modes, self.ac_dim), logits=(self.num_modes,), ) def forward_train(self, obs_dict, actions=None, goal_dict=None, low_noise_eval=None): """ Return full GMM distribution, which is useful for computing quantities necessary at train-time, like log-likelihood, KL divergence, etc. 
Args: obs_dict (dict): batch of observations actions (torch.Tensor): batch of actions goal_dict (dict): if not None, batch of goal observations Returns: dists (Distribution): sequence of GMM distributions over the timesteps """ if self._is_goal_conditioned: assert goal_dict is not None # repeat the goal observation in time to match dimension with obs_dict mod = list(obs_dict.keys())[0] goal_dict = TensorUtils.unsqueeze_expand_at(goal_dict, size=obs_dict[mod].shape[1], dim=1) forward_kwargs = dict(obs=obs_dict, goal=goal_dict) outputs = MIMO_Transformer.forward(self, **forward_kwargs) means = outputs["mean"] scales = outputs["scale"] logits = outputs["logits"] # apply tanh squashing to mean if not using tanh-GMM to ensure means are in [-1, 1] if not self.use_tanh: means = torch.tanh(means) if low_noise_eval is None: low_noise_eval = self.low_noise_eval if low_noise_eval and (not self.training): # low-noise for all Gaussian dists scales = torch.ones_like(means) * 1e-4 else: # post-process the scale accordingly scales = self.activations[self.std_activation](scales) + self.min_std # mixture components - make sure that `batch_shape` for the distribution is equal # to (batch_size, timesteps, num_modes) since MixtureSameFamily expects this shape component_distribution = D.Normal(loc=means, scale=scales) component_distribution = D.Independent(component_distribution, 1) # shift action dim to event shape # unnormalized logits to categorical distribution for mixing the modes mixture_distribution = D.Categorical(logits=logits) dists = D.MixtureSameFamily( mixture_distribution=mixture_distribution, component_distribution=component_distribution, ) if self.use_tanh: # Wrap distribution with Tanh dists = TanhWrappedDistribution(base_dist=dists, scale=1.) return dists def forward(self, obs_dict, actions=None, goal_dict=None): """ Samples actions from the policy distribution. 
Args: obs_dict (dict): batch of observations actions (torch.Tensor): batch of actions goal_dict (dict): if not None, batch of goal observations Returns: action (torch.Tensor): batch of actions from policy distribution """ out = self.forward_train(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict) return out.sample() def _to_string(self): """Info to pretty print.""" msg = "action_dim={}, std_activation={}, low_noise_eval={}, num_nodes={}, min_std={}".format( self.ac_dim, self.std_activation, self.low_noise_eval, self.num_modes, self.min_std) return msg class VAEActor(Module): """ A VAE that models a distribution of actions conditioned on observations. The VAE prior and decoder are used at test-time as the policy. """ def __init__( self, obs_shapes, ac_dim, encoder_layer_dims, decoder_layer_dims, latent_dim, device, decoder_is_conditioned=True, decoder_reconstruction_sum_across_elements=False, latent_clip=None, prior_learn=False, prior_is_conditioned=False, prior_layer_dims=(), prior_use_gmm=False, prior_gmm_num_modes=10, prior_gmm_learn_weights=False, prior_use_categorical=False, prior_categorical_dim=10, prior_categorical_gumbel_softmax_hard=False, goal_shapes=None, encoder_kwargs=None, ): """ Args: obs_shapes (OrderedDict): a dictionary that maps modality to expected shapes for observations. ac_dim (int): dimension of action space. goal_shapes (OrderedDict): a dictionary that maps modality to expected shapes for goal observations. encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be nested dictionary containing relevant per-modality information for encoder networks. Should be of form: obs_modality1: dict feature_dimension: int core_class: str core_kwargs: dict ... ... obs_randomizer_class: str obs_randomizer_kwargs: dict ... ... obs_modality2: dict ... 
""" super(VAEActor, self).__init__() self.obs_shapes = obs_shapes self.ac_dim = ac_dim action_shapes = OrderedDict(action=(self.ac_dim,)) # ensure VAE decoder will squash actions into [-1, 1] output_squash = ['action'] output_scales = OrderedDict(action=1.) self._vae = VAE( input_shapes=action_shapes, output_shapes=action_shapes, encoder_layer_dims=encoder_layer_dims, decoder_layer_dims=decoder_layer_dims, latent_dim=latent_dim, device=device, condition_shapes=self.obs_shapes, decoder_is_conditioned=decoder_is_conditioned, decoder_reconstruction_sum_across_elements=decoder_reconstruction_sum_across_elements, latent_clip=latent_clip, output_squash=output_squash, output_scales=output_scales, prior_learn=prior_learn, prior_is_conditioned=prior_is_conditioned, prior_layer_dims=prior_layer_dims, prior_use_gmm=prior_use_gmm, prior_gmm_num_modes=prior_gmm_num_modes, prior_gmm_learn_weights=prior_gmm_learn_weights, prior_use_categorical=prior_use_categorical, prior_categorical_dim=prior_categorical_dim, prior_categorical_gumbel_softmax_hard=prior_categorical_gumbel_softmax_hard, goal_shapes=goal_shapes, encoder_kwargs=encoder_kwargs, ) def encode(self, actions, obs_dict, goal_dict=None): """ Args: actions (torch.Tensor): a batch of actions obs_dict (dict): a dictionary that maps modalities to torch.Tensor batches. These should correspond to the observation modalities used for conditioning in either the decoder or the prior (or both). goal_dict (dict): a dictionary that maps modalities to torch.Tensor batches. These should correspond to goal modalities. Returns: posterior params (dict): dictionary with the following keys: mean (torch.Tensor): posterior encoder means logvar (torch.Tensor): posterior encoder logvars """ inputs = OrderedDict(action=actions) return self._vae.encode(inputs=inputs, conditions=obs_dict, goals=goal_dict) def decode(self, obs_dict=None, goal_dict=None, z=None, n=None): """ Thin wrapper around @VaeNets.VAE implementation. 
Args: obs_dict (dict): a dictionary that maps modalities to torch.Tensor batches. Only needs to be provided if @decoder_is_conditioned or @z is None (since the prior will require it to generate z). goal_dict (dict): a dictionary that maps modalities to torch.Tensor batches. These should correspond to goal modalities. z (torch.Tensor): if provided, these latents are used to generate reconstructions from the VAE, and the prior is not sampled. n (int): this argument is used to specify the number of samples to generate from the prior. Only required if @z is None - i.e. sampling takes place Returns: recons (dict): dictionary of reconstructed inputs (this will be a dictionary with a single "action" key) """ return self._vae.decode(conditions=obs_dict, goals=goal_dict, z=z, n=n) def sample_prior(self, obs_dict=None, goal_dict=None, n=None): """ Thin wrapper around @VaeNets.VAE implementation. Args: n (int): this argument is used to specify the number of samples to generate from the prior. obs_dict (dict): a dictionary that maps modalities to torch.Tensor batches. Only needs to be provided if @prior_is_conditioned. goal_dict (dict): a dictionary that maps modalities to torch.Tensor batches. These should correspond to goal modalities. Returns: z (torch.Tensor): latents sampled from the prior """ return self._vae.sample_prior(n=n, conditions=obs_dict, goals=goal_dict) def set_gumbel_temperature(self, temperature): """ Used by external algorithms to schedule Gumbel-Softmax temperature, which is used during reparametrization at train-time. Should only be used if @prior_use_categorical is True. """ self._vae.set_gumbel_temperature(temperature) def get_gumbel_temperature(self): """ Return current Gumbel-Softmax temperature. Should only be used if @prior_use_categorical is True. 
""" return self._vae.get_gumbel_temperature() def output_shape(self, input_shape=None): """ This implementation is required by the Module superclass, but is unused since we never chain this module to other ones. """ return [self.ac_dim] def forward_train(self, actions, obs_dict, goal_dict=None, freeze_encoder=False): """ A full pass through the VAE network used during training to construct KL and reconstruction losses. See @VAE class for more info. Args: actions (torch.Tensor): a batch of actions obs_dict (dict): a dictionary that maps modalities to torch.Tensor batches. These should correspond to the observation modalities used for conditioning in either the decoder or the prior (or both). goal_dict (dict): a dictionary that maps modalities to torch.Tensor batches. These should correspond to goal modalities. Returns: vae_outputs (dict): a dictionary that contains the following outputs. encoder_params (dict): parameters for the posterior distribution from the encoder forward pass encoder_z (torch.Tensor): latents sampled from the encoder posterior decoder_outputs (dict): action reconstructions from the decoder kl_loss (torch.Tensor): KL loss over the batch of data reconstruction_loss (torch.Tensor): reconstruction loss over the batch of data """ action_inputs = OrderedDict(action=actions) return self._vae.forward( inputs=action_inputs, outputs=action_inputs, conditions=obs_dict, goals=goal_dict, freeze_encoder=freeze_encoder) def forward(self, obs_dict, goal_dict=None, z=None): """ Samples actions from the policy distribution. 
Args: obs_dict (dict): batch of observations goal_dict (dict): if not None, batch of goal observations z (torch.Tensor): if not None, use the provided batch of latents instead of sampling from the prior Returns: action (torch.Tensor): batch of actions from policy distribution """ n = None if z is None: # prior will be sampled - so we must provide number of samples explicitly mod = list(obs_dict.keys())[0] n = obs_dict[mod].shape[0] return self.decode(obs_dict=obs_dict, goal_dict=goal_dict, z=z, n=n)["action"]