"""
Implementation of Behavioral Cloning (BC).
"""
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
import robomimic.models.base_nets as BaseNets
import robomimic.models.obs_nets as ObsNets
import robomimic.models.policy_nets as PolicyNets
import robomimic.models.vae_nets as VAENets
import robomimic.utils.loss_utils as LossUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.algo import register_algo_factory_func, PolicyAlgo
@register_algo_factory_func("bc")
def algo_config_to_class(algo_config):
"""
Maps algo config to the BC algo class to instantiate, along with additional algo kwargs.
Args:
algo_config (Config instance): algo config
Returns:
algo_class: subclass of Algo
algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
"""
# note: we need the check below because some configs import BCConfig and exclude
# some of these options
gaussian_enabled = ("gaussian" in algo_config and algo_config.gaussian.enabled)
gmm_enabled = ("gmm" in algo_config and algo_config.gmm.enabled)
vae_enabled = ("vae" in algo_config and algo_config.vae.enabled)
rnn_enabled = algo_config.rnn.enabled
# support legacy configs that do not have "transformer" item
transformer_enabled = ("transformer" in algo_config) and algo_config.transformer.enabled
if gaussian_enabled:
if rnn_enabled:
raise NotImplementedError
elif transformer_enabled:
raise NotImplementedError
else:
algo_class, algo_kwargs = BC_Gaussian, {}
elif gmm_enabled:
if rnn_enabled:
algo_class, algo_kwargs = BC_RNN_GMM, {}
elif transformer_enabled:
algo_class, algo_kwargs = BC_Transformer_GMM, {}
else:
algo_class, algo_kwargs = BC_GMM, {}
elif vae_enabled:
if rnn_enabled:
raise NotImplementedError
elif transformer_enabled:
raise NotImplementedError
else:
algo_class, algo_kwargs = BC_VAE, {}
else:
if rnn_enabled:
algo_class, algo_kwargs = BC_RNN, {}
elif transformer_enabled:
algo_class, algo_kwargs = BC_Transformer, {}
else:
algo_class, algo_kwargs = BC, {}
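    # e.g. a config with gmm.enabled=True and rnn.enabled=True maps to (BC_RNN_GMM, {});
    # with all of these flags disabled, we fall back to plain (BC, {})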
return algo_class, algo_kwargs
class BC(PolicyAlgo):
"""
Normal BC training.
"""
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
self.nets = nn.ModuleDict()
self.nets["policy"] = PolicyNets.ActorNetwork(
obs_shapes=self.obs_shapes,
goal_shapes=self.goal_shapes,
ac_dim=self.ac_dim,
mlp_layer_dims=self.algo_config.actor_layer_dims,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
)
self.nets = self.nets.float().to(self.device)
def process_batch_for_training(self, batch):
"""
Processes input batch from a data loader to filter out
relevant information and prepare the batch for training.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader
Returns:
input_batch (dict): processed and filtered batch that
will be used for training
"""
input_batch = dict()
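        # the data loader returns temporal sequences of shape (B, T, ...); plain BC
        # trains on single frames, so keep only the first timestep of obs and actions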
input_batch["obs"] = {k: batch["obs"][k][:, 0, :] for k in batch["obs"]}
input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
input_batch["actions"] = batch["actions"][:, 0, :]
# we move to device first before float conversion because image observation modalities will be uint8 -
# this minimizes the amount of data transferred to GPU
return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
def train_on_batch(self, batch, epoch, validate=False):
"""
Training on a single batch of data.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
epoch (int): epoch number - required by some Algos that need
to perform staged training and early stopping
validate (bool): if True, don't perform any learning updates.
Returns:
info (dict): dictionary of relevant inputs, outputs, and losses
that might be relevant for logging
"""
with TorchUtils.maybe_no_grad(no_grad=validate):
info = super(BC, self).train_on_batch(batch, epoch, validate=validate)
predictions = self._forward_training(batch)
losses = self._compute_losses(predictions, batch)
info["predictions"] = TensorUtils.detach(predictions)
info["losses"] = TensorUtils.detach(losses)
if not validate:
step_info = self._train_step(losses)
info.update(step_info)
return info
def _forward_training(self, batch):
"""
Internal helper function for BC algo class. Compute forward pass
and return network outputs in @predictions dict.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
predictions (dict): dictionary containing network outputs
"""
predictions = OrderedDict()
actions = self.nets["policy"](obs_dict=batch["obs"], goal_dict=batch["goal_obs"])
predictions["actions"] = actions
return predictions
def _compute_losses(self, predictions, batch):
"""
Internal helper function for BC algo class. Compute losses based on
network outputs in @predictions dict, using reference labels in @batch.
Args:
predictions (dict): dictionary containing network outputs, from @_forward_training
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
losses (dict): dictionary of losses computed over the batch
"""
losses = OrderedDict()
a_target = batch["actions"]
actions = predictions["actions"]
losses["l2_loss"] = nn.MSELoss()(actions, a_target)
losses["l1_loss"] = nn.SmoothL1Loss()(actions, a_target)
# cosine direction loss on eef delta position
losses["cos_loss"] = LossUtils.cosine_loss(actions[..., :3], a_target[..., :3])
action_losses = [
self.algo_config.loss.l2_weight * losses["l2_loss"],
self.algo_config.loss.l1_weight * losses["l1_loss"],
self.algo_config.loss.cos_weight * losses["cos_loss"],
]
action_loss = sum(action_losses)
losses["action_loss"] = action_loss
return losses
def _train_step(self, losses):
"""
Internal helper function for BC algo class. Perform backpropagation on the
loss tensors in @losses to update networks.
Args:
losses (dict): dictionary of losses computed over the batch, from @_compute_losses
"""
# gradient step
info = OrderedDict()
policy_grad_norms = TorchUtils.backprop_for_loss(
net=self.nets["policy"],
optim=self.optimizers["policy"],
loss=losses["action_loss"],
)
info["policy_grad_norms"] = policy_grad_norms
return info
def log_info(self, info):
"""
Process info dictionary from @train_on_batch to summarize
information to pass to tensorboard for logging.
Args:
info (dict): dictionary of info
Returns:
loss_log (dict): name -> summary statistic
"""
log = super(BC, self).log_info(info)
log["Loss"] = info["losses"]["action_loss"].item()
if "l2_loss" in info["losses"]:
log["L2_Loss"] = info["losses"]["l2_loss"].item()
if "l1_loss" in info["losses"]:
log["L1_Loss"] = info["losses"]["l1_loss"].item()
if "cos_loss" in info["losses"]:
log["Cosine_Loss"] = info["losses"]["cos_loss"].item()
if "policy_grad_norms" in info:
log["Policy_Grad_Norms"] = info["policy_grad_norms"]
return log
def get_action(self, obs_dict, goal_dict=None):
"""
Get policy action outputs.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
action (torch.Tensor): action tensor
"""
assert not self.nets.training
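        # rollout usage sketch: switch to eval mode first (e.g. via @self.set_eval(),
        # which this assert checks for) and query actions under torch.no_grad()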
return self.nets["policy"](obs_dict, goal_dict=goal_dict)
class BC_Gaussian(BC):
"""
BC training with a Gaussian policy.
"""
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
assert self.algo_config.gaussian.enabled
self.nets = nn.ModuleDict()
self.nets["policy"] = PolicyNets.GaussianActorNetwork(
obs_shapes=self.obs_shapes,
goal_shapes=self.goal_shapes,
ac_dim=self.ac_dim,
mlp_layer_dims=self.algo_config.actor_layer_dims,
fixed_std=self.algo_config.gaussian.fixed_std,
init_std=self.algo_config.gaussian.init_std,
std_limits=(self.algo_config.gaussian.min_std, 7.5),
std_activation=self.algo_config.gaussian.std_activation,
low_noise_eval=self.algo_config.gaussian.low_noise_eval,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
)
self.nets = self.nets.float().to(self.device)
def _forward_training(self, batch):
"""
        Internal helper function for BC_Gaussian algo class. Compute forward pass
and return network outputs in @predictions dict.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
predictions (dict): dictionary containing network outputs
"""
dists = self.nets["policy"].forward_train(
obs_dict=batch["obs"],
goal_dict=batch["goal_obs"],
)
# make sure that this is a batch of multivariate action distributions, so that
# the log probability computation will be correct
assert len(dists.batch_shape) == 1
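        # dists has batch_shape (B,) and event_shape (ac_dim,), so log_prob reduces
        # over the action dimension and returns a (B,) tensor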
log_probs = dists.log_prob(batch["actions"])
predictions = OrderedDict(
log_probs=log_probs,
)
return predictions
def _compute_losses(self, predictions, batch):
"""
        Internal helper function for BC_Gaussian algo class. Compute losses based on
network outputs in @predictions dict, using reference labels in @batch.
Args:
predictions (dict): dictionary containing network outputs, from @_forward_training
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
losses (dict): dictionary of losses computed over the batch
"""
# loss is just negative log-likelihood of action targets
action_loss = -predictions["log_probs"].mean()
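        # also return the (positive) mean log-likelihood so @log_info can report it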
return OrderedDict(
log_probs=-action_loss,
action_loss=action_loss,
)
def log_info(self, info):
"""
Process info dictionary from @train_on_batch to summarize
information to pass to tensorboard for logging.
Args:
info (dict): dictionary of info
Returns:
loss_log (dict): name -> summary statistic
"""
log = PolicyAlgo.log_info(self, info)
log["Loss"] = info["losses"]["action_loss"].item()
log["Log_Likelihood"] = info["losses"]["log_probs"].item()
if "policy_grad_norms" in info:
log["Policy_Grad_Norms"] = info["policy_grad_norms"]
return log
class BC_GMM(BC_Gaussian):
"""
BC training with a Gaussian Mixture Model policy.
"""
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
assert self.algo_config.gmm.enabled
self.nets = nn.ModuleDict()
self.nets["policy"] = PolicyNets.GMMActorNetwork(
obs_shapes=self.obs_shapes,
goal_shapes=self.goal_shapes,
ac_dim=self.ac_dim,
mlp_layer_dims=self.algo_config.actor_layer_dims,
num_modes=self.algo_config.gmm.num_modes,
min_std=self.algo_config.gmm.min_std,
std_activation=self.algo_config.gmm.std_activation,
low_noise_eval=self.algo_config.gmm.low_noise_eval,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
)
self.nets = self.nets.float().to(self.device)
class BC_VAE(BC):
"""
BC training with a VAE policy.
"""
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
self.nets = nn.ModuleDict()
self.nets["policy"] = PolicyNets.VAEActor(
obs_shapes=self.obs_shapes,
goal_shapes=self.goal_shapes,
ac_dim=self.ac_dim,
device=self.device,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
**VAENets.vae_args_from_config(self.algo_config.vae),
)
self.nets = self.nets.float().to(self.device)
def train_on_batch(self, batch, epoch, validate=False):
"""
Update from superclass to set categorical temperature, for categorical VAEs.
"""
if self.algo_config.vae.prior.use_categorical:
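            # linearly anneal the Gumbel-Softmax temperature with the epoch number,
            # clamped below by categorical_min_temp, before running the usual BC update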
            temperature = (
                self.algo_config.vae.prior.categorical_init_temp
                - epoch * self.algo_config.vae.prior.categorical_temp_anneal_step
            )
temperature = max(temperature, self.algo_config.vae.prior.categorical_min_temp)
self.nets["policy"].set_gumbel_temperature(temperature)
return super(BC_VAE, self).train_on_batch(batch, epoch, validate=validate)
def _forward_training(self, batch):
"""
        Internal helper function for BC_VAE algo class. Compute forward pass
and return network outputs in @predictions dict.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
predictions (dict): dictionary containing network outputs
"""
vae_inputs = dict(
actions=batch["actions"],
obs_dict=batch["obs"],
goal_dict=batch["goal_obs"],
freeze_encoder=batch.get("freeze_encoder", False),
)
vae_outputs = self.nets["policy"].forward_train(**vae_inputs)
predictions = OrderedDict(
actions=vae_outputs["decoder_outputs"],
kl_loss=vae_outputs["kl_loss"],
reconstruction_loss=vae_outputs["reconstruction_loss"],
encoder_z=vae_outputs["encoder_z"],
)
if not self.algo_config.vae.prior.use_categorical:
with torch.no_grad():
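                # exponentiate the predicted log-variance; used only for logging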
encoder_variance = torch.exp(vae_outputs["encoder_params"]["logvar"])
predictions["encoder_variance"] = encoder_variance
return predictions
def _compute_losses(self, predictions, batch):
"""
        Internal helper function for BC_VAE algo class. Compute losses based on
network outputs in @predictions dict, using reference labels in @batch.
Args:
predictions (dict): dictionary containing network outputs, from @_forward_training
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
losses (dict): dictionary of losses computed over the batch
"""
# total loss is sum of reconstruction and KL, weighted by beta
kl_loss = predictions["kl_loss"]
recons_loss = predictions["reconstruction_loss"]
action_loss = recons_loss + self.algo_config.vae.kl_weight * kl_loss
return OrderedDict(
recons_loss=recons_loss,
kl_loss=kl_loss,
action_loss=action_loss,
)
def log_info(self, info):
"""
Process info dictionary from @train_on_batch to summarize
information to pass to tensorboard for logging.
Args:
info (dict): dictionary of info
Returns:
loss_log (dict): name -> summary statistic
"""
log = PolicyAlgo.log_info(self, info)
log["Loss"] = info["losses"]["action_loss"].item()
log["KL_Loss"] = info["losses"]["kl_loss"].item()
log["Reconstruction_Loss"] = info["losses"]["recons_loss"].item()
if self.algo_config.vae.prior.use_categorical:
log["Gumbel_Temperature"] = self.nets["policy"].get_gumbel_temperature()
else:
log["Encoder_Variance"] = info["predictions"]["encoder_variance"].mean().item()
if "policy_grad_norms" in info:
log["Policy_Grad_Norms"] = info["policy_grad_norms"]
return log
class BC_RNN(BC):
"""
BC training with an RNN policy.
"""
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
self.nets = nn.ModuleDict()
self.nets["policy"] = PolicyNets.RNNActorNetwork(
obs_shapes=self.obs_shapes,
goal_shapes=self.goal_shapes,
ac_dim=self.ac_dim,
mlp_layer_dims=self.algo_config.actor_layer_dims,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
**BaseNets.rnn_args_from_config(self.algo_config.rnn),
)
self._rnn_hidden_state = None
self._rnn_horizon = self.algo_config.rnn.horizon
self._rnn_counter = 0
self._rnn_is_open_loop = self.algo_config.rnn.get("open_loop", False)
self.nets = self.nets.float().to(self.device)
def process_batch_for_training(self, batch):
"""
Processes input batch from a data loader to filter out
relevant information and prepare the batch for training.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader
Returns:
input_batch (dict): processed and filtered batch that
will be used for training
"""
input_batch = dict()
input_batch["obs"] = batch["obs"]
input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
input_batch["actions"] = batch["actions"]
if self._rnn_is_open_loop:
# replace the observation sequence with one that only consists of the first observation.
# This way, all actions are predicted "open-loop" after the first observation, based
# on the rnn hidden state.
n_steps = batch["actions"].shape[1]
obs_seq_start = TensorUtils.index_at_time(batch["obs"], ind=0)
input_batch["obs"] = TensorUtils.unsqueeze_expand_at(obs_seq_start, size=n_steps, dim=1)
# we move to device first before float conversion because image observation modalities will be uint8 -
# this minimizes the amount of data transferred to GPU
return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
def get_action(self, obs_dict, goal_dict=None):
"""
Get policy action outputs.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
action (torch.Tensor): action tensor
"""
assert not self.nets.training
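        # re-initialize the hidden state every @self._rnn_horizon steps so test-time
        # unrolls match the sequence length the RNN was trained on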
if self._rnn_hidden_state is None or self._rnn_counter % self._rnn_horizon == 0:
batch_size = list(obs_dict.values())[0].shape[0]
self._rnn_hidden_state = self.nets["policy"].get_rnn_init_state(batch_size=batch_size, device=self.device)
if self._rnn_is_open_loop:
# remember the initial observation, and use it instead of the current observation
# for open-loop action sequence prediction
self._open_loop_obs = TensorUtils.clone(TensorUtils.detach(obs_dict))
obs_to_use = obs_dict
if self._rnn_is_open_loop:
# replace current obs with last recorded obs
obs_to_use = self._open_loop_obs
self._rnn_counter += 1
action, self._rnn_hidden_state = self.nets["policy"].forward_step(
obs_to_use, goal_dict=goal_dict, rnn_state=self._rnn_hidden_state)
return action
def reset(self):
"""
Reset algo state to prepare for environment rollouts.
"""
self._rnn_hidden_state = None
self._rnn_counter = 0
class BC_RNN_GMM(BC_RNN):
"""
BC training with an RNN GMM policy.
"""
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
assert self.algo_config.gmm.enabled
assert self.algo_config.rnn.enabled
self.nets = nn.ModuleDict()
self.nets["policy"] = PolicyNets.RNNGMMActorNetwork(
obs_shapes=self.obs_shapes,
goal_shapes=self.goal_shapes,
ac_dim=self.ac_dim,
mlp_layer_dims=self.algo_config.actor_layer_dims,
num_modes=self.algo_config.gmm.num_modes,
min_std=self.algo_config.gmm.min_std,
std_activation=self.algo_config.gmm.std_activation,
low_noise_eval=self.algo_config.gmm.low_noise_eval,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
**BaseNets.rnn_args_from_config(self.algo_config.rnn),
)
self._rnn_hidden_state = None
self._rnn_horizon = self.algo_config.rnn.horizon
self._rnn_counter = 0
self._rnn_is_open_loop = self.algo_config.rnn.get("open_loop", False)
self.nets = self.nets.float().to(self.device)
def _forward_training(self, batch):
"""
        Internal helper function for BC_RNN_GMM algo class. Compute forward pass
and return network outputs in @predictions dict.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
predictions (dict): dictionary containing network outputs
"""
dists = self.nets["policy"].forward_train(
obs_dict=batch["obs"],
goal_dict=batch["goal_obs"],
)
# make sure that this is a batch of multivariate action distributions, so that
# the log probability computation will be correct
assert len(dists.batch_shape) == 2 # [B, T]
log_probs = dists.log_prob(batch["actions"])
predictions = OrderedDict(
log_probs=log_probs,
)
return predictions
def _compute_losses(self, predictions, batch):
"""
        Internal helper function for BC_RNN_GMM algo class. Compute losses based on
network outputs in @predictions dict, using reference labels in @batch.
Args:
predictions (dict): dictionary containing network outputs, from @_forward_training
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
losses (dict): dictionary of losses computed over the batch
"""
# loss is just negative log-likelihood of action targets
action_loss = -predictions["log_probs"].mean()
return OrderedDict(
log_probs=-action_loss,
action_loss=action_loss,
)
def log_info(self, info):
"""
Process info dictionary from @train_on_batch to summarize
information to pass to tensorboard for logging.
Args:
info (dict): dictionary of info
Returns:
loss_log (dict): name -> summary statistic
"""
log = PolicyAlgo.log_info(self, info)
log["Loss"] = info["losses"]["action_loss"].item()
log["Log_Likelihood"] = info["losses"]["log_probs"].item()
if "policy_grad_norms" in info:
log["Policy_Grad_Norms"] = info["policy_grad_norms"]
return log
class BC_Transformer(BC):
"""
BC training with a Transformer policy.
"""
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
assert self.algo_config.transformer.enabled
self.nets = nn.ModuleDict()
self.nets["policy"] = PolicyNets.TransformerActorNetwork(
obs_shapes=self.obs_shapes,
goal_shapes=self.goal_shapes,
ac_dim=self.ac_dim,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
**BaseNets.transformer_args_from_config(self.algo_config.transformer),
)
self._set_params_from_config()
self.nets = self.nets.float().to(self.device)
def _set_params_from_config(self):
"""
Read specific config variables we need for training / eval.
Called by @_create_networks method
"""
self.context_length = self.algo_config.transformer.context_length
self.supervise_all_steps = self.algo_config.transformer.supervise_all_steps
def process_batch_for_training(self, batch):
"""
Processes input batch from a data loader to filter out
relevant information and prepare the batch for training.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader
Returns:
input_batch (dict): processed and filtered batch that
will be used for training
"""
input_batch = dict()
h = self.context_length
input_batch["obs"] = {k: batch["obs"][k][:, :h, :] for k in batch["obs"]}
input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
if self.supervise_all_steps:
# supervision on entire sequence (instead of just current timestep)
input_batch["actions"] = batch["actions"][:, :h, :]
else:
# just use current timestep
input_batch["actions"] = batch["actions"][:, h-1, :]
        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
def _forward_training(self, batch, epoch=None):
"""
Internal helper function for BC_Transformer algo class. Compute forward pass
and return network outputs in @predictions dict.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
predictions (dict): dictionary containing network outputs
"""
# ensure that transformer context length is consistent with temporal dimension of observations
TensorUtils.assert_size_at_dim(
batch["obs"],
            size=self.context_length,
            dim=1,
            msg="Error: expected temporal dimension of obs batch to match transformer context length {}".format(self.context_length),
)
predictions = OrderedDict()
predictions["actions"] = self.nets["policy"](obs_dict=batch["obs"], actions=None, goal_dict=batch["goal_obs"])
if not self.supervise_all_steps:
# only supervise final timestep
predictions["actions"] = predictions["actions"][:, -1, :]
return predictions
def get_action(self, obs_dict, goal_dict=None):
"""
Get policy action outputs.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
action (torch.Tensor): action tensor
"""
assert not self.nets.training
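        # the policy outputs an action for every timestep in the context window;
        # the prediction at the final timestep is used as the current action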
return self.nets["policy"](obs_dict, actions=None, goal_dict=goal_dict)[:, -1, :]
class BC_Transformer_GMM(BC_Transformer):
"""
BC training with a Transformer GMM policy.
"""
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
assert self.algo_config.gmm.enabled
assert self.algo_config.transformer.enabled
self.nets = nn.ModuleDict()
self.nets["policy"] = PolicyNets.TransformerGMMActorNetwork(
obs_shapes=self.obs_shapes,
goal_shapes=self.goal_shapes,
ac_dim=self.ac_dim,
num_modes=self.algo_config.gmm.num_modes,
min_std=self.algo_config.gmm.min_std,
std_activation=self.algo_config.gmm.std_activation,
low_noise_eval=self.algo_config.gmm.low_noise_eval,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
**BaseNets.transformer_args_from_config(self.algo_config.transformer),
)
self._set_params_from_config()
self.nets = self.nets.float().to(self.device)
def _forward_training(self, batch, epoch=None):
"""
Modify from super class to support GMM training.
"""
# ensure that transformer context length is consistent with temporal dimension of observations
TensorUtils.assert_size_at_dim(
batch["obs"],
            size=self.context_length,
            dim=1,
            msg="Error: expected temporal dimension of obs batch to match transformer context length {}".format(self.context_length),
)
dists = self.nets["policy"].forward_train(
obs_dict=batch["obs"],
actions=None,
goal_dict=batch["goal_obs"],
low_noise_eval=False,
)
# make sure that this is a batch of multivariate action distributions, so that
# the log probability computation will be correct
assert len(dists.batch_shape) == 2 # [B, T]
if not self.supervise_all_steps:
# only use final timestep prediction by making a new distribution with only final timestep.
# This essentially does `dists = dists[:, -1]`
component_distribution = D.Normal(
loc=dists.component_distribution.base_dist.loc[:, -1],
scale=dists.component_distribution.base_dist.scale[:, -1],
)
component_distribution = D.Independent(component_distribution, 1)
mixture_distribution = D.Categorical(logits=dists.mixture_distribution.logits[:, -1])
dists = D.MixtureSameFamily(
mixture_distribution=mixture_distribution,
component_distribution=component_distribution,
)
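        # log_prob reduces over the action dimension, giving one value per batch
        # element (and per timestep when supervising all steps)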
log_probs = dists.log_prob(batch["actions"])
predictions = OrderedDict(
log_probs=log_probs,
)
return predictions
def _compute_losses(self, predictions, batch):
"""
Internal helper function for BC_Transformer_GMM algo class. Compute losses based on
network outputs in @predictions dict, using reference labels in @batch.
Args:
predictions (dict): dictionary containing network outputs, from @_forward_training
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
Returns:
losses (dict): dictionary of losses computed over the batch
"""
# loss is just negative log-likelihood of action targets
action_loss = -predictions["log_probs"].mean()
return OrderedDict(
log_probs=-action_loss,
action_loss=action_loss,
)
def log_info(self, info):
"""
Process info dictionary from @train_on_batch to summarize
information to pass to tensorboard for logging.
Args:
info (dict): dictionary of info
Returns:
loss_log (dict): name -> summary statistic
"""
log = PolicyAlgo.log_info(self, info)
log["Loss"] = info["losses"]["action_loss"].item()
log["Log_Likelihood"] = info["losses"]["log_probs"].item()
if "policy_grad_norms" in info:
log["Policy_Grad_Norms"] = info["policy_grad_norms"]
return log