File size: 5,826 Bytes
068b511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import json 
import torch 
from typing import Any, Dict, List, Optional, Tuple
from torch.utils.data import DataLoader, Dataset 
import torchvision.transforms as TT  
from torchvision import transforms
from torchvision.transforms.functional import center_crop, resize 
from torchvision.transforms import InterpolationMode 
import random 
# `decord` is a hard requirement of this module: fail at import time with an
# installation hint instead of failing later when the first video is opened.
try:
    import decord
except ImportError:
    raise ImportError(
        "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
    )

# Select decord's "torch" bridge. The dataset below treats decoded frame
# batches as torch tensors (it calls `.permute()` / `.contiguous()` on them).
decord.bridge.set_bridge("torch")

class ImageVideoDataset(Dataset):
    """Dataset yielding ``(frames, caption_token_ids)`` pairs from a JSON manifest.

    The manifest is a JSON sequence whose entries provide a ``'file_path'``
    (video file opened with ``decord``) and a ``'text'`` (caption) key.
    For every item a clip of exactly ``max_num_frames`` frames is sampled;
    entries that cannot be decoded, or that are too short, are skipped in
    favour of the next entry in the manifest.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        max_sequence_length: int = 226,
        height: int = 480,
        width: int = 720,
        video_reshape_mode: str = "center",
        fps: int = 8,
        stripe: int = 2,
        max_num_frames: int = 49,
        skip_frames_start: int = 0,
        skip_frames_end: int = 0,
        random_flip: Optional[float] = None,
    ) -> None:
        """Load the manifest and build the frame transform pipeline.

        Args:
            data_root: Path of the JSON manifest file.
            tokenizer: Callable used to tokenize captions; its result must
                expose ``input_ids`` (presumably a Hugging Face tokenizer --
                confirm against the caller).
            max_sequence_length: Length captions are padded/truncated to.
            height: Output frame height in pixels.
            width: Output frame width in pixels.
            video_reshape_mode: Crop placement, ``"center"``, ``"random"`` or
                ``"none"`` (``"none"`` is handled like ``"random"``).
            fps: Stored on the instance; not used by this class.
            stripe: Step between sampled frame indices. Falls back to 1 for
                videos with fewer than ``stripe * max_num_frames`` frames.
            max_num_frames: Exact number of frames every returned clip has.
                Clips are trimmed to ``4k + 1`` frames, so a value that is
                not of that form can never be satisfied.
            skip_frames_start: Stored on the instance; not used by this class.
            skip_frames_end: Stored on the instance; not used by this class.
            random_flip: Probability of horizontally flipping a clip; a falsy
                value disables flipping.
        """
        super().__init__()

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(data_root, "r", encoding="utf-8") as f:
            self.data_list = json.load(f)

        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.height = height
        self.width = width
        self.video_reshape_mode = video_reshape_mode
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frames_start = skip_frames_start
        self.skip_frames_end = skip_frames_end
        self.stripe = stripe
        # Applied once to a whole [F, C, H, W] clip (see __getitem__), so a
        # random flip is the same for every frame of the clip.
        self.video_transforms = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(random_flip) if random_flip else transforms.Lambda(lambda x: x),
                transforms.Lambda(lambda x: x / 255.0),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

    def __len__(self):
        """Return the number of entries in the manifest."""
        return len(self.data_list)

    def _crop_offsets(self, delta_h, delta_w):
        """Return the ``(top, left)`` crop origin for the configured reshape mode.

        Args:
            delta_h: Resized height minus target height.
            delta_w: Resized width minus target width.

        Raises:
            NotImplementedError: If ``video_reshape_mode`` is not supported.
        """
        mode = self.video_reshape_mode
        if mode in ("random", "none"):
            # random.randint is inclusive on both ends: offsets in [0, delta].
            return random.randint(0, delta_h), random.randint(0, delta_w)
        if mode == "center":
            return delta_h // 2, delta_w // 2
        raise NotImplementedError(f"Unsupported video_reshape_mode: {mode!r}")

    def _resize_for_rectangle_crop(self, arr):
        """Resize ``arr`` to cover ``(height, width)``, then crop to that size.

        Args:
            arr: Tensor indexed as ``[F, C, H, W]`` (axes 2 and 3 are read as
                height and width).

        Returns:
            The resized and cropped tensor; the frame axis is kept.

        Raises:
            NotImplementedError: If ``video_reshape_mode`` is not supported.
        """
        target_h, target_w = self.height, self.width
        src_h, src_w = arr.shape[2], arr.shape[3]

        # Scale so the side that is relatively shorter matches the target and
        # the other side overshoots; the overshoot is cropped away below.
        if src_w / src_h > target_w / target_h:
            new_size = [target_h, int(src_w * target_h / src_h)]
        else:
            new_size = [int(src_h * target_w / src_w), target_w]
        arr = resize(arr, size=new_size, interpolation=InterpolationMode.BICUBIC)

        delta_h = arr.shape[2] - target_h
        delta_w = arr.shape[3] - target_w
        top, left = self._crop_offsets(delta_h, delta_w)
        return TT.functional.crop(arr, top=top, left=left, height=target_h, width=target_w)

    def _sample_frame_indices(self, video_num_frames):
        """Pick ``max_num_frames`` evenly spaced frame indices at a random offset.

        Args:
            video_num_frames: Total number of frames in the video.

        Returns:
            A list of ``max_num_frames`` indices. The indices may exceed the
            video length for short videos; the caller relies on the decoder
            rejecting them.
        """
        # Short videos cannot afford the configured step; sample densely.
        if self.stripe * self.max_num_frames > video_num_frames:
            stripe = 1
        else:
            stripe = self.stripe

        random_range = max(1, video_num_frames - stripe * self.max_num_frames - 1)
        start_frame = random.randint(1, random_range)
        return list(range(start_frame, start_frame + stripe * self.max_num_frames, stripe))

    def _try_load_frames(self, index):
        """Decode one clip for ``index``.

        Returns:
            The decoded frames with exactly ``max_num_frames`` frames, or
            ``None`` when the entry cannot be decoded or is too short.
        """
        # Pre-bind the diagnostics so the error report below cannot itself
        # fail when the exception happens before they are computed.
        video_num_frames = None
        indices = None
        try:
            video_reader = decord.VideoReader(
                self.data_list[index]['file_path'], width=self.width, height=self.height
            )
            video_num_frames = len(video_reader)
            indices = self._sample_frame_indices(video_num_frames)
            frames = video_reader.get_batch(indices)

            # Ensure that we don't go over the limit
            frames = frames[: self.max_num_frames]

            # Choose first (4k + 1) frames as this is how many is required by the VAE
            remainder = (3 + (frames.shape[0] % 4)) % 4
            if remainder != 0:
                frames = frames[:-remainder]
        except Exception as e:
            # Deliberate best effort: any failure for this entry is reported
            # and the caller moves on to another entry.
            print(
                f"Error encountered while loading video at index {index} "
                f"(num_frames={video_num_frames}, indices={indices}): {e}"
            )
            return None

        if frames.shape[0] != self.max_num_frames:
            return None
        return frames

    def __getitem__(self, index):
        """Return ``(frames, text_input_ids)`` for ``index``.

        If the entry at ``index`` is unusable, following entries (wrapping
        around) are tried instead; the caption always belongs to the entry
        the frames were taken from.

        Returns:
            A tuple of the transformed ``[F, C, H, W]`` clip (divided by 255
            and normalised with mean 0.5 / std 0.5) and the caption token ids
            (first row of the tokenizer's ``input_ids``).

        Raises:
            RuntimeError: If no entry of the manifest yields a usable clip.
        """
        num_items = len(self.data_list)
        frames = None
        # Try every entry at most once instead of looping forever when the
        # whole manifest is unusable.
        for _ in range(num_items):
            frames = self._try_load_frames(index)
            if frames is not None:
                break
            index = (index + 1) % num_items
        else:
            raise RuntimeError(
                f"No video in the dataset yields a clip of {self.max_num_frames} frames "
                "(files unreadable or too short, or max_num_frames is not of the form 4k + 1)."
            )

        # Training transforms
        frames = frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]
        frames = self._resize_for_rectangle_crop(frames)
        # Transform the clip as a whole so a random horizontal flip is applied
        # consistently to all of its frames.
        frames = self.video_transforms(frames)

        text_inputs = self.tokenizer(
            [self.data_list[index]['text']],
            padding="max_length",
            max_length=self.max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids[0]

        return frames.contiguous(), text_input_ids