"""
Implementation of IRIS (https://arxiv.org/abs/1911.05321).
"""
from copy import deepcopy

import torch

import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.config.config import Config
from robomimic.algo import register_algo_factory_func, algo_name_to_factory_func, HBC, ValuePlanner, ValueAlgo


@register_algo_factory_func("iris")
def algo_config_to_class(algo_config):
    """
    Maps algo config to the IRIS algo class to instantiate, along with additional algo kwargs.

    Args:
        algo_config (Config instance): algo config

    Returns:
        algo_class: subclass of Algo
        algo_kwargs (dict): dictionary of additional kwargs to pass to algorithm
    """
    pol_cls, _ = algo_name_to_factory_func("bc")(algo_config.actor)
    plan_cls, _ = algo_name_to_factory_func("gl")(algo_config.value_planner.planner)
    value_cls, _ = algo_name_to_factory_func("bcq")(algo_config.value_planner.value)
    return IRIS, dict(policy_algo_class=pol_cls, planner_algo_class=plan_cls, value_algo_class=value_cls)
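
# A hedged sketch of how this factory is typically consumed (the registry resolves
# "iris" to the function above; treat the exact call below as illustrative of the
# robomimic convention, not the canonical training entry point):
#
#   algo_cls, algo_kwargs = algo_name_to_factory_func("iris")(config.algo)
#   model = algo_cls(
#       algo_config=config.algo, obs_config=config.observation, global_config=config,
#       obs_key_shapes=obs_key_shapes, ac_dim=ac_dim, device=device, **algo_kwargs,
#   )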


class IRIS(HBC, ValueAlgo):
    """
    Implementation of IRIS (https://arxiv.org/abs/1911.05321).
    """
    def __init__(
        self,
        planner_algo_class,
        value_algo_class,
        policy_algo_class,
        algo_config,
        obs_config,
        global_config,
        obs_key_shapes,
        ac_dim,
        device,
    ):
        """
        Args:
            planner_algo_class (Algo class): algo class for the planner

            value_algo_class (Algo class): algo class for the value network used to
                rank planner subgoal proposals

            policy_algo_class (Algo class): algo class for the policy

            algo_config (Config object): instance of Config corresponding to the algo section
                of the config

            obs_config (Config object): instance of Config corresponding to the observation
                section of the config

            global_config (Config object): global training config

            obs_key_shapes (OrderedDict): dictionary that maps input/output observation keys to shapes

            ac_dim (int): action dimension

            device: torch device
        """
        self.algo_config = algo_config
        self.obs_config = obs_config
        self.global_config = global_config

        self.ac_dim = ac_dim
        self.device = device

        self._subgoal_step_count = 0  # current step count for deciding when to update subgoal
        self._current_subgoal = None  # latest subgoal
        self._subgoal_update_interval = self.algo_config.subgoal_update_interval  # subgoal update frequency
        self._subgoal_horizon = self.algo_config.value_planner.planner.subgoal_horizon
        self._actor_horizon = self.algo_config.actor.rnn.horizon

        self._algo_mode = self.algo_config.mode
        assert self._algo_mode in ["separate", "cascade"]
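        # "separate": planner and actor are trained independently (the actor sees
        # ground-truth subgoals); "cascade": the actor is conditioned on subgoals
        # predicted by the planner (behavior inherited from HBC)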

        self.planner = ValuePlanner(
            planner_algo_class=planner_algo_class,
            value_algo_class=value_algo_class,
            algo_config=algo_config.value_planner,
            obs_config=obs_config.value_planner,
            global_config=global_config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=ac_dim,
            device=device
        )
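        # ValuePlanner composes the subgoal planner (a "gl" variant per the factory
        # above) with a value function (a "bcq" variant): subgoal proposals sampled
        # from the planner are ranked by the value function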

        self.actor_goal_shapes = self.planner.subgoal_shapes
        assert not algo_config.latent_subgoal.enabled, "IRIS does not support latent subgoals"

        # only for the actor: override goal modalities and shapes to match the subgoal set by the planner
        actor_obs_key_shapes = deepcopy(obs_key_shapes)
        # make sure we are not modifying existing observation key shapes
        for k in self.actor_goal_shapes:
            if k in actor_obs_key_shapes:
                assert actor_obs_key_shapes[k] == self.actor_goal_shapes[k]
        actor_obs_key_shapes.update(self.actor_goal_shapes)

        goal_modalities = {obs_modality: [] for obs_modality in ObsUtils.OBS_MODALITY_CLASSES.keys()}
        for k in self.actor_goal_shapes.keys():
            goal_modalities[ObsUtils.OBS_KEYS_TO_MODALITIES[k]].append(k)

        actor_obs_config = deepcopy(obs_config.actor)
        with actor_obs_config.unlocked():
            actor_obs_config["goal"] = Config(**goal_modalities)

        self.actor = policy_algo_class(
            algo_config=algo_config.actor,
            obs_config=actor_obs_config,
            global_config=global_config,
            obs_key_shapes=actor_obs_key_shapes,
            ac_dim=ac_dim,
            device=device
        )

    def process_batch_for_training(self, batch):
        """
        Processes an input batch from a data loader to extract the
        relevant information and prepare the batch for training.

        Args:
            batch (dict): dictionary with torch.Tensors sampled
                from a data loader

        Returns:
            input_batch (dict): processed and filtered batch that
                will be used for training
        """
        input_batch = dict()

        input_batch["planner"] = self.planner.process_batch_for_training(batch)
        input_batch["actor"] = self.actor.process_batch_for_training(batch)

        if self.algo_config.actor_use_random_subgoals:
            # optionally use a randomly sampled future observation as the policy goal:
            # index i in [0, seq_length - 1] of next_obs corresponds to step i + 1
            policy_subgoal_indices = torch.randint(
                low=0, high=self.global_config.train.seq_length, size=(batch["actions"].shape[0],))
            goal_obs = TensorUtils.gather_sequence(batch["next_obs"], policy_subgoal_indices)
            goal_obs = TensorUtils.to_float(TensorUtils.to_device(goal_obs, self.device))
            input_batch["actor"]["goal_obs"] = goal_obs
        else:
            # otherwise, use planner subgoal target as goal for the policy
            input_batch["actor"]["goal_obs"] = input_batch["planner"]["planner"]["target_subgoals"]

        # move observations to the device before float conversion: image modalities
        # are uint8, which minimizes the amount of data transferred to the GPU
        return TensorUtils.to_float(TensorUtils.to_device(input_batch, self.device))

    def get_state_value(self, obs_dict, goal_dict=None):
        """
        Get state value outputs.

        Args:
            obs_dict (dict): current observation
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        return self.planner.get_state_value(obs_dict=obs_dict, goal_dict=goal_dict)

    def get_state_action_value(self, obs_dict, actions, goal_dict=None):
        """
        Get state-action value outputs.

        Args:
            obs_dict (dict): current observation
            actions (torch.Tensor): action
            goal_dict (dict): (optional) goal

        Returns:
            value (torch.Tensor): value tensor
        """
        return self.planner.get_state_action_value(obs_dict=obs_dict, actions=actions, goal_dict=goal_dict)
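

# Minimal usage sketch (illustrative only): `config_factory` and `algo_factory`
# are robomimic entry points, but the shapes and values below are placeholders,
# not part of this module.
#
#   from robomimic.config import config_factory
#   from robomimic.algo import algo_factory
#
#   config = config_factory(algo_name="iris")
#   ObsUtils.initialize_obs_utils_with_config(config)
#   model = algo_factory(
#       algo_name="iris", config=config,
#       obs_key_shapes={"robot0_eef_pos": [3]},  # placeholder observation shapes
#       ac_dim=7,                                # placeholder action dimension
#       device=torch.device("cpu"),
#   )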