Spaces:
Sleeping
Sleeping
File size: 6,829 Bytes
96da58e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
__version__ = "0.3.1"

# Global registry of released robosuite / real-robot dataset links, filled in
# by register_dataset_link(). Layout:
#
#   DATASET_REGISTRY[task][dataset_type][hdf5_type] = {
#       "url": <download link, may be None>,
#       "horizon": <evaluation rollout horizon>,
#   }
DATASET_REGISTRY = {}

# Global registry of released MoMaRT dataset links, filled in by
# register_momart_dataset_link(). Layout:
#
#   MOMART_DATASET_REGISTRY[task][dataset_type] = {
#       "url": <download link>,
#       "size": <dataset size, in GB>,
#   }
MOMART_DATASET_REGISTRY = {}
def register_dataset_link(task, dataset_type, hdf5_type, link, horizon):
    """
    Store a dataset download link, plus its evaluation horizon, in DATASET_REGISTRY.

    The entry lands at DATASET_REGISTRY[task][dataset_type][hdf5_type]. Missing
    intermediate dictionaries are created on demand. An existing entry for the
    same (task, dataset_type, hdf5_type) triple is overwritten.

    Args:
        task (str): name of task for this dataset
        dataset_type (str): type of dataset (usually identifies the dataset source)
        hdf5_type (str): type of hdf5 - usually one of "raw", "low_dim", or "image",
            to identify the kind of observations in the dataset
        link (str or None): download link for the dataset (None when no link is registered)
        horizon (int): evaluation rollout horizon that should be used with this dataset
    """
    hdf5_entries = DATASET_REGISTRY.setdefault(task, {}).setdefault(dataset_type, {})
    hdf5_entries[hdf5_type] = dict(url=link, horizon=horizon)
def register_all_links():
    """
    Record all released robosuite / real-robot dataset links in DATASET_REGISTRY.

    Every entry is registered through register_dataset_link(). The "image*"
    hdf5 types are registered with ``link=None``. Tasks are registered in the
    order they are listed below.
    """
    base_url = "http://downloads.cs.stanford.edu/downloads/rt_benchmark"

    def url(task, dataset_type, filename):
        # every benchmark file lives at <base_url>/<task>/<dataset_type>/<filename>
        return f"{base_url}/{task}/{dataset_type}/{filename}"

    def register_sim(task, dataset_type, horizon):
        # simulated datasets get a raw demo file, a low-dim file, and an image
        # entry registered without a download link
        register_dataset_link(task=task, dataset_type=dataset_type, hdf5_type="raw", horizon=horizon,
                              link=url(task, dataset_type, "demo_v141.hdf5"))
        register_dataset_link(task=task, dataset_type=dataset_type, hdf5_type="low_dim", horizon=horizon,
                              link=url(task, dataset_type, "low_dim_v141.hdf5"))
        register_dataset_link(task=task, dataset_type=dataset_type, hdf5_type="image", horizon=horizon,
                              link=None)

    # Each table below maps task -> evaluation rollout horizon. Using one dict
    # (instead of parallel task / horizon lists passed to zip) keeps every
    # horizon next to its task. zip would silently truncate on a length
    # mismatch and drop or misalign registrations.

    # all proficient human datasets
    ph_horizons = {
        "lift": 400,
        "can": 400,
        "square": 400,
        "transport": 700,
        "tool_hang": 700,
        "lift_real": 1000,
        "can_real": 1000,
        "tool_hang_real": 1000,
    }
    for task, horizon in ph_horizons.items():
        if "real" in task:
            # real world datasets only have demo.hdf5 files, which already
            # contain all observation modalities
            register_dataset_link(task=task, dataset_type="ph", hdf5_type="raw", horizon=horizon,
                                  link=url(task, "ph", "demo.hdf5"))
        else:
            # sim datasets store raw low-dim mujoco states in demo_v141.hdf5,
            # so they also get low_dim / image entries
            register_sim(task, "ph", horizon)

    # all multi human datasets
    mh_horizons = {"lift": 500, "can": 500, "square": 500, "transport": 1100}
    for task, horizon in mh_horizons.items():
        register_sim(task, "mh", horizon)

    # all machine generated datasets: a raw demo file, plus sparse and dense
    # variants of the low-dim (downloadable) and image (no link) entries
    mg_horizons = {"lift": 400, "can": 400}
    for task, horizon in mg_horizons.items():
        register_dataset_link(task=task, dataset_type="mg", hdf5_type="raw", horizon=horizon,
                              link=url(task, "mg", "demo_v141.hdf5"))
        for density in ("sparse", "dense"):
            register_dataset_link(task=task, dataset_type="mg", hdf5_type=f"low_dim_{density}",
                                  horizon=horizon, link=url(task, "mg", f"low_dim_{density}_v141.hdf5"))
            register_dataset_link(task=task, dataset_type="mg", hdf5_type=f"image_{density}",
                                  horizon=horizon, link=None)

    # can-paired dataset
    register_sim("can", "paired", 400)
def register_momart_dataset_link(task, dataset_type, link, dataset_size):
    """
    Helper function to register a MoMaRT dataset link in MOMART_DATASET_REGISTRY.

    The entry is stored at MOMART_DATASET_REGISTRY[task][dataset_type]. The
    per-task dictionary is created on demand. An existing entry for the same
    (task, dataset_type) pair is replaced.

    Args:
        task (str): name of task for this dataset
        dataset_type (str): type of dataset (usually identifies the dataset source)
        link (str): download link for the dataset
        dataset_size (float): size of the dataset, in GB
    """
    # The entry is assigned directly. Pre-creating an empty dict for
    # dataset_type would be overwritten immediately.
    MOMART_DATASET_REGISTRY.setdefault(task, {})[dataset_type] = dict(url=link, size=dataset_size)
def register_all_momart_links():
    """
    Record every released MoMaRT dataset link in MOMART_DATASET_REGISTRY.

    Each task is registered once per dataset type through
    register_momart_dataset_link(). The download URL is built from the
    dataset type and task name, and the dataset size is given in GB.
    """
    # dataset types, in the same order as the size columns below
    dataset_types = ("expert", "suboptimal", "generalize", "sample")

    # task -> sizes (GB), ordered [expert, suboptimal, generalize, sample]
    sizes_by_task = {
        "table_setup_from_dishwasher": [14, 14, 3.3, 0.6],
        "table_setup_from_dresser": [16, 17, 3.1, 0.7],
        "table_cleanup_to_dishwasher": [23, 36, 5.3, 1.1],
        "table_cleanup_to_sink": [17, 28, 2.9, 0.8],
        "unload_dishwasher": [21, 27, 5.4, 1.0],
    }

    # Flatten to (task, dataset_type, size) triples, task-major, before registering.
    entries = [
        (task, dataset_type, size)
        for task, sizes in sizes_by_task.items()
        for dataset_type, size in zip(dataset_types, sizes)
    ]
    for task, dataset_type, size in entries:
        register_momart_dataset_link(
            task=task,
            dataset_type=dataset_type,
            link=f"http://downloads.cs.stanford.edu/downloads/rt_mm/{dataset_type}/{task}_{dataset_type}.hdf5",
            dataset_size=size,
        )
# Populate both registries as a side effect of importing this module, so the
# dataset links are available to any code that imports it.
register_all_links()
register_all_momart_links()
|