"""
Implementation of Behavioral Cloning (BC).

This module registers the "bc" algo factory function and defines the BC algo
class together with its Gaussian, GMM, VAE, RNN and Transformer variants.
"""
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

import robomimic.models.base_nets as BaseNets
import robomimic.models.obs_nets as ObsNets
import robomimic.models.policy_nets as PolicyNets
import robomimic.models.vae_nets as VAENets
import robomimic.utils.loss_utils as LossUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils

from robomimic.algo import register_algo_factory_func, PolicyAlgo


@register_algo_factory_func("bc")
def algo_config_to_class(algo_config):
    """
    Maps algo config to the BC algo class to instantiate, along with additional algo kwargs.

    The policy output type is checked in the order gaussian, gmm, vae, and falls back
    to a deterministic policy. Within each output type, the backbone is chosen in the
    order rnn, transformer, plain MLP. Combinations without an implementation raise
    NotImplementedError.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm

    Raises:
        NotImplementedError: if a gaussian or vae policy is requested together with
            an rnn or transformer backbone
    """

    # note: we need the check below because some configs import BCConfig and exclude
    # some of these options
    gaussian_enabled = ("gaussian" in algo_config and algo_config.gaussian.enabled)
    gmm_enabled = ("gmm" in algo_config and algo_config.gmm.enabled)
    vae_enabled = ("vae" in algo_config and algo_config.vae.enabled)

    rnn_enabled = algo_config.rnn.enabled
    # support legacy configs that do not have "transformer" item
    transformer_enabled = ("transformer" in algo_config) and algo_config.transformer.enabled

    if gaussian_enabled:
        if rnn_enabled:
            raise NotImplementedError
        elif transformer_enabled:
            raise NotImplementedError
        else:
            algo_class, algo_kwargs = BC_Gaussian, {}
    elif gmm_enabled:
        if rnn_enabled:
            algo_class, algo_kwargs = BC_RNN_GMM, {}
        elif transformer_enabled:
            algo_class, algo_kwargs = BC_Transformer_GMM, {}
        else:
            algo_class, algo_kwargs = BC_GMM, {}
    elif vae_enabled:
        if rnn_enabled:
            raise NotImplementedError
        elif transformer_enabled:
            raise NotImplementedError
        else:
            algo_class, algo_kwargs = BC_VAE, {}
    else:
        if rnn_enabled:
            algo_class, algo_kwargs = BC_RNN, {}
        elif transformer_enabled:
            algo_class, algo_kwargs = BC_Transformer, {}
        else:
            algo_class, algo_kwargs = BC, {}

    return algo_class, algo_kwargs


class BC(PolicyAlgo):
    """
    Normal BC training.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        self.nets = nn.ModuleDict()
        self.nets["policy"] = PolicyNets.ActorNetwork(
            obs_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            ac_dim=self.ac_dim,
            mlp_layer_dims=self.algo_config.actor_layer_dims,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
        )
        self.nets = self.nets.float().to(self.device)

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Only index 0 along dimension 1 of each observation and of the actions is kept.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training
        """
        input_batch = dict()
        input_batch["obs"] = {k: batch["obs"][k][:, 0, :] for k in batch["obs"]}
        input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
        input_batch["actions"] = batch["actions"][:, 0, :]
        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))

    def train_on_batch(self, batch, epoch, validate=False):
        """
        Training on a single batch of data.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): epoch number - required by some Algos that need
                to perform staged training and early stopping

            validate (bool): if True, don't perform any learning updates.

        Returns:
            info (dict): dictionary of relevant inputs, outputs, and losses
                that might be relevant for logging
        """
        with TorchUtils.maybe_no_grad(no_grad=validate):
            info = super(BC, self).train_on_batch(batch, epoch, validate=validate)
            predictions = self._forward_training(batch)
            losses = self._compute_losses(predictions, batch)

            # store detached copies for logging
            info["predictions"] = TensorUtils.detach(predictions)
            info["losses"] = TensorUtils.detach(losses)

            if not validate:
                step_info = self._train_step(losses)
                info.update(step_info)

        return info

    def _forward_training(self, batch):
        """
        Internal helper function for BC algo class. Compute forward pass
        and return network outputs in @predictions dict.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

        Returns:
            predictions (dict): dictionary containing network outputs
        """
        predictions = OrderedDict()
        actions = self.nets["policy"](obs_dict=batch["obs"], goal_dict=batch["goal_obs"])
        predictions["actions"] = actions
        return predictions

    def _compute_losses(self, predictions, batch):
        """
        Internal helper function for BC algo class. Compute losses based on
        network outputs in @predictions dict, using reference labels in @batch.

        The returned "action_loss" is the sum of the l2, l1 and cosine losses, each
        scaled by its weight from @self.algo_config.loss.

        Args:
            predictions (dict): dictionary containing network outputs, from @_forward_training
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

        Returns:
            losses (dict): dictionary of losses computed over the batch
        """
        losses = OrderedDict()
        a_target = batch["actions"]
        actions = predictions["actions"]
        losses["l2_loss"] = nn.MSELoss()(actions, a_target)
        # note: the entry named "l1_loss" is computed with the smooth L1 loss
        losses["l1_loss"] = nn.SmoothL1Loss()(actions, a_target)
        # cosine direction loss on eef delta position
        losses["cos_loss"] = LossUtils.cosine_loss(actions[..., :3], a_target[..., :3])

        action_losses = [
            self.algo_config.loss.l2_weight * losses["l2_loss"],
            self.algo_config.loss.l1_weight * losses["l1_loss"],
            self.algo_config.loss.cos_weight * losses["cos_loss"],
        ]
        action_loss = sum(action_losses)
        losses["action_loss"] = action_loss
        return losses

    def _train_step(self, losses):
        """
        Internal helper function for BC algo class. Perform backpropagation on the
        loss tensors in @losses to update networks.

        Args:
            losses (dict): dictionary of losses computed over the batch, from @_compute_losses

        Returns:
            info (dict): dictionary with the value returned by the backprop helper,
                stored under "policy_grad_norms"
        """

        # gradient step
        info = OrderedDict()
        policy_grad_norms = TorchUtils.backprop_for_loss(
            net=self.nets["policy"],
            optim=self.optimizers["policy"],
            loss=losses["action_loss"],
        )
        info["policy_grad_norms"] = policy_grad_norms
        return info

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        log = super(BC, self).log_info(info)
        log["Loss"] = info["losses"]["action_loss"].item()
        # individual loss terms are optional so that subclasses with other losses can reuse this
        if "l2_loss" in info["losses"]:
            log["L2_Loss"] = info["losses"]["l2_loss"].item()
        if "l1_loss" in info["losses"]:
            log["L1_Loss"] = info["losses"]["l1_loss"].item()
        if "cos_loss" in info["losses"]:
            log["Cosine_Loss"] = info["losses"]["cos_loss"].item()
        # grad norms are only present when a training step was taken
        if "policy_grad_norms" in info:
            log["Policy_Grad_Norms"] = info["policy_grad_norms"]
        return log

    def get_action(self, obs_dict, goal_dict=None):
        """
        Get policy action outputs.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            action (torch.Tensor): action tensor
        """
        assert not self.nets.training
        return self.nets["policy"](obs_dict, goal_dict=goal_dict)


class BC_Gaussian(BC):
    """
    BC training with a Gaussian policy.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        assert self.algo_config.gaussian.enabled

        self.nets = nn.ModuleDict()
        self.nets["policy"] = PolicyNets.GaussianActorNetwork(
            obs_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            ac_dim=self.ac_dim,
            mlp_layer_dims=self.algo_config.actor_layer_dims,
            fixed_std=self.algo_config.gaussian.fixed_std,
            init_std=self.algo_config.gaussian.init_std,
            # lower limit comes from the config, upper limit is hard-coded
            std_limits=(self.algo_config.gaussian.min_std, 7.5),
            std_activation=self.algo_config.gaussian.std_activation,
            low_noise_eval=self.algo_config.gaussian.low_noise_eval,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
        )

        self.nets = self.nets.float().to(self.device)

    def _forward_training(self, batch):
        """
        Internal helper function for BC algo class. Compute forward pass
        and return network outputs in @predictions dict.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

        Returns:
            predictions (dict): dictionary containing network outputs
                (the log probabilities of the batch actions, under "log_probs")
        """
        dists = self.nets["policy"].forward_train(
            obs_dict=batch["obs"],
            goal_dict=batch["goal_obs"],
        )

        # make sure that this is a batch of multivariate action distributions, so that
        # the log probability computation will be correct
        assert len(dists.batch_shape) == 1
        log_probs = dists.log_prob(batch["actions"])

        predictions = OrderedDict(
            log_probs=log_probs,
        )
        return predictions

    def _compute_losses(self, predictions, batch):
        """
        Internal helper function for BC algo class. Compute losses based on
        network outputs in @predictions dict, using reference labels in @batch.

        Args:
            predictions (dict): dictionary containing network outputs, from @_forward_training
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

        Returns:
            losses (dict): dictionary of losses computed over the batch
        """

        # loss is just negative log-likelihood of action targets
        action_loss = -predictions["log_probs"].mean()
        return OrderedDict(
            # mean log-likelihood, kept for logging
            log_probs=-action_loss,
            action_loss=action_loss,
        )

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        # call PolicyAlgo directly to bypass BC.log_info
        log = PolicyAlgo.log_info(self, info)
        log["Loss"] = info["losses"]["action_loss"].item()
        log["Log_Likelihood"] = info["losses"]["log_probs"].item()
        if "policy_grad_norms" in info:
            log["Policy_Grad_Norms"] = info["policy_grad_norms"]
        return log


class BC_GMM(BC_Gaussian):
    """
    BC training with a Gaussian Mixture Model policy.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        assert self.algo_config.gmm.enabled

        self.nets = nn.ModuleDict()
        self.nets["policy"] = PolicyNets.GMMActorNetwork(
            obs_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            ac_dim=self.ac_dim,
            mlp_layer_dims=self.algo_config.actor_layer_dims,
            num_modes=self.algo_config.gmm.num_modes,
            min_std=self.algo_config.gmm.min_std,
            std_activation=self.algo_config.gmm.std_activation,
            low_noise_eval=self.algo_config.gmm.low_noise_eval,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
        )

        self.nets = self.nets.float().to(self.device)


class BC_VAE(BC):
    """
    BC training with a VAE policy.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        self.nets = nn.ModuleDict()
        self.nets["policy"] = PolicyNets.VAEActor(
            obs_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            ac_dim=self.ac_dim,
            device=self.device,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
            **VAENets.vae_args_from_config(self.algo_config.vae),
        )

        self.nets = self.nets.float().to(self.device)

    def train_on_batch(self, batch, epoch, validate=False):
        """
        Update from superclass to set categorical temperature, for categorical VAEs.

        The temperature starts at the configured initial value, decreases by the
        configured anneal step per epoch, and is clamped below by the configured minimum.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): epoch number - used here to anneal the temperature

            validate (bool): if True, don't perform any learning updates.

        Returns:
            info (dict): dictionary of relevant inputs, outputs, and losses
                that might be relevant for logging
        """
        if self.algo_config.vae.prior.use_categorical:
            temperature = (
                self.algo_config.vae.prior.categorical_init_temp
                - epoch * self.algo_config.vae.prior.categorical_temp_anneal_step
            )
            temperature = max(temperature, self.algo_config.vae.prior.categorical_min_temp)
            self.nets["policy"].set_gumbel_temperature(temperature)
        return super(BC_VAE, self).train_on_batch(batch, epoch, validate=validate)

    def _forward_training(self, batch):
        """
        Internal helper function for BC algo class. Compute forward pass
        and return network outputs in @predictions dict.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

        Returns:
            predictions (dict): dictionary containing network outputs
        """
        vae_inputs = dict(
            actions=batch["actions"],
            obs_dict=batch["obs"],
            goal_dict=batch["goal_obs"],
            freeze_encoder=batch.get("freeze_encoder", False),
        )

        vae_outputs = self.nets["policy"].forward_train(**vae_inputs)
        predictions = OrderedDict(
            actions=vae_outputs["decoder_outputs"],
            kl_loss=vae_outputs["kl_loss"],
            reconstruction_loss=vae_outputs["reconstruction_loss"],
            encoder_z=vae_outputs["encoder_z"],
        )
        if not self.algo_config.vae.prior.use_categorical:
            # encoder variance is only read by @log_info, so no gradient is needed
            with torch.no_grad():
                encoder_variance = torch.exp(vae_outputs["encoder_params"]["logvar"])
            predictions["encoder_variance"] = encoder_variance
        return predictions

    def _compute_losses(self, predictions, batch):
        """
        Internal helper function for BC algo class. Compute losses based on
        network outputs in @predictions dict, using reference labels in @batch.

        Args:
            predictions (dict): dictionary containing network outputs, from @_forward_training
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

        Returns:
            losses (dict): dictionary of losses computed over the batch
        """

        # total loss is sum of reconstruction and KL, weighted by beta
        kl_loss = predictions["kl_loss"]
        recons_loss = predictions["reconstruction_loss"]
        action_loss = recons_loss + self.algo_config.vae.kl_weight * kl_loss
        return OrderedDict(
            recons_loss=recons_loss,
            kl_loss=kl_loss,
            action_loss=action_loss,
        )

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        # call PolicyAlgo directly to bypass BC.log_info
        log = PolicyAlgo.log_info(self, info)
        log["Loss"] = info["losses"]["action_loss"].item()
        log["KL_Loss"] = info["losses"]["kl_loss"].item()
        log["Reconstruction_Loss"] = info["losses"]["recons_loss"].item()
        if self.algo_config.vae.prior.use_categorical:
            log["Gumbel_Temperature"] = self.nets["policy"].get_gumbel_temperature()
        else:
            log["Encoder_Variance"] = info["predictions"]["encoder_variance"].mean().item()
        if "policy_grad_norms" in info:
            log["Policy_Grad_Norms"] = info["policy_grad_norms"]
        return log


class BC_RNN(BC):
    """
    BC training with an RNN policy.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.

        Also initializes the rollout state used by @get_action (hidden state,
        horizon, step counter and open-loop flag).
        """
        self.nets = nn.ModuleDict()
        self.nets["policy"] = PolicyNets.RNNActorNetwork(
            obs_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            ac_dim=self.ac_dim,
            mlp_layer_dims=self.algo_config.actor_layer_dims,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
            **BaseNets.rnn_args_from_config(self.algo_config.rnn),
        )

        self._rnn_hidden_state = None
        self._rnn_horizon = self.algo_config.rnn.horizon
        self._rnn_counter = 0
        self._rnn_is_open_loop = self.algo_config.rnn.get("open_loop", False)

        self.nets = self.nets.float().to(self.device)

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training
        """
        input_batch = dict()
        input_batch["obs"] = batch["obs"]
        input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
        input_batch["actions"] = batch["actions"]

        if self._rnn_is_open_loop:
            # replace the observation sequence with one that only consists of the first observation.
            # This way, all actions are predicted "open-loop" after the first observation, based
            # on the rnn hidden state.
            n_steps = batch["actions"].shape[1]
            obs_seq_start = TensorUtils.index_at_time(batch["obs"], ind=0)
            input_batch["obs"] = TensorUtils.unsqueeze_expand_at(obs_seq_start, size=n_steps, dim=1)

        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))

    def get_action(self, obs_dict, goal_dict=None):
        """
        Get policy action outputs.

        The RNN hidden state is (re-)initialized when there is none yet, and whenever
        the step counter is a multiple of the rnn horizon.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            action (torch.Tensor): action tensor
        """
        assert not self.nets.training

        if self._rnn_hidden_state is None or self._rnn_counter % self._rnn_horizon == 0:
            batch_size = list(obs_dict.values())[0].shape[0]
            self._rnn_hidden_state = self.nets["policy"].get_rnn_init_state(batch_size=batch_size, device=self.device)

            if self._rnn_is_open_loop:
                # remember the initial observation, and use it instead of the current observation
                # for open-loop action sequence prediction
                self._open_loop_obs = TensorUtils.clone(TensorUtils.detach(obs_dict))

        obs_to_use = obs_dict
        if self._rnn_is_open_loop:
            # replace current obs with last recorded obs
            obs_to_use = self._open_loop_obs

        self._rnn_counter += 1
        action, self._rnn_hidden_state = self.nets["policy"].forward_step(
            obs_to_use, goal_dict=goal_dict, rnn_state=self._rnn_hidden_state)
        return action

    def reset(self):
        """
        Reset algo state to prepare for environment rollouts.
        """
        self._rnn_hidden_state = None
        self._rnn_counter = 0


class BC_RNN_GMM(BC_RNN):
    """
    BC training with an RNN GMM policy.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.

        Also initializes the rollout state used by @get_action (hidden state,
        horizon, step counter and open-loop flag).
        """
        assert self.algo_config.gmm.enabled
        assert self.algo_config.rnn.enabled

        self.nets = nn.ModuleDict()
        self.nets["policy"] = PolicyNets.RNNGMMActorNetwork(
            obs_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            ac_dim=self.ac_dim,
            mlp_layer_dims=self.algo_config.actor_layer_dims,
            num_modes=self.algo_config.gmm.num_modes,
            min_std=self.algo_config.gmm.min_std,
            std_activation=self.algo_config.gmm.std_activation,
            low_noise_eval=self.algo_config.gmm.low_noise_eval,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
            **BaseNets.rnn_args_from_config(self.algo_config.rnn),
        )

        self._rnn_hidden_state = None
        self._rnn_horizon = self.algo_config.rnn.horizon
        self._rnn_counter = 0
        self._rnn_is_open_loop = self.algo_config.rnn.get("open_loop", False)

        self.nets = self.nets.float().to(self.device)

    def _forward_training(self, batch):
        """
        Internal helper function for BC algo class. Compute forward pass
        and return network outputs in @predictions dict.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

        Returns:
            predictions (dict): dictionary containing network outputs
                (the log probabilities of the batch actions, under "log_probs")
        """
        dists = self.nets["policy"].forward_train(
            obs_dict=batch["obs"],
            goal_dict=batch["goal_obs"],
        )

        # make sure that this is a batch of multivariate action distributions, so that
        # the log probability computation will be correct
        assert len(dists.batch_shape) == 2 # [B, T]
        log_probs = dists.log_prob(batch["actions"])

        predictions = OrderedDict(
            log_probs=log_probs,
        )
        return predictions

    def _compute_losses(self, predictions, batch):
        """
        Internal helper function for BC algo class. Compute losses based on
        network outputs in @predictions dict, using reference labels in @batch.

        Args:
            predictions (dict): dictionary containing network outputs, from @_forward_training
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

        Returns:
            losses (dict): dictionary of losses computed over the batch
        """

        # loss is just negative log-likelihood of action targets
        action_loss = -predictions["log_probs"].mean()
        return OrderedDict(
            # mean log-likelihood, kept for logging
            log_probs=-action_loss,
            action_loss=action_loss,
        )

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        # call PolicyAlgo directly to bypass BC.log_info
        log = PolicyAlgo.log_info(self, info)
        log["Loss"] = info["losses"]["action_loss"].item()
        log["Log_Likelihood"] = info["losses"]["log_probs"].item()
        if "policy_grad_norms" in info:
            log["Policy_Grad_Norms"] = info["policy_grad_norms"]
        return log


class BC_Transformer(BC):
    """
    BC training with a Transformer policy.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        assert self.algo_config.transformer.enabled

        self.nets = nn.ModuleDict()
        self.nets["policy"] = PolicyNets.TransformerActorNetwork(
            obs_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            ac_dim=self.ac_dim,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
            **BaseNets.transformer_args_from_config(self.algo_config.transformer),
        )
        self._set_params_from_config()
        self.nets = self.nets.float().to(self.device)

    def _set_params_from_config(self):
        """
        Read specific config variables we need for training / eval.
        Called by @_create_networks method
        """
        self.context_length = self.algo_config.transformer.context_length
        self.supervise_all_steps = self.algo_config.transformer.supervise_all_steps

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Observations are truncated to the first @self.context_length steps along
        dimension 1. Actions are either truncated the same way (when supervising all
        steps) or reduced to the step at index context_length - 1.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training
        """
        input_batch = dict()
        h = self.context_length
        input_batch["obs"] = {k: batch["obs"][k][:, :h, :] for k in batch["obs"]}
        input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present

        if self.supervise_all_steps:
            # supervision on entire sequence (instead of just current timestep)
            input_batch["actions"] = batch["actions"][:, :h, :]
        else:
            # just use current timestep
            input_batch["actions"] = batch["actions"][:, h-1, :]

        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU (same order as the other BC variants)
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))

    def _forward_training(self, batch, epoch=None):
        """
        Internal helper function for BC_Transformer algo class. Compute forward pass
        and return network outputs in @predictions dict.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): (optional) epoch number - not used by this method

        Returns:
            predictions (dict): dictionary containing network outputs
        """
        # ensure that transformer context length is consistent with temporal dimension of observations
        TensorUtils.assert_size_at_dim(
            batch["obs"],
            size=(self.context_length),
            dim=1,
            msg="Error: expect temporal dimension of obs batch to match transformer context length {}".format(self.context_length),
        )

        predictions = OrderedDict()
        predictions["actions"] = self.nets["policy"](obs_dict=batch["obs"], actions=None, goal_dict=batch["goal_obs"])
        if not self.supervise_all_steps:
            # only supervise final timestep
            predictions["actions"] = predictions["actions"][:, -1, :]
        return predictions

    def get_action(self, obs_dict, goal_dict=None):
        """
        Get policy action outputs.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            action (torch.Tensor): action tensor
        """
        assert not self.nets.training

        # only the output at the last index along dimension 1 is returned
        return self.nets["policy"](obs_dict, actions=None, goal_dict=goal_dict)[:, -1, :]


class BC_Transformer_GMM(BC_Transformer):
    """
    BC training with a Transformer GMM policy.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        assert self.algo_config.gmm.enabled
        assert self.algo_config.transformer.enabled

        self.nets = nn.ModuleDict()
        self.nets["policy"] = PolicyNets.TransformerGMMActorNetwork(
            obs_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            ac_dim=self.ac_dim,
            num_modes=self.algo_config.gmm.num_modes,
            min_std=self.algo_config.gmm.min_std,
            std_activation=self.algo_config.gmm.std_activation,
            low_noise_eval=self.algo_config.gmm.low_noise_eval,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
            **BaseNets.transformer_args_from_config(self.algo_config.transformer),
        )
        self._set_params_from_config()
        self.nets = self.nets.float().to(self.device)

    def _forward_training(self, batch, epoch=None):
        """
        Modify from super class to support GMM training.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): (optional) epoch number - not used by this method

        Returns:
            predictions (dict): dictionary containing network outputs
                (the log probabilities of the batch actions, under "log_probs")
        """
        # ensure that transformer context length is consistent with temporal dimension of observations
        TensorUtils.assert_size_at_dim(
            batch["obs"],
            size=(self.context_length),
            dim=1,
            msg="Error: expect temporal dimension of obs batch to match transformer context length {}".format(self.context_length),
        )

        dists = self.nets["policy"].forward_train(
            obs_dict=batch["obs"],
            actions=None,
            goal_dict=batch["goal_obs"],
            low_noise_eval=False,
        )

        # make sure that this is a batch of multivariate action distributions, so that
        # the log probability computation will be correct
        assert len(dists.batch_shape) == 2 # [B, T]

        if not self.supervise_all_steps:
            # only use final timestep prediction by making a new distribution with only final timestep.
            # This essentially does `dists = dists[:, -1]`
            component_distribution = D.Normal(
                loc=dists.component_distribution.base_dist.loc[:, -1],
                scale=dists.component_distribution.base_dist.scale[:, -1],
            )
            component_distribution = D.Independent(component_distribution, 1)
            mixture_distribution = D.Categorical(logits=dists.mixture_distribution.logits[:, -1])
            dists = D.MixtureSameFamily(
                mixture_distribution=mixture_distribution,
                component_distribution=component_distribution,
            )

        log_probs = dists.log_prob(batch["actions"])

        predictions = OrderedDict(
            log_probs=log_probs,
        )
        return predictions

    def _compute_losses(self, predictions, batch):
        """
        Internal helper function for BC_Transformer_GMM algo class. Compute losses based on
        network outputs in @predictions dict, using reference labels in @batch.

        Args:
            predictions (dict): dictionary containing network outputs, from @_forward_training
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

        Returns:
            losses (dict): dictionary of losses computed over the batch
        """

        # loss is just negative log-likelihood of action targets
        action_loss = -predictions["log_probs"].mean()
        return OrderedDict(
            # mean log-likelihood, kept for logging
            log_probs=-action_loss,
            action_loss=action_loss,
        )

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        # call PolicyAlgo directly to bypass BC.log_info
        log = PolicyAlgo.log_info(self, info)
        log["Loss"] = info["losses"]["action_loss"].item()
        log["Log_Likelihood"] = info["losses"]["log_probs"].item()
        if "policy_grad_norms" in info:
            log["Policy_Grad_Norms"] = info["policy_grad_norms"]
        return log