File size: 9,037 Bytes
96da58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
"""
Utilities for testing algorithm implementations - used mainly by scripts in tests directory.
"""
import os
import json
import shutil
import traceback
from termcolor import colored

import numpy as np
import torch

import robomimic
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.torch_utils as TorchUtils
from robomimic.config import Config, config_factory
from robomimic.scripts.train import train


def maybe_remove_dir(dir_to_remove):
    """
    Delete a directory tree, doing nothing when the path is absent.

    Args:
        dir_to_remove (str): path of the directory to delete
    """
    if not os.path.exists(dir_to_remove):
        # nothing on disk at this path - no work to do
        return
    shutil.rmtree(dir_to_remove)


def maybe_remove_file(file_to_remove):
    """
    Delete a file, doing nothing when the path is absent.

    Args:
        file_to_remove (str): path of the file to delete
    """
    if not os.path.exists(file_to_remove):
        # nothing on disk at this path - no work to do
        return
    os.remove(file_to_remove)


def example_dataset_path():
    """
    Location of the hdf5 dataset used for testing and example purposes.
    The file is expected under the tests/assets directory; when it is
    missing there, it is downloaded from a server before the path is returned.
    """
    assets_dir = os.path.join(robomimic.__path__[0], "../tests/assets/")
    hdf5_path = os.path.join(assets_dir, "test_v141.hdf5")

    # common case: dataset is already on disk
    if os.path.exists(hdf5_path):
        return hdf5_path

    # otherwise fetch it into the assets directory (created if needed)
    print("\nWARNING: test hdf5 does not exist! Downloading from server...")
    os.makedirs(assets_dir, exist_ok=True)
    FileUtils.download_url(
        url="http://downloads.cs.stanford.edu/downloads/rt_benchmark/test_v141.hdf5",
        download_dir=assets_dir,
    )
    return hdf5_path


def example_momart_dataset_path():
    """
    Path to momart dataset to use for testing and example purposes. It should
    exist under the tests/assets directory. If it does not exist, the user is
    asked for confirmation on stdin, and the dataset is then downloaded from
    a server.

    Returns:
        dataset_path (str): path to the momart test hdf5

    Raises:
        AssertionError: if the dataset is missing and the user does not
            answer "y" or "yes" (case-insensitive) to the download prompt
    """
    dataset_folder = os.path.join(robomimic.__path__[0], "../tests/assets/")
    dataset_path = os.path.join(dataset_folder, "test_momart.hdf5")
    if not os.path.exists(dataset_path):
        user_response = input("\nWARNING: momart test hdf5 does not exist! We will download sample dataset. "
                              "This will take 0.6GB space. Proceed? y/n\n")

        # raise explicitly instead of using an assert statement - assert statements are
        # removed when python runs with -O, which would start the download without consent
        if user_response.strip().lower() not in {"yes", "y"}:
            raise AssertionError("Did not receive confirmation. Aborting download.")

        print("\nDownloading from server...")

        os.makedirs(dataset_folder, exist_ok=True)
        FileUtils.download_url(
            url="http://downloads.cs.stanford.edu/downloads/rt_mm/sample/test_momart.hdf5",
            download_dir=dataset_folder,
        )
    return dataset_path


def temp_model_dir_path():
    """
    Location of a scratch model directory that tests and examples can write to.
    """
    package_root = robomimic.__path__[0]
    return os.path.join(package_root, "../tests/tmp_model_dir")


def temp_dataset_path():
    """
    Default location of the dataset file that tests write to.
    """
    tests_dir = os.path.join(robomimic.__path__[0], "../tests/")
    return os.path.join(tests_dir, "tmp.hdf5")


def temp_video_path():
    """
    Default location of the video file that tests write to.
    """
    tests_dir = os.path.join(robomimic.__path__[0], "../tests/")
    return os.path.join(tests_dir, "tmp.mp4")


def get_base_config(algo_name):
    """
    Build the config that algorithm tests start from.

    Args:
        algo_name (str): name of algorithm - selects which json file in the
            config templates directory gets loaded
    """

    # start from the template json for this algorithm; settings below override it
    template_path = os.path.join(robomimic.__path__[0], "exps/templates/{}.json".format(algo_name))
    with open(template_path, 'r') as template_file:
        template_dict = json.load(template_file)
    config = Config(template_dict)

    # train on the example dataset (downloaded first if it is not on disk yet)
    config.train.data = example_dataset_path()

    # write outputs to the temporary model dir, clearing out any previous contents
    output_dir = temp_model_dir_path()
    maybe_remove_dir(output_dir)
    config.train.output_dir = output_dir

    # a single epoch, with 3 steps per training epoch and per validation epoch
    config.experiment.name = "test"
    config.experiment.validate = True
    config.experiment.epoch_every_n_steps = 3
    config.experiment.validation_epoch_every_n_steps = 3
    config.train.num_epochs = 1

    # filter keys used to select the train and validation splits
    config.train.hdf5_filter_key = "train"
    config.train.hdf5_validation_filter_key = "valid"

    # exercise model saving, rollouts, and video rendering as part of the test
    config.experiment.save.enabled = True
    config.experiment.save.every_n_epochs = 1
    config.experiment.rollout.enabled = True
    config.experiment.rollout.rate = 1
    config.experiment.rollout.n = 1
    config.experiment.rollout.horizon = 10
    config.experiment.render_video = True

    # disable the terminal_output_to_txt logging option, since it can interfere
    # with testing code outputs
    config.experiment.logging.terminal_output_to_txt = False

    # ask for cuda (used if available)
    config.train.cuda = True

    return config


def config_from_modifier(base_config, config_modifier):
    """
    Create the default config for the algorithm named in @base_config,
    overlay the @base_config settings on top of it, and pass the result
    through @config_modifier. Both the overlay and the modifier run inside
    the config's unlocked() context.

    Args:
        base_config (BaseConfig instance): starting config object - its
            "algo_name" entry picks the default config, and its contents
            are used to update those defaults before @config_modifier runs

        config_modifier (function): function that takes a config object as
            input, modifies it, and returns it - the return value of this
            function is what gets returned here
    """

    # default config for the algorithm named in the base config
    config = config_factory(base_config["algo_name"])

    with config.unlocked():
        # overlay the settings specified in the base config onto the defaults
        config.update(base_config)

        # let the modifier make its changes while the config is still unlocked
        modified_config = config_modifier(config)

    return modified_config


def checkpoint_path_from_test_run():
    """
    Helper function that gets the path of a model checkpoint after a test training run is finished.

    Looks inside the "test" subdirectory of the temporary model directory, which
    must contain exactly one subdirectory, and returns a file from that
    subdirectory's "models" folder whose name starts with "model". When several
    such files exist, the first one in sorted name order is returned, so the
    result does not depend on the order in which the filesystem lists entries.

    Returns:
        ckpt_path (str): path to the selected checkpoint file

    Raises:
        AssertionError: if the experiment directory does not contain exactly
            one subdirectory
        IndexError: if the "models" folder has no entry starting with "model"
    """
    exp_dir = os.path.join(temp_model_dir_path(), "test")

    # use scandir as a context manager so the directory handle is released promptly
    with os.scandir(exp_dir) as entries:
        time_dir_names = [f.name for f in entries if f.is_dir()]
    if len(time_dir_names) != 1:
        raise AssertionError(
            "expected exactly one run directory in {}, but found: {}".format(exp_dir, time_dir_names)
        )

    path_to_models = os.path.join(exp_dir, time_dir_names[0], "models")

    # directory listing order is arbitrary - sort to make the choice deterministic
    model_names = sorted(name for name in os.listdir(path_to_models) if name.startswith("model"))
    if not model_names:
        raise IndexError("no checkpoint starting with 'model' found in {}".format(path_to_models))

    return os.path.join(path_to_models, model_names[0])


def test_eval_agent_from_checkpoint(ckpt_path, device):
    """
    Load a policy and its environment from a checkpoint, then step the
    agent through the environment for a small number of steps.

    Args:
        ckpt_path (str): path to a checkpoint pth file

        device (torch.Device): torch device
    """

    # restore the policy first - the checkpoint dict it returns is used to create the env
    policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path=ckpt_path, device=device, verbose=True)
    env, _ = FileUtils.env_from_checkpoint(ckpt_dict=ckpt_dict, verbose=True)

    # short rollout: reset the env, start a policy episode, then step a fixed number of times
    num_steps = 15
    obs = env.reset()
    policy.start_episode()
    for _step in range(num_steps):
        action = policy(ob=obs)
        # only the next observation is needed; the other step outputs are ignored
        obs, _reward, _done, _info = env.step(action)


def test_run(base_config, config_modifier):
    """
    Takes a base_config and config_modifier (function that modifies a passed Config object)
    and runs training as a test. It also takes the trained checkpoint, tries to load the
    policy and environment from the checkpoint, and run an evaluation rollout. Returns
    a string that is colored green if the run finished successfully without any issues,
    and colored red if an error occurred. If an error occurs, the traceback is included
    in the string. The temporary model directory is removed before returning, whether
    or not the run succeeded.

    Args:
        base_config (BaseConfig instance): starting config object that is
            loaded (to change algorithm config defaults), and then modified
            with @config_modifier

        config_modifier (function): function that takes a config object as
            input, and modifies it

    Returns:
        ret (str): a green "passed!" string, or a red "failed with error" string that contains
            the traceback
    """

    # this file has no top-level import that defines Macros, so bring the macros
    # module into scope here (binds only the local name Macros)
    import robomimic.macros as Macros

    # disable some macros for testing
    Macros.RESULTS_SYNC_PATH = None
    Macros.USE_MAGLEV = False
    Macros.USE_NGC = False

    try:
        # get config
        config = config_from_modifier(base_config=base_config, config_modifier=config_modifier)

        # set torch device
        device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)

        # run training
        train(config, device=device)

        # test evaluating a trained agent using saved checkpoint
        ckpt_path = checkpoint_path_from_test_run()
        test_eval_agent_from_checkpoint(ckpt_path, device=device)

        # indicate success
        ret = colored("passed!", "green")

    except Exception as e:
        # indicate failure by returning error string
        ret = colored("failed with error:\n{}\n\n{}".format(e, traceback.format_exc()), "red")

    finally:
        # make sure model directory is cleaned up on every exit path, including
        # interruptions that are not caught by the except clause above
        maybe_remove_dir(temp_model_dir_path())

    return ret