"""Loads tar files using webdataset."""
import os
import webdataset as wds
import decord
from torch.utils.data import DataLoader
import numpy as np
import einops
# Define a function to decode videos using Decord
def decode_video(video_bytes, frame_step=5):
    """Decode an encoded video into a list of sampled frames using decord.

    Args:
        video_bytes: The encoded video. Bytes-like input (``bytes``,
            ``bytearray``, ``memoryview``) is wrapped in an in-memory file,
            because ``decord.VideoReader`` takes a path or a file-like
            object rather than a raw byte string. Any other value is handed
            to decord unchanged.
        frame_step: Keep every ``frame_step``-th frame, starting at frame 0.
            Defaults to 5, the stride that used to be hard-coded.

    Returns:
        A list with one array per kept frame (the result of ``asnumpy()``).

    Raises:
        ValueError: If ``frame_step`` is smaller than 1.
    """
    import io  # local import: nothing else in the file needs it

    # A zero step makes range() fail and a negative one silently yields no
    # frames, so reject both up front with a clear message.
    if frame_step < 1:
        raise ValueError(f"frame_step must be >= 1, got {frame_step!r}")

    if isinstance(video_bytes, (bytes, bytearray, memoryview)):
        video_bytes = io.BytesIO(video_bytes)

    vr = decord.VideoReader(video_bytes)
    return [vr[i].asnumpy() for i in range(0, len(vr), frame_step)]
def convert_bytes_to_frames(video_bytes, frame_step=5):
    """Turn an encoded video into a list of frames, sampling at a fixed stride.

    Args:
        video_bytes: Encoded video content. ``bytes``, ``bytearray`` and
            ``memoryview`` values are wrapped in ``io.BytesIO`` first, since
            ``decord.VideoReader`` reads from a path or a file-like object.
            Paths and file-like objects are passed through untouched.
        frame_step: Distance between consecutive kept frame indices; frame 0
            is always the first one kept. Defaults to 5 (the previous
            hard-coded value).

    Returns:
        A list of per-frame arrays obtained via ``asnumpy()``.

    Raises:
        ValueError: If ``frame_step`` is smaller than 1.
    """
    import io  # local import: nothing else in the file needs it

    if frame_step < 1:
        # Guard here: a negative step would otherwise return [] silently.
        raise ValueError(f"frame_step must be >= 1, got {frame_step!r}")

    source = video_bytes
    if isinstance(source, (bytes, bytearray, memoryview)):
        source = io.BytesIO(source)

    reader = decord.VideoReader(source)
    kept_indices = range(0, len(reader), frame_step)
    return [reader[idx].asnumpy() for idx in kept_indices]
def decode_video(video_bytes):
    """Given video bytes, decode them into frames.

    This definition replaces any earlier ``decode_video`` in the module, so
    it must do real work: as an empty stub it made every caller receive
    ``None``.

    Args:
        video_bytes: Encoded video content. Bytes-like values are wrapped in
            an in-memory file before decoding, because the decord reader
            used by ``convert_bytes_to_frames`` takes a path or a file-like
            object. Other values are forwarded unchanged.

    Returns:
        Whatever ``convert_bytes_to_frames`` returns for the input (a list
        of frame arrays).
    """
    import io  # local import: nothing else in the file needs it

    if isinstance(video_bytes, (bytes, bytearray, memoryview)):
        video_bytes = io.BytesIO(video_bytes)
    return convert_bytes_to_frames(video_bytes)
if __name__ == "__main__":
    # Exploratory script: read a few SSv2 shards with webdataset, decode the
    # first video, then drop into a debugger to inspect the results.
    import io  # local import: only the script entry point needs it

    shard_folder = "/work/piyush/from_nfs2/datasets/SSv2/ssv2_shards/"
    # Brace pattern covering every shard that is read (0000 through 0002).
    dataset_path = os.path.join(shard_folder, "shard-{0000..0002}.tar")
    num_shards = 3

    # Undecoded samples, kept for inspection in the debugger below.
    ds = wds.WebDataset(dataset_path)
    sample = next(iter(ds))

    # webdataset distributes whole shards across DataLoader workers, so any
    # worker beyond the shard count gets no data (recent webdataset versions
    # raise on such empty workers). Use one worker per shard.
    dl = DataLoader(ds, batch_size=16, num_workers=num_shards)
    batch = next(iter(dl))

    # Stream 1-tuples holding only the encoded ".webm" payload of a sample.
    dataset = wds.WebDataset(dataset_path).to_tuple("webm")

    for (video,) in dataset:
        print(type(video))
        # `video` is the compressed webm container, not raw pixels, so it
        # has to be decoded; reinterpreting the bytes with np.frombuffer and
        # reshaping them does not produce frames. Wrapping it in BytesIO
        # gives the decord reader the file-like object it expects.
        frames = convert_bytes_to_frames(io.BytesIO(video))
        # Stack the per-frame arrays (presumably (h, w, c) each - confirm in
        # the debugger) and move channels ahead of the spatial axes.
        video_array = np.stack(frames)
        reshaped_video = einops.rearrange(video_array, "t h w c -> t c h w")
        print(reshaped_video.shape)
        break  # one sample is enough for inspection

    # Interactive inspection of sample / batch / reshaped_video.
    import ipdb
    ipdb.set_trace()