Spaces:
Running on Zero
Running on Zero
| """ | |
| Subgoal prediction models, used in HBC / IRIS. | |
| """ | |
| import numpy as np | |
| from collections import OrderedDict | |
| from copy import deepcopy | |
| import torch | |
| import torch.nn as nn | |
| import robomimic.models.obs_nets as ObsNets | |
| import robomimic.models.vae_nets as VAENets | |
| import robomimic.utils.tensor_utils as TensorUtils | |
| import robomimic.utils.torch_utils as TorchUtils | |
| import robomimic.utils.obs_utils as ObsUtils | |
| from robomimic.algo import register_algo_factory_func, PlannerAlgo, ValueAlgo | |
@register_algo_factory_func("gl")
def algo_config_to_class(algo_config):
    """
    Maps algo config to the GL algo class to instantiate, along with additional algo kwargs.

    The function is registered under the algo name "gl" via @register_algo_factory_func,
    which is imported by this module for exactly that purpose.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo - GL_VAE when @algo_config.vae.enabled is set,
            otherwise the deterministic GL
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
            (always empty for GL)
    """
    if algo_config.vae.enabled:
        return GL_VAE, {}
    return GL, {}
class GL(PlannerAlgo):
    """
    Implements goal prediction component for HBC and IRIS.

    A deterministic network (ObsNets.MIMO_MLP) maps the current observation - and the
    goal observation, when goal observations are configured - to a predicted subgoal
    observation. Training minimizes a per-key MSE loss against the observation found
    @subgoal_horizon steps ahead in the batch.
    """
    def __init__(
        self,
        algo_config,
        obs_config,
        global_config,
        obs_key_shapes,
        ac_dim,
        device
    ):
        """
        Args:
            algo_config (Config object): instance of Config corresponding to the algo section
                of the config

            obs_config (Config object): instance of Config corresponding to the observation
                section of the config

            global_config (Config object): global training config

            obs_key_shapes (OrderedDict): dictionary that maps observation keys to shapes

            ac_dim (int): dimension of action space

            device (torch.Device): where the algo should live (i.e. cpu, gpu)
        """
        # stored before the superclass constructor runs - presumably so anything the base
        # constructor triggers can already read it (TODO confirm against PlannerAlgo)
        self._subgoal_horizon = algo_config.subgoal_horizon
        super(GL, self).__init__(
            algo_config=algo_config,
            obs_config=obs_config,
            global_config=global_config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=ac_dim,
            device=device
        )

    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        self.nets = nn.ModuleDict()

        # the network always consumes the current observation; the goal observation is
        # added as a second input group only when goal observations are configured
        obs_group_shapes = OrderedDict()
        obs_group_shapes["obs"] = OrderedDict(self.obs_shapes)
        if len(self.goal_shapes) > 0:
            obs_group_shapes["goal"] = OrderedDict(self.goal_shapes)

        # deterministic goal prediction network
        self.nets["goal_network"] = ObsNets.MIMO_MLP(
            input_obs_group_shapes=obs_group_shapes,
            output_shapes=self.subgoal_shapes,
            layer_dims=self.algo_config.ae.planner_layer_dims,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
        )

        self.nets = self.nets.float().to(self.device)

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training. Keys are "obs", "subgoals",
                "target_subgoals" and "goal_obs" (None if absent from @batch).
        """
        input_batch = dict()

        # keep only the first timestep of every observation sequence. Indexing with
        # [:, 0] selects the same data as [:, 0, :] for tensors of rank >= 3, and also
        # works for rank-2 (batch, time) signals
        input_batch["obs"] = {k: batch["obs"][k][:, 0] for k in batch["obs"]}

        # extract multi-horizon subgoal target - index (horizon - 1) of next_obs, assuming
        # next_obs[:, t] is the observation that follows obs[:, t]
        horizon_index = self._subgoal_horizon - 1
        input_batch["subgoals"] = {
            k: batch["next_obs"][k][:, horizon_index] for k in batch["next_obs"]
        }
        input_batch["target_subgoals"] = input_batch["subgoals"]
        input_batch["goal_obs"] = batch.get("goal_obs", None)  # goals may not be present

        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))

    def get_actor_goal_for_training_from_processed_batch(self, processed_batch, **kwargs):
        """
        Retrieve subgoals from processed batch to use for training the actor. Subclasses
        can modify this function to change the subgoals.

        Args:
            processed_batch (dict): processed batch from @process_batch_for_training

        Returns:
            actor_subgoals (dict): subgoal observations to condition actor on
        """
        return processed_batch["target_subgoals"]

    def train_on_batch(self, batch, epoch, validate=False):
        """
        Training on a single batch of data.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): epoch number - required by some Algos that need
                to perform staged training and early stopping

            validate (bool): if True, don't perform any learning updates.

        Returns:
            info (dict): dictionary of relevant inputs, outputs, and losses
                that might be relevant for logging
        """
        with TorchUtils.maybe_no_grad(no_grad=validate):
            info = super(GL, self).train_on_batch(batch, epoch, validate=validate)

            # predict subgoal observations with goal network
            pred_subgoals = self.nets["goal_network"](obs=batch["obs"], goal=batch["goal_obs"])

            # compute loss as L2 error for each observation key, summed into one goal loss
            losses = OrderedDict()
            target_subgoals = batch["target_subgoals"]  # targets for network prediction
            goal_loss = 0.
            for k in pred_subgoals:
                assert pred_subgoals[k].shape == target_subgoals[k].shape, "mismatch in predicted and target subgoals!"
                mode_loss = nn.MSELoss()(pred_subgoals[k], target_subgoals[k])
                goal_loss += mode_loss
                losses["goal_{}_loss".format(k)] = mode_loss

            losses["goal_loss"] = goal_loss
            # only detached copies go into @info - the attached loss is used below
            info.update(TensorUtils.detach(losses))

            if not validate:
                # gradient step
                goal_grad_norms = TorchUtils.backprop_for_loss(
                    net=self.nets["goal_network"],
                    optim=self.optimizers["goal_network"],
                    loss=losses["goal_loss"],
                )
                info["goal_grad_norms"] = goal_grad_norms
        return info

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        loss_log = super(GL, self).log_info(info)
        loss_log["Loss"] = info["goal_loss"].item()
        # every entry whose key ends in "_loss" (per-key and total) is logged as a scalar
        for k in info:
            if k.endswith("_loss"):
                loss_log[k] = info[k].item()
        if "goal_grad_norms" in info:
            loss_log["Grad_Norms"] = info["goal_grad_norms"]
        return loss_log

    def get_subgoal_predictions(self, obs_dict, goal_dict=None):
        """
        Takes a batch of observations and predicts a batch of subgoals.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            subgoal prediction (dict): name -> Tensor [batch_size, ...]
        """
        return self.nets["goal_network"](obs=obs_dict, goal=goal_dict)

    def sample_subgoals(self, obs_dict, goal_dict=None, num_samples=1):
        """
        Sample @num_samples subgoals from the network per observation.
        Since this class implements a deterministic subgoal prediction,
        this function returns identical subgoals for each input observation.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
        """
        # stack observations to get all samples in one forward pass
        obs_tiled = ObsUtils.repeat_and_stack_observation(obs_dict, n=num_samples)
        goal_tiled = None
        if goal_dict is not None:
            goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_samples)

        # [batch_size * num_samples, ...]
        goals = self.get_subgoal_predictions(obs_dict=obs_tiled, goal_dict=goal_tiled)
        # reshape to [batch_size, num_samples, ...]
        return TensorUtils.reshape_dimensions(goals, begin_axis=0, end_axis=0, target_dims=(-1, num_samples))

    def get_action(self, obs_dict, goal_dict=None):
        """
        Get policy action outputs. Assumes one input observation (first dimension should be 1).

        GL only predicts subgoals, so rollouts are not supported.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Raises:
            NotImplementedError: always
        """
        raise NotImplementedError("Rollouts are not supported by GL")
class GL_VAE(GL):
    """
    Implements goal prediction via VAE.

    The VAE encodes full subgoal observations, conditioned on the current observation
    (and goal, if provided), and decodes target subgoals. At test time, subgoals are
    produced by decoding from the prior (@sample_subgoals), or, when latent subgoals
    are enabled, by returning prior latent samples directly.
    """
    def _create_networks(self):
        """
        Creates networks and places them into @self.nets.
        """
        self.nets = nn.ModuleDict()

        self.nets["goal_network"] = VAENets.VAE(
            input_shapes=self.subgoal_shapes,
            output_shapes=self.subgoal_shapes,
            condition_shapes=self.obs_shapes,
            goal_shapes=self.goal_shapes,
            device=self.device,
            encoder_kwargs=ObsUtils.obs_encoder_kwargs_from_config(self.obs_config.encoder),
            **VAENets.vae_args_from_config(self.algo_config.vae),
        )

        self.nets = self.nets.float().to(self.device)

    def get_actor_goal_for_training_from_processed_batch(
        self,
        processed_batch,
        use_latent_subgoals=False,
        use_prior_correction=False,
        num_prior_samples=100,
        **kwargs,
    ):
        """
        Modify from superclass to support a @use_latent_subgoals option.
        The VAE can optionally return latent subgoals by passing the subgoal
        observations in the batch through the encoder.

        Args:
            processed_batch (dict): processed batch from @process_batch_for_training

            use_latent_subgoals (bool): if True, condition the actor on latent subgoals
                by using the VAE encoder to encode subgoal observations at train-time,
                and using the VAE prior to generate latent subgoals at test-time

            use_prior_correction (bool): if True, use a "prior correction" trick to
                choose a latent subgoal sampled from the prior that is close to the
                latent from the VAE encoder (posterior). This can help with issues at
                test-time where the encoder latent distribution might not match
                the prior latent distribution.

            num_prior_samples (int): number of VAE prior samples to take and choose among,
                if @use_prior_correction is true

        Returns:
            actor_subgoals (dict): subgoal observations to condition actor on, or
                {"latent_subgoal": Tensor} when @use_latent_subgoals is True
        """
        if not use_latent_subgoals:
            return processed_batch["target_subgoals"]

        # batch variables
        obs = processed_batch["obs"]
        subgoals = processed_batch["subgoals"]  # full subgoal observations
        target_subgoals = processed_batch["target_subgoals"]  # targets for network prediction
        goal_obs = processed_batch["goal_obs"]

        with torch.no_grad():
            # run VAE forward pass to get samples from posterior for the current observation and subgoal
            vae_outputs = self.nets["goal_network"](
                inputs=subgoals,  # encoder takes full subgoals
                outputs=target_subgoals,  # reconstruct target subgoals
                goals=goal_obs,
                conditions=obs,  # condition on observations
            )
            posterior_z = vae_outputs["encoder_z"]
            latent_subgoals = posterior_z

            if use_prior_correction:
                # instead of treating posterior samples as latent subgoals, sample latents from
                # the prior and choose the closest one as the latent subgoal
                first_key = next(iter(obs))
                batch_size = obs[first_key].shape[0]

                # for each batch member, get @num_prior_samples samples from the prior
                obs_tiled = ObsUtils.repeat_and_stack_observation(obs, n=num_prior_samples)
                goal_tiled = None
                if len(self.goal_shapes) > 0:
                    goal_tiled = ObsUtils.repeat_and_stack_observation(goal_obs, n=num_prior_samples)

                prior_z_samples = self.nets["goal_network"].sample_prior(
                    conditions=obs_tiled,
                    goals=goal_tiled,
                )

                # choose prior samples that are closest to the sampled posterior latents
                # note: every posterior sample in the batch has @num_prior_samples corresponding prior samples

                # reshape prior samples to (batch_size, num_samples, latent_dim)
                prior_z_samples = prior_z_samples.reshape(batch_size, num_prior_samples, -1)

                # flatten posterior latents to (batch_size, 1, latent_dim) so they share the
                # prior's flattened layout (identical to unsqueeze(1) for 2-D latents)
                posterior_z_expanded = posterior_z.reshape(batch_size, 1, -1)

                # compute distances with broadcasting so that each posterior sample
                # has distances to all of its prior samples
                distances = (prior_z_samples - posterior_z_expanded).pow(2).sum(dim=2)

                # then gather the closest prior sample for each posterior sample;
                # the row index is created on the same device as the argmin result
                neighbors = torch.argmin(distances, dim=1)
                batch_inds = torch.arange(batch_size, device=neighbors.device)
                latent_subgoals = prior_z_samples[batch_inds, neighbors]

        return {"latent_subgoal": latent_subgoals}

    def train_on_batch(self, batch, epoch, validate=False):
        """
        Training on a single batch of data.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): epoch number - required by some Algos that need
                to perform staged training and early stopping

            validate (bool): if True, don't perform any learning updates.

        Returns:
            info (dict): dictionary of relevant inputs, outputs, and losses
                that might be relevant for logging
        """
        with TorchUtils.maybe_no_grad(no_grad=validate):
            # deliberately skip GL.train_on_batch (its deterministic MSE objective and
            # obs=/goal= network call do not apply to the VAE) and go to GL's parent
            info = super(GL, self).train_on_batch(batch, epoch, validate=validate)

            if self.algo_config.vae.prior.use_categorical:
                # linearly anneal the Gumbel-softmax temperature per epoch, floored at the minimum
                temperature = self.algo_config.vae.prior.categorical_init_temp - epoch * self.algo_config.vae.prior.categorical_temp_anneal_step
                temperature = max(temperature, self.algo_config.vae.prior.categorical_min_temp)
                self.nets["goal_network"].set_gumbel_temperature(temperature)

            # batch variables
            obs = batch["obs"]
            subgoals = batch["subgoals"]  # full subgoal observations
            target_subgoals = batch["target_subgoals"]  # targets for network prediction
            goal_obs = batch["goal_obs"]

            vae_outputs = self.nets["goal_network"](
                inputs=subgoals,  # encoder takes full subgoals
                outputs=target_subgoals,  # reconstruct target subgoals
                goals=goal_obs,
                conditions=obs,  # condition on observations
            )
            recons_loss = vae_outputs["reconstruction_loss"]
            kl_loss = vae_outputs["kl_loss"]
            goal_loss = recons_loss + self.algo_config.vae.kl_weight * kl_loss

            # log detached copies (consistent with GL) so @info does not keep the autograd
            # graph alive; the attached @goal_loss is used for the gradient step below
            info.update(TensorUtils.detach(OrderedDict(
                recons_loss=recons_loss,
                kl_loss=kl_loss,
                goal_loss=goal_loss,
            )))

            if not self.algo_config.vae.prior.use_categorical:
                with torch.no_grad():
                    info["encoder_variance"] = torch.exp(vae_outputs["encoder_params"]["logvar"])

            # VAE gradient step
            if not validate:
                goal_grad_norms = TorchUtils.backprop_for_loss(
                    net=self.nets["goal_network"],
                    optim=self.optimizers["goal_network"],
                    loss=goal_loss,
                )
                info["goal_grad_norms"] = goal_grad_norms
        return info

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic
        """
        loss_log = super(GL_VAE, self).log_info(info)
        loss_log["Reconstruction_Loss"] = info["recons_loss"].item()
        loss_log["KL_Loss"] = info["kl_loss"].item()
        # "encoder_variance" is only recorded for the non-categorical prior in train_on_batch
        if self.algo_config.vae.prior.use_categorical:
            loss_log["Gumbel_Temperature"] = self.nets["goal_network"].get_gumbel_temperature()
        else:
            loss_log["Encoder_Variance"] = info["encoder_variance"].mean().item()
        return loss_log

    def get_subgoal_predictions(self, obs_dict, goal_dict=None):
        """
        Takes a batch of observations and predicts a batch of subgoals.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            subgoal prediction (dict): name -> Tensor [batch_size, ...]
        """
        if self.global_config.algo.latent_subgoal.enabled:
            # latent subgoals from sampling prior
            latent_subgoals = self.nets["goal_network"].sample_prior(
                conditions=obs_dict,
                goals=goal_dict,
            )
            return OrderedDict(latent_subgoal=latent_subgoals)

        # sample a single goal from the VAE and drop the singleton sample dimension
        goals = self.sample_subgoals(obs_dict=obs_dict, goal_dict=goal_dict, num_samples=1)
        return {k: goals[k][:, 0, ...] for k in goals}

    def sample_subgoals(self, obs_dict, goal_dict=None, num_samples=1):
        """
        Sample @num_samples subgoals from the VAE per observation.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
        """
        # stack observations to get all samples in one forward pass
        obs_tiled = ObsUtils.repeat_and_stack_observation(obs_dict, n=num_samples)
        goal_tiled = None
        if goal_dict is not None:
            goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_samples)

        # VAE decode expects number of samples explicitly
        first_key = next(iter(obs_tiled))
        n = obs_tiled[first_key].shape[0]

        # [batch_size * num_samples, ...]
        goals = self.nets["goal_network"].decode(n=n, conditions=obs_tiled, goals=goal_tiled)
        # reshape to [batch_size, num_samples, ...]
        return TensorUtils.reshape_dimensions(goals, begin_axis=0, end_axis=0, target_dims=(-1, num_samples))
class ValuePlanner(PlannerAlgo, ValueAlgo):
    """
    Base class for all algorithms that are used for planning subgoals
    based on (1) a @PlannerAlgo that is used to sample candidate subgoals
    and (2) a @ValueAlgo that is used to select one of the subgoals.

    This class is a composite: it does not call the base-class constructors and
    instead delegates training, logging, (de)serialization and inference to its
    @planner and @value_net members.
    """
    def __init__(
        self,
        planner_algo_class,
        value_algo_class,
        algo_config,
        obs_config,
        global_config,
        obs_key_shapes,
        ac_dim,
        device,
    ):
        """
        Args:
            planner_algo_class (Algo class): algo class for the planner

            value_algo_class (Algo class): algo class for the value network

            algo_config (Config object): instance of Config corresponding to the algo section
                of the config

            obs_config (Config object): instance of Config corresponding to the observation
                section of the config

            global_config (Config object): global config

            obs_key_shapes (OrderedDict): dictionary that maps input/output observation keys to shapes

            ac_dim (int): action dimension

            device: torch device
        """
        self.algo_config = algo_config
        self.obs_config = obs_config
        self.global_config = global_config
        self.ac_dim = ac_dim
        self.device = device

        # each sub-algo receives its own section of the algo and obs configs
        self.planner = planner_algo_class(
            algo_config=algo_config.planner,
            obs_config=obs_config.planner,
            global_config=global_config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=ac_dim,
            device=device
        )

        self.value_net = value_algo_class(
            algo_config=algo_config.value,
            obs_config=obs_config.value,
            global_config=global_config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=ac_dim,
            device=device
        )

        self.subgoal_shapes = self.planner.subgoal_shapes

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training, with one sub-batch under
                "planner" and one under "value_net"
        """
        input_batch = dict()

        input_batch["planner"] = self.planner.process_batch_for_training(batch)
        input_batch["value_net"] = self.value_net.process_batch_for_training(batch)

        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))

    def train_on_batch(self, batch, epoch, validate=False):
        """
        Training on a single batch of data.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader and filtered by @process_batch_for_training

            epoch (int): epoch number - required by some Algos that need
                to perform staged training and early stopping

            validate (bool): if True, don't perform any learning updates.

        Returns:
            info (dict): dictionary of relevant inputs, outputs, and losses
                that might be relevant for logging, nested under "planner"
                and "value_net"
        """
        if validate:
            # validation must run with both sub-algos in eval mode (see @set_eval)
            assert not self.planner.nets.training
            assert not self.value_net.nets.training

        info = dict(planner=dict(), value_net=dict())

        # train planner
        info["planner"].update(self.planner.train_on_batch(batch["planner"], epoch, validate=validate))

        # train value network
        info["value_net"].update(self.value_net.train_on_batch(batch["value_net"], epoch, validate=validate))

        return info

    def log_info(self, info):
        """
        Process info dictionary from @train_on_batch to summarize
        information to pass to tensorboard for logging.

        Args:
            info (dict): dictionary of info

        Returns:
            loss_log (dict): name -> summary statistic. Sub-algo entries are
                prefixed with "Planner/" or "ValueNetwork/", and "Loss" is the
                sum of the two sub-algo losses.
        """
        loss = 0.

        # planner
        planner_log = self.planner.log_info(info["planner"])
        planner_log = {"Planner/" + k: v for k, v in planner_log.items()}
        loss += planner_log["Planner/Loss"]

        # value network
        value_net_log = self.value_net.log_info(info["value_net"])
        value_net_log = {"ValueNetwork/" + k: v for k, v in value_net_log.items()}
        loss += value_net_log["ValueNetwork/Loss"]

        planner_log.update(value_net_log)
        planner_log["Loss"] = loss
        return planner_log

    def on_epoch_end(self, epoch):
        """
        Called at the end of each epoch.
        """
        self.planner.on_epoch_end(epoch)
        self.value_net.on_epoch_end(epoch)

    def set_eval(self):
        """
        Prepare networks for evaluation.
        """
        self.planner.set_eval()
        self.value_net.set_eval()

    def set_train(self):
        """
        Prepare networks for training.
        """
        self.planner.set_train()
        self.value_net.set_train()

    def serialize(self):
        """
        Get dictionary of current model parameters.

        Returns:
            dict with the serialized planner under "planner" and the
            serialized value network under "value_net"
        """
        return dict(
            planner=self.planner.serialize(),
            value_net=self.value_net.serialize(),
        )

    def deserialize(self, model_dict):
        """
        Load model from a checkpoint.

        Args:
            model_dict (dict): a dictionary saved by self.serialize(), i.e. with
                "planner" and "value_net" entries
        """
        self.planner.deserialize(model_dict["planner"])
        self.value_net.deserialize(model_dict["value_net"])

    def reset(self):
        """
        Reset algo state to prepare for environment rollouts.
        """
        self.planner.reset()
        self.value_net.reset()

    def __repr__(self):
        """
        Pretty print algorithm and network description.
        """
        import textwrap
        msg = str(self.__class__.__name__)
        # newline after the class name so it is not fused with the "Planner:" header
        return msg + "\nPlanner:\n" + textwrap.indent(self.planner.__repr__(), '  ') + \
            "\n\nValue Network:\n" + textwrap.indent(self.value_net.__repr__(), '  ')

    def get_subgoal_predictions(self, obs_dict, goal_dict=None):
        """
        Takes a batch of observations and predicts a batch of subgoals.

        Samples @algo_config.num_samples candidate subgoals per observation from the
        planner, scores each with the value network, and keeps the highest-valued one.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            subgoal prediction (dict): name -> Tensor [batch_size, ...]
        """
        num_samples = self.algo_config.num_samples

        # sample subgoals from the planner (shape: [batch_size, num_samples, ...])
        subgoals = self.sample_subgoals(obs_dict=obs_dict, goal_dict=goal_dict, num_samples=num_samples)

        # stack subgoals to get all values in one forward pass (shape [batch_size * num_samples, ...])
        first_key = next(iter(obs_dict))
        bsize = obs_dict[first_key].shape[0]
        subgoals_tiled = TensorUtils.reshape_dimensions(subgoals, begin_axis=0, end_axis=1, target_dims=(bsize * num_samples,))

        # also repeat goals if necessary
        goal_tiled = None
        if len(self.planner.goal_shapes) > 0:
            goal_tiled = ObsUtils.repeat_and_stack_observation(goal_dict, n=num_samples)

        # evaluate the value of each subgoal
        subgoal_values = self.value_net.get_state_value(obs_dict=subgoals_tiled, goal_dict=goal_tiled).reshape(-1, num_samples)

        # pick the best subgoal; row indices live on the same device as the argmax result
        best_index = torch.argmax(subgoal_values, dim=1)
        batch_inds = torch.arange(bsize, device=best_index.device)
        best_subgoal = {name: subgoals[name][batch_inds, best_index] for name in subgoals}
        return best_subgoal

    def sample_subgoals(self, obs_dict, goal_dict=None, num_samples=1):
        """
        Sample @num_samples subgoals from the planner algo per observation.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            subgoals (dict): name -> Tensor [batch_size, num_samples, ...]
        """
        return self.planner.sample_subgoals(obs_dict=obs_dict, goal_dict=goal_dict, num_samples=num_samples)

    def get_state_value(self, obs_dict, goal_dict=None):
        """
        Get state value outputs.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        return self.value_net.get_state_value(obs_dict=obs_dict, goal_dict=goal_dict)

    def get_state_action_value(self, obs_dict, actions, goal_dict=None):
        """
        Get state-action value outputs.

        Args:
            obs_dict (dict): current observation
            actions (torch.Tensor): action
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        return self.value_net.get_state_action_value(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict)