"""
Implementation of IRIS (https://arxiv.org/abs/1911.05321).
"""
import numpy as np
from collections import OrderedDict
from copy import deepcopy

import torch

import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.config.config import Config
from robomimic.algo import register_algo_factory_func, algo_name_to_factory_func, HBC, ValuePlanner, ValueAlgo, GL_VAE


@register_algo_factory_func("iris")
def algo_config_to_class(algo_config):
    """
    Maps algo config to the IRIS algo class to instantiate, along with additional algo kwargs.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
    """
    pol_cls, _ = algo_name_to_factory_func("bc")(algo_config.actor)
    plan_cls, _ = algo_name_to_factory_func("gl")(algo_config.value_planner.planner)
    value_cls, _ = algo_name_to_factory_func("bcq")(algo_config.value_planner.value)
    return IRIS, dict(policy_algo_class=pol_cls, planner_algo_class=plan_cls, value_algo_class=value_cls)
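

# Hedged usage sketch (our annotation, not part of the source): the registered
# factory lets generic training code resolve "iris" to a class plus kwargs and
# instantiate it. `config`, `obs_key_shapes`, `ac_dim`, and `device` are assumed
# to come from the usual robomimic training setup.
#
#   algo_cls, algo_kwargs = algo_config_to_class(config.algo)
#   model = algo_cls(
#       algo_config=config.algo,
#       obs_config=config.observation,
#       global_config=config,
#       obs_key_shapes=obs_key_shapes,
#       ac_dim=ac_dim,
#       device=device,
#       **algo_kwargs,  # planner / value / policy classes resolved above
#   )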


class IRIS(HBC, ValueAlgo):
    """
    Implementation of IRIS (https://arxiv.org/abs/1911.05321).
    """
    def __init__(
        self,
        planner_algo_class,
        value_algo_class,
        policy_algo_class,
        algo_config,
        obs_config,
        global_config,
        obs_key_shapes,
        ac_dim,
        device,
    ):
        """
        Args:
            planner_algo_class (Algo class): algo class for the planner
            value_algo_class (Algo class): algo class for the value network
            policy_algo_class (Algo class): algo class for the policy
            algo_config (Config object): instance of Config corresponding to the algo section
                of the config
            obs_config (Config object): instance of Config corresponding to the observation
                section of the config
            global_config (Config object): global training config
            obs_key_shapes (OrderedDict): dictionary that maps input/output observation keys to shapes
            ac_dim (int): action dimension
            device: torch device
        """
        self.algo_config = algo_config
        self.obs_config = obs_config
        self.global_config = global_config

        self.ac_dim = ac_dim
        self.device = device

        self._subgoal_step_count = 0  # current step count for deciding when to update subgoal
        self._current_subgoal = None  # latest subgoal
        self._subgoal_update_interval = self.algo_config.subgoal_update_interval  # subgoal update frequency
        self._subgoal_horizon = self.algo_config.value_planner.planner.subgoal_horizon
        self._actor_horizon = self.algo_config.actor.rnn.horizon

        self._algo_mode = self.algo_config.mode
        assert self._algo_mode in ["separate", "cascade"]
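        # Mode semantics (our annotation; the behavior is inherited from HBC rather
        # than documented here): in "separate" mode the subgoal is re-planned only
        # every `subgoal_update_interval` environment steps, while in "cascade" mode
        # the actor is conditioned on a fresh planner subgoal prediction at every step.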

        self.planner = ValuePlanner(
            planner_algo_class=planner_algo_class,
            value_algo_class=value_algo_class,
            algo_config=algo_config.value_planner,
            obs_config=obs_config.value_planner,
            global_config=global_config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=ac_dim,
            device=device,
        )

        self.actor_goal_shapes = self.planner.subgoal_shapes
        assert not algo_config.latent_subgoal.enabled, "IRIS does not support latent subgoals"

        # only for the actor: override goal modalities and shapes to match the subgoal set by the planner
        actor_obs_key_shapes = deepcopy(obs_key_shapes)
        # make sure we are not modifying existing observation key shapes
        for k in self.actor_goal_shapes:
            if k in actor_obs_key_shapes:
                assert actor_obs_key_shapes[k] == self.actor_goal_shapes[k]
        actor_obs_key_shapes.update(self.actor_goal_shapes)

        goal_modalities = {obs_modality: [] for obs_modality in ObsUtils.OBS_MODALITY_CLASSES.keys()}
        for k in self.actor_goal_shapes.keys():
            goal_modalities[ObsUtils.OBS_KEYS_TO_MODALITIES[k]].append(k)

        actor_obs_config = deepcopy(obs_config.actor)
        with actor_obs_config.unlocked():
            actor_obs_config["goal"] = Config(**goal_modalities)

        self.actor = policy_algo_class(
            algo_config=algo_config.actor,
            obs_config=actor_obs_config,
            global_config=global_config,
            obs_key_shapes=actor_obs_key_shapes,
            ac_dim=ac_dim,
            device=device,
        )
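
    # Illustrative sketch (our annotation, with made-up keys) of the goal override
    # performed in __init__: if the planner predicts subgoals over
    # {"robot0_eef_pos": (3,), "agentview_image": (84, 84, 3)}, the actor's goal
    # config is rebuilt so that each key lands under its registered modality, e.g.
    #
    #   goal_modalities = {"low_dim": ["robot0_eef_pos"], "rgb": ["agentview_image"]}
    #
    # and `actor_obs_key_shapes` gains those keys/shapes, so the actor's goal
    # encoder is constructed to consume exactly what the planner produces.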

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training
        """
        input_batch = dict()

        input_batch["planner"] = self.planner.process_batch_for_training(batch)
        input_batch["actor"] = self.actor.process_batch_for_training(batch)

        if self.algo_config.actor_use_random_subgoals:
            # optionally use randomly sampled step between [1, seq_length] as policy goal
            policy_subgoal_indices = torch.randint(
                low=0, high=self.global_config.train.seq_length, size=(batch["actions"].shape[0],))
            goal_obs = TensorUtils.gather_sequence(batch["next_obs"], policy_subgoal_indices)
            goal_obs = TensorUtils.to_float(TensorUtils.to_device(goal_obs, self.device))
            input_batch["actor"]["goal_obs"] = goal_obs
        else:
            # otherwise, use planner subgoal target as goal for the policy
            input_batch["actor"]["goal_obs"] = input_batch["planner"]["planner"]["target_subgoals"]

        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))
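
    # Rough equivalent (our annotation) of the `TensorUtils.gather_sequence` call
    # above for one hypothetical low-dim key: for each batch index b, pick time
    # step inds[b] from a (B, T, ...) tensor to get a (B, ...) goal.
    #
    #   B, T, D = 4, 10, 3
    #   seq = torch.randn(B, T, D)                       # e.g. batch["next_obs"]["proprio"]
    #   inds = torch.randint(low=0, high=T, size=(B,))   # one subgoal step per batch element
    #   idx = inds.view(B, 1, 1).expand(B, 1, D)
    #   goal = torch.gather(seq, dim=1, index=idx).squeeze(1)   # shape (B, D)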

    def get_state_value(self, obs_dict, goal_dict=None):
        """
        Get state value outputs.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        return self.planner.get_state_value(obs_dict=obs_dict, goal_dict=goal_dict)

    def get_state_action_value(self, obs_dict, actions, goal_dict=None):
        """
        Get state-action value outputs.

        Args:
            obs_dict (dict): current observation
            actions (torch.Tensor): action
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        return self.planner.get_state_action_value(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict)
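

# Hedged rollout sketch (our annotation; assumes a trained IRIS instance `model`
# and a batched observation dict `ob` on `device` with keys/shapes matching
# `obs_key_shapes`): value queries are simply forwarded to the internal
# ValuePlanner, so hierarchical action selection and value inspection go through
# the same object.
#
#   model.set_eval()
#   model.reset()
#   ac = model.get_action(obs_dict=ob)        # hierarchical action (from HBC)
#   v = model.get_state_value(obs_dict=ob)    # planner's state value estimate
#   q = model.get_state_action_value(obs_dict=ob, actions=ac)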