Spaces:
Runtime error
Runtime error
| """ | |
| Merge data shards generated from `encode_{extern,openx}_dataset.py`. | |
| In addition to the CLI args, `SHARD_DATA_FORMAT` must be edited to match the dataset being merged. | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import numpy as np | |
| from tqdm.auto import tqdm | |
# Template for a shard directory; filled with (shard index, total number of shards).
SHARD_DATA_FORMAT = "/private/home/xinleic/LR/HPT-Video-KZ/sharded_data/droid_magvit_shard{}_of_{}_train"


def find_valid_shards(shard_paths: list[str]) -> list[tuple[str, dict]]:
    """Return ``(shard_path, metadata)`` for every shard directory that has a ``metadata.json``.

    Shards without a ``metadata.json`` file are reported on stdout and skipped.
    The position of a path in ``shard_paths`` is used as its shard index in that report.
    """
    valid_shards = []
    for shard_ind, shard_path in enumerate(shard_paths):
        metadata_path = os.path.join(shard_path, "metadata.json")
        if not os.path.isfile(metadata_path):
            print(f"{shard_ind=} is invalid.")
            continue
        with open(metadata_path, "r") as f:
            valid_shards.append((shard_path, json.load(f)))
    return valid_shards


def merge_shards(shards, reference_metadata: dict, num_frames: int, out_data_dir: str) -> int:
    """Concatenate the shards' ``video.bin``, ``actions.bin`` and ``segment_ids.bin`` into ``out_data_dir``.

    Args:
        shards: Iterable of ``(shard_path, metadata)`` pairs, in the order they should be concatenated.
            May be wrapped in a progress bar by the caller.
        reference_metadata: Metadata dict providing ``token_dtype``, ``quantized``, ``h``, ``w``,
            ``action_dim`` (and ``latent_channels`` when not quantized) for the merged arrays.
        num_frames: Total number of frames across all shards; sizes the output files.
        out_data_dir: Directory that receives the merged files; created if missing.

    Returns:
        The number of frames actually written.
    """
    token_dtype = np.dtype(reference_metadata["token_dtype"])
    if reference_metadata["quantized"]:
        frame_dims = (reference_metadata["h"], reference_metadata["w"])
    else:
        frame_dims = (reference_metadata["latent_channels"], reference_metadata["h"], reference_metadata["w"])
    action_dim = reference_metadata["action_dim"]

    os.makedirs(os.path.join(out_data_dir, "actions"), exist_ok=True)

    # "w+" creates (or overwrites) the backing file and opens it for reading and writing.
    videos = np.memmap(
        os.path.join(out_data_dir, "video.bin"),
        dtype=token_dtype,
        mode="w+",
        shape=(num_frames, *frame_dims),
    )
    actions = np.memmap(
        os.path.join(out_data_dir, "actions", "actions.bin"),
        dtype=np.float32,
        mode="w+",
        shape=(num_frames, action_dim),
    )
    segment_ids = np.memmap(
        os.path.join(out_data_dir, "segment_ids.bin"),
        dtype=np.int32,
        mode="w+",
        shape=(num_frames,),
    )

    frame_offset = 0
    segment_offset = 0
    for shard_path, shard_metadata in shards:
        shard_num_frames = shard_metadata["num_images"]
        if shard_num_frames == 0:
            # Nothing to copy. Skipping also avoids memory-mapping an empty file and
            # keeps the segment-ID offset from being derived from an unrelated element.
            continue
        stop = frame_offset + shard_num_frames

        videos[frame_offset:stop] = np.memmap(
            os.path.join(shard_path, "video.bin"),
            dtype=np.dtype(shard_metadata["token_dtype"]),
            mode="r",
            shape=(shard_num_frames, *frame_dims),
        )
        actions[frame_offset:stop] = np.memmap(
            os.path.join(shard_path, "actions", "actions.bin"),
            dtype=np.float32,
            mode="r",
            shape=(shard_num_frames, action_dim),
        )
        # Shift the shard-local segment IDs so they continue after the previous shard's IDs.
        segment_ids[frame_offset:stop] = np.memmap(
            os.path.join(shard_path, "segment_ids.bin"),
            dtype=np.int32,
            mode="r",
            shape=(shard_num_frames,),
        ) + segment_offset

        # NOTE(review): uses the shard's last ID as its largest, i.e. assumes IDs are
        # non-decreasing within a shard -- confirm against the encoding scripts.
        segment_offset = int(segment_ids[stop - 1]) + 1
        frame_offset = stop

    # Push pending writes to disk explicitly instead of relying on garbage collection.
    for merged in (videos, actions, segment_ids):
        merged.flush()
    return frame_offset


def main() -> None:
    """Parse CLI arguments, merge all valid shards and write the merged ``metadata.json``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_data_dir", type=str, required=True,
                        help="Directory to save merged data, must not exist.")
    parser.add_argument("--num_shards", type=int, required=True, help="Number of shards the dataset was split into.")
    args = parser.parse_args()

    # Explicit check rather than `assert`, which is removed when Python runs with -O.
    if os.path.exists(args.out_data_dir):
        raise FileExistsError(f"Will not overwrite existing directory. ({args.out_data_dir})")

    shard_paths = [SHARD_DATA_FORMAT.format(shard_ind, args.num_shards) for shard_ind in range(args.num_shards)]
    valid_shards = find_valid_shards(shard_paths)
    num_frames = sum(shard_metadata["num_images"] for _, shard_metadata in valid_shards)
    if num_frames == 0:
        # Return before anything is created so a later run is not blocked by an empty output directory.
        print("No valid shards")
        return

    # The last valid shard's metadata describes the merged arrays and seeds the merged metadata.
    reference_metadata = valid_shards[-1][1]
    merged_frames = merge_shards(tqdm(valid_shards), reference_metadata, num_frames, args.out_data_dir)
    assert merged_frames == num_frames  # internal invariant, not input validation
    print("Finished")

    with open(os.path.join(args.out_data_dir, "metadata.json"), "w") as f:
        merged_metadata = reference_metadata \
            | vars(args) \
            | {"num_images": num_frames, "input_path": SHARD_DATA_FORMAT.format(0, args.num_shards)}
        json.dump(merged_metadata, f)


if __name__ == "__main__":
    main()