| """Loads tar files using webdataset.""" |
| import os |
| import webdataset as wds |
| import decord |
| from torch.utils.data import DataLoader |
| import numpy as np |
| import einops |
|
|
|
|
| |
def decode_video(video_bytes):
    """Decode an encoded video into a list of frames, keeping every 5th frame.

    Args:
        video_bytes: Encoded video content as a bytes-like object. A file
            path or an already-open binary file object is handed to
            ``decord.VideoReader`` unchanged.

    Returns:
        A list with one NumPy array per sampled frame (frame indices
        0, 5, 10, ...), each obtained via ``asnumpy()`` on the decord frame.
    """
    import io

    # decord.VideoReader takes a path or a readable file-like object, so raw
    # in-memory bytes are wrapped in a buffer before being handed over.
    if isinstance(video_bytes, (bytes, bytearray, memoryview)):
        video_bytes = io.BytesIO(video_bytes)

    vr = decord.VideoReader(video_bytes)
    return [vr[i].asnumpy() for i in range(0, len(vr), 5)]
|
|
|
|
def convert_bytes_to_frames(video_bytes):
    """Turn an encoded video payload into a list of temporally subsampled frames.

    Args:
        video_bytes: Encoded video as a bytes-like object (for example the
            raw value of a shard entry). Anything that is not bytes-like,
            such as a path string or a binary file object, is forwarded to
            ``decord.VideoReader`` as-is.

    Returns:
        A list of NumPy arrays, one per kept frame. Frames are taken at a
        fixed stride of 5 starting from frame 0; an empty reader yields an
        empty list.
    """
    import io

    frame_stride = 5  # keep one frame out of every five

    source = video_bytes
    if isinstance(source, (bytes, bytearray, memoryview)):
        # The reader expects a path or a file-like object, not a bare buffer.
        source = io.BytesIO(source)

    reader = decord.VideoReader(source)
    frames = []
    for index in range(0, len(reader), frame_stride):
        frames.append(reader[index].asnumpy())
    return frames
|
|
|
|
|
|
def decode_video(video_bytes):
    """Given video bytes, decode them into frames.

    This is the definition of ``decode_video`` that is in effect after the
    module is imported (it is the last one in the file), so it must return
    real frames rather than ``None``.

    Args:
        video_bytes: Encoded video content, passed straight through to
            ``convert_bytes_to_frames``.

    Returns:
        The list of frames produced by ``convert_bytes_to_frames``.
    """
    # Delegate instead of duplicating the decoding logic a third time.
    return convert_bytes_to_frames(video_bytes)
|
|
|
|
if __name__ == "__main__":
    # Scratch entry point: open a few SSv2 shards, pull one sample and one
    # batch, decode a single clip, then drop into the debugger to inspect.
    import io

    shard_folder = "/work/piyush/from_nfs2/datasets/SSv2/ssv2_shards/"
    # Path of a single shard; not used below.
    shard_path = os.path.join(shard_folder, "shard-0000.tar")

    # Brace pattern covering shard-0000 ... shard-(num_shards - 1).
    num_shards = 3
    last_shard = num_shards - 1
    dataset_path = os.path.join(
        shard_folder, f"shard-{{0000..{last_shard:04d}}}.tar"
    )

    ds = wds.WebDataset(dataset_path)
    sample = next(iter(ds))
    # Shards are divided among DataLoader workers, so asking for more workers
    # than shards leaves some workers with nothing to read.
    dl = DataLoader(ds, batch_size=16, num_workers=min(8, num_shards))
    batch = next(iter(dl))

    # Keep only the "webm" entry of every sample, as a 1-tuple.
    dataset = wds.WebDataset(dataset_path).to_tuple("webm")

    for (video,) in dataset:
        print(type(video))
        # The payload is an encoded video file, not a raw pixel buffer, so it
        # has to go through the decoder; frame height/width/channels then come
        # from the decoded frames instead of hard-coded constants.
        frames = convert_bytes_to_frames(io.BytesIO(video))
        video_array = np.stack(frames)
        # NOTE(review): assumes decoded frames are laid out (h, w, c) - confirm.
        reshaped_video = einops.rearrange(video_array, "t h w c -> t c h w")
        break  # one clip is enough for inspection

    import ipdb

    ipdb.set_trace()