Spaces:
Sleeping
Sleeping
File size: 7,169 Bytes
96da58e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
"""
Tests for a handful of scripts. Excludes stdout output by
default (pass --verbose to see stdout output).
"""
import argparse
import traceback
import h5py
import numpy as np
import torch
from collections import OrderedDict
from termcolor import colored
import robomimic
import robomimic.utils.test_utils as TestUtils
import robomimic.utils.torch_utils as TorchUtils
from robomimic.config import Config
from robomimic.utils.log_utils import silence_stdout
from robomimic.utils.torch_utils import dummy_context_mgr
from robomimic.scripts.train import train
from robomimic.scripts.playback_dataset import playback_dataset
from robomimic.scripts.run_trained_agent import run_trained_agent
def get_checkpoint_to_test():
"""
Run a quick training run to get a checkpoint. This function runs a basic bc-image
training run. RGB modality is used for a harder test case for the run agent
script, which will need to also try writing image observations to the rollout
dataset.
"""
# prepare image training run
config = TestUtils.get_base_config(algo_name="bc")
def image_modifier(conf):
# using high-dimensional images - don't load entire dataset into memory, and smaller batch size
conf.train.hdf5_cache_mode = "low_dim"
conf.train.num_data_workers = 0
conf.train.batch_size = 16
# replace object with rgb modality
conf.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]
conf.observation.modalities.obs.rgb = ["agentview_image"]
# set up visual encoders
conf.observation.encoder.rgb.core_class = "VisualCore"
conf.observation.encoder.rgb.core_kwargs.feature_dimension = 64
conf.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv' # ResNet backbone for image observations (unused if no image observations)
conf.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False # kwargs for visual core
conf.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False
conf.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax" # Alternate options are "SpatialMeanPool" or None (no pooling)
conf.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32 # Default arguments for "SpatialSoftmax"
conf.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False # Default arguments for "SpatialSoftmax"
conf.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0 # Default arguments for "SpatialSoftmax"
conf.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0
# observation randomizer class - set to None to use no randomization, or 'CropRandomizer' to use crop randomization
conf.observation.encoder.rgb.obs_randomizer_class = None
return conf
config = TestUtils.config_from_modifier(base_config=config, config_modifier=image_modifier)
# run training
device = TorchUtils.get_torch_device(try_to_use_cuda=True)
train(config, device=device)
# return checkpoint
ckpt_path = TestUtils.checkpoint_path_from_test_run()
return ckpt_path
def test_playback_script(silence=True, use_actions=False, use_obs=False):
context = silence_stdout() if silence else dummy_context_mgr()
with context:
try:
# setup args and run script
args = argparse.Namespace()
args.dataset = TestUtils.example_dataset_path()
args.filter_key = None
args.n = 3 # playback 3 demonstrations
args.use_actions = use_actions
args.use_obs = use_obs
args.render = False
args.video_path = TestUtils.temp_video_path() # dump video
args.video_skip = 5
if use_obs:
# camera observation names
args.render_image_names = ["agentview_image", "robot0_eye_in_hand_image"]
else:
# camera names
args.render_image_names = ["agentview", "robot0_eye_in_hand"]
args.first = False
playback_dataset(args)
# indicate success
ret = colored("passed!", "green")
except Exception as e:
# indicate failure by returning error string
ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red")
# delete output video
TestUtils.maybe_remove_file(TestUtils.temp_video_path())
act_str = "-action_playback" if use_actions else ""
obs_str = "-obs" if use_obs else ""
test_name = "playback-script{}{}".format(act_str, obs_str)
print("{}: {}".format(test_name, ret))
def test_run_agent_script(silence=True):
context = silence_stdout() if silence else dummy_context_mgr()
with context:
try:
# get a model checkpoint
ckpt_path = get_checkpoint_to_test()
# setup args and run script
args = argparse.Namespace()
args.agent = ckpt_path
args.n_rollouts = 3 # 3 rollouts
args.horizon = 10 # short rollouts - 10 steps
args.env = None
args.render = False
args.video_path = TestUtils.temp_video_path() # dump video
args.video_skip = 5
args.camera_names = ["agentview", "robot0_eye_in_hand"]
args.dataset_path = TestUtils.temp_dataset_path() # dump dataset
args.dataset_obs = True
args.seed = 0
run_trained_agent(args)
# simple sanity check for shape of image observations in rollout dataset
f = h5py.File(TestUtils.temp_dataset_path(), "r")
assert f["data/demo_1/obs/agentview_image"].shape == (10, 84, 84, 3)
assert f["data/demo_1/obs/agentview_image"].dtype == np.uint8
f.close()
# indicate success
ret = colored("passed!", "green")
except Exception as e:
# indicate failure by returning error string
ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red")
# delete trained model directory, output video, and output dataset
TestUtils.maybe_remove_dir(TestUtils.temp_model_dir_path())
TestUtils.maybe_remove_file(TestUtils.temp_video_path())
TestUtils.maybe_remove_file(TestUtils.temp_dataset_path())
test_name = "run-agent-script"
print("{}: {}".format(test_name, ret))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--verbose",
action='store_true',
help="don't suppress stdout during tests",
)
args = parser.parse_args()
test_playback_script(silence=(not args.verbose), use_actions=False, use_obs=False)
test_playback_script(silence=(not args.verbose), use_actions=True, use_obs=False)
test_playback_script(silence=(not args.verbose), use_actions=False, use_obs=True)
test_run_agent_script(silence=(not args.verbose))
|