Spaces:
Sleeping
Sleeping
| """ | |
| Test script for IRIS algorithms. Each test trains a variant of IRIS | |
| for a handful of gradient steps and tries one rollout with | |
| the model. Excludes stdout output by default (pass --verbose | |
| to see stdout output). | |
| """ | |
| import argparse | |
| from collections import OrderedDict | |
| import robomimic | |
| import robomimic.utils.test_utils as TestUtils | |
| from robomimic.utils.log_utils import silence_stdout | |
| from robomimic.utils.torch_utils import dummy_context_mgr | |
| def get_algo_base_config(): | |
| """ | |
| Base config for testing BCQ algorithms. | |
| """ | |
| # config with basic settings for quick training run | |
| config = TestUtils.get_base_config(algo_name="iris") | |
| # low-level obs (note that we define it here because @observation structure might vary per algorithm, | |
| # for example iris) | |
| config.observation.value_planner.planner.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"] | |
| config.observation.value_planner.planner.modalities.obs.rgb = [] | |
| config.observation.value_planner.planner.modalities.subgoal.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"] | |
| config.observation.value_planner.planner.modalities.subgoal.rgb = [] | |
| config.observation.value_planner.value.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"] | |
| config.observation.value_planner.value.modalities.obs.rgb = [] | |
| config.observation.actor.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos", "object"] | |
| config.observation.actor.modalities.obs.rgb = [] | |
| # by default, basic N(0, 1) prior for both planner VAE and BCQ cVAE | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = False | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = False | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False | |
| return config | |
| # mapping from test name to config modifier functions | |
| MODIFIERS = OrderedDict() | |
| def register_mod(test_name): | |
| def decorator(config_modifier): | |
| MODIFIERS[test_name] = config_modifier | |
| return decorator | |
| def iris_modifier_1(config): | |
| # no-op | |
| return config | |
| def iris_modifier_2(config): | |
| # learn parameters of Gaussian prior (obs-independent) | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = True | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = False | |
| config.algo.value_planner.planner.vae.prior.use_gmm = False | |
| config.algo.value_planner.planner.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_3(config): | |
| # learn parameters of Gaussian prior (obs-dependent) | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = True | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = True | |
| config.algo.value_planner.planner.vae.prior.use_gmm = False | |
| config.algo.value_planner.planner.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_4(config): | |
| # learn parameters of GMM prior (obs-independent, weights-fixed) | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = True | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = False | |
| config.algo.value_planner.planner.vae.prior.use_gmm = True | |
| config.algo.value_planner.planner.vae.prior.gmm_learn_weights = False | |
| config.algo.value_planner.planner.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_5(config): | |
| # learn parameters of GMM prior (obs-independent, weights-learned) | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = True | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = False | |
| config.algo.value_planner.planner.vae.prior.use_gmm = True | |
| config.algo.value_planner.planner.vae.prior.gmm_learn_weights = True | |
| config.algo.value_planner.planner.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_6(config): | |
| # learn parameters of GMM prior (obs-dependent, weights-fixed) | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = True | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = True | |
| config.algo.value_planner.planner.vae.prior.use_gmm = True | |
| config.algo.value_planner.planner.vae.prior.gmm_learn_weights = False | |
| config.algo.value_planner.planner.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_7(config): | |
| # learn parameters of GMM prior (obs-dependent, weights-learned) | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = True | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = True | |
| config.algo.value_planner.planner.vae.prior.use_gmm = True | |
| config.algo.value_planner.planner.vae.prior.gmm_learn_weights = True | |
| config.algo.value_planner.planner.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_8(config): | |
| # uniform categorical prior | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = False | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = False | |
| config.algo.value_planner.planner.vae.prior.use_gmm = False | |
| config.algo.value_planner.planner.vae.prior.use_categorical = True | |
| return config | |
| def iris_modifier_9(config): | |
| # learn parameters of categorical prior (obs-independent) | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = True | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = False | |
| config.algo.value_planner.planner.vae.prior.use_gmm = False | |
| config.algo.value_planner.planner.vae.prior.use_categorical = True | |
| return config | |
| def iris_modifier_10(config): | |
| # learn parameters of categorical prior (obs-dependent) | |
| config.algo.value_planner.planner.vae.enabled = True | |
| config.algo.value_planner.planner.vae.prior.learn = True | |
| config.algo.value_planner.planner.vae.prior.is_conditioned = True | |
| config.algo.value_planner.planner.vae.prior.use_gmm = False | |
| config.algo.value_planner.planner.vae.prior.use_categorical = True | |
| return config | |
| def iris_modifier_11(config): | |
| # bcq action sampler is GMM | |
| config.algo.value_planner.value.action_sampler.gmm.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.enabled = False | |
| return config | |
| def iris_modifier_12(config): | |
| # bcq value function is distributional | |
| config.algo.value_planner.value.critic.distributional.enabled = True | |
| config.algo.value_planner.value.critic.value_bounds = [-100., 100.] | |
| return config | |
| def iris_modifier_13(config): | |
| # learn parameters of Gaussian prior (obs-independent) | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_14(config): | |
| # learn parameters of Gaussian prior (obs-dependent) | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_15(config): | |
| # learn parameters of GMM prior (obs-independent, weights-fixed) | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.gmm_learn_weights = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_16(config): | |
| # learn parameters of GMM prior (obs-independent, weights-learned) | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.gmm_learn_weights = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_17(config): | |
| # learn parameters of GMM prior (obs-dependent, weights-fixed) | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.gmm_learn_weights = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_18(config): | |
| # learn parameters of GMM prior (obs-dependent, weights-learned) | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.gmm_learn_weights = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = False | |
| return config | |
| def iris_modifier_19(config): | |
| # uniform categorical prior | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = True | |
| return config | |
| def iris_modifier_20(config): | |
| # learn parameters of categorical prior (obs-independent) | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = True | |
| return config | |
| def iris_modifier_21(config): | |
| # learn parameters of categorical prior (obs-dependent) | |
| config.algo.value_planner.value.action_sampler.vae.enabled = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.learn = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.is_conditioned = True | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_gmm = False | |
| config.algo.value_planner.value.action_sampler.vae.prior.use_categorical = True | |
| return config | |
| def test_iris(silence=True): | |
| for test_name in MODIFIERS: | |
| context = silence_stdout() if silence else dummy_context_mgr() | |
| with context: | |
| base_config = get_algo_base_config() | |
| res_str = TestUtils.test_run(base_config=base_config, config_modifier=MODIFIERS[test_name]) | |
| print("{}: {}".format(test_name, res_str)) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--verbose", | |
| action='store_true', | |
| help="don't suppress stdout during tests", | |
| ) | |
| args = parser.parse_args() | |
| test_iris(silence=(not args.verbose)) | |