Safetensors
English
llava
video-retrieval
text-to-video-search
multimodal-embedding
File size: 2,169 Bytes
7daf628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Loads tar files using webdataset."""
import io
import os

import webdataset as wds
import decord
from torch.utils.data import DataLoader
import numpy as np
import einops


# Define a function to decode videos using Decord
def decode_video(video_bytes):
    """Decode an encoded video into a list of frames, keeping every 5th frame.

    Args:
        video_bytes: Encoded video container contents (for example the raw
            bytes of a ``.webm`` file). A file path or a file-like object is
            passed through to ``decord.VideoReader`` unchanged.

    Returns:
        A list with one NumPy array per kept frame (frames 0, 5, 10, ...).
    """
    # decord's VideoReader reads from a path or a file-like object. A bare
    # bytes payload would be treated as a path, so expose it through an
    # in-memory buffer instead.
    if isinstance(video_bytes, (bytes, bytearray, memoryview)):
        video_bytes = io.BytesIO(video_bytes)
    vr = decord.VideoReader(video_bytes)
    return [vr[i].asnumpy() for i in range(0, len(vr), 5)]


def convert_bytes_to_frames(video_bytes, frame_step=5):
    """Decode an encoded video into a list of frames sampled at a fixed stride.

    Args:
        video_bytes: Encoded video container contents (for example the raw
            bytes of a ``.webm`` file). A file path or a file-like object is
            passed through to ``decord.VideoReader`` unchanged.
        frame_step: Keep one frame out of every ``frame_step``; the default
            of 5 matches the previously hard-coded stride.

    Returns:
        A list with one NumPy array per kept frame
        (frames 0, frame_step, 2 * frame_step, ...).

    Raises:
        ValueError: If ``frame_step`` is smaller than 1.
    """
    if frame_step < 1:
        # A negative stride would silently produce an empty list.
        raise ValueError(f"frame_step must be >= 1, got {frame_step}")

    # decord's VideoReader reads from a path or a file-like object. A bare
    # bytes payload would be treated as a path, so expose it through an
    # in-memory buffer instead.
    if isinstance(video_bytes, (bytes, bytearray, memoryview)):
        video_bytes = io.BytesIO(video_bytes)
    vr = decord.VideoReader(video_bytes)
    return [vr[i].asnumpy() for i in range(0, len(vr), frame_step)]



def decode_video(video_bytes):
    """
    Given video bytes, decode them into frames.

    Args:
        video_bytes: Encoded video container contents (for example the raw
            bytes of a ``.webm`` file). A file path or a file-like object is
            passed through to ``decord.VideoReader`` unchanged.

    Returns:
        A list with one NumPy array per kept frame. Every 5th frame is kept
        (frames 0, 5, 10, ...).
    """
    # decord's VideoReader reads from a path or a file-like object. A bare
    # bytes payload would be treated as a path, so expose it through an
    # in-memory buffer instead.
    if isinstance(video_bytes, (bytes, bytearray, memoryview)):
        video_bytes = io.BytesIO(video_bytes)
    reader = decord.VideoReader(video_bytes)
    return [reader[i].asnumpy() for i in range(0, len(reader), 5)]


if __name__ == "__main__":
    # Exploratory script: open a few SSv2 shards, pull one sample/batch for
    # inspection, decode the first video into a (t, c, h, w) array, and then
    # drop into an interactive debugger.
    shard_folder = "/work/piyush/from_nfs2/datasets/SSv2/ssv2_shards/"

    # Define your WebDataset path pattern (all shards)
    dataset_path = os.path.join(shard_folder, "shard-{0000..0002}.tar")

    # Undecoded samples, kept around so they can be inspected in the debugger.
    ds = wds.WebDataset(dataset_path)
    sample = next(iter(ds))
    # NOTE(review): 8 workers are requested for a 3-shard pattern, so some
    # workers may receive no shard -- confirm this is acceptable for the
    # installed webdataset version.
    dl = DataLoader(ds, batch_size=16, num_workers=8)
    batch = next(iter(dl))

    # Stream 1-tuples holding the content stored under the "webm" key.
    dataset = wds.WebDataset(dataset_path).to_tuple("webm")

    for (video,) in dataset:
        print(type(video))
        # The "webm" entry is the encoded container, not raw pixels, so it has
        # to be decoded before it can be viewed as a frame array. Reshaping
        # the byte buffer directly only works for uncompressed data.
        reader = decord.VideoReader(io.BytesIO(video))
        # Presumably each decoded frame is (height, width, channel), the same
        # layout the original reshape pattern assumed -- TODO confirm.
        frames = np.stack([reader[i].asnumpy() for i in range(len(reader))])
        reshaped_video = einops.rearrange(frames, "t h w c -> t c h w")
        break

    # Leave the variables above available for interactive inspection.
    import ipdb
    ipdb.set_trace()