File size: 7,169 Bytes
96da58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
Tests for a handful of scripts. Excludes stdout output by 
default (pass --verbose to see stdout output).
"""
import argparse
import traceback
import h5py
import numpy as np
import torch
from collections import OrderedDict
from termcolor import colored

import robomimic
import robomimic.utils.test_utils as TestUtils
import robomimic.utils.torch_utils as TorchUtils
from robomimic.config import Config
from robomimic.utils.log_utils import silence_stdout
from robomimic.utils.torch_utils import dummy_context_mgr
from robomimic.scripts.train import train
from robomimic.scripts.playback_dataset import playback_dataset
from robomimic.scripts.run_trained_agent import run_trained_agent


def get_checkpoint_to_test():
    """
    Produce a model checkpoint by running a short training job.

    Trains a basic bc agent with the rgb modality enabled. Image
    observations make this a harder test case for the run-agent script,
    which must then also write image observations into the rollout dataset.
    """

    # start from the stock bc test config
    base_config = TestUtils.get_base_config(algo_name="bc")

    def modify_for_images(cfg):
        # high-dimensional image data: don't cache the full dataset in
        # memory, keep data workers at zero, and shrink the batch size
        cfg.train.hdf5_cache_mode = "low_dim"
        cfg.train.num_data_workers = 0
        cfg.train.batch_size = 16

        # swap the object modality out for an rgb camera observation
        cfg.observation.modalities.obs.low_dim = ["robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"]
        cfg.observation.modalities.obs.rgb = ["agentview_image"]

        # visual encoder settings: ResNet backbone + spatial softmax pooling
        cfg.observation.encoder.rgb.core_class = "VisualCore"
        cfg.observation.encoder.rgb.core_kwargs.feature_dimension = 64
        cfg.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv'             # ResNet backbone for image observations
        cfg.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False          # kwargs for visual core
        cfg.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False
        cfg.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax"               # alternatives: "SpatialMeanPool" or None (no pooling)
        cfg.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32                     # default arguments for "SpatialSoftmax"
        cfg.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False
        cfg.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0
        cfg.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0

        # no observation randomization (set to 'CropRandomizer' for crop randomization)
        cfg.observation.encoder.rgb.obs_randomizer_class = None

        return cfg

    config = TestUtils.config_from_modifier(base_config=base_config, config_modifier=modify_for_images)

    # run the (short) training loop on the best available device
    device = TorchUtils.get_torch_device(try_to_use_cuda=True)
    train(config, device=device)

    # hand back the path to the checkpoint the run produced
    return TestUtils.checkpoint_path_from_test_run()


def test_playback_script(silence=True, use_actions=False, use_obs=False):
    """
    Exercise the playback_dataset script on the example dataset.

    Args:
        silence (bool): if True, suppress stdout while the script runs
        use_actions (bool): if True, play back by replaying recorded actions
        use_obs (bool): if True, play back by rendering observations directly
    """
    stdout_ctx = silence_stdout() if silence else dummy_context_mgr()
    with stdout_ctx:

        try:
            # build the argument namespace the script expects
            args = argparse.Namespace()
            args.dataset = TestUtils.example_dataset_path()
            args.filter_key = None
            args.n = 3  # playback 3 demonstrations
            args.use_actions = use_actions
            args.use_obs = use_obs
            args.render = False
            args.video_path = TestUtils.temp_video_path()  # dump video
            args.video_skip = 5
            # observation playback uses camera observation keys;
            # otherwise use the raw camera names
            args.render_image_names = (
                ["agentview_image", "robot0_eye_in_hand_image"]
                if use_obs
                else ["agentview", "robot0_eye_in_hand"]
            )
            args.first = False
            playback_dataset(args)

            # success marker
            ret = colored("passed!", "green")

        except Exception as e:
            # report the failure as a colored error string with traceback
            ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red")

        # clean up the video that playback wrote
        TestUtils.maybe_remove_file(TestUtils.temp_video_path())

    act_str = "-action_playback" if use_actions else ""
    obs_str = "-obs" if use_obs else ""
    test_name = "playback-script{}{}".format(act_str, obs_str)
    print("{}: {}".format(test_name, ret))


def test_run_agent_script(silence=True):
    """
    Test the run_trained_agent script: train a quick image-based checkpoint,
    roll it out, and sanity-check the rollout dataset it writes.

    Args:
        silence (bool): if True, suppress stdout while the test runs
    """
    context = silence_stdout() if silence else dummy_context_mgr()
    with context:

        try:
            # get a model checkpoint
            ckpt_path = get_checkpoint_to_test()

            # setup args and run script
            args = argparse.Namespace()
            args.agent = ckpt_path
            args.n_rollouts = 3 # 3 rollouts
            args.horizon = 10 # short rollouts - 10 steps
            args.env = None
            args.render = False
            args.video_path = TestUtils.temp_video_path() # dump video
            args.video_skip = 5
            args.camera_names = ["agentview", "robot0_eye_in_hand"]
            args.dataset_path = TestUtils.temp_dataset_path() # dump dataset
            args.dataset_obs = True
            args.seed = 0
            run_trained_agent(args)

            # simple sanity check for shape of image observations in rollout dataset.
            # use a context manager so the hdf5 handle is closed even when an
            # assert fails (the cleanup below deletes this same file)
            with h5py.File(TestUtils.temp_dataset_path(), "r") as f:
                assert f["data/demo_1/obs/agentview_image"].shape == (10, 84, 84, 3)
                assert f["data/demo_1/obs/agentview_image"].dtype == np.uint8

            # indicate success
            ret = colored("passed!", "green")

        except Exception as e:
            # indicate failure by returning error string
            ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red")

        # delete trained model directory, output video, and output dataset
        TestUtils.maybe_remove_dir(TestUtils.temp_model_dir_path())
        TestUtils.maybe_remove_file(TestUtils.temp_video_path())
        TestUtils.maybe_remove_file(TestUtils.temp_dataset_path())

    test_name = "run-agent-script"
    print("{}: {}".format(test_name, ret))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--verbose",
        action='store_true',
        help="don't suppress stdout during tests",
    )
    args = parser.parse_args()

    test_playback_script(silence=(not args.verbose), use_actions=False, use_obs=False)
    test_playback_script(silence=(not args.verbose), use_actions=True, use_obs=False)
    test_playback_script(silence=(not args.verbose), use_actions=False, use_obs=True)
    test_run_agent_script(silence=(not args.verbose))