# When launched directly as a script, make the directory three levels above
# this file importable. This must happen before the project-local
# `diffusion_policy` import further down is executed.
if __name__ == "__main__":
    import sys
    import os
    import pathlib

    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    sys.path.append(ROOT_DIR)
| |
|
| | import multiprocessing |
| | import os |
| | import shutil |
| | import click |
| | import pathlib |
| | import h5py |
| | from tqdm import tqdm |
| | import collections |
| | import pickle |
| | from diffusion_policy.common.robomimic_util import RobomimicAbsoluteActionConverter |
| |
|
def worker(x):
    """Convert the actions of one demo; entry point for pool workers.

    Args:
        x: a ``(path, idx, do_eval)`` tuple. ``path`` is handed to
            ``RobomimicAbsoluteActionConverter``, ``idx`` selects the demo
            and ``do_eval`` chooses between plain conversion and
            conversion plus evaluation.

    Returns:
        ``(abs_actions, info)``. ``info`` is an empty dict when ``do_eval``
        is falsy, otherwise whatever ``convert_and_eval_idx`` reports.
    """
    dataset_path, demo_idx, evaluate = x
    # A single tuple argument is used so the function works with Pool.map.
    action_converter = RobomimicAbsoluteActionConverter(dataset_path)

    if not evaluate:
        return action_converter.convert_idx(demo_idx), dict()

    abs_actions, info = action_converter.convert_and_eval_idx(demo_idx)
    return abs_actions, info
| |
|
@click.command()
@click.option('-i', '--input', required=True, help='input hdf5 path')
@click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
@click.option('-e', '--eval_dir', default=None, help='directory to output evaluation metrics')
@click.option('-n', '--num_workers', default=None, type=int)
def main(input, output, eval_dir, num_workers):
    """Copy an hdf5 dataset and overwrite each demo's actions with converted ones.

    Every demo index of the input is converted in a worker process by
    RobomimicAbsoluteActionConverter. The input file is then copied to the
    output path and the ``actions`` dataset of each ``data/demo_<i>`` group
    in the copy is overwritten in place with the converted result.

    When an evaluation directory is given, the per-demo info returned by the
    converter is pickled to ``error_stats.pkl`` and plotted to
    ``error_stats.pdf`` and ``error_stats.png`` inside that directory.
    """
    # Normalise and validate all paths before any expensive work starts.
    input = pathlib.Path(input).expanduser()
    assert input.is_file(), f'input is not a file: {input}'
    output = pathlib.Path(output).expanduser()
    assert output.parent.is_dir(), f'output parent directory does not exist: {output.parent}'
    assert not output.is_dir(), f'output is a directory: {output}'

    # Evaluation is enabled simply by supplying an evaluation directory.
    do_eval = eval_dir is not None
    if do_eval:
        eval_dir = pathlib.Path(eval_dir).expanduser()
        assert eval_dir.parent.exists(), f'eval_dir parent does not exist: {eval_dir.parent}'

    # This converter instance is only used to learn how many demos exist;
    # each worker call builds its own converter from the path.
    converter = RobomimicAbsoluteActionConverter(input)
    n_demos = len(converter)

    # Convert all demos in parallel; results[i] belongs to demo index i.
    with multiprocessing.Pool(num_workers) as pool:
        results = pool.map(worker, [(input, i, do_eval) for i in range(n_demos)])

    # Start from a full copy of the input so everything except the actions
    # is carried over unchanged.
    print('Copying hdf5')
    shutil.copy(str(input), str(output))

    # Overwrite the action datasets of the copy in place.
    with h5py.File(output, 'r+') as out_file:
        for i, (abs_actions, _) in enumerate(tqdm(results, desc="Writing to output")):
            out_file[f'data/demo_{i}']['actions'][:] = abs_actions

    if not do_eval:
        return

    # Only the leaf directory is created; its parent was validated above.
    eval_dir.mkdir(parents=False, exist_ok=True)

    print("Writing error_stats.pkl")
    infos = [info for _, info in results]
    # Use a context manager so the pickle file is flushed and closed
    # deterministically instead of whenever the handle is garbage collected.
    with eval_dir.joinpath('error_stats.pkl').open('wb') as pkl_file:
        pickle.dump(infos, pkl_file)

    print("Generating visualization")
    metrics = ['pos', 'rot']
    # metrics_dicts[metric][key] is the list of that metric's values for
    # `key`, one entry per demo, in demo order.
    metrics_dicts = {m: collections.defaultdict(list) for m in metrics}
    for info in infos:
        for k, v in info.items():
            for m in metrics:
                metrics_dicts[m][k].append(v[m])

    # Imported lazily so matplotlib is only required when evaluating.
    from matplotlib import pyplot as plt
    plt.switch_backend('PDF')

    # One subplot per metric, one labelled line per info key.
    fig, ax = plt.subplots(1, len(metrics))
    for axis, metric in zip(ax, metrics):
        for key, value in metrics_dicts[metric].items():
            axis.plot(value, label=key)
        axis.legend()
        axis.set_title(metric)
    fig.set_size_inches(10, 4)
    fig.savefig(str(eval_dir.joinpath('error_stats.pdf')))
    fig.savefig(str(eval_dir.joinpath('error_stats.png')))
| |
|
| |
|
# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
| |
|