"""
Contains torch Modules that help deal with inputs consisting of multiple
modalities. This is extremely common when networks must deal with one or
more observation dictionaries, where each input dictionary can have
observation keys of a certain modality and shape.

As an example, an observation could consist of a flat "robot0_eef_pos" observation key,
and a 3-channel RGB "agentview_image" observation key.
"""
import sys
import numpy as np
import textwrap
from copy import deepcopy
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

from robomimic.utils.python_utils import extract_class_init_kwargs_from_dict
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.models.base_nets import Module, Sequential, MLP, RNN_Base, ResNet18Conv, SpatialSoftmax, \
    FeatureAggregator
from robomimic.models.obs_core import VisualCore, Randomizer
from robomimic.models.transformers import PositionalEncoding, GPT_Backbone


def obs_encoder_factory(
        obs_shapes,
        feature_activation=nn.ReLU,
        encoder_kwargs=None,
    ):
    """
    Utility function to create an @ObservationEncoder from kwargs specified in config.

    Args:
        obs_shapes (OrderedDict): a dictionary that maps observation key to
            expected shapes for observations.

        feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass
            None to apply no activation.

        encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should be
            nested dictionary containing relevant per-modality information for encoder networks.
            Should be of form:

            obs_modality1: dict
                feature_dimension: int
                core_class: str
                core_kwargs: dict
                    ...
                    ...
                obs_randomizer_class: str
                obs_randomizer_kwargs: dict
                    ...
                    ...
            obs_modality2: dict
                ...

    Returns:
        enc (ObservationEncoder): encoder with one registered key per entry of
            @obs_shapes, on which @make has already been called.
    """
    enc = ObservationEncoder(feature_activation=feature_activation)
    for k, obs_shape in obs_shapes.items():
        obs_modality = ObsUtils.OBS_KEYS_TO_MODALITIES[k]
        # always work on a copy so that neither the defaults nor the caller's config get mutated
        enc_kwargs = deepcopy(ObsUtils.DEFAULT_ENCODER_KWARGS[obs_modality]) if encoder_kwargs is None else \
            deepcopy(encoder_kwargs[obs_modality])

        for obs_module, cls_mapping in zip(("core", "obs_randomizer"),
                                           (ObsUtils.OBS_ENCODER_CORES, ObsUtils.OBS_RANDOMIZERS)):
            # Sanity check for kwargs in case they don't exist / are None
            if enc_kwargs.get(f"{obs_module}_kwargs", None) is None:
                enc_kwargs[f"{obs_module}_kwargs"] = {}
            # Add in input shape info
            enc_kwargs[f"{obs_module}_kwargs"]["input_shape"] = obs_shape
            # If group class is specified, then make sure corresponding kwargs only contain relevant kwargs
            if enc_kwargs[f"{obs_module}_class"] is not None:
                enc_kwargs[f"{obs_module}_kwargs"] = extract_class_init_kwargs_from_dict(
                    cls=cls_mapping[enc_kwargs[f"{obs_module}_class"]],
                    dic=enc_kwargs[f"{obs_module}_kwargs"],
                    copy=False,
                )

        # Instantiate the randomizer (if any) for this observation key
        randomizer = None if enc_kwargs["obs_randomizer_class"] is None else \
            ObsUtils.OBS_RANDOMIZERS[enc_kwargs["obs_randomizer_class"]](**enc_kwargs["obs_randomizer_kwargs"])

        enc.register_obs_key(
            name=k,
            shape=obs_shape,
            net_class=enc_kwargs["core_class"],
            net_kwargs=enc_kwargs["core_kwargs"],
            randomizer=randomizer,
        )

    enc.make()
    return enc


class ObservationEncoder(Module):
    """
    Module that processes inputs by observation key and then concatenates the processed
    observation keys together. Each key is processed with an encoder head network.
    Call @register_obs_key to register observation keys with the encoder and then
    finally call @make to create the encoder networks.
    """
    def __init__(self, feature_activation=nn.ReLU):
        """
        Args:
            feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass
                None to apply no activation.
        """
        super(ObservationEncoder, self).__init__()
        # per-key bookkeeping, all keyed by observation key name (insertion-ordered)
        self.obs_shapes = OrderedDict()
        self.obs_nets_classes = OrderedDict()
        self.obs_nets_kwargs = OrderedDict()
        self.obs_share_mods = OrderedDict()
        self.obs_nets = nn.ModuleDict()
        self.obs_randomizers = nn.ModuleDict()
        self.feature_activation = feature_activation
        # set to True by @make; no keys can be registered afterwards
        self._locked = False

    def register_obs_key(
        self,
        name,
        shape,
        net_class=None,
        net_kwargs=None,
        net=None,
        randomizer=None,
        share_net_from=None,
    ):
        """
        Register an observation key that this encoder should be responsible for.

        Args:
            name (str): modality name
            shape (int tuple): shape of modality
            net_class (str): name of class in base_nets.py that should be used
                to process this observation key before concatenation. Pass None to flatten
                and concatenate the observation key directly.
            net_kwargs (dict): arguments to pass to @net_class
            net (Module instance): if provided, use this Module to process the observation key
                instead of creating a different net
            randomizer (Randomizer instance): if provided, use this Module to augment observation keys
                coming in to the encoder, and possibly augment the processed output as well
            share_net_from (str): if provided, use the same instance of @net_class
                as another observation key. This observation key must already exist in this encoder.
                Warning: Note that this does not share the observation key randomizer
        """
        assert not self._locked, "ObservationEncoder: @register_obs_key called after @make"
        assert name not in self.obs_shapes, "ObservationEncoder: modality {} already exists".format(name)

        if net is not None:
            assert isinstance(net, Module), "ObservationEncoder: @net must be instance of Module class"
            assert (net_class is None) and (net_kwargs is None) and (share_net_from is None), \
                "ObservationEncoder: @net provided - ignore other net creation options"

        if share_net_from is not None:
            # share processing with another modality
            assert (net_class is None) and (net_kwargs is None)
            assert share_net_from in self.obs_shapes

        # copy so that the caller's dictionary is never mutated below
        net_kwargs = deepcopy(net_kwargs) if net_kwargs is not None else {}
        if randomizer is not None:
            assert isinstance(randomizer, Randomizer)
            # the net sees the randomizer's output, so update input shape to visual core
            net_kwargs["input_shape"] = randomizer.output_shape_in(shape)

        self.obs_shapes[name] = shape
        self.obs_nets_classes[name] = net_class
        self.obs_nets_kwargs[name] = net_kwargs
        self.obs_nets[name] = net
        self.obs_randomizers[name] = randomizer
        self.obs_share_mods[name] = share_net_from

    def make(self):
        """
        Creates the encoder networks and locks the encoder so that more modalities cannot be added.
        """
        assert not self._locked, "ObservationEncoder: @make called more than once"
        self._create_layers()
        self._locked = True

    def _create_layers(self):
        """
        Creates all networks and layers required by this encoder using the registered modalities.
        """
        assert not self._locked, "ObservationEncoder: layers have already been created"

        for k in self.obs_shapes:
            if self.obs_nets_classes[k] is not None:
                # create net to process this modality
                self.obs_nets[k] = ObsUtils.OBS_ENCODER_CORES[self.obs_nets_classes[k]](**self.obs_nets_kwargs[k])
            elif self.obs_share_mods[k] is not None:
                # make sure net is shared with another modality
                self.obs_nets[k] = self.obs_nets[self.obs_share_mods[k]]

        # a single activation instance is shared across all obs nets
        self.activation = None
        if self.feature_activation is not None:
            self.activation = self.feature_activation()

    def forward(self, obs_dict):
        """
        Processes modalities according to the ordering in @self.obs_shapes. For each
        modality, it is processed with a randomizer (if present), an encoder
        network (if present), and again with the randomizer (if present), flattened,
        and then concatenated with the other processed modalities.

        Args:
            obs_dict (OrderedDict): dictionary that maps modalities to torch.Tensor
                batches that agree with @self.obs_shapes. All modalities in
                @self.obs_shapes must be present, but additional modalities
                can also be present.

        Returns:
            feats (torch.Tensor): flat features of shape [B, D]
        """
        assert self._locked, "ObservationEncoder: @make has not been called yet"

        # ensure all modalities that the encoder handles are present
        assert set(self.obs_shapes.keys()).issubset(obs_dict), "ObservationEncoder: {} does not contain all modalities {}".format(
            list(obs_dict.keys()), list(self.obs_shapes.keys())
        )

        # process modalities by order given by @self.obs_shapes
        feats = []
        for k in self.obs_shapes:
            x = obs_dict[k]
            # maybe process encoder input with randomizer
            if self.obs_randomizers[k] is not None:
                x = self.obs_randomizers[k].forward_in(x)
            # maybe process with obs net (activation is only applied to obs net outputs)
            if self.obs_nets[k] is not None:
                x = self.obs_nets[k](x)
                if self.activation is not None:
                    x = self.activation(x)
            # maybe process encoder output with randomizer
            if self.obs_randomizers[k] is not None:
                x = self.obs_randomizers[k].forward_out(x)
            # flatten to [B, D]
            x = TensorUtils.flatten(x, begin_axis=1)
            feats.append(x)

        # concatenate all features together
        return torch.cat(feats, dim=-1)

    def output_shape(self, input_shape=None):
        """
        Compute the output shape of the encoder.

        Args:
            input_shape: unused - shapes are taken from the registered observation keys.

        Returns:
            shape ([int]): single-element list holding the total flat feature dimension
        """
        feat_dim = 0
        for k in self.obs_shapes:
            # mirror the processing order used in @forward
            feat_shape = self.obs_shapes[k]
            if self.obs_randomizers[k] is not None:
                feat_shape = self.obs_randomizers[k].output_shape_in(feat_shape)
            if self.obs_nets[k] is not None:
                feat_shape = self.obs_nets[k].output_shape(feat_shape)
            if self.obs_randomizers[k] is not None:
                feat_shape = self.obs_randomizers[k].output_shape_out(feat_shape)
            feat_dim += int(np.prod(feat_shape))
        return [feat_dim]

    def __repr__(self):
        """
        Pretty print the encoder.
        """
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        for k in self.obs_shapes:
            msg += textwrap.indent('\nKey(\n', ' ' * 4)
            indent = ' ' * 8
            msg += textwrap.indent("name={}\nshape={}\n".format(k, self.obs_shapes[k]), indent)
            msg += textwrap.indent("modality={}\n".format(ObsUtils.OBS_KEYS_TO_MODALITIES[k]), indent)
            msg += textwrap.indent("randomizer={}\n".format(self.obs_randomizers[k]), indent)
            msg += textwrap.indent("net={}\n".format(self.obs_nets[k]), indent)
            msg += textwrap.indent("sharing_from={}\n".format(self.obs_share_mods[k]), indent)
            msg += textwrap.indent(")", ' ' * 4)
        msg += textwrap.indent("\noutput_shape={}".format(self.output_shape()), ' ' * 4)
        msg = header + '(' + msg + '\n)'
        return msg


class ObservationDecoder(Module):
    """
    Module that can generate observation outputs by modality. Inputs are assumed
    to be flat (usually outputs from some hidden layer). Each observation output
    is generated with a linear layer from these flat inputs. Subclass this
    module in order to implement more complex schemes for generating each
    modality.
    """
    def __init__(
        self,
        decode_shapes,
        input_feat_dim,
    ):
        """
        Args:
            decode_shapes (OrderedDict): a dictionary that maps observation key to
                expected shape. This is used to generate output modalities from the
                input features.

            input_feat_dim (int): flat input dimension size
        """
        super(ObservationDecoder, self).__init__()

        # important: require an OrderedDict and copy it in order, to ensure
        # consistent ordering of modalities
        assert isinstance(decode_shapes, OrderedDict)
        self.obs_shapes = OrderedDict()
        for k in decode_shapes:
            self.obs_shapes[k] = decode_shapes[k]

        self.input_feat_dim = input_feat_dim
        self._create_layers()

    def _create_layers(self):
        """
        Create a linear layer to predict each modality.
        """
        self.nets = nn.ModuleDict()
        for k in self.obs_shapes:
            layer_out_dim = int(np.prod(self.obs_shapes[k]))
            self.nets[k] = nn.Linear(self.input_feat_dim, layer_out_dim)

    def output_shape(self, input_shape=None):
        """
        Returns output shape for this module, which is a dictionary instead
        of a list since outputs are dictionaries.
        """
        return { k : list(self.obs_shapes[k]) for k in self.obs_shapes }

    def forward(self, feats):
        """
        Predict each modality from input features, and reshape to each modality's shape.

        Args:
            feats (torch.Tensor): flat input features

        Returns:
            output (dict): maps each observation key to a tensor reshaped
                to (-1, *shape) for that key
        """
        output = {}
        for k in self.obs_shapes:
            out = self.nets[k](feats)
            output[k] = out.reshape(-1, *self.obs_shapes[k])
        return output

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        for k in self.obs_shapes:
            msg += textwrap.indent('\nKey(\n', ' ' * 4)
            indent = ' ' * 8
            msg += textwrap.indent("name={}\nshape={}\n".format(k, self.obs_shapes[k]), indent)
            msg += textwrap.indent("modality={}\n".format(ObsUtils.OBS_KEYS_TO_MODALITIES[k]), indent)
            msg += textwrap.indent("net=({})\n".format(self.nets[k]), indent)
            msg += textwrap.indent(")", ' ' * 4)
        msg = header + '(' + msg + '\n)'
        return msg


class ObservationGroupEncoder(Module):
    """
    This class allows networks to encode multiple observation dictionaries into a single
    flat, concatenated vector representation. It does this by assigning each observation
    dictionary (observation group) an @ObservationEncoder object.

    The class takes a dictionary of dictionaries, @observation_group_shapes.
    Each key corresponds to a observation group (e.g. 'obs', 'subgoal', 'goal')
    and each OrderedDict should be a map between modalities and
    expected input shapes (e.g. { 'image' : (3, 120, 160) }).
    """
    def __init__(
        self,
        observation_group_shapes,
        feature_activation=nn.ReLU,
        encoder_kwargs=None,
    ):
        """
        Args:
            observation_group_shapes (OrderedDict): a dictionary of dictionaries.
                Each key in this dictionary should specify an observation group, and
                the value should be an OrderedDict that maps modalities to
                expected shapes.

            feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass
                None to apply no activation.

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-modality information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        super(ObservationGroupEncoder, self).__init__()

        # type checking
        assert isinstance(observation_group_shapes, OrderedDict)
        assert all(isinstance(observation_group_shapes[k], OrderedDict) for k in observation_group_shapes)

        self.observation_group_shapes = observation_group_shapes

        # create an observation encoder per observation group
        self.nets = nn.ModuleDict()
        for obs_group in self.observation_group_shapes:
            self.nets[obs_group] = obs_encoder_factory(
                obs_shapes=self.observation_group_shapes[obs_group],
                feature_activation=feature_activation,
                encoder_kwargs=encoder_kwargs,
            )

    def forward(self, **inputs):
        """
        Process each set of inputs in its own observation group.

        Args:
            inputs (dict): dictionary that maps observation groups to observation
                dictionaries of torch.Tensor batches that agree with
                @self.observation_group_shapes. All observation groups in
                @self.observation_group_shapes must be present, but additional
                observation groups can also be present. Note that these are specified
                as kwargs for ease of use with networks that name each observation
                stream in their forward calls.

        Returns:
            outputs (torch.Tensor): flat outputs of shape [B, D]
        """

        # ensure all observation groups we need are present
        assert set(self.observation_group_shapes.keys()).issubset(inputs), "{} does not contain all observation groups {}".format(
            list(inputs.keys()), list(self.observation_group_shapes.keys())
        )

        outputs = []
        # Deterministic order since self.observation_group_shapes is OrderedDict
        for obs_group in self.observation_group_shapes:
            # pass through encoder
            outputs.append(
                self.nets[obs_group].forward(inputs[obs_group])
            )

        return torch.cat(outputs, dim=-1)

    def output_shape(self):
        """
        Compute the output shape of this encoder.

        Returns:
            shape ([int]): single-element list holding the summed flat feature
                dimension over all observation groups
        """
        feat_dim = 0
        for obs_group in self.observation_group_shapes:
            # get feature dimension of these keys
            feat_dim += self.nets[obs_group].output_shape()[0]
        return [feat_dim]

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        for k in self.observation_group_shapes:
            msg += '\n'
            indent = ' ' * 4
            msg += textwrap.indent("group={}\n{}".format(k, self.nets[k]), indent)
        msg = header + '(' + msg + '\n)'
        return msg


class MIMO_MLP(Module):
    """
    Extension to MLP to accept multiple observation dictionaries as input and
    to output dictionaries of tensors. Inputs are specified as a dictionary of
    observation dictionaries, with each key corresponding to an observation group.

    This module utilizes @ObservationGroupEncoder to process the multiple input dictionaries and
    @ObservationDecoder to generate tensor dictionaries. The default behavior
    for encoding the inputs is to process visual inputs with a learned CNN and concatenating
    the flat encodings with the other flat inputs. The default behavior for generating
    outputs is to use a linear layer branch to produce each modality separately
    (including visual outputs).
    """
    def __init__(
        self,
        input_obs_group_shapes,
        output_shapes,
        layer_dims,
        layer_func=nn.Linear,
        activation=nn.ReLU,
        encoder_kwargs=None,
    ):
        """
        Args:
            input_obs_group_shapes (OrderedDict): a dictionary of dictionaries.
                Each key in this dictionary should specify an observation group, and
                the value should be an OrderedDict that maps modalities to
                expected shapes.

            output_shapes (OrderedDict): a dictionary that maps modality to
                expected shapes for outputs.

            layer_dims ([int]): sequence of integers for the MLP hidden layer sizes

            layer_func: mapping per MLP layer - defaults to Linear

            activation: non-linearity per MLP layer - defaults to ReLU

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-modality information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        super(MIMO_MLP, self).__init__()

        assert isinstance(input_obs_group_shapes, OrderedDict)
        assert all(isinstance(input_obs_group_shapes[k], OrderedDict) for k in input_obs_group_shapes)
        assert isinstance(output_shapes, OrderedDict)

        self.input_obs_group_shapes = input_obs_group_shapes
        self.output_shapes = output_shapes

        self.nets = nn.ModuleDict()

        # Encoder for all observation groups.
        self.nets["encoder"] = ObservationGroupEncoder(
            observation_group_shapes=input_obs_group_shapes,
            encoder_kwargs=encoder_kwargs,
        )

        # flat encoder output dimension
        mlp_input_dim = self.nets["encoder"].output_shape()[0]

        # intermediate MLP layers
        self.nets["mlp"] = MLP(
            input_dim=mlp_input_dim,
            output_dim=layer_dims[-1],
            layer_dims=layer_dims[:-1],
            layer_func=layer_func,
            activation=activation,
            output_activation=activation, # make sure non-linearity is applied before decoder
        )

        # decoder for output modalities
        self.nets["decoder"] = ObservationDecoder(
            decode_shapes=self.output_shapes,
            input_feat_dim=layer_dims[-1],
        )

    def output_shape(self, input_shape=None):
        """
        Returns output shape for this module, which is a dictionary instead
        of a list since outputs are dictionaries.
        """
        return { k : list(self.output_shapes[k]) for k in self.output_shapes }

    def forward(self, **inputs):
        """
        Process each set of inputs in its own observation group.

        Args:
            inputs (dict): a dictionary of dictionaries with one dictionary per
                observation group. Each observation group's dictionary should map
                modality to torch.Tensor batches. Should be consistent with
                @self.input_obs_group_shapes.

        Returns:
            outputs (dict): dictionary of output torch.Tensors, that corresponds
                to @self.output_shapes
        """
        enc_outputs = self.nets["encoder"](**inputs)
        mlp_out = self.nets["mlp"](enc_outputs)
        return self.nets["decoder"](mlp_out)

    def _to_string(self):
        """
        Subclasses should override this method to print out info about network / policy.
        """
        return ''

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        indent = ' ' * 4
        if self._to_string() != '':
            msg += textwrap.indent("\n" + self._to_string() + "\n", indent)
        msg += textwrap.indent("\nencoder={}".format(self.nets["encoder"]), indent)
        msg += textwrap.indent("\n\nmlp={}".format(self.nets["mlp"]), indent)
        msg += textwrap.indent("\n\ndecoder={}".format(self.nets["decoder"]), indent)
        msg = header + '(' + msg + '\n)'
        return msg


class RNN_MIMO_MLP(Module):
    """
    A wrapper class for a multi-step RNN and a per-step MLP and a decoder.

    Structure: [encoder -> rnn -> mlp -> decoder]

    All temporal inputs are processed by a shared @ObservationGroupEncoder,
    followed by an RNN, and then a per-step multi-output MLP.
    """
    def __init__(
        self,
        input_obs_group_shapes,
        output_shapes,
        mlp_layer_dims,
        rnn_hidden_dim,
        rnn_num_layers,
        rnn_type="LSTM", # [LSTM, GRU]
        rnn_kwargs=None,
        mlp_activation=nn.ReLU,
        mlp_layer_func=nn.Linear,
        per_step=True,
        encoder_kwargs=None,
    ):
        """
        Args:
            input_obs_group_shapes (OrderedDict): a dictionary of dictionaries.
                Each key in this dictionary should specify an observation group, and
                the value should be an OrderedDict that maps modalities to
                expected shapes.

            output_shapes (OrderedDict): a dictionary that maps modality to
                expected shapes for outputs.

            mlp_layer_dims ([int]): sequence of integers for the MLP layer sizes. If
                empty, no MLP is created and the decoder is applied directly to
                the RNN outputs.

            rnn_hidden_dim (int): RNN hidden dimension

            rnn_num_layers (int): number of RNN layers

            rnn_type (str): [LSTM, GRU]

            rnn_kwargs (dict or None): kwargs for the rnn model. None is treated
                as "no extra kwargs" (i.e. a unidirectional RNN).

            mlp_activation: non-linearity applied at the output of the MLP - defaults to ReLU

            mlp_layer_func: mapping per MLP layer - defaults to Linear

            per_step (bool): if True, apply the MLP and observation decoder into @output_shapes
                at every step of the RNN. Otherwise, apply them to the final hidden state of the
                RNN.

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-modality information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        super(RNN_MIMO_MLP, self).__init__()
        assert isinstance(input_obs_group_shapes, OrderedDict)
        assert all(isinstance(input_obs_group_shapes[k], OrderedDict) for k in input_obs_group_shapes)
        assert isinstance(output_shapes, OrderedDict)
        self.input_obs_group_shapes = input_obs_group_shapes
        self.output_shapes = output_shapes
        self.per_step = per_step

        self.nets = nn.ModuleDict()

        # Encoder for all observation groups.
        self.nets["encoder"] = ObservationGroupEncoder(
            observation_group_shapes=input_obs_group_shapes,
            encoder_kwargs=encoder_kwargs,
        )

        # flat encoder output dimension
        rnn_input_dim = self.nets["encoder"].output_shape()[0]

        # bidirectional RNNs mean that the output of RNN will be twice the hidden dimension.
        # NOTE: @rnn_kwargs defaults to None, so the lookup must be None-safe.
        rnn_is_bidirectional = (rnn_kwargs or {}).get("bidirectional", False)
        num_directions = int(rnn_is_bidirectional) + 1 # 2 if bidirectional, 1 otherwise
        rnn_output_dim = num_directions * rnn_hidden_dim

        per_step_net = None
        self._has_mlp = (len(mlp_layer_dims) > 0)
        if self._has_mlp:
            self.nets["mlp"] = MLP(
                input_dim=rnn_output_dim,
                output_dim=mlp_layer_dims[-1],
                layer_dims=mlp_layer_dims[:-1],
                output_activation=mlp_activation,
                layer_func=mlp_layer_func
            )
            self.nets["decoder"] = ObservationDecoder(
                decode_shapes=self.output_shapes,
                input_feat_dim=mlp_layer_dims[-1],
            )
            if self.per_step:
                per_step_net = Sequential(self.nets["mlp"], self.nets["decoder"])
        else:
            self.nets["decoder"] = ObservationDecoder(
                decode_shapes=self.output_shapes,
                input_feat_dim=rnn_output_dim,
            )
            if self.per_step:
                per_step_net = self.nets["decoder"]

        # core network
        self.nets["rnn"] = RNN_Base(
            input_dim=rnn_input_dim,
            rnn_hidden_dim=rnn_hidden_dim,
            rnn_num_layers=rnn_num_layers,
            rnn_type=rnn_type,
            per_step_net=per_step_net,
            rnn_kwargs=rnn_kwargs
        )

    def get_rnn_init_state(self, batch_size, device):
        """
        Get a default RNN state (zeros)

        Args:
            batch_size (int): batch size dimension

            device: device the hidden state should be sent to.

        Returns:
            hidden_state (torch.Tensor or tuple): returns hidden state tensor or tuple of hidden state tensors
                depending on the RNN type
        """
        return self.nets["rnn"].get_rnn_init_state(batch_size, device=device)

    def output_shape(self, input_shape):
        """
        Returns output shape for this module, which is a dictionary instead
        of a list since outputs are dictionaries.

        Args:
            input_shape (dict): dictionary of dictionaries, where each top-level key
                corresponds to an observation group, and the low-level dictionaries
                specify the shape for each modality in an observation dictionary
        """

        # infers temporal dimension from input shape
        obs_group = list(self.input_obs_group_shapes.keys())[0]
        mod = list(self.input_obs_group_shapes[obs_group].keys())[0]
        T = input_shape[obs_group][mod][0]
        TensorUtils.assert_size_at_dim(input_shape, size=T, dim=0,
                msg="RNN_MIMO_MLP: input_shape inconsistent in temporal dimension")
        # returns a dictionary instead of list since outputs are dictionaries
        return { k : [T] + list(self.output_shapes[k]) for k in self.output_shapes }

    def forward(self, rnn_init_state=None, return_state=False, **inputs):
        """
        Args:
            inputs (dict): a dictionary of dictionaries with one dictionary per
                observation group. Each observation group's dictionary should map
                modality to torch.Tensor batches. Should be consistent with
                @self.input_obs_group_shapes. First two leading dimensions should
                be batch and time [B, T, ...] for each tensor.

            rnn_init_state: rnn hidden state, initialize to zero state if set to None

            return_state (bool): whether to return hidden state

        Returns:
            outputs (dict): dictionary of output torch.Tensors, that corresponds
                to @self.output_shapes. Leading dimensions will be batch and time [B, T, ...]
                for each tensor.

            rnn_state (torch.Tensor or tuple): return the new rnn state (if @return_state)
        """
        for obs_group in self.input_obs_group_shapes:
            for k in self.input_obs_group_shapes[obs_group]:
                # first two dimensions should be [B, T] for inputs
                assert inputs[obs_group][k].ndim - 2 == len(self.input_obs_group_shapes[obs_group][k])

        # use encoder to extract flat rnn inputs
        rnn_inputs = TensorUtils.time_distributed(inputs, self.nets["encoder"], inputs_as_kwargs=True)
        assert rnn_inputs.ndim == 3  # [B, T, D]
        if self.per_step:
            # the RNN applies the per-step net (MLP + decoder) itself
            return self.nets["rnn"].forward(inputs=rnn_inputs, rnn_init_state=rnn_init_state, return_state=return_state)

        # apply MLP + decoder to last RNN output
        outputs = self.nets["rnn"].forward(inputs=rnn_inputs, rnn_init_state=rnn_init_state, return_state=return_state)
        if return_state:
            outputs, rnn_state = outputs

        assert outputs.ndim == 3 # [B, T, D]
        if self._has_mlp:
            outputs = self.nets["decoder"](self.nets["mlp"](outputs[:, -1]))
        else:
            outputs = self.nets["decoder"](outputs[:, -1])

        if return_state:
            return outputs, rnn_state
        return outputs

    def forward_step(self, rnn_state, **inputs):
        """
        Unroll network over a single timestep.

        Args:
            inputs (dict): expects same observation groups and modalities as
                @self.input_obs_group_shapes, with additional batch dimension
                (but NOT time), since this is a single time step.

            rnn_state (torch.Tensor): rnn hidden state

        Returns:
            outputs (dict): dictionary of output torch.Tensors, that corresponds
                to @self.output_shapes. Does not contain time dimension.

            rnn_state: return the new rnn state
        """
        # ensure that the only extra dimension is batch dim, not temporal dim
        assert all(
            inputs[obs_group][k].ndim - 1 == len(self.input_obs_group_shapes[obs_group][k])
            for obs_group in self.input_obs_group_shapes
            for k in self.input_obs_group_shapes[obs_group]
        )

        # add a singleton time dimension so that @forward can be reused
        inputs = TensorUtils.to_sequence(inputs)

        # observation groups must be passed as kwargs - @forward's first positional
        # parameter is @rnn_init_state, not the inputs
        outputs, rnn_state = self.forward(
            rnn_init_state=rnn_state,
            return_state=True,
            **inputs
        )
        if self.per_step:
            # if outputs are not per-step, the time dimension is already reduced.
            # outputs is a dictionary of tensors, so strip the time dimension per key
            outputs = { k : outputs[k][:, 0] for k in outputs }
        return outputs, rnn_state

    def _to_string(self):
        """
        Subclasses should override this method to print out info about network / policy.
        """
        return ''

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        indent = ' ' * 4
        msg += textwrap.indent("\n" + self._to_string(), indent)
        msg += textwrap.indent("\n\nencoder={}".format(self.nets["encoder"]), indent)
        msg += textwrap.indent("\n\nrnn={}".format(self.nets["rnn"]), indent)
        msg = header + '(' + msg + '\n)'
        return msg


class MIMO_Transformer(Module):
    """
    Extension to Transformer (based on GPT architecture) to accept multiple observation
    dictionaries as input and to output dictionaries of tensors. Inputs are specified as
    a dictionary of observation dictionaries, with each key corresponding to an observation group.
    This module utilizes @ObservationGroupEncoder to process the multiple input dictionaries and
    @ObservationDecoder to generate tensor dictionaries. The default behavior
    for encoding the inputs is to process visual inputs with a learned CNN and concatenating
    the flat encodings with the other flat inputs. The default behavior for generating
    outputs is to use a linear layer branch to produce each modality separately
    (including visual outputs).
    """
    def __init__(
        self,
        input_obs_group_shapes,
        output_shapes,
        transformer_embed_dim,
        transformer_num_layers,
        transformer_num_heads,
        transformer_context_length,
        transformer_emb_dropout=0.1,
        transformer_attn_dropout=0.1,
        transformer_block_output_dropout=0.1,
        transformer_sinusoidal_embedding=False,
        transformer_activation="gelu",
        transformer_nn_parameter_for_timesteps=False,
        encoder_kwargs=None,
    ):
        """
        Args:
            input_obs_group_shapes (OrderedDict): a dictionary of dictionaries.
                Each key in this dictionary should specify an observation group, and
                the value should be an OrderedDict that maps modalities to
                expected shapes.

            output_shapes (OrderedDict): a dictionary that maps modality to
                expected shapes for outputs.

            transformer_embed_dim (int): dimension for embeddings used by transformer

            transformer_num_layers (int): number of transformer blocks to stack

            transformer_num_heads (int): number of attention heads for each
                transformer block - must divide @transformer_embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            transformer_context_length (int): expected length of input sequences

            transformer_emb_dropout (float): dropout probability for embedding inputs in transformer

            transformer_attn_dropout (float): dropout probability for attention outputs for each transformer block

            transformer_block_output_dropout (float): dropout probability for final outputs for each transformer block

            transformer_sinusoidal_embedding (bool): if True, use @PositionalEncoding for
                timestep embeddings. Takes precedence over @transformer_nn_parameter_for_timesteps
                when the embedding module is created.

            transformer_activation: non-linearity for input and output layers used in transformer

            transformer_nn_parameter_for_timesteps (bool): if True (and sinusoidal embeddings
                are not used), store timestep embeddings as a zero-initialized nn.Parameter of
                shape (1, @transformer_context_length, @transformer_embed_dim). Otherwise an
                nn.Embedding is used.

            encoder_kwargs (dict): observation encoder config
        """
        super(MIMO_Transformer, self).__init__()

        assert isinstance(input_obs_group_shapes, OrderedDict)
        assert all(isinstance(input_obs_group_shapes[k], OrderedDict) for k in input_obs_group_shapes)
        assert isinstance(output_shapes, OrderedDict)

        self.input_obs_group_shapes = input_obs_group_shapes
        self.output_shapes = output_shapes

        self.nets = nn.ModuleDict()
        self.params = nn.ParameterDict()

        # Encoder for all observation groups.
        self.nets["encoder"] = ObservationGroupEncoder(
            observation_group_shapes=input_obs_group_shapes,
            encoder_kwargs=encoder_kwargs,
            feature_activation=None,
        )

        # flat encoder output dimension
        transformer_input_dim = self.nets["encoder"].output_shape()[0]

        self.nets["embed_encoder"] = nn.Linear(
            transformer_input_dim, transformer_embed_dim
        )

        max_timestep = transformer_context_length

        if transformer_sinusoidal_embedding:
            self.nets["embed_timestep"] = PositionalEncoding(transformer_embed_dim)
        elif transformer_nn_parameter_for_timesteps:
            assert (
                not transformer_sinusoidal_embedding
            ), "nn.Parameter only works with learned embeddings"
            self.params["embed_timestep"] = nn.Parameter(
                torch.zeros(1, max_timestep, transformer_embed_dim)
            )
        else:
            self.nets["embed_timestep"] = nn.Embedding(max_timestep, transformer_embed_dim)

        # layer norm for embeddings
        self.nets["embed_ln"] = nn.LayerNorm(transformer_embed_dim)

        # dropout for input embeddings
        self.nets["embed_drop"] = nn.Dropout(transformer_emb_dropout)

        # GPT transformer
        self.nets["transformer"] = GPT_Backbone(
            embed_dim=transformer_embed_dim,
            num_layers=transformer_num_layers,
            num_heads=transformer_num_heads,
            context_length=transformer_context_length,
            attn_dropout=transformer_attn_dropout,
            block_output_dropout=transformer_block_output_dropout,
            activation=transformer_activation,
        )

        # decoder for output modalities
        self.nets["decoder"] = ObservationDecoder(
            decode_shapes=self.output_shapes,
            input_feat_dim=transformer_embed_dim,
        )

        self.transformer_context_length = transformer_context_length
        self.transformer_embed_dim = transformer_embed_dim
        self.transformer_sinusoidal_embedding = transformer_sinusoidal_embedding
        self.transformer_nn_parameter_for_timesteps = transformer_nn_parameter_for_timesteps

    def output_shape(self, input_shape=None):
        """
        Returns output shape for this module, which is a dictionary instead
        of a list since outputs are dictionaries.
        """
        return { k : list(self.output_shapes[k]) for k in self.output_shapes }

    def embed_timesteps(self, embeddings):
        """
        Computes timestep-based embeddings (aka positional embeddings) to add to embeddings.

        Args:
            embeddings (torch.Tensor): embeddings prior to positional embeddings are computed

        Returns:
            time_embeddings (torch.Tensor): positional embeddings to add to embeddings
        """
        # timestep indices 0..T-1 (T = embeddings.shape[1]), repeated for every batch element
        timesteps = (
            torch.arange(
                0,
                embeddings.shape[1],
                dtype=embeddings.dtype,
                device=embeddings.device,
            )
            .unsqueeze(0)
            .repeat(embeddings.shape[0], 1)
        )
        assert (timesteps >= 0.0).all(), "timesteps must be positive!"
        if self.transformer_sinusoidal_embedding:
            assert torch.is_floating_point(timesteps), timesteps.dtype
        else:
            # learned embeddings are looked up by integer index
            timesteps = timesteps.long()

        if self.transformer_nn_parameter_for_timesteps:
            # NOTE(review): the parameter has shape (1, context_length, embed_dim) and is
            # returned as-is, so adding it presumably relies on broadcasting with
            # T == context_length - confirm against callers
            time_embeddings = self.params["embed_timestep"]
        else:
            time_embeddings = self.nets["embed_timestep"](
                timesteps
            )  # these are NOT fed into transformer, only added to the inputs.
            # compute how many modalities were combined into embeddings, replicate time embeddings that many times
            num_replicates = embeddings.shape[-1] // self.transformer_embed_dim
            time_embeddings = torch.cat([time_embeddings for _ in range(num_replicates)], -1)
            assert (
                embeddings.shape == time_embeddings.shape
            ), f"{embeddings.shape}, {time_embeddings.shape}"
        return time_embeddings

    def input_embedding(
        self,
        inputs,
    ):
        """
        Process encoded observations into embeddings to pass to transformer,
        Adds timestep-based embeddings (aka positional embeddings) to inputs.

        Args:
            inputs (torch.Tensor): outputs from observation encoder

        Returns:
            embeddings (torch.Tensor): input embeddings to pass to transformer backbone.
        """
        embeddings = self.nets["embed_encoder"](inputs)
        time_embeddings = self.embed_timesteps(embeddings)
        embeddings = embeddings + time_embeddings
        embeddings = self.nets["embed_ln"](embeddings)
        embeddings = self.nets["embed_drop"](embeddings)

        return embeddings

    def forward(self, **inputs):
        """
        Process each set of inputs in its own observation group.

        Args:
            inputs (dict): a dictionary of dictionaries with one dictionary per
                observation group. Each observation group's dictionary should map
                modality to torch.Tensor batches. Should be consistent with
                @self.input_obs_group_shapes. First two leading dimensions should
                be batch and time [B, T, ...] for each tensor.

        Returns:
            outputs (dict): dictionary of output torch.Tensors, that corresponds
                to @self.output_shapes. Leading dimensions will be batch and time [B, T, ...]
                for each tensor. Additionally contains the raw transformer outputs
                under the key "transformer_encoder_outputs".
        """
        for obs_group in self.input_obs_group_shapes:
            for k in self.input_obs_group_shapes[obs_group]:
                # first two dimensions should be [B, T] for inputs
                if inputs[obs_group][k] is None:
                    continue
                assert inputs[obs_group][k].ndim - 2 == len(self.input_obs_group_shapes[obs_group][k])

        # shallow copy of the kwargs dictionary before passing it on
        inputs = inputs.copy()

        transformer_inputs = TensorUtils.time_distributed(
            inputs, self.nets["encoder"], inputs_as_kwargs=True
        )
        assert transformer_inputs.ndim == 3  # [B, T, D]

        transformer_embeddings = self.input_embedding(transformer_inputs)
        # pass encoded sequences through transformer
        transformer_encoder_outputs = self.nets["transformer"].forward(transformer_embeddings)

        # apply decoder to each timestep of sequence to get a dictionary of outputs
        transformer_outputs = TensorUtils.time_distributed(
            transformer_encoder_outputs, self.nets["decoder"]
        )
        transformer_outputs["transformer_encoder_outputs"] = transformer_encoder_outputs
        return transformer_outputs

    def _to_string(self):
        """
        Subclasses should override this method to print out info about network / policy.
        """
        return ''

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        indent = ' ' * 4
        if self._to_string() != '':
            msg += textwrap.indent("\n" + self._to_string() + "\n", indent)
        msg += textwrap.indent("\nencoder={}".format(self.nets["encoder"]), indent)
        msg += textwrap.indent("\n\ntransformer={}".format(self.nets["transformer"]), indent)
        msg += textwrap.indent("\n\ndecoder={}".format(self.nets["decoder"]), indent)
        msg = header + '(' + msg + '\n)'
        return msg