Spaces:
Sleeping
Sleeping
| """ | |
| Implementation of IRIS (https://arxiv.org/abs/1911.05321). | |
| """ | |
| import numpy as np | |
| from collections import OrderedDict | |
| from copy import deepcopy | |
| import torch | |
| import robomimic.utils.tensor_utils as TensorUtils | |
| import robomimic.utils.obs_utils as ObsUtils | |
| from robomimic.config.config import Config | |
| from robomimic.algo import register_algo_factory_func, algo_name_to_factory_func, HBC, ValuePlanner, ValueAlgo, GL_VAE | |
def algo_config_to_class(algo_config):
    """
    Maps algo config to the IRIS algo class to instantiate, along with additional algo kwargs.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
    """
    # IRIS is assembled from three sub-algorithms: a BC actor, a GL planner,
    # and a BCQ value network. Resolve each concrete class from its registry.
    policy_class, _ = algo_name_to_factory_func("bc")(algo_config.actor)
    planner_class, _ = algo_name_to_factory_func("gl")(algo_config.value_planner.planner)
    value_class, _ = algo_name_to_factory_func("bcq")(algo_config.value_planner.value)
    algo_kwargs = dict(
        policy_algo_class=policy_class,
        planner_algo_class=planner_class,
        value_algo_class=value_class,
    )
    return IRIS, algo_kwargs
class IRIS(HBC, ValueAlgo):
    """
    Implementation of IRIS (https://arxiv.org/abs/1911.05321).

    A hierarchical algorithm built from two sub-components: a value-guided
    planner (:class:`ValuePlanner`) that proposes subgoals, and a low-level
    policy (``self.actor``) that is goal-conditioned on those subgoals.
    Training/rollout plumbing is inherited from ``HBC``; the ``ValueAlgo``
    mixin exposes the planner's value estimates (see the two getters below).
    """
    def __init__(
        self,
        planner_algo_class,
        value_algo_class,
        policy_algo_class,
        algo_config,
        obs_config,
        global_config,
        obs_key_shapes,
        ac_dim,
        device,
    ):
        """
        Args:
            planner_algo_class (Algo class): algo class for the planner

            value_algo_class (Algo class): algo class for the value network
                used by the ValuePlanner

            policy_algo_class (Algo class): algo class for the policy

            algo_config (Config object): instance of Config corresponding to the algo section
                of the config

            obs_config (Config object): instance of Config corresponding to the observation
                section of the config

            global_config (Config object): global training config

            obs_key_shapes (OrderedDict): dictionary that maps input/output observation keys to shapes

            ac_dim (int): action dimension

            device: torch device
        """
        self.algo_config = algo_config
        self.obs_config = obs_config
        self.global_config = global_config

        self.ac_dim = ac_dim
        self.device = device

        self._subgoal_step_count = 0  # current step count for deciding when to update subgoal
        self._current_subgoal = None  # latest subgoal
        self._subgoal_update_interval = self.algo_config.subgoal_update_interval  # subgoal update frequency
        self._subgoal_horizon = self.algo_config.value_planner.planner.subgoal_horizon
        self._actor_horizon = self.algo_config.actor.rnn.horizon

        # "separate" vs "cascade" subgoal handling; the modes themselves are
        # implemented in the HBC base class — NOTE(review): confirm there.
        self._algo_mode = self.algo_config.mode
        assert self._algo_mode in ["separate", "cascade"]

        # planner: wraps a goal-prediction model plus a value network that
        # scores candidate subgoals
        self.planner = ValuePlanner(
            planner_algo_class=planner_algo_class,
            value_algo_class=value_algo_class,
            algo_config=algo_config.value_planner,
            obs_config=obs_config.value_planner,
            global_config=global_config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=ac_dim,
            device=device
        )

        # the actor consumes whatever subgoal keys/shapes the planner produces
        self.actor_goal_shapes = self.planner.subgoal_shapes
        assert not algo_config.latent_subgoal.enabled, "IRIS does not support latent subgoals"

        # only for the actor: override goal modalities and shapes to match the subgoal set by the planner
        actor_obs_key_shapes = deepcopy(obs_key_shapes)
        # make sure we are not modifying existing observation key shapes
        for k in self.actor_goal_shapes:
            if k in actor_obs_key_shapes:
                assert actor_obs_key_shapes[k] == self.actor_goal_shapes[k]
        actor_obs_key_shapes.update(self.actor_goal_shapes)

        # group the subgoal keys by observation modality (e.g. low_dim / rgb)
        # so they can be registered as the actor's goal modalities
        goal_modalities = {obs_modality: [] for obs_modality in ObsUtils.OBS_MODALITY_CLASSES.keys()}
        for k in self.actor_goal_shapes.keys():
            goal_modalities[ObsUtils.OBS_KEYS_TO_MODALITIES[k]].append(k)

        # configs are locked by default — unlock to inject the goal section
        actor_obs_config = deepcopy(obs_config.actor)
        with actor_obs_config.unlocked():
            actor_obs_config["goal"] = Config(**goal_modalities)

        self.actor = policy_algo_class(
            algo_config=algo_config.actor,
            obs_config=actor_obs_config,
            global_config=global_config,
            obs_key_shapes=actor_obs_key_shapes,
            ac_dim=ac_dim,
            device=device
        )

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training
        """
        input_batch = dict()

        # let each sub-algorithm extract what it needs from the raw batch
        input_batch["planner"] = self.planner.process_batch_for_training(batch)
        input_batch["actor"] = self.actor.process_batch_for_training(batch)

        if self.algo_config.actor_use_random_subgoals:
            # optionally use randomly sampled step between [1, seq_length] as policy goal
            policy_subgoal_indices = torch.randint(
                low=0, high=self.global_config.train.seq_length, size=(batch["actions"].shape[0],))
            goal_obs = TensorUtils.gather_sequence(batch["next_obs"], policy_subgoal_indices)
            goal_obs = TensorUtils.to_float(TensorUtils.to_device(goal_obs, self.device))
            input_batch["actor"]["goal_obs"] = goal_obs
        else:
            # otherwise, use planner subgoal target as goal for the policy
            input_batch["actor"]["goal_obs"] = input_batch["planner"]["planner"]["target_subgoals"]

        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))

    def get_state_value(self, obs_dict, goal_dict=None):
        """
        Get state value outputs.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        # delegate to the planner, which owns the value network
        return self.planner.get_state_value(obs_dict=obs_dict, goal_dict=goal_dict)

    def get_state_action_value(self, obs_dict, actions, goal_dict=None):
        """
        Get state-action value outputs.

        Args:
            obs_dict (dict): current observation
            actions (torch.Tensor): action
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        # delegate to the planner, which owns the value network
        return self.planner.get_state_action_value(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict)