Spaces:
Runtime error
Runtime error
| # -------------------------------------------------------- | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # -------------------------------------------------------- | |
| import argparse | |
| import json | |
| import os | |
| import time | |
| import traceback | |
| from typing import Optional | |
| import math | |
| import numpy as np | |
| import tensorflow_datasets as tfds | |
| from tensorflow_datasets.core import DatasetBuilder | |
| from tqdm import tqdm | |
| from . import utils | |
SCRIPT_DESCRIPTION = """
Converts an Open X-Embodiment dataset from Google Cloud Storage (GCS) to encoded/tokenized data on disk.
This script only encodes one split (specified by `--data_split`)
of one OpenX dataset (specified by `--dataset_name`) at a time.
Optionally, each split can be partitioned into multiple shards,
which is useful for parallelized encoding across GPUs.

Example usage:
    CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name bc_z --data_split train --episode_cnt 500 --num_shards 16 --curr_shard_rank 0
    CUDA_VISIBLE_DEVICES=1 python -m datasets.encode_openx_dataset --dataset_name bc_z --data_split train --episode_cnt 500 --num_shards 16 --curr_shard_rank 1

Example usage (even shards on GPU 0, odd shards on GPU 1):
    set -e
    for ((i = 0; i < 64; i += 2)); do
        CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name bridge --data_split train --num_shards 64 --curr_shard_rank $i --root_dir sharded_data
    done

    set -e
    for ((i = 1; i < 64; i += 2)); do
        CUDA_VISIBLE_DEVICES=1 python -m datasets.encode_openx_dataset --dataset_name bridge --data_split train --num_shards 64 --curr_shard_rank $i --root_dir sharded_data
    done

Example usage (SVD tokenizer):
    CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_openx_dataset --dataset_name language_table --data_split val --no_quantization --encoder_type temporalvae --encoder_name_or_path 'stabilityai/stable-video-diffusion-img2vid'
""".strip()
# The validation set is the first VAL_RATIO fraction of the episodes in the dataset,
# with its size clipped to [MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES].
VAL_RATIO = 0.05
MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES = 20, 200

# Per-dataset frequency (Hz), written as the "hz" field of each dataset's metadata.json.
# Datasets that are not listed here are recorded with a default of 1.
DATA_FREQ_TABLE = {
    "austin_sailor_dataset_converted_externally_to_rlds": 20,
    "stanford_hydra_dataset_converted_externally_to_rlds": 10,
    "austin_buds_dataset_converted_externally_to_rlds": 20,
    "austin_sirius_dataset_converted_externally_to_rlds": 20,
    "berkeley_mvp_converted_externally_to_rlds": 5,
    "berkeley_rpt_converted_externally_to_rlds": 30,
    "ucsd_kitchen_dataset_converted_externally_to_rlds": 2,
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20,
    "utaustin_mutex": 20,
    "imperialcollege_sawyer_wrist_cam": 10,
    "language_table": 2,  # changed to match frequency
    "kuka": 2,  # changed to match frequency
    "bc_z": 10,
    "robo_net": 1,
    "dlr_sara_pour_converted_externally_to_rlds": 10,
    "stanford_robocook_converted_externally_to_rlds": 5,
    "cmu_play_fusion": 5,
    "bridge": 5,
    "furniture_bench_dataset_converted_externally_to_rlds": 10,
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3,
    "usc_cloth_sim_converted_externally_to_rlds": 10,
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20,
    "roboturk": 10,
    "kaist_nonprehensile_converted_externally_to_rlds": 10,
    "asu_table_top_converted_externally_to_rlds": 12,
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
    "berkeley_cable_routing": 10,
    "droid": 15,
    "uiuc_d3field": 1,
    "robo_set": 5,
    "toto": 30,
    "nyu_door_opening_surprising_effectiveness": 3,
    "nyu_franka_play_dataset_converted_externally_to_rlds": 3,
    "mimic_play": 15,
    "maniskill_dataset_converted_externally_to_rlds": 20,
    "columbia_cairlab_pusht_real": 10,
    "conq_hose_manipulation": 30,
    "dlr_edan_shared_control_converted_externally_to_rlds": 5,
    "berkeley_gnm_sac_son": 10,
    "berkeley_autolab_ur5": 5,
    "aloha_mobile": 30,
    "1x_humanoid": 30,
    "epic_kitchen_originalres": 30,
    "epic_kitchen": 30,
    "exoego4d": 30,
    "ego4d": 1,  # less than this.
    "metaworld": 6,
    "frodobot": 30,
    "fractal20220817_data": 3,
    # robomimic variants (previously "robomimic" was listed twice with the same value)
    "robomimic": 6,  # average length around 50
    "robomimic_new": 6,  # average length around 50
    "robomimic_multitask_new": 6,  # average length around 50
    "robomimic_new_perturb": 6,  # average length around 50
    "robomimic_multitask_new_perturb": 6,  # average length around 50
}
def select_image(observation, verbose=False):
    """
    Collect candidate camera frames from an observation dict.

    Keys containing "rgb" are collected first, then keys containing "image";
    any key containing "depth" is skipped. Each key is returned at most once,
    so a key such as "rgb_image" is not duplicated. Callers in this module use
    the first element as the canonical frame.

    Args:
        observation: mapping from observation key to an image that is either a
            NumPy array (including ndarray subclasses) or an object exposing
            ``.numpy()`` (presumably an eager TF tensor from tfds -- confirm).
        verbose: if True, print each selected key.

    Returns:
        list of NumPy arrays in selection order; empty if no key matches.
    """
    imgs = []
    seen = set()
    # No preference for wrist cameras: order is purely by key pattern, then dict order.
    for key in ("rgb", "image"):
        for obs_key in observation:
            if key not in obs_key or "depth" in obs_key or obs_key in seen:
                continue
            seen.add(obs_key)
            image = observation[obs_key]
            # isinstance (not an exact type check) so ndarray subclasses such as
            # np.memmap are used as-is instead of failing on a missing .numpy().
            if not isinstance(image, np.ndarray):
                image = image.numpy()
            if verbose:
                print("selected image key:", obs_key)
            imgs.append(image)
    return imgs
def process_dataset_step(step, encoder_type: str, encoder_name_or_path: str,
                         keep_res=False, quantize=True, no_encoding=False):
    """
    Map one dataset-specific step to a unified ``{"action", "image"}`` dict.

    Args:
        step (dict): one step of an episode; must contain "observation" and may
            contain "action" (either array-like or a dict of array-like parts).
        encoder_type (str): image encoder/tokenizer type forwarded to ``utils``.
        encoder_name_or_path (str): checkpoint path or model name of the encoder.
        keep_res (bool): forwarded as ``keep_res`` to the encoder helper.
        quantize (bool): if True, encode with ``utils.get_quantized_image_embeddings``;
            otherwise with ``utils.get_vae_image_embeddings``.
        no_encoding (bool): if True, store ``utils.resize_image`` of the raw frame
            instead of an encoding; takes priority over ``quantize``.

    Returns:
        dict: the processed step. This is best-effort: on any error the traceback
        is printed and the partially filled dict is returned, so callers must not
        assume that both "action" and "image" are present.
    """
    step_dict = {}
    try:
        if "action" in step:
            raw_action = step["action"]
            if isinstance(raw_action, dict):
                # Outlier datasets store the action as a dict of components; flatten
                # them in sorted-key order so the resulting layout is deterministic.
                parts = []
                for key in sorted(raw_action):
                    value = raw_action[key]
                    value = value.numpy() if hasattr(value, "numpy") else value
                    parts.append(np.asarray(value).reshape(-1))
                step_dict["action"] = np.concatenate(parts)
            else:
                step_dict["action"] = np.array(raw_action)

        images = select_image(step["observation"])
        if not images:
            raise ValueError(
                f"no RGB image found in observation keys {list(step['observation'])}"
            )
        frame = images[0]  # canonical frame

        if no_encoding:
            step_dict["image"] = utils.resize_image(frame)
        else:
            encode_fn = (
                utils.get_quantized_image_embeddings if quantize
                else utils.get_vae_image_embeddings
            )
            step_dict["image"] = encode_fn(
                frame,
                encoder_type=encoder_type,
                encoder_name_or_path=encoder_name_or_path,
                keep_res=keep_res,
            )
    except Exception:
        # Deliberately best-effort: report and hand back whatever was computed.
        print("--------------------------")
        print("process_dataset_step exception:", traceback.format_exc())
    return step_dict
def get_dataset_builder(gs_dataset_name, versions=("0.1.0", "1.0.0", "0.0.1")) -> tuple[DatasetBuilder, int]:
    """
    Returns the dataset builder and the total number of examples (for the train split).

    Datasets under gs://gresearch/robotics/ are stored under differing version
    subdirectories, so each candidate in ``versions`` is tried in order and the
    first one that loads is used.

    Args:
        gs_dataset_name: dataset directory name under gs://gresearch/robotics/.
        versions: version subdirectories to try, in priority order.

    Returns:
        (builder, num_examples), where num_examples is the size of the "train" split.

    Raises:
        ValueError: if ``versions`` is empty.
        Whatever ``tfds.builder_from_directory`` raised for the last version, if
        no version could be loaded.
    """
    if not versions:
        raise ValueError("versions must contain at least one version string")
    builder = None
    for i, version in enumerate(versions):
        try:
            builder = tfds.builder_from_directory(
                builder_dir=f"gs://gresearch/robotics/{gs_dataset_name}/{version}/"
            )
            break
        except Exception:
            # Swallow failures only while fallbacks remain. Unlike the previous bare
            # `except:`, this no longer traps KeyboardInterrupt / SystemExit.
            if i == len(versions) - 1:
                raise
    num_examples = builder.info.splits["train"].num_examples
    return builder, num_examples
def get_shard_inds(first_split_ind: int, last_split_ind: int, curr_shard_rank: int, num_shards: int) -> tuple[int, int]:
    """
    Given the indices of the first (inclusive) and last (exclusive) examples in the data split
    (i.e. entire train dataset or val dataset), returns the indices of the first (inclusive)
    and last (exclusive) examples for the current shard in this data split.

    Shard ``r`` of ``k`` covers ``[first + ceil(r * n / k), first + ceil((r + 1) * n / k))``,
    where ``n`` is the split size. Consecutive shards therefore tile the split with no gaps
    or overlaps, and their sizes differ by at most one.

    Raises:
        ValueError: if ``num_shards < 1`` or ``curr_shard_rank`` is not in ``[0, num_shards)``.
    """
    if num_shards < 1:
        raise ValueError(f"{num_shards=} must be >= 1")
    if not 0 <= curr_shard_rank < num_shards:
        raise ValueError(f"{curr_shard_rank=} must be in [0, {num_shards})")
    split_num_examples = last_split_ind - first_split_ind

    def _ceil_offset(rank: int) -> int:
        # Exact integer ceil(rank * n / k). The float formulation can round a
        # mathematically-integer quotient slightly upward and be off by one.
        return -(-rank * split_num_examples // num_shards)

    return (
        first_split_ind + _ceil_offset(curr_shard_rank),
        min(first_split_ind + _ceil_offset(curr_shard_rank + 1), last_split_ind),
    )
def _write_memmap(path, array, dtype):
    """Write ``array`` to ``path`` as a raw ``np.memmap`` of ``dtype`` and flush it to disk."""
    fp = np.memmap(path, dtype=dtype, mode='w+', shape=array.shape)
    fp[:] = array[:]
    fp.flush()
    del fp


def encode_dataset_split(
    gs_dataset_name: str,
    split: str,
    max_episodes: Optional[int],
    original_res: bool,
    no_quantization: bool,
    curr_shard_rank: int,
    num_shards: int,
    root_dir: str,
    encoder_type: str,
    encoder_name_or_path: str,
    dataset_postfix: str = "",
    no_encoding: bool = False,
):
    """
    Converts an Open X-Embodiment dataset from GS to encoded/tokenized data on disk.

    The data written to disk can be used to load a `RawTokenDataset` (or the continuous version).
    Output goes to ``{root_dir}/{name}_{encoder_type}_{dataset_postfix}_{split}`` and consists of
    ``video.bin``, ``actions/actions.bin`` and ``segment_ids.bin`` (raw np.memmap files) plus
    ``metadata.json``.

    Args:
        gs_dataset_name: the name of the dataset in Google Storage.
            Can be checked with gsutil ls -d gs://gresearch/robotics/*/
        split: "train" or "val". The first ``num_val_examples`` episodes of the tfds "train"
            split form "val"; the remaining episodes form "train".
        max_episodes: the maximum number of trajectories to include in the dataset.
        original_res: if True, will maintain original resolution of the video rather than resizing it to 256x256.
        no_quantization: if True, will not perform quantization step in image encoder.
        curr_shard_rank: 0-indexed shard of this split to encode.
        num_shards: number of shards the split is partitioned into.
        root_dir: directory under which the output directory is created.
        encoder_type: image encoder/tokenizer type ("magvit" or "temporalvae").
        encoder_name_or_path: checkpoint path or model name of the image encoder.
        dataset_postfix: will be a suffix of the output dirname.
        no_encoding: if True, store resized raw frames instead of encoder outputs.

    Raises:
        NotImplementedError: unknown ``split`` or ``encoder_type``.
        ValueError: the dataset has too few episodes to carve out a validation set.
        RuntimeError: no step of this shard could be encoded.
    """
    gs_dataset_name = gs_dataset_name.strip()  # never modified
    suffixed_dataset_name = gs_dataset_name  # will modify later

    # Validate cheap arguments up front instead of after hours of encoding.
    if split not in ("train", "val"):
        raise NotImplementedError(f"{split=}")
    if encoder_type == "magvit":
        vocab_size = int(2 ** 18)
    elif encoder_type == "temporalvae":
        vocab_size = None
    else:
        raise NotImplementedError(f"{encoder_type=}")

    # dtype of video.bin. The precedence mirrors process_dataset_step, where
    # no_encoding (raw resized frames) takes priority over quantization.
    if no_encoding:
        video_dtype = np.uint8
    elif no_quantization:
        video_dtype = np.float16
    else:
        video_dtype = np.uint32

    if original_res:
        suffixed_dataset_name = f"{suffixed_dataset_name}_originalres"
    if no_quantization:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noquant"
    if no_encoding:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noencoding"
    save_dirname = "_".join([suffixed_dataset_name, encoder_type, dataset_postfix, split])
    dataset_path = os.path.join(root_dir, save_dirname)
    print("=" * 25)
    print(f"{dataset_path=}")
    utils.mkdir_if_missing(dataset_path)

    # Load data
    builder, num_examples = get_dataset_builder(gs_dataset_name)
    if max_episodes is not None:
        num_examples = min(num_examples, max_episodes)  # clip num_examples

    # We only operate on a subset of the tfds "train" examples, depending on:
    # 1) the split (train/val): some examples are reserved for the other split;
    # 2) sharding within that split.
    if num_examples <= MIN_VAL_EXAMPLES:
        # Otherwise the train split would have a non-positive number of examples.
        raise ValueError(f"{num_examples=} {MIN_VAL_EXAMPLES=}")
    num_val_examples = int(np.clip(int(VAL_RATIO * num_examples), MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES))
    if split == "train":  # first_ind inclusive, last_ind exclusive
        first_split_ind, last_split_ind = num_val_examples, num_examples
    else:  # split == "val" (validated above)
        first_split_ind, last_split_ind = 0, num_val_examples

    first_shard_ind, last_shard_ind = get_shard_inds(first_split_ind, last_split_ind, curr_shard_rank, num_shards)
    print(f"Total number of examples in {suffixed_dataset_name}: {num_examples}")
    print(f"Number of examples for {split=}, shard {curr_shard_rank} of {num_shards}: "
          f"{last_shard_ind - first_shard_ind}. {first_shard_ind=} {last_shard_ind=}")

    ##### Encode data #####
    traj_lens = []  # only used to print statistics
    videos = []  # NOTE: videos/actions for the entire shard are stored in RAM until the end
    actions = []
    segment_ids = []

    # Episodes are loaded in fixed-size slices (a fresh tfds dataset per slice) to limit RAM usage.
    max_batch_per_loading = 10
    timeout_s = 86400 * 2  # 2 day timeout
    timed_out = False
    pbar = tqdm(range(first_shard_ind, last_shard_ind, max_batch_per_loading), position=0, leave=True)
    start_time = time.time()
    for start_idx in pbar:
        end_idx = min(start_idx + max_batch_per_loading, last_shard_ind)
        pbar.set_description(f"{suffixed_dataset_name} caching episodes: {start_idx}:{end_idx}")
        ds = builder.as_dataset(split=f"train[{start_idx}:{end_idx}]")
        for chunk_idx, episode in enumerate(tqdm(ds, position=1, leave=False)):
            segment_id = start_idx + chunk_idx
            num_steps = 0
            try:
                for step_data in episode["steps"]:
                    dataset_step = process_dataset_step(
                        step_data,
                        encoder_type=encoder_type,
                        encoder_name_or_path=encoder_name_or_path,
                        keep_res=original_res,
                        quantize=not no_quantization,
                        no_encoding=no_encoding,
                    )
                    # process_dataset_step is best-effort and may omit keys. Read both
                    # outputs before appending so the three lists stay the same length.
                    image = dataset_step["image"]
                    action = dataset_step["action"]
                    segment_ids.append(segment_id)
                    videos.append(image)
                    actions.append(action)
                    num_steps += 1
                traj_lens.append(num_steps)  # number of steps in this trajectory
            except Exception:
                print("-" * 25)
                print(f"Add episode failed: {segment_id=}", traceback.format_exc(), suffixed_dataset_name)

            if time.time() - start_time > timeout_s:
                print(f"Writing dataset {suffixed_dataset_name} timed out")
                timed_out = True
                break
        if timed_out:
            # Also leave the outer loop; otherwise one more episode per slice would be encoded.
            break

    if not videos:
        raise RuntimeError(
            f"No steps were encoded for {suffixed_dataset_name} "
            f"({split=}, shard {curr_shard_rank} of {num_shards}); nothing to write."
        )

    # Continuous latents are unpacked as (C, H, W); tokens and raw frames use the first two dims as (H, W).
    if no_quantization and not no_encoding:
        num_channels, height, width = videos[-1].shape[:3]
    else:
        height, width = videos[-1].shape[:2]
        num_channels = None

    ##### Write videos, actions, segment_ids, and metadata #####
    # Align format to save segment_ids.bin, video.bin, actions/actions.bin, metadata.json.
    videos = np.stack(videos, axis=0)
    _write_memmap(f'{dataset_path}/video.bin', videos, video_dtype)

    utils.mkdir_if_missing(f'{dataset_path}/actions')
    actions = np.stack(actions, axis=0)
    _write_memmap(f'{dataset_path}/actions/actions.bin', actions, np.float32)

    segment_ids = np.array(segment_ids)
    _write_memmap(f'{dataset_path}/segment_ids.bin', segment_ids, np.int32)  # map to trajectory index

    with open(f'{dataset_path}/metadata.json', 'w') as f:  # Technically only need to save most of this data for shard 0
        json.dump({
            # Record the dtype actually written to video.bin (np.memmap casts on
            # assignment), not the dtype the encoder happened to return.
            "token_dtype": str(np.dtype(video_dtype)),
            "action_dim": actions[0].shape[-1],
            "s": 16,
            "h": height,
            "w": width,
            "vocab_size": vocab_size,
            "hz": DATA_FREQ_TABLE.get(gs_dataset_name, 1),  # to be loaded from the data code TODO: remove default?
            "encoder_name_or_path": encoder_name_or_path,
            "encoder_type": encoder_type,
            "num_images": len(videos),
            "name": gs_dataset_name,
            "latent_channels": num_channels,
            # Use the parameter, not the script-level `args` global (undefined when imported).
            "quantized": not no_quantization,
        }, f)

    print(f"{len(traj_lens)=} {np.mean(traj_lens)=} {np.sum(traj_lens)=}")
    print(f"Dataset creation time: {time.time() - start_time:.3f}")
def parse_args():
    """
    Parse command-line arguments for the encoding script.

    Uses RawDescriptionHelpFormatter so that the multi-line usage examples in
    SCRIPT_DESCRIPTION keep their line breaks in ``--help`` output. The default
    formatter re-wraps them into a single paragraph.

    Returns:
        argparse.Namespace with the parsed options. ``num_shards`` is guaranteed
        to be >= 1 and ``curr_shard_rank`` to be in ``[0, num_shards)``; otherwise
        the parser exits with a usage error.
    """
    parser = argparse.ArgumentParser(
        description=SCRIPT_DESCRIPTION,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--dataset_name", type=str, required=True,
        help="The name of the Open X-Embodiment dataset on Google Storage. "
             "Can be checked with gsutil ls -d gs://gresearch/robotics/*/. "
    )
    parser.add_argument(
        "--data_split", type=str, choices=["train", "val"], required=True,
        help="The split of the dataset to create."
    )
    parser.add_argument(
        "--episode_cnt", type=int,
        help="If specified, will limit the maximum number of trajectories to encode."
    )
    parser.add_argument(
        "--original_res", action='store_true',
        help="Maintain original resolution of the video rather than resizing it to 256x256."
    )
    parser.add_argument(
        "--no_quantization", action='store_true',
        help="Skip quantization step in visual encoder."
    )
    parser.add_argument(
        "--num_shards", type=int, default=1,
        help="The number of shards to partition the train/val dataset into."
    )
    parser.add_argument(
        "--curr_shard_rank", type=int, default=0,
        help="The (0-indexed) shard number to encode."
    )
    parser.add_argument(
        "--root_dir", type=str, default="data",
        help="The root directory to write all datasets to."
    )
    parser.add_argument(
        "--encoder_type", type=str, default="magvit", choices=["magvit", "temporalvae"],
        help="Type of the image tokenizer."
    )
    parser.add_argument(
        "--encoder_name_or_path", type=str, default="data/magvit2.ckpt",
        help="The path or name of the image encoder."
    )
    parser.add_argument(
        "--no_encoding", action='store_true',
        help="Preserve the groundtruth raw images to compute metrics in validation."
    )
    args = parser.parse_args()

    # Reject inconsistent sharding early with a usage message rather than producing
    # an empty or inverted shard range hours later.
    if args.num_shards < 1:
        parser.error(f"--num_shards must be >= 1, got {args.num_shards}")
    if not 0 <= args.curr_shard_rank < args.num_shards:
        parser.error(
            f"--curr_shard_rank must be in [0, {args.num_shards}), got {args.curr_shard_rank}"
        )
    return args
if __name__ == "__main__":
    # Kept at module level (not inside a main() function) so `args` remains a module global.
    args = parse_args()
    utils.set_seed(233)

    # Output-directory postfix: optional "max{N}" followed by optional "shard{r}_of_{k}",
    # joined with "_" (empty string when neither applies).
    postfix_parts = []
    if args.episode_cnt is not None:
        postfix_parts.append(f"max{args.episode_cnt}")
    if args.num_shards > 1:
        postfix_parts.append(f"shard{args.curr_shard_rank}_of_{args.num_shards}")
    dataset_postfix = "_".join(postfix_parts)

    encode_dataset_split(
        gs_dataset_name=args.dataset_name,
        split=args.data_split,
        max_episodes=args.episode_cnt,
        dataset_postfix=dataset_postfix,
        original_res=args.original_res,
        no_quantization=args.no_quantization,
        no_encoding=args.no_encoding,
        num_shards=args.num_shards,
        curr_shard_rank=args.curr_shard_rank,
        root_dir=args.root_dir,
        encoder_type=args.encoder_type,
        encoder_name_or_path=args.encoder_name_or_path,
    )