"""
Implementation of IRIS (https://arxiv.org/abs/1911.05321).
"""
import numpy as np
from collections import OrderedDict
from copy import deepcopy
import torch
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.config.config import Config
from robomimic.algo import register_algo_factory_func, algo_name_to_factory_func, HBC, ValuePlanner, ValueAlgo, GL_VAE
@register_algo_factory_func("iris")
def algo_config_to_class(algo_config):
    """
    Maps algo config to the IRIS algo class to instantiate, along with additional algo kwargs.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
    """
    # Resolve the concrete classes for each sub-component by dispatching to
    # the registered factories of the underlying algorithms ("bc" actor,
    # "gl" subgoal planner, "bcq" value network).
    actor_cls, _ = algo_name_to_factory_func("bc")(algo_config.actor)
    planner_cls, _ = algo_name_to_factory_func("gl")(algo_config.value_planner.planner)
    critic_cls, _ = algo_name_to_factory_func("bcq")(algo_config.value_planner.value)
    algo_kwargs = {
        "policy_algo_class": actor_cls,
        "planner_algo_class": planner_cls,
        "value_algo_class": critic_cls,
    }
    return IRIS, algo_kwargs
class IRIS(HBC, ValueAlgo):
    """
    Implementation of IRIS (https://arxiv.org/abs/1911.05321).

    Composes a value-guided subgoal planner (a ValuePlanner wrapping a planner
    algo and a value algo) with a goal-conditioned low-level actor. Inherits
    the subgoal-update / rollout machinery from HBC and exposes the planner's
    value functions through the ValueAlgo interface.
    """
    def __init__(
        self,
        planner_algo_class,
        value_algo_class,
        policy_algo_class,
        algo_config,
        obs_config,
        global_config,
        obs_key_shapes,
        ac_dim,
        device,
    ):
        """
        Args:
            planner_algo_class (Algo class): algo class for the planner

            value_algo_class (Algo class): algo class for the value network used by the planner

            policy_algo_class (Algo class): algo class for the policy

            algo_config (Config object): instance of Config corresponding to the algo section
                of the config

            obs_config (Config object): instance of Config corresponding to the observation
                section of the config

            global_config (Config object): global training config

            obs_key_shapes (OrderedDict): dictionary that maps input/output observation keys to shapes

            ac_dim (int): action dimension

            device: torch device
        """
        self.algo_config = algo_config
        self.obs_config = obs_config
        self.global_config = global_config

        self.ac_dim = ac_dim
        self.device = device

        self._subgoal_step_count = 0  # current step count for deciding when to update subgoal
        self._current_subgoal = None  # latest subgoal
        self._subgoal_update_interval = self.algo_config.subgoal_update_interval  # subgoal update frequency
        self._subgoal_horizon = self.algo_config.value_planner.planner.subgoal_horizon
        self._actor_horizon = self.algo_config.actor.rnn.horizon

        # "separate" vs "cascade" control how planner and actor interact
        # during rollouts (behavior defined in the HBC base class)
        self._algo_mode = self.algo_config.mode
        assert self._algo_mode in ["separate", "cascade"]

        # planner is itself a composite: a subgoal-proposal algo plus a value
        # network used to rank proposed subgoals
        self.planner = ValuePlanner(
            planner_algo_class=planner_algo_class,
            value_algo_class=value_algo_class,
            algo_config=algo_config.value_planner,
            obs_config=obs_config.value_planner,
            global_config=global_config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=ac_dim,
            device=device
        )

        # the actor's goal space is exactly the planner's subgoal space
        self.actor_goal_shapes = self.planner.subgoal_shapes
        assert not algo_config.latent_subgoal.enabled, "IRIS does not support latent subgoals"

        # only for the actor: override goal modalities and shapes to match the subgoal set by the planner
        actor_obs_key_shapes = deepcopy(obs_key_shapes)
        # make sure we are not modifying existing observation key shapes
        for k in self.actor_goal_shapes:
            if k in actor_obs_key_shapes:
                assert actor_obs_key_shapes[k] == self.actor_goal_shapes[k]
        actor_obs_key_shapes.update(self.actor_goal_shapes)

        # bucket each subgoal key under its registered modality (e.g. low_dim, rgb)
        goal_modalities = {obs_modality: [] for obs_modality in ObsUtils.OBS_MODALITY_CLASSES.keys()}
        for k in self.actor_goal_shapes.keys():
            goal_modalities[ObsUtils.OBS_KEYS_TO_MODALITIES[k]].append(k)

        # configs are locked by default; rewrite the actor's goal section
        # on a deep copy so the shared obs_config is left untouched
        actor_obs_config = deepcopy(obs_config.actor)
        with actor_obs_config.unlocked():
            actor_obs_config["goal"] = Config(**goal_modalities)

        self.actor = policy_algo_class(
            algo_config=algo_config.actor,
            obs_config=actor_obs_config,
            global_config=global_config,
            obs_key_shapes=actor_obs_key_shapes,
            ac_dim=ac_dim,
            device=device
        )

    def process_batch_for_training(self, batch):
        """
        Processes input batch from a data loader to filter out
        relevant information and prepare the batch for training.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training
        """
        input_batch = dict()
        # delegate per-component batch preparation, then attach goals for the actor
        input_batch["planner"] = self.planner.process_batch_for_training(batch)
        input_batch["actor"] = self.actor.process_batch_for_training(batch)

        if self.algo_config.actor_use_random_subgoals:
            # optionally use randomly sampled step between [1, seq_length] as policy goal
            # NOTE(review): high=seq_length with next_obs indexing yields steps 1..seq_length
            # relative to the start of the sequence
            policy_subgoal_indices = torch.randint(
                low=0, high=self.global_config.train.seq_length, size=(batch["actions"].shape[0],))
            goal_obs = TensorUtils.gather_sequence(batch["next_obs"], policy_subgoal_indices)
            goal_obs = TensorUtils.to_float(TensorUtils.to_device(goal_obs, self.device))
            input_batch["actor"]["goal_obs"] = goal_obs
        else:
            # otherwise, use planner subgoal target as goal for the policy
            input_batch["actor"]["goal_obs"] = input_batch["planner"]["planner"]["target_subgoals"]

        # we move to device first before float conversion because image observation modalities will be uint8 -
        # this minimizes the amount of data transferred to GPU
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))

    def get_state_value(self, obs_dict, goal_dict=None):
        """
        Get state value outputs.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        # value estimation lives entirely in the planner component
        return self.planner.get_state_value(obs_dict=obs_dict, goal_dict=goal_dict)

    def get_state_action_value(self, obs_dict, actions, goal_dict=None):
        """
        Get state-action value outputs.

        Args:
            obs_dict (dict): current observation
            actions (torch.Tensor): action
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        # value estimation lives entirely in the planner component
        return self.planner.get_state_action_value(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict)