File size: 6,719 Bytes
96da58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
Test script for the HBC algorithm. Each test trains a variant of HBC
for a handful of gradient steps and tries one rollout with
the model. Stdout is suppressed by default (pass --verbose
to see it).
"""
import argparse
from collections import OrderedDict

import robomimic
import robomimic.utils.test_utils as TestUtils
from robomimic.utils.log_utils import silence_stdout
from robomimic.utils.torch_utils import dummy_context_mgr


def get_algo_base_config():
    """
    Build the base config shared by the HBC tests in this file.

    Returns:
        config: result of TestUtils.get_base_config(algo_name="hbc") with the
            low-dim observation keys set for the planner (obs and subgoal) and
            the actor (obs), empty rgb key lists, and the planner VAE disabled
    """
    # the same low-dim keys are used for every modality group below
    low_dim_keys = ("robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object")

    # config with basic settings for quick training run
    config = TestUtils.get_base_config(algo_name="hbc")

    # observation keys are defined here (not in the shared base config) because the
    # @observation structure might vary per algorithm - HBC has separate planner and
    # actor sections
    planner_modalities = config.observation.planner.modalities
    planner_modalities.obs.low_dim = list(low_dim_keys)
    planner_modalities.obs.rgb = []

    planner_modalities.subgoal.low_dim = list(low_dim_keys)
    planner_modalities.subgoal.rgb = []

    actor_modalities = config.observation.actor.modalities
    actor_modalities.obs.low_dim = list(low_dim_keys)
    actor_modalities.obs.rgb = []

    # by default the planner is a deterministic predictor; individual test
    # modifiers turn the VAE on
    config.algo.planner.vae.enabled = False

    return config


# mapping from test name to config modifier functions (an OrderedDict, so
# iteration follows registration order)
MODIFIERS = OrderedDict()


def register_mod(test_name):
    """
    Decorator factory that registers a config modifier function under @test_name.

    Args:
        test_name (str): key under which the decorated function is stored
            in MODIFIERS

    Returns:
        decorator (function): decorator that stores the function it receives
            in MODIFIERS[test_name] and returns that same function
    """
    def decorator(config_modifier):
        MODIFIERS[test_name] = config_modifier
        # return the function so the decorated module-level name stays bound to
        # it; without this the decorator returns None and rebinds the name to None
        return config_modifier
    return decorator


@register_mod("hbc")
def hbc_modifier(config):
    """
    Config modifier for the plain HBC test: returns @config unchanged.
    """
    return config


@register_mod("hbc-vae, N(0, 1) prior")
def hbc_vae_modifier_1(config):
    """
    Enable the planner VAE with a prior that is neither learned nor conditioned.
    Returns the same @config object after modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    prior_config = vae_config.prior
    prior_config.learn = False
    prior_config.is_conditioned = False
    return config


@register_mod("hbc-vae, Gaussian prior (obs-independent)")
def hbc_vae_modifier_2(config):
    """
    Enable the planner VAE and learn the parameters of a Gaussian prior
    (obs-independent). Returns the same @config object after modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    # learned, unconditioned prior; neither GMM nor categorical
    prior_config = vae_config.prior
    prior_config.learn = True
    prior_config.is_conditioned = False
    prior_config.use_gmm = False
    prior_config.use_categorical = False
    return config


@register_mod("hbc-vae, Gaussian prior (obs-dependent)")
def hbc_vae_modifier_3(config):
    """
    Enable the planner VAE and learn the parameters of a Gaussian prior
    (obs-dependent). Returns the same @config object after modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    # learned, conditioned prior; neither GMM nor categorical
    prior_config = vae_config.prior
    prior_config.learn = True
    prior_config.is_conditioned = True
    prior_config.use_gmm = False
    prior_config.use_categorical = False
    return config


@register_mod("hbc-vae, GMM prior (obs-independent, weights-fixed)")
def hbc_vae_modifier_4(config):
    """
    Enable the planner VAE and learn the parameters of a GMM prior
    (obs-independent, weights-fixed). Returns the same @config object after
    modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    # learned, unconditioned GMM prior whose weights are not learned
    prior_config = vae_config.prior
    prior_config.learn = True
    prior_config.is_conditioned = False
    prior_config.use_gmm = True
    prior_config.gmm_learn_weights = False
    prior_config.use_categorical = False
    return config


@register_mod("hbc-vae, GMM prior (obs-independent, weights-learned)")
def hbc_vae_modifier_5(config):
    """
    Enable the planner VAE and learn the parameters of a GMM prior
    (obs-independent, weights-learned). Returns the same @config object after
    modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    # learned, unconditioned GMM prior whose weights are also learned
    prior_config = vae_config.prior
    prior_config.learn = True
    prior_config.is_conditioned = False
    prior_config.use_gmm = True
    prior_config.gmm_learn_weights = True
    prior_config.use_categorical = False
    return config


@register_mod("hbc-vae, GMM prior (obs-dependent, weights-fixed)")
def hbc_vae_modifier_6(config):
    """
    Enable the planner VAE and learn the parameters of a GMM prior
    (obs-dependent, weights-fixed). Returns the same @config object after
    modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    # learned, conditioned GMM prior whose weights are not learned
    prior_config = vae_config.prior
    prior_config.learn = True
    prior_config.is_conditioned = True
    prior_config.use_gmm = True
    prior_config.gmm_learn_weights = False
    prior_config.use_categorical = False
    return config


@register_mod("hbc-vae, GMM prior (obs-dependent, weights-learned)")
def hbc_vae_modifier_7(config):
    """
    Enable the planner VAE and learn the parameters of a GMM prior
    (obs-dependent, weights-learned). Returns the same @config object after
    modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    # learned, conditioned GMM prior whose weights are also learned
    prior_config = vae_config.prior
    prior_config.learn = True
    prior_config.is_conditioned = True
    prior_config.use_gmm = True
    prior_config.gmm_learn_weights = True
    prior_config.use_categorical = False
    return config


@register_mod("hbc-vae, uniform categorical prior")
def hbc_vae_modifier_8(config):
    """
    Enable the planner VAE with a uniform categorical prior (not learned, not
    conditioned). Returns the same @config object after modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    # categorical prior with learning and conditioning both turned off
    prior_config = vae_config.prior
    prior_config.learn = False
    prior_config.is_conditioned = False
    prior_config.use_gmm = False
    prior_config.use_categorical = True
    return config


@register_mod("hbc-vae, categorical prior (obs-independent)")
def hbc_vae_modifier_9(config):
    """
    Enable the planner VAE and learn the parameters of a categorical prior
    (obs-independent). Returns the same @config object after modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    # learned, unconditioned categorical prior
    prior_config = vae_config.prior
    prior_config.learn = True
    prior_config.is_conditioned = False
    prior_config.use_gmm = False
    prior_config.use_categorical = True
    return config


@register_mod("hbc-vae, categorical prior (obs-dependent)")
def hbc_vae_modifier_10(config):
    """
    Enable the planner VAE and learn the parameters of a categorical prior
    (obs-dependent). Returns the same @config object after modifying it in place.
    """
    vae_config = config.algo.planner.vae
    vae_config.enabled = True

    # learned, conditioned categorical prior
    prior_config = vae_config.prior
    prior_config.learn = True
    prior_config.is_conditioned = True
    prior_config.use_gmm = False
    prior_config.use_categorical = True
    return config


def test_hbc(silence=True):
    """
    Run every registered HBC test variant and print one result line per variant.

    Args:
        silence (bool): if True, each variant runs inside silence_stdout();
            otherwise it runs inside a dummy context manager
    """
    for test_name, config_modifier in MODIFIERS.items():
        if silence:
            context = silence_stdout()
        else:
            context = dummy_context_mgr()

        # a fresh base config is built for every variant, inside the context
        with context:
            res_str = TestUtils.test_run(
                base_config=get_algo_base_config(),
                config_modifier=config_modifier,
            )

        # the result line is printed outside the context so it is always shown
        print("{}: {}".format(test_name, res_str))


if __name__ == "__main__":
    # command-line entry point: run all HBC tests, silencing stdout unless
    # --verbose is passed
    cli = argparse.ArgumentParser()
    cli.add_argument("--verbose", action="store_true", help="don't suppress stdout during tests")
    verbose = cli.parse_args().verbose

    test_hbc(silence=not verbose)