""" Contains torch Modules for value networks. These networks take an observation dictionary as input (and possibly additional conditioning, such as subgoal or goal dictionaries) and produce value or action-value estimates or distributions. """ import numpy as np from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F import torch.distributions as D import robomimic.utils.tensor_utils as TensorUtils from robomimic.models.obs_nets import MIMO_MLP from robomimic.models.distributions import DiscreteValueDistribution class ValueNetwork(MIMO_MLP): """ A basic value network that predicts values from observations. Can optionally be goal conditioned on future observations. """ def __init__( self, obs_shapes, mlp_layer_dims, value_bounds=None, goal_shapes=None, encoder_kwargs=None, ): """ Args: obs_shapes (OrderedDict): a dictionary that maps observation keys to expected shapes for observations. mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes. value_bounds (tuple): a 2-tuple corresponding to the lowest and highest possible return that the network should be possible of generating. The network will rescale outputs using a tanh layer to lie within these bounds. If None, no tanh re-scaling is done. goal_shapes (OrderedDict): a dictionary that maps observation keys to expected shapes for goal observations. encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be nested dictionary containing relevant per-observation key information for encoder networks. Should be of form: obs_modality1: dict feature_dimension: int core_class: str core_kwargs: dict ... ... obs_randomizer_class: str obs_randomizer_kwargs: dict ... ... obs_modality2: dict ... """ self.value_bounds = value_bounds if self.value_bounds is not None: # convert [lb, ub] to a scale and offset for the tanh output, which is in [-1, 1] self._value_scale = (float(self.value_bounds[1]) - float(self.value_bounds[0])) / 2. self._value_offset = (float(self.value_bounds[1]) + float(self.value_bounds[0])) / 2. assert isinstance(obs_shapes, OrderedDict) self.obs_shapes = obs_shapes # set up different observation groups for @MIMO_MLP observation_group_shapes = OrderedDict() observation_group_shapes["obs"] = OrderedDict(self.obs_shapes) self._is_goal_conditioned = False if goal_shapes is not None and len(goal_shapes) > 0: assert isinstance(goal_shapes, OrderedDict) self._is_goal_conditioned = True self.goal_shapes = OrderedDict(goal_shapes) observation_group_shapes["goal"] = OrderedDict(self.goal_shapes) else: self.goal_shapes = OrderedDict() output_shapes = self._get_output_shapes() super(ValueNetwork, self).__init__( input_obs_group_shapes=observation_group_shapes, output_shapes=output_shapes, layer_dims=mlp_layer_dims, encoder_kwargs=encoder_kwargs, ) def _get_output_shapes(self): """ Allow subclasses to re-define outputs from @MIMO_MLP, since we won't always directly predict values, but may instead predict the parameters of a value distribution. """ return OrderedDict(value=(1,)) def output_shape(self, input_shape=None): """ Function to compute output shape from inputs to this module. Args: input_shape (iterable of int): shape of input. Does not include batch dimension. Some modules may not need this argument, if their output does not depend on the size of the input, or if they assume fixed size input. Returns: out_shape ([int]): list of integers corresponding to output shape """ return [1] def forward(self, obs_dict, goal_dict=None): """ Forward through value network, and then optionally use tanh scaling. """ values = super(ValueNetwork, self).forward(obs=obs_dict, goal=goal_dict)["value"] if self.value_bounds is not None: values = self._value_offset + self._value_scale * torch.tanh(values) return values def _to_string(self): return "value_bounds={}".format(self.value_bounds) class ActionValueNetwork(ValueNetwork): """ A basic Q (action-value) network that predicts values from observations and actions. Can optionally be goal conditioned on future observations. """ def __init__( self, obs_shapes, ac_dim, mlp_layer_dims, value_bounds=None, goal_shapes=None, encoder_kwargs=None, ): """ Args: obs_shapes (OrderedDict): a dictionary that maps observation keys to expected shapes for observations. ac_dim (int): dimension of action space. mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes. value_bounds (tuple): a 2-tuple corresponding to the lowest and highest possible return that the network should be possible of generating. The network will rescale outputs using a tanh layer to lie within these bounds. If None, no tanh re-scaling is done. goal_shapes (OrderedDict): a dictionary that maps observation keys to expected shapes for goal observations. encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be nested dictionary containing relevant per-observation key information for encoder networks. Should be of form: obs_modality1: dict feature_dimension: int core_class: str core_kwargs: dict ... ... obs_randomizer_class: str obs_randomizer_kwargs: dict ... ... obs_modality2: dict ... """ # add in action as a modality new_obs_shapes = OrderedDict(obs_shapes) new_obs_shapes["action"] = (ac_dim,) self.ac_dim = ac_dim # pass to super class to instantiate network super(ActionValueNetwork, self).__init__( obs_shapes=new_obs_shapes, mlp_layer_dims=mlp_layer_dims, value_bounds=value_bounds, goal_shapes=goal_shapes, encoder_kwargs=encoder_kwargs, ) def forward(self, obs_dict, acts, goal_dict=None): """ Modify forward from super class to include actions in inputs. """ inputs = dict(obs_dict) inputs["action"] = acts return super(ActionValueNetwork, self).forward(inputs, goal_dict) def _to_string(self): return "action_dim={}\nvalue_bounds={}".format(self.ac_dim, self.value_bounds) class DistributionalActionValueNetwork(ActionValueNetwork): """ Distributional Q (action-value) network that outputs a categorical distribution over a discrete grid of value atoms. See https://arxiv.org/pdf/1707.06887.pdf for more details. """ def __init__( self, obs_shapes, ac_dim, mlp_layer_dims, value_bounds, num_atoms, goal_shapes=None, encoder_kwargs=None, ): """ Args: obs_shapes (OrderedDict): a dictionary that maps modality to expected shapes for observations. ac_dim (int): dimension of action space. mlp_layer_dims ([int]): sequence of integers for the MLP hidden layers sizes. value_bounds (tuple): a 2-tuple corresponding to the lowest and highest possible return that the network should be possible of generating. This defines the support of the value distribution. num_atoms (int): number of value atoms to use for the categorical distribution - which is the representation of the value distribution. goal_shapes (OrderedDict): a dictionary that maps modality to expected shapes for goal observations. encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be nested dictionary containing relevant per-modality information for encoder networks. Should be of form: obs_modality1: dict feature_dimension: int core_class: str core_kwargs: dict ... ... obs_randomizer_class: str obs_randomizer_kwargs: dict ... ... obs_modality2: dict ... """ # parameters specific to DistributionalActionValueNetwork self.num_atoms = num_atoms self._atoms = np.linspace(value_bounds[0], value_bounds[1], num_atoms) # pass to super class to instantiate network super(DistributionalActionValueNetwork, self).__init__( obs_shapes=obs_shapes, ac_dim=ac_dim, mlp_layer_dims=mlp_layer_dims, value_bounds=value_bounds, goal_shapes=goal_shapes, encoder_kwargs=encoder_kwargs, ) def _get_output_shapes(self): """ Network outputs log probabilities for categorical distribution over discrete value grid. """ return OrderedDict(log_probs=(self.num_atoms,)) def forward_train(self, obs_dict, acts, goal_dict=None): """ Return full critic categorical distribution. Args: obs_dict (dict): batch of observations acts (torch.Tensor): batch of actions goal_dict (dict): if not None, batch of goal observations Returns: value_distribution (DiscreteValueDistribution instance) """ # add in actions inputs = dict(obs_dict) inputs["action"] = acts # network returns unnormalized log probabilities (logits) for each of the value atoms logits = MIMO_MLP.forward(self, obs=inputs, goal=goal_dict)["log_probs"] # turn these logits into a categorical distribution over the value atoms. # (unsqueeze to make sure atoms are compatible with batch operations) value_atoms = torch.Tensor(self._atoms).unsqueeze(0).to(logits.device) return DiscreteValueDistribution(values=value_atoms, logits=logits) def forward(self, obs_dict, acts, goal_dict=None): """ Return mean of critic categorical distribution. Useful for obtaining point estimates of critic values. Args: obs_dict (dict): batch of observations acts (torch.Tensor): batch of actions goal_dict (dict): if not None, batch of goal observations Returns: mean_value (torch.Tensor): expectation of value distribution """ vd = self.forward_train(obs_dict=obs_dict, acts=acts, goal_dict=goal_dict) return vd.mean() def _to_string(self): return "action_dim={}\nvalue_bounds={}\nnum_atoms={}".format(self.ac_dim, self.value_bounds, self.num_atoms)