Spaces:
Sleeping
Sleeping
| """ | |
| Tests for a handful of scripts. Excludes stdout output by | |
| default (pass --verbose to see stdout output). | |
| """ | |
| import argparse | |
| import traceback | |
| import h5py | |
| import numpy as np | |
| import torch | |
| from collections import OrderedDict | |
| from termcolor import colored | |
| import robomimic | |
| import robomimic.utils.test_utils as TestUtils | |
| import robomimic.utils.torch_utils as TorchUtils | |
| from robomimic.config import Config | |
| from robomimic.utils.log_utils import silence_stdout | |
| from robomimic.utils.torch_utils import dummy_context_mgr | |
| from robomimic.scripts.train import train | |
| from robomimic.scripts.playback_dataset import playback_dataset | |
| from robomimic.scripts.run_trained_agent import run_trained_agent | |
| def get_checkpoint_to_test(): | |
| """ | |
| Run a quick training run to get a checkpoint. This function runs a basic bc-image | |
| training run. RGB modality is used for a harder test case for the run agent | |
| script, which will need to also try writing image observations to the rollout | |
| dataset. | |
| """ | |
| # prepare image training run | |
| config = TestUtils.get_base_config(algo_name="bc") | |
| def image_modifier(conf): | |
| # using high-dimensional images - don't load entire dataset into memory, and smaller batch size | |
| conf.train.hdf5_cache_mode = "low_dim" | |
| conf.train.num_data_workers = 0 | |
| conf.train.batch_size = 16 | |
| # replace object with rgb modality | |
| conf.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"] | |
| conf.observation.modalities.obs.rgb = ["agentview_image"] | |
| # set up visual encoders | |
| conf.observation.encoder.rgb.core_class = "VisualCore" | |
| conf.observation.encoder.rgb.core_kwargs.feature_dimension = 64 | |
| conf.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv' # ResNet backbone for image observations (unused if no image observations) | |
| conf.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False # kwargs for visual core | |
| conf.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False | |
| conf.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax" # Alternate options are "SpatialMeanPool" or None (no pooling) | |
| conf.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32 # Default arguments for "SpatialSoftmax" | |
| conf.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False # Default arguments for "SpatialSoftmax" | |
| conf.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0 # Default arguments for "SpatialSoftmax" | |
| conf.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0 | |
| # observation randomizer class - set to None to use no randomization, or 'CropRandomizer' to use crop randomization | |
| conf.observation.encoder.rgb.obs_randomizer_class = None | |
| return conf | |
| config = TestUtils.config_from_modifier(base_config=config, config_modifier=image_modifier) | |
| # run training | |
| device = TorchUtils.get_torch_device(try_to_use_cuda=True) | |
| train(config, device=device) | |
| # return checkpoint | |
| ckpt_path = TestUtils.checkpoint_path_from_test_run() | |
| return ckpt_path | |
| def test_playback_script(silence=True, use_actions=False, use_obs=False): | |
| context = silence_stdout() if silence else dummy_context_mgr() | |
| with context: | |
| try: | |
| # setup args and run script | |
| args = argparse.Namespace() | |
| args.dataset = TestUtils.example_dataset_path() | |
| args.filter_key = None | |
| args.n = 3 # playback 3 demonstrations | |
| args.use_actions = use_actions | |
| args.use_obs = use_obs | |
| args.render = False | |
| args.video_path = TestUtils.temp_video_path() # dump video | |
| args.video_skip = 5 | |
| if use_obs: | |
| # camera observation names | |
| args.render_image_names = ["agentview_image", "robot0_eye_in_hand_image"] | |
| else: | |
| # camera names | |
| args.render_image_names = ["agentview", "robot0_eye_in_hand"] | |
| args.first = False | |
| playback_dataset(args) | |
| # indicate success | |
| ret = colored("passed!", "green") | |
| except Exception as e: | |
| # indicate failure by returning error string | |
| ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red") | |
| # delete output video | |
| TestUtils.maybe_remove_file(TestUtils.temp_video_path()) | |
| act_str = "-action_playback" if use_actions else "" | |
| obs_str = "-obs" if use_obs else "" | |
| test_name = "playback-script{}{}".format(act_str, obs_str) | |
| print("{}: {}".format(test_name, ret)) | |
| def test_run_agent_script(silence=True): | |
| context = silence_stdout() if silence else dummy_context_mgr() | |
| with context: | |
| try: | |
| # get a model checkpoint | |
| ckpt_path = get_checkpoint_to_test() | |
| # setup args and run script | |
| args = argparse.Namespace() | |
| args.agent = ckpt_path | |
| args.n_rollouts = 3 # 3 rollouts | |
| args.horizon = 10 # short rollouts - 10 steps | |
| args.env = None | |
| args.render = False | |
| args.video_path = TestUtils.temp_video_path() # dump video | |
| args.video_skip = 5 | |
| args.camera_names = ["agentview", "robot0_eye_in_hand"] | |
| args.dataset_path = TestUtils.temp_dataset_path() # dump dataset | |
| args.dataset_obs = True | |
| args.seed = 0 | |
| run_trained_agent(args) | |
| # simple sanity check for shape of image observations in rollout dataset | |
| f = h5py.File(TestUtils.temp_dataset_path(), "r") | |
| assert f["data/demo_1/obs/agentview_image"].shape == (10, 84, 84, 3) | |
| assert f["data/demo_1/obs/agentview_image"].dtype == np.uint8 | |
| f.close() | |
| # indicate success | |
| ret = colored("passed!", "green") | |
| except Exception as e: | |
| # indicate failure by returning error string | |
| ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red") | |
| # delete trained model directory, output video, and output dataset | |
| TestUtils.maybe_remove_dir(TestUtils.temp_model_dir_path()) | |
| TestUtils.maybe_remove_file(TestUtils.temp_video_path()) | |
| TestUtils.maybe_remove_file(TestUtils.temp_dataset_path()) | |
| test_name = "run-agent-script" | |
| print("{}: {}".format(test_name, ret)) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--verbose", | |
| action='store_true', | |
| help="don't suppress stdout during tests", | |
| ) | |
| args = parser.parse_args() | |
| test_playback_script(silence=(not args.verbose), use_actions=False, use_obs=False) | |
| test_playback_script(silence=(not args.verbose), use_actions=True, use_obs=False) | |
| test_playback_script(silence=(not args.verbose), use_actions=False, use_obs=True) | |
| test_run_agent_script(silence=(not args.verbose)) | |