Spaces:
Runtime error
Runtime error
| # -------------------------------------------------------- | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # -------------------------------------------------------- | |
| import argparse | |
| import json | |
| import os | |
| import time | |
| import traceback | |
| from typing import Optional | |
| import numpy as np | |
| from tqdm import tqdm | |
| from datasets.encode_openx_dataset import MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES, get_shard_inds, VAL_RATIO, \ | |
| process_dataset_step, DATA_FREQ_TABLE | |
| from datasets.extern.ego4d import ego4d_dataset_size, ego4d_dataset_generator | |
| from datasets.extern.egoexo4d import egoexo4d_dataset_size, egoexo4d_dataset_generator | |
| from datasets.extern.robomimic import robomimic_dataset_generator, robomimic_dataset_size | |
| from . import utils | |
# CLI help text; passed verbatim to argparse as the program description in parse_args().
SCRIPT_DESCRIPTION="""
Similar to encode_openx_dataset.py except for non-OpenX datasets.
Again, each split can be partitioned into multiple shards,
which is useful for parallelized encoding across GPUs.
Example usage:
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name egoexo4d --data_split train --num_shards 1000 --curr_shard_rank 400
Untested usage (SVD tokenizer):
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name robomimic --data_split val --no_quantization --encoder_type temporalvae --encoder_name_or_path 'stabilityai/stable-video-diffusion-img2vid'
""".strip()

# Registry of supported external datasets:
# dataset name -> (episode generator function, dataset-size function).
# The generator yields episodes for a given index range; the size function
# returns the total number of episodes available.
DATASET_TO_GEN_AND_SIZE = {
    "ego4d": (ego4d_dataset_generator, ego4d_dataset_size),
    "egoexo4d": (egoexo4d_dataset_generator, egoexo4d_dataset_size),
    "robomimic": (robomimic_dataset_generator, robomimic_dataset_size),
}
def encode_dataset_split(
    extern_dataset_name: str,
    split: str,
    max_episodes: Optional[int],
    original_res: bool,
    no_quantization: bool,
    curr_shard_rank: int,
    num_shards: int,
    root_dir: str,
    encoder_type: str,
    encoder_name_or_path: str,
    dataset_postfix: str = "",
    no_encoding: bool = False,
):
    """
    Encodes (e.g. tokenizes) one shard of one split of an external (non-OpenX) dataset.

    The data written to disk can be used to load a `RawTokenDataset` (or the continuous version.)
    Output layout under the dataset directory: video.bin, actions/actions.bin,
    segment_ids.bin, metadata.json (and error.json if the shard is empty).

    Args:
        extern_dataset_name: key into `DATASET_TO_GEN_AND_SIZE` ("ego4d", "egoexo4d", "robomimic").
        split: expected to be either "train" or "val". TODO: decide how to split
        max_episodes: the maximum number of trajectories to include in the dataset.
        original_res: if True, will maintain original resolution of the video rather than resizing it to 256x256.
        no_quantization: if True, will not perform quantization step in image encoder.
        curr_shard_rank: the (0-indexed) shard of this split to encode.
        num_shards: the total number of shards the split is partitioned into.
        root_dir: root directory under which the output dataset directory is created.
        encoder_type: string specifying the type of image encoder/tokenizer to use ("magvit" or "temporalvae").
        encoder_name_or_path: path or name of the image encoder checkpoint.
        dataset_postfix: will be a suffix of the output dirname.
        no_encoding: if True, preserves groundtruth raw images instead of encoding (e.g. for validation metrics).
    """
    extern_dataset_name = extern_dataset_name.strip()  # never modified
    suffixed_dataset_name = extern_dataset_name  # will modify later
    if original_res:
        suffixed_dataset_name = f"{suffixed_dataset_name}_originalres"
    if no_quantization:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noquant"
    if no_encoding:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noencoding"

    save_dirname = "_".join([suffixed_dataset_name, encoder_type, dataset_postfix, split])
    dataset_path = os.path.join(root_dir, save_dirname)
    print("=" * 25)
    print(f"{dataset_path=}")
    utils.mkdir_if_missing(dataset_path)

    # Load data
    generator, size_func = DATASET_TO_GEN_AND_SIZE[extern_dataset_name]
    num_examples = size_func()
    if max_episodes is not None:
        num_examples = min(num_examples, max_episodes)  # clip num_examples

    # We will only operate on a subset of the training examples, depending on:
    # 1) The split (train/val). Some examples are reserved for the other split.
    # 2) Sharding
    assert num_examples > MIN_VAL_EXAMPLES  # non-positive number of train examples otherwise
    num_val_examples = np.clip(int(VAL_RATIO * num_examples), MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES)
    if split == "train":  # first_ind inclusive, last_ind exclusive
        first_split_ind, last_split_ind = num_val_examples, num_examples
    elif split == "val":
        first_split_ind, last_split_ind = 0, num_val_examples
    else:
        raise NotImplementedError(f"{split=}")

    first_shard_ind, last_shard_ind = get_shard_inds(first_split_ind, last_split_ind, curr_shard_rank, num_shards)
    print(f"Total number of examples in {suffixed_dataset_name}: {num_examples}")
    print(f"Number of examples for {split=}, shard {curr_shard_rank} of {num_shards}: "
          f"{last_shard_ind - first_shard_ind}. {first_shard_ind=} {last_shard_ind=}")

    ##### Encode data #####
    traj_lens = []  # only used to print statistics
    videos = []  # NOTE: videos/actions for the entire shard are stored in RAM until the end
    actions = []
    segment_ids = []

    # split based on some fixed batch sizes to reset RAM.
    max_batch_per_loading = 10
    pbar = tqdm(range(first_shard_ind, last_shard_ind, max_batch_per_loading), position=0, leave=True)
    start_time = time.time()
    timed_out = False  # set on 2-day timeout; makes BOTH loops stop (a bare inner `break` only exits the episode loop)

    for start_idx in pbar:
        if timed_out:
            break
        end_idx = min(start_idx + max_batch_per_loading, last_shard_ind)
        pbar.set_description(f"{suffixed_dataset_name} caching episodes: {start_idx}:{end_idx}")
        ds = generator(range(start_idx, end_idx))
        for chunk_idx, episode in enumerate(tqdm(ds, position=1, leave=False)):
            segment_id = start_idx + chunk_idx
            try:
                # batchify the data and then process
                for step_ind, step_data in enumerate(episode["steps"]):
                    dataset_step = process_dataset_step(
                        step_data,
                        encoder_type=encoder_type,
                        encoder_name_or_path=encoder_name_or_path,
                        keep_res=original_res,
                        quantize=not no_quantization,
                        no_encoding=no_encoding
                    )
                    segment_ids.append(segment_id)
                    videos.append(dataset_step["image"])
                    actions.append(dataset_step["action"])
                traj_lens.append(step_ind + 1)  # number of steps in this trajectory
            except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
                print("-" * 25)
                print(f"Add episode failed: {segment_id=}", traceback.format_exc(), suffixed_dataset_name)

            # 2 day timeout: stop encoding and flush whatever was collected so far.
            if time.time() - start_time > 86400 * 2:
                print(f"Writing dataset {suffixed_dataset_name} timed out")
                timed_out = True
                break

    if len(videos) == 0:
        # Record the failure so downstream tooling can tell an empty shard from a crash.
        print("Empty shard!")
        with open(f"{dataset_path}/error.json", "w") as f:
            json.dump({"status": "empty_shard"}, f)
        return

    if no_quantization:
        num_channels, height, width = videos[-1].shape[:3]  # num_channels is not actually stored in metadata
    else:
        height, width = videos[-1].shape[:2]
        num_channels = None

    ##### Write videos, actions, segment_ids, and metadata #####
    # align format to save segment_ids.bin, video.bin, actions/action.bin, metadata.json
    # save videos
    videos = np.stack(videos, axis=0)
    videos.tofile(f'{dataset_path}/video.bin')

    # save action
    utils.mkdir_if_missing(f'{dataset_path}/actions')
    actions = np.stack(actions, axis=0).astype(np.float32)
    actions.tofile(f'{dataset_path}/actions/actions.bin')

    # save segment_ids (map each frame back to its trajectory index)
    segment_ids = np.array(segment_ids).astype(np.int32)
    segment_ids.tofile(f'{dataset_path}/segment_ids.bin')

    # save metadata
    if encoder_type == "magvit":
        vocab_size = int(2 ** 18)
    elif encoder_type == "temporalvae":
        vocab_size = None
    else:
        raise NotImplementedError(f"{encoder_type=}")

    with open(f'{dataset_path}/metadata.json', 'w') as f:  # Technically only need to save most of this data for shard 0
        json.dump({
            "token_dtype": str(np.dtype(videos.dtype)),
            "action_dim": actions[0].shape[-1],
            "s": 16,
            "h": height,
            "w": width,
            "vocab_size": vocab_size,
            "hz": DATA_FREQ_TABLE.get(extern_dataset_name, 1),  # to be loaded from the data code
            "encoder_name_or_path": encoder_name_or_path,
            "encoder_type": encoder_type,
            "num_images": len(videos),
            "latent_channels": num_channels,
            "name": extern_dataset_name,
        }, f)

    print(f"{len(traj_lens)=} {np.mean(traj_lens)=} {np.sum(traj_lens)=}")
    print(f"Dataset creation time: {time.time() - start_time:.3f}")
def parse_args():
    """Parse command-line arguments; see SCRIPT_DESCRIPTION / --help for usage examples."""
    parser = argparse.ArgumentParser(description=SCRIPT_DESCRIPTION)
    parser.add_argument(
        "--dataset_name", type=str, required=True, choices=DATASET_TO_GEN_AND_SIZE.keys(),
        help="The name of the external (non-OpenX) dataset to encode."
    )
    parser.add_argument(
        "--data_split", type=str, choices=["train", "val"], required=True,
        help="The split of the dataset to create."
    )
    parser.add_argument(
        "--episode_cnt", type=int,
        help="If specified, will limit the maximum number of trajectories to encode."
    )
    parser.add_argument(
        "--original_res", action='store_true',
        help="Maintain original resolution of the video rather than resizing it to 256x256."
    )
    parser.add_argument(
        "--no_quantization", action='store_true',
        help="Skip quantization step in visual encoder."
    )
    parser.add_argument(
        "--num_shards", type=int, default=1,
        help="The number of shards to partition the train/val dataset into."
    )
    parser.add_argument(
        "--curr_shard_rank", type=int, default=0,
        help="The (0-indexed) shard number to encode."
    )
    parser.add_argument(
        "--root_dir", type=str, default="data",
        help="The root directory to write all datasets to."
    )
    parser.add_argument(
        "--encoder_type", type=str, default="magvit", choices=["magvit", "temporalvae"],
        help="Type of the image tokenizer."
    )
    parser.add_argument(
        "--encoder_name_or_path", type=str, default="data/magvit2.ckpt",
        help="The path or name of the image encoder."
    )
    parser.add_argument(
        "--no_encoding", action='store_true',
        help="Preserve the groundtruth raw images to compute metrics in validation."
    )
    return parser.parse_args()
if __name__ == "__main__":
    cli_args = parse_args()
    utils.set_seed(233)

    # Bake the shard index (and, when given, the episode cap) into the
    # output directory name so parallel shard jobs never collide.
    postfix = f"shard{cli_args.curr_shard_rank}_of_{cli_args.num_shards}"
    if cli_args.episode_cnt is not None:
        postfix = f"max{cli_args.episode_cnt}_{postfix}"

    encode_dataset_split(
        extern_dataset_name=cli_args.dataset_name,
        split=cli_args.data_split,
        max_episodes=cli_args.episode_cnt,
        original_res=cli_args.original_res,
        no_quantization=cli_args.no_quantization,
        curr_shard_rank=cli_args.curr_shard_rank,
        num_shards=cli_args.num_shards,
        root_dir=cli_args.root_dir,
        encoder_type=cli_args.encoder_type,
        encoder_name_or_path=cli_args.encoder_name_or_path,
        dataset_postfix=postfix,
        no_encoding=cli_args.no_encoding,
    )