Spaces:
Sleeping
Sleeping
File size: 2,583 Bytes
96da58e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
"""
Helper script to prepare datasets for diffusion policy training by (1) adding absolute actions and (2)
writing the absolute actions to action dictionaries.
"""
import os
import h5py
import argparse
import socket
import json
import numpy as np
import robomimic
import robomimic.macros as Macros
from robomimic.scripts.conversion.extract_action_dict import extract_action_dict
import mimicgen
from mimicgen.scripts.add_datagen_info import add_datagen_info
# Default dataset paths; the script entry point falls back to this list
# when no --datasets argument is supplied on the command line.
DATASETS = [
    "/tmp/coffee/src_10.hdf5",
    "/tmp/stack/src_10.hdf5",
]
def convert_actions_in_dataset(dataset_path, output_name=None, absolute_mg=False):
    """
    Run the two conversion scripts needed to get absolute action dicts for one dataset:
    add_datagen_info (with absolute actions requested), followed by extract_action_dict.

    Args:
        dataset_path (str): path to the source hdf5 dataset

        output_name (str or None): if provided, name of the output file, which is placed
            in the same directory as @dataset_path and is also the file handed to
            extract_action_dict. If None, an output of None is passed to add_datagen_info
            and extract_action_dict is run on @dataset_path itself.

        absolute_mg (bool): forwarded unchanged to add_datagen_info as @absolute_mg
    """
    # work out which file the second step should operate on
    if output_name is None:
        output_path = None
        converted_path = dataset_path
    else:
        output_path = os.path.join(os.path.dirname(dataset_path), output_name)
        converted_path = output_path

    # step 1: get absolute actions
    datagen_args = argparse.Namespace(
        dataset=dataset_path,
        n=None,
        absolute=True,
        absolute_mg=absolute_mg,
        output=output_path,
    )
    add_datagen_info(datagen_args)

    # step 2: convert actions to dict
    action_dict_args = argparse.Namespace(dataset=converted_path)
    extract_action_dict(action_dict_args)
if __name__ == "__main__":
    # command-line options
    cli = argparse.ArgumentParser()
    cli.add_argument("--datasets", type=str, nargs='+', default=None)
    cli.add_argument("--output_name", type=str, default=None)
    cli.add_argument(
        "--absolute_mg",
        action='store_true',
        help="extract absolute actions using existing datagen info, and skip extraction of datagen info",
    )
    cli.add_argument(
        "--slack",
        action='store_true',
        help="try to give slack notification after script finishes",
    )
    args = cli.parse_args()

    # use the module-level default list when no datasets were passed in
    datasets = args.datasets if args.datasets is not None else DATASETS

    # convert every dataset in turn ("~" in a path is expanded first)
    for raw_path in datasets:
        convert_actions_in_dataset(
            os.path.expanduser(raw_path),
            output_name=args.output_name,
            absolute_mg=args.absolute_mg,
        )

    # optional notification: only attempted if requested and a slack token is configured
    should_notify = args.slack and (Macros.SLACK_TOKEN is not None)
    if should_notify:
        from robomimic.scripts.give_slack_notification import give_slack_notif

        hostname = socket.gethostname()
        datasets_json = json.dumps(dict(datasets=datasets), indent=4)
        header = "Completed the following action conversion run!\nHostname: {}\n".format(hostname)
        give_slack_notif(header + "```{}```".format(datasets_json))
|