"""
Subgoal prediction models, used in HBC / IRIS.
"""
import numpy as np
from collections import OrderedDict
from copy import deepcopy
import textwrap
import torch
import torch.nn as nn
import robomimic.models.obs_nets as ObsNets
import robomimic.models.vae_nets as VAENets
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.algo import register_algo_factory_func, PlannerAlgo, ValueAlgo
@register_algo_factory_func("gl")
def algo_config_to_class(algo_config):
"""
Maps algo config to the GL algo class to instantiate, along with additional algo kwargs.
Args:
algo_config (Config instance): algo config
Returns:
algo_class: subclass of Algo
algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
"""
if algo_config.vae.enabled:
return GL_VAE, {}
return GL, {}
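
# Hypothetical usage sketch (config names assumed): with a robomimic-style
# Config where config.algo.vae.enabled is True, the registered factory
# resolves to the VAE variant:
#
#   algo_cls, algo_kwargs = algo_config_to_class(config.algo)
#   # algo_cls is GL_VAE, algo_kwargs is {}
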
class GL(PlannerAlgo):
"""
Implements goal prediction component for HBC and IRIS.
"""
def __init__(
self,
algo_config,
obs_config,
global_config,
obs_key_shapes,
ac_dim,
device
):
"""
Args:
algo_config (Config object): instance of Config corresponding to the algo section
of the config
obs_config (Config object): instance of Config corresponding to the observation
section of the config
global_config (Config object): global training config
obs_key_shapes (OrderedDict): dictionary that maps observation keys to shapes
ac_dim (int): dimension of action space
            device (torch.Device): where the algo should live (e.g. cpu, gpu)
"""
self._subgoal_horizon = algo_config.subgoal_horizon
super(GL, self).__init__(
algo_config=algo_config,
obs_config=obs_config,
global_config=global_config,
obs_key_shapes=obs_key_shapes,
ac_dim=ac_dim,
device=device
)
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
self.nets = nn.ModuleDict()
obs_group_shapes = OrderedDict()
obs_group_shapes["obs"] = OrderedDict(self.obs_shapes)
if len(self.goal_shapes) > 0:
obs_group_shapes["goal"] = OrderedDict(self.goal_shapes)
# deterministic goal prediction network
self.nets["goal_network"] = ObsNets.MIMO_MLP(
input_obs_group_shapes=obs_group_shapes,
output_shapes=self.subgoal_shapes,
layer_dims=self.algo_config.ae.planner_layer_dims,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
)
self.nets = self.nets.float().to(self.device)
def process_batch_for_training(self, batch):
"""
Processes input batch from a data loader to filter out
relevant information and prepare the batch for training.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader
Returns:
input_batch (dict): processed and filtered batch that
will be used for training
"""
input_batch = dict()
# remove temporal batches for all except scalar signals (to be compatible with model outputs)
input_batch["obs"] = { k: batch["obs"][k][:, 0, :] for k in batch["obs"] }
# extract multi-horizon subgoal target
input_batch["subgoals"] = {k: batch["next_obs"][k][:, self._subgoal_horizon - 1, :] for k in batch["next_obs"]}
input_batch["target_subgoals"] = input_batch["subgoals"]
input_batch["goal_obs"] = batch.get("goal_obs", None) # goals may not be present
# we move to device first before float conversion because image observation modalities will be uint8 -
# this minimizes the amount of data transferred to GPU
return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
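    # Shape sketch for process_batch_for_training above (temporal batch layout
    # assumed): the data loader yields [B, T, ...] tensors per observation key;
    # "obs" keeps only t = 0 ([B, ...]), while "subgoals" / "target_subgoals"
    # keep index subgoal_horizon - 1 from next_obs, i.e. the observation
    # @subgoal_horizon steps after the current one.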
def get_actor_goal_for_training_from_processed_batch(self, processed_batch, **kwargs):
"""
Retrieve subgoals from processed batch to use for training the actor. Subclasses
can modify this function to change the subgoals.
Args:
processed_batch (dict): processed batch from @process_batch_for_training
Returns:
actor_subgoals (dict): subgoal observations to condition actor on
"""
return processed_batch["target_subgoals"]
def train_on_batch(self, batch, epoch, validate=False):
"""
Training on a single batch of data.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
epoch (int): epoch number - required by some Algos that need
to perform staged training and early stopping
validate (bool): if True, don't perform any learning updates.
Returns:
info (dict): dictionary of relevant inputs, outputs, and losses
that might be relevant for logging
"""
with TorchUtils.maybe_no_grad(no_grad=validate):
info = super(GL, self).train_on_batch(batch, epoch, validate=validate)
# predict subgoal observations with goal network
pred_subgoals = self.nets["goal_network"](obs=batch["obs"], goal=batch["goal_obs"])
# compute loss as L2 error for each observation key
losses = OrderedDict()
target_subgoals = batch["target_subgoals"] # targets for network prediction
goal_loss = 0.
for k in pred_subgoals:
assert pred_subgoals[k].shape == target_subgoals[k].shape, "mismatch in predicted and target subgoals!"
mode_loss = nn.MSELoss()(pred_subgoals[k], target_subgoals[k])
goal_loss += mode_loss
losses["goal_{}_loss".format(k)] = mode_loss
losses["goal_loss"] = goal_loss
info.update(TensorUtils.detach(losses))
if not validate:
# gradient step
goal_grad_norms = TorchUtils.backprop_for_loss(
net=self.nets["goal_network"],
optim=self.optimizers["goal_network"],
loss=losses["goal_loss"],
)
info["goal_grad_norms"] = goal_grad_norms
return info
def log_info(self, info):
"""
Process info dictionary from @train_on_batch to summarize
information to pass to tensorboard for logging.
Args:
info (dict): dictionary of info
Returns:
loss_log (dict): name -> summary statistic
"""
loss_log = super(GL, self).log_info(info)
loss_log["Loss"] = info["goal_loss"].item()
for k in info:
if k.endswith("_loss"):
loss_log[k] = info[k].item()
if "goal_grad_norms" in info:
loss_log["Grad_Norms"] = info["goal_grad_norms"]
return loss_log
def get_subgoal_predictions(self, obs_dict, goal_dict=None):
"""
Takes a batch of observations and predicts a batch of subgoals.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
subgoal prediction (dict): name -> Tensor [batch_size, ...]
"""
return self.nets["goal_network"](obs=obs_dict, goal=goal_dict)
def sample_subgoals(self, obs_dict, goal_dict=None, num_samples=1):
"""
        Sample @num_samples subgoals from the network per observation.
        Since this class implements deterministic subgoal prediction, the
        @num_samples subgoals returned per input observation are identical.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
"""
# stack observations to get all samples in one forward pass
obs_tiled = ObsUtils.repeat_and_stack_observation(obs_dict, n=num_samples)
goal_tiled = None
if goal_dict is not None:
goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_samples)
# [batch_size * num_samples, ...]
goals = self.get_subgoal_predictions(obs_dict=obs_tiled, goal_dict=goal_tiled)
# reshape to [batch_size, num_samples, ...]
return TensorUtils.reshape_dimensions(goals, begin_axis=0, end_axis=0, target_dims=(-1, num_samples))
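    # Tile-and-reshape sketch for sample_subgoals above (shapes assumed):
    #   obs_dict[k]  : [B, ...]
    #   obs_tiled[k] : [B * num_samples, ...]   (repeat_and_stack_observation)
    #   goals[k]     : [B * num_samples, ...]   (single forward pass)
    #   returned[k]  : [B, num_samples, ...]    (reshape_dimensions)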
def get_action(self, obs_dict, goal_dict=None):
"""
Get policy action outputs. Assumes one input observation (first dimension should be 1).
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
action (torch.Tensor): action tensor
"""
raise Exception("Rollouts are not supported by GL")
class GL_VAE(GL):
"""
Implements goal prediction via VAE.
"""
def _create_networks(self):
"""
Creates networks and places them into @self.nets.
"""
self.nets = nn.ModuleDict()
self.nets["goal_network"] = VAENets.VAE(
input_shapes=self.subgoal_shapes,
output_shapes=self.subgoal_shapes,
condition_shapes=self.obs_shapes,
goal_shapes=self.goal_shapes,
device=self.device,
encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
**VAENets.vae_args_from_config(self.algo_config.vae),
)
self.nets = self.nets.float().to(self.device)
def get_actor_goal_for_training_from_processed_batch(
self,
processed_batch,
use_latent_subgoals=False,
use_prior_correction=False,
num_prior_samples=100,
**kwargs,
):
"""
        Modified from superclass to support a @use_latent_subgoals option.
The VAE can optionally return latent subgoals by passing the subgoal
observations in the batch through the encoder.
Args:
processed_batch (dict): processed batch from @process_batch_for_training
use_latent_subgoals (bool): if True, condition the actor on latent subgoals
by using the VAE encoder to encode subgoal observations at train-time,
and using the VAE prior to generate latent subgoals at test-time
use_prior_correction (bool): if True, use a "prior correction" trick to
choose a latent subgoal sampled from the prior that is close to the
latent from the VAE encoder (posterior). This can help with issues at
test-time where the encoder latent distribution might not match
the prior latent distribution.
num_prior_samples (int): number of VAE prior samples to take and choose among,
if @use_prior_correction is true
Returns:
actor_subgoals (dict): subgoal observations to condition actor on
"""
if not use_latent_subgoals:
return processed_batch["target_subgoals"]
# batch variables
obs = processed_batch["obs"]
subgoals = processed_batch["subgoals"] # full subgoal observations
target_subgoals = processed_batch["target_subgoals"] # targets for network prediction
goal_obs = processed_batch["goal_obs"]
with torch.no_grad():
# run VAE forward pass to get samples from posterior for the current observation and subgoal
vae_outputs = self.nets["goal_network"](
inputs=subgoals, # encoder takes full subgoals
outputs=target_subgoals, # reconstruct target subgoals
goals=goal_obs,
conditions=obs, # condition on observations
)
posterior_z = vae_outputs["encoder_z"]
latent_subgoals = posterior_z
if use_prior_correction:
# instead of treating posterior samples as latent subgoals, sample latents from
# the prior and choose the closest one as the latent subgoal
random_key = list(obs.keys())[0]
batch_size = obs[random_key].shape[0]
# for each batch member, get @num_prior_samples samples from the prior
obs_tiled = ObsUtils.repeat_and_stack_observation(obs, n=num_prior_samples)
goal_tiled = None
if len(self.goal_shapes) > 0:
goal_tiled = ObsUtils.repeat_and_stack_observation(goal_obs, n=num_prior_samples)
prior_z_samples = self.nets["goal_network"].sample_prior(
conditions=obs_tiled,
goals=goal_tiled,
)
# choose prior samples that are closest to the sampled posterior latents
# note: every posterior sample in the batch has @num_prior_samples corresponding prior samples
# reshape prior samples to (batch_size, num_samples, latent_dim)
prior_z_samples = prior_z_samples.reshape(batch_size, num_prior_samples, -1)
# reshape posterior latents to (batch_size, 1, latent_dim)
posterior_z_expanded = posterior_z.unsqueeze(1)
# compute distances with broadcasting so that each posterior sample
# has distances to all of its prior samples
distances = (prior_z_samples - posterior_z_expanded).pow(2).sum(dim=2)
# then gather the closest prior sample for each posterior sample
neighbors = torch.argmin(distances, dim=1)
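                # advanced indexing with (batch index, neighbor index) pairs selects
                # one prior sample per batch member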
latent_subgoals = prior_z_samples[torch.arange(batch_size).long(), neighbors]
return { "latent_subgoal" : latent_subgoals }
def train_on_batch(self, batch, epoch, validate=False):
"""
Training on a single batch of data.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
epoch (int): epoch number - required by some Algos that need
to perform staged training and early stopping
validate (bool): if True, don't perform any learning updates.
Returns:
info (dict): dictionary of relevant inputs, outputs, and losses
that might be relevant for logging
"""
with TorchUtils.maybe_no_grad(no_grad=validate):
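            # note: intentionally skip GL's train_on_batch (which adds a deterministic
            # MSE loss) and call the base class version; the VAE losses are computed below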
info = super(GL, self).train_on_batch(batch, epoch, validate=validate)
if self.algo_config.vae.prior.use_categorical:
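                # linearly anneal the Gumbel-Softmax temperature with epoch, clamped
                # below at the configured minimum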
temperature = self.algo_config.vae.prior.categorical_init_temp - epoch * self.algo_config.vae.prior.categorical_temp_anneal_step
temperature = max(temperature, self.algo_config.vae.prior.categorical_min_temp)
self.nets["goal_network"].set_gumbel_temperature(temperature)
# batch variables
obs = batch["obs"]
subgoals = batch["subgoals"] # full subgoal observations
target_subgoals = batch["target_subgoals"] # targets for network prediction
goal_obs = batch["goal_obs"]
vae_outputs = self.nets["goal_network"](
inputs=subgoals, # encoder takes full subgoals
outputs=target_subgoals, # reconstruct target subgoals
goals=goal_obs,
conditions=obs, # condition on observations
)
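            # total loss is the usual beta-VAE objective: reconstruction + kl_weight * KL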
recons_loss = vae_outputs["reconstruction_loss"]
kl_loss = vae_outputs["kl_loss"]
goal_loss = recons_loss + self.algo_config.vae.kl_weight * kl_loss
info["recons_loss"] = recons_loss
info["kl_loss"] = kl_loss
info["goal_loss"] = goal_loss
if not self.algo_config.vae.prior.use_categorical:
with torch.no_grad():
info["encoder_variance"] = torch.exp(vae_outputs["encoder_params"]["logvar"])
# VAE gradient step
if not validate:
goal_grad_norms = TorchUtils.backprop_for_loss(
net=self.nets["goal_network"],
optim=self.optimizers["goal_network"],
loss=goal_loss,
)
info["goal_grad_norms"] = goal_grad_norms
return info
def log_info(self, info):
"""
Process info dictionary from @train_on_batch to summarize
information to pass to tensorboard for logging.
Args:
info (dict): dictionary of info
Returns:
loss_log (dict): name -> summary statistic
"""
loss_log = super(GL_VAE, self).log_info(info)
loss_log["Reconstruction_Loss"] = info["recons_loss"].item()
loss_log["KL_Loss"] = info["kl_loss"].item()
if self.algo_config.vae.prior.use_categorical:
loss_log["Gumbel_Temperature"] = self.nets["goal_network"].get_gumbel_temperature()
else:
loss_log["Encoder_Variance"] = info["encoder_variance"].mean().item()
return loss_log
def get_subgoal_predictions(self, obs_dict, goal_dict=None):
"""
Takes a batch of observations and predicts a batch of subgoals.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
subgoal prediction (dict): name -> Tensor [batch_size, ...]
"""
if self.global_config.algo.latent_subgoal.enabled:
# latent subgoals from sampling prior
latent_subgoals = self.nets["goal_network"].sample_prior(
conditions=obs_dict,
goals=goal_dict,
)
return OrderedDict(latent_subgoal=latent_subgoals)
# sample a single goal from the VAE
goals = self.sample_subgoals(obs_dict=obs_dict, goal_dict=goal_dict, num_samples=1)
return { k : goals[k][:, 0, ...] for k in goals }
def sample_subgoals(self, obs_dict, goal_dict=None, num_samples=1):
"""
Sample @num_samples subgoals from the VAE per observation.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
"""
# stack observations to get all samples in one forward pass
obs_tiled = ObsUtils.repeat_and_stack_observation(obs_dict, n=num_samples)
goal_tiled = None
if goal_dict is not None:
goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_samples)
# VAE decode expects number of samples explicitly
mod = list(obs_tiled.keys())[0]
n = obs_tiled[mod].shape[0]
# [batch_size * num_samples, ...]
goals = self.nets["goal_network"].decode(n=n, conditions=obs_tiled, goals=goal_tiled)
# reshape to [batch_size, num_samples, ...]
return TensorUtils.reshape_dimensions(goals, begin_axis=0, end_axis=0, target_dims=(-1, num_samples))
class ValuePlanner(PlannerAlgo, ValueAlgo):
"""
Base class for all algorithms that are used for planning subgoals
based on (1) a @PlannerAlgo that is used to sample candidate subgoals
and (2) a @ValueAlgo that is used to select one of the subgoals.
"""
def __init__(
self,
planner_algo_class,
value_algo_class,
algo_config,
obs_config,
global_config,
obs_key_shapes,
ac_dim,
device,
):
"""
Args:
planner_algo_class (Algo class): algo class for the planner
value_algo_class (Algo class): algo class for the value network
algo_config (Config object): instance of Config corresponding to the algo section
of the config
obs_config (Config object): instance of Config corresponding to the observation
section of the config
            global_config (Config object): global training config
obs_key_shapes (OrderedDict): dictionary that maps input/output observation keys to shapes
ac_dim (int): action dimension
            device (torch.Device): where the algo should live (e.g. cpu, gpu)
"""
self.algo_config = algo_config
self.obs_config = obs_config
self.global_config = global_config
self.ac_dim = ac_dim
self.device = device
self.planner = planner_algo_class(
algo_config=algo_config.planner,
obs_config=obs_config.planner,
global_config=global_config,
obs_key_shapes=obs_key_shapes,
ac_dim=ac_dim,
device=device
)
self.value_net = value_algo_class(
algo_config=algo_config.value,
obs_config=obs_config.value,
global_config=global_config,
obs_key_shapes=obs_key_shapes,
ac_dim=ac_dim,
device=device
)
self.subgoal_shapes = self.planner.subgoal_shapes
def process_batch_for_training(self, batch):
"""
Processes input batch from a data loader to filter out
relevant information and prepare the batch for training.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader
Returns:
input_batch (dict): processed and filtered batch that
will be used for training
"""
input_batch = dict()
input_batch["planner"] = self.planner.process_batch_for_training(batch)
input_batch["value_net"] = self.value_net.process_batch_for_training(batch)
# we move to device first before float conversion because image observation modalities will be uint8 -
# this minimizes the amount of data transferred to GPU
return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
def train_on_batch(self, batch, epoch, validate=False):
"""
Training on a single batch of data.
Args:
batch (dict): dictionary with torch.Tensors sampled
from a data loader and filtered by @process_batch_for_training
epoch (int): epoch number - required by some Algos that need
to perform staged training and early stopping
validate (bool): if True, don't perform any learning updates.
Returns:
info (dict): dictionary of relevant inputs, outputs, and losses
that might be relevant for logging
"""
if validate:
assert not self.planner.nets.training
assert not self.value_net.nets.training
info = dict(planner=dict(), value_net=dict())
# train planner
info["planner"].update(self.planner.train_on_batch(batch["planner"], epoch, validate=validate))
# train value network
info["value_net"].update(self.value_net.train_on_batch(batch["value_net"], epoch, validate=validate))
return info
def log_info(self, info):
"""
Process info dictionary from @train_on_batch to summarize
information to pass to tensorboard for logging.
Args:
info (dict): dictionary of info
Returns:
loss_log (dict): name -> summary statistic
"""
loss = 0.
# planner
planner_log = self.planner.log_info(info["planner"])
planner_log = dict(("Planner/" + k, v) for k, v in planner_log.items())
loss += planner_log["Planner/Loss"]
# value network
value_net_log = self.value_net.log_info(info["value_net"])
value_net_log = dict(("ValueNetwork/" + k, v) for k, v in value_net_log.items())
loss += value_net_log["ValueNetwork/Loss"]
planner_log.update(value_net_log)
planner_log["Loss"] = loss
return planner_log
def on_epoch_end(self, epoch):
"""
Called at the end of each epoch.
"""
self.planner.on_epoch_end(epoch)
self.value_net.on_epoch_end(epoch)
def set_eval(self):
"""
Prepare networks for evaluation.
"""
self.planner.set_eval()
self.value_net.set_eval()
def set_train(self):
"""
Prepare networks for training.
"""
self.planner.set_train()
self.value_net.set_train()
def serialize(self):
"""
Get dictionary of current model parameters.
"""
return dict(
planner=self.planner.serialize(),
value_net=self.value_net.serialize(),
)
def deserialize(self, model_dict):
"""
Load model from a checkpoint.
Args:
            model_dict (dict): a dictionary saved by self.serialize() that contains
                the "planner" and "value_net" model parameters
"""
self.planner.deserialize(model_dict["planner"])
self.value_net.deserialize(model_dict["value_net"])
def reset(self):
"""
Reset algo state to prepare for environment rollouts.
"""
self.planner.reset()
self.value_net.reset()
def __repr__(self):
"""
Pretty print algorithm and network description.
"""
        msg = str(self.__class__.__name__)
        return msg + "\n\nPlanner:\n" + textwrap.indent(self.planner.__repr__(), '  ') + \
            "\n\nValue Network:\n" + textwrap.indent(self.value_net.__repr__(), '  ')
def get_subgoal_predictions(self, obs_dict, goal_dict=None):
"""
Takes a batch of observations and predicts a batch of subgoals.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
subgoal prediction (dict): name -> Tensor [batch_size, ...]
"""
num_samples = self.algo_config.num_samples
# sample subgoals from the planner (shape: [batch_size, num_samples, ...])
subgoals = self.sample_subgoals(obs_dict=obs_dict, goal_dict=goal_dict, num_samples=num_samples)
# stack subgoals to get all values in one forward pass (shape [batch_size * num_samples, ...])
k = list(obs_dict.keys())[0]
bsize = obs_dict[k].shape[0]
subgoals_tiled = TensorUtils.reshape_dimensions(subgoals, begin_axis=0, end_axis=1, target_dims=(bsize * num_samples,))
# also repeat goals if necessary
goal_tiled = None
if len(self.planner.goal_shapes) > 0:
goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_samples)
# evaluate the value of each subgoal
subgoal_values = self.value_net.get_state_value(obs_dict=subgoals_tiled, goal_dict=goal_tiled).reshape(-1, num_samples)
# pick the best subgoal
best_index = torch.argmax(subgoal_values, dim=1)
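        # advanced indexing: pick subgoals[k][b, best_index[b]] for each batch member b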
best_subgoal = {k: subgoals[k][torch.arange(bsize), best_index] for k in subgoals}
return best_subgoal
    def sample_subgoals(self, obs_dict, goal_dict=None, num_samples=1):
"""
Sample @num_samples subgoals from the planner algo per observation.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
"""
return self.planner.sample_subgoals(obs_dict=obs_dict, goal_dict=goal_dict, num_samples=num_samples)
def get_state_value(self, obs_dict, goal_dict=None):
"""
Get state value outputs.
Args:
obs_dict (dict): current observation
goal_dict (dict): (optional) goal
Returns:
value (torch.Tensor): value tensor
"""
return self.value_net.get_state_value(obs_dict=obs_dict, goal_dict=goal_dict)
def get_state_action_value(self, obs_dict, actions, goal_dict=None):
"""
Get state-action value outputs.
Args:
obs_dict (dict): current observation
actions (torch.Tensor): action
goal_dict (dict): (optional) goal
Returns:
value (torch.Tensor): value tensor
"""
return self.value_net.get_state_action_value(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict)