Spaces:
Sleeping
Sleeping
__version__ = "0.3.1"

# Global registry of released robosuite / real-robot dataset download links,
# together with the evaluation rollout horizon to use for each dataset.
# It is nested three levels deep, and each level may hold several entries:
#
#   DATASET_REGISTRY[task][dataset_type][hdf5_type] == {"url": <link>, "horizon": <int>}
DATASET_REGISTRY = dict()

# Global registry of released MOMART dataset download links and dataset sizes.
# It is nested two levels deep, and each level may hold several entries:
#
#   MOMART_DATASET_REGISTRY[task][dataset_type] == {"url": <link>, "size": <size in GB>}
MOMART_DATASET_REGISTRY = dict()
def register_dataset_link(task, dataset_type, hdf5_type, link, horizon):
    """
    Record one dataset download link in the global DATASET_REGISTRY.

    The entry is stored at DATASET_REGISTRY[task][dataset_type][hdf5_type].
    Intermediate dictionaries are created on first use. Registering the same
    (task, dataset_type, hdf5_type) triple again replaces the earlier entry.
    Besides the link, the entry keeps the evaluation rollout @horizon that
    should be used during training.

    Args:
        task (str): name of task for this dataset
        dataset_type (str): type of dataset (usually identifies the dataset source)
        hdf5_type (str): type of hdf5 - usually one of "raw", "low_dim", or "image",
            to identify the kind of observations in the dataset
        link (str): download link for the dataset (callers in this module pass
            None for hdf5 types that have no download link)
        horizon (int): evaluation rollout horizon that should be used with this dataset
    """
    # setdefault creates each missing nesting level and returns the existing one otherwise.
    types_for_source = DATASET_REGISTRY.setdefault(task, {}).setdefault(dataset_type, {})
    types_for_source[hdf5_type] = dict(url=link, horizon=horizon)
def register_all_links():
    """
    Register every released robosuite / real-robot dataset link.

    Populates DATASET_REGISTRY, via register_dataset_link, with these datasets:
      - proficient-human ("ph")
      - multi-human ("mh")
      - machine-generated ("mg")
      - can-paired ("paired")

    hdf5 types listed below with a filename of None are registered with a
    link of None.
    """
    base_url = "http://downloads.cs.stanford.edu/downloads/rt_benchmark"

    def _register_files(task, dataset_type, horizon, files):
        # Register each (hdf5_type, filename) pair in order. A filename of None
        # means the hdf5 type is registered without a download link.
        for hdf5_type, filename in files:
            link = None
            if filename is not None:
                link = f"{base_url}/{task}/{dataset_type}/{filename}.hdf5"
            register_dataset_link(
                task=task,
                dataset_type=dataset_type,
                hdf5_type=hdf5_type,
                horizon=horizon,
                link=link,
            )

    # Sim datasets store raw low-dim mujoco states in demo_v141.hdf5 and also
    # ship a low_dim file. Image files have no download link.
    sim_files = (
        ("raw", "demo_v141"),
        ("low_dim", "low_dim_v141"),
        ("image", None),
    )
    # Real-world datasets only have a demo.hdf5, which already contains all
    # observation modalities.
    real_files = (("raw", "demo"),)

    # proficient-human datasets: task -> evaluation rollout horizon
    ph_horizons = {
        "lift": 400,
        "can": 400,
        "square": 400,
        "transport": 700,
        "tool_hang": 700,
        "lift_real": 1000,
        "can_real": 1000,
        "tool_hang_real": 1000,
    }
    for task, horizon in ph_horizons.items():
        _register_files(task, "ph", horizon, real_files if "real" in task else sim_files)

    # multi-human datasets: task -> evaluation rollout horizon
    mh_horizons = {"lift": 500, "can": 500, "square": 500, "transport": 1100}
    for task, horizon in mh_horizons.items():
        _register_files(task, "mh", horizon, sim_files)

    # machine-generated datasets come in sparse and dense variants
    mg_files = (
        ("raw", "demo_v141"),
        ("low_dim_sparse", "low_dim_sparse_v141"),
        ("image_sparse", None),
        ("low_dim_dense", "low_dim_dense_v141"),
        ("image_dense", None),
    )
    for task in ("lift", "can"):
        _register_files(task, "mg", 400, mg_files)

    # can-paired dataset
    _register_files("can", "paired", 400, sim_files)
def register_momart_dataset_link(task, dataset_type, link, dataset_size):
    """
    Helper function to register a MOMART dataset link in the global
    MOMART_DATASET_REGISTRY dictionary.

    The entry is stored at MOMART_DATASET_REGISTRY[task][dataset_type] as a
    dict with keys "url" and "size". It replaces any earlier entry for the
    same (task, dataset_type) pair.

    Args:
        task (str): name of task for this dataset
        dataset_type (str): type of dataset (usually identifies the dataset source)
        link (str): download link for the dataset
        dataset_size (float): size of the dataset, in GB
    """
    # Only the per-task level must be created on demand. The per-type value is
    # assigned outright, so pre-creating an empty dict for it (as before) was
    # dead work.
    MOMART_DATASET_REGISTRY.setdefault(task, {})[dataset_type] = dict(url=link, size=dataset_size)
def register_all_momart_links():
    """
    Register every released MOMART dataset link.

    For every task, one entry per dataset type ("expert", "suboptimal",
    "generalize", "sample") is added through register_momart_dataset_link,
    together with that dataset's size.
    """
    dataset_types = ("expert", "suboptimal", "generalize", "sample")

    # Each task's sizes (in GB) are listed in the same order as dataset_types.
    sizes_per_task = (
        ("table_setup_from_dishwasher", (14, 14, 3.3, 0.6)),
        ("table_setup_from_dresser", (16, 17, 3.1, 0.7)),
        ("table_cleanup_to_dishwasher", (23, 36, 5.3, 1.1)),
        ("table_cleanup_to_sink", (17, 28, 2.9, 0.8)),
        ("unload_dishwasher", (21, 27, 5.4, 1.0)),
    )

    for task, sizes in sizes_per_task:
        for dataset_type, dataset_size in zip(dataset_types, sizes):
            url = f"http://downloads.cs.stanford.edu/downloads/rt_mm/{dataset_type}/{task}_{dataset_type}.hdf5"
            register_momart_dataset_link(
                task=task,
                dataset_type=dataset_type,
                link=url,
                dataset_size=dataset_size,
            )
# Fill both global registries once, at import time, so lookups work immediately.
for _populate in (register_all_links, register_all_momart_links):
    _populate()
# Drop the loop variable so the module namespace is unchanged.
del _populate