import math
import random
from collections import Counter
from typing import List, Optional

import decord
import torch
from torch.nn import functional as F
from torch.utils.data import Sampler

IMG_EXTENSIONS = [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG"]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
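
# Example (filename is illustrative): is_image_file("clip_0001.png") -> True.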


class DecordInit(object):
    """Use Decord (https://github.com/dmlc/decord) to initialize the video reader."""

    def __init__(self, num_threads=1):
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Perform the Decord initialization.

        Args:
            filename (str): Path of the video file to open.

        Returns:
            decord.VideoReader: A reader over the video's frames.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        repr_str = (f"{self.__class__.__name__}("
                    f"num_threads={self.num_threads})")
        return repr_str
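
# Usage sketch (the path is illustrative):
#     init = DecordInit(num_threads=2)
#     vr = init("/path/to/video.mp4")
#     num_frames, first_frame = len(vr), vr[0].asnumpy()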


def pad_to_multiple(number, ds_stride):
    """Round `number` up to the nearest multiple of `ds_stride`."""
    remainder = number % ds_stride
    if remainder == 0:
        return number
    return number + (ds_stride - remainder)
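
# e.g. pad_to_multiple(17, 8) -> 24 and pad_to_multiple(16, 8) -> 16.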


class Collate:
    """Collate variable-shape video tubes into padded, masked training batches."""

    def __init__(self, args):
        self.batch_size = args.train_batch_size
        self.group_frame = args.group_frame
        self.group_resolution = args.group_resolution

        self.max_height = args.max_height
        self.max_width = args.max_width
        self.ae_stride = args.ae_stride

        self.ae_stride_t = args.ae_stride_t
        self.ae_stride_thw = (self.ae_stride_t, self.ae_stride, self.ae_stride)

        self.patch_size = args.patch_size
        self.patch_size_t = args.patch_size_t

        self.num_frames = args.num_frames
        self.use_image_num = args.use_image_num
        self.max_thw = (self.num_frames, self.max_height, self.max_width)

    def package(self, batch):
        batch_tubes = [i["pixel_values"] for i in batch]
        input_ids = [i["input_ids"] for i in batch]
        cond_mask = [i["cond_mask"] for i in batch]
        return batch_tubes, input_ids, cond_mask

    def __call__(self, batch):
        batch_tubes, input_ids, cond_mask = self.package(batch)

        ds_stride = self.ae_stride * self.patch_size
        t_ds_stride = self.ae_stride_t * self.patch_size_t

        pad_batch_tubes, attention_mask, input_ids, cond_mask = self.process(
            batch_tubes,
            input_ids,
            cond_mask,
            t_ds_stride,
            ds_stride,
            self.max_thw,
            self.ae_stride_thw,
        )
        assert not torch.any(torch.isnan(pad_batch_tubes)), \
            "NaN values found in pad_batch_tubes after padding"
        return pad_batch_tubes, attention_mask, input_ids, cond_mask
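
    # Note: ds_stride (ae_stride * patch_size) is the total spatial downsampling
    # factor and t_ds_stride the temporal one; `process` pads H/W to a multiple
    # of ds_stride, and T (offset by one frame) to align with t_ds_stride, so
    # the latent grid divides evenly into transformer patches.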

    def process(
        self,
        batch_tubes,
        input_ids,
        cond_mask,
        t_ds_stride,
        ds_stride,
        max_thw,
        ae_stride_thw,
    ):
        # Each tube has shape (C, T, H, W).
        batch_input_size = [i.shape for i in batch_tubes]
        assert len(batch_input_size) == self.batch_size
        if self.group_frame or self.group_resolution or self.batch_size == 1:
            len_each_batch = batch_input_size
            idx_length_dict = dict(
                [*zip(list(range(self.batch_size)), len_each_batch)])
            count_dict = Counter(len_each_batch)
            if len(count_dict) != 1:
                # Shapes disagree within the batch: keep the most common shape
                # and refill the remaining slots by resampling its candidates.
                sorted_by_value = sorted(count_dict.items(),
                                         key=lambda item: item[1])
                pick_length = sorted_by_value[-1][0]
                candidate_batch = [
                    idx for idx, length in idx_length_dict.items()
                    if length == pick_length
                ]
                random_select_batch = [
                    random.choice(candidate_batch)
                    for _ in range(len(len_each_batch) - len(candidate_batch))
                ]
                print(
                    batch_input_size,
                    idx_length_dict,
                    count_dict,
                    sorted_by_value,
                    pick_length,
                    candidate_batch,
                    random_select_batch,
                )
                pick_idx = candidate_batch + random_select_batch

                batch_tubes = [batch_tubes[i] for i in pick_idx]
                batch_input_size = [i.shape for i in batch_tubes]
                input_ids = [input_ids[i] for i in pick_idx]
                cond_mask = [cond_mask[i] for i in pick_idx]

            # All tubes in the batch must now share a single shape.
            for i in range(1, self.batch_size):
                assert batch_input_size[0] == batch_input_size[i]
            max_t = max([i[1] for i in batch_input_size])
            max_h = max([i[2] for i in batch_input_size])
            max_w = max([i[3] for i in batch_input_size])
        else:
            max_t, max_h, max_w = max_thw
        # Pad so that (T - 1 + ae_stride_t) is a multiple of t_ds_stride, and
        # H/W are multiples of ds_stride, matching the (T - 1) // stride + 1
        # latent-length formula below.
        pad_max_t, pad_max_h, pad_max_w = (
            pad_to_multiple(max_t - 1 + self.ae_stride_t, t_ds_stride),
            pad_to_multiple(max_h, ds_stride),
            pad_to_multiple(max_w, ds_stride),
        )
        pad_max_t = pad_max_t + 1 - self.ae_stride_t
        each_pad_t_h_w = [[
            pad_max_t - i.shape[1],
            pad_max_h - i.shape[2],
            pad_max_w - i.shape[3],
        ] for i in batch_tubes]
        pad_batch_tubes = [
            F.pad(im, (0, pad_w, 0, pad_h, 0, pad_t), value=0)
            for (pad_t, pad_h, pad_w), im in zip(each_pad_t_h_w, batch_tubes)
        ]
        pad_batch_tubes = torch.stack(pad_batch_tubes, dim=0)

        max_tube_size = [pad_max_t, pad_max_h, pad_max_w]
        max_latent_size = [
            (max_tube_size[0] - 1) // ae_stride_thw[0] + 1,
            max_tube_size[1] // ae_stride_thw[1],
            max_tube_size[2] // ae_stride_thw[2],
        ]
        valid_latent_size = [[
            int(math.ceil((i[1] - 1) / ae_stride_thw[0])) + 1,
            int(math.ceil(i[2] / ae_stride_thw[1])),
            int(math.ceil(i[3] / ae_stride_thw[2])),
        ] for i in batch_input_size]
        # The attention mask is 1 over each sample's valid latent region and 0
        # over the padding.
        attention_mask = [
            F.pad(
                torch.ones(i, dtype=pad_batch_tubes.dtype),
                (
                    0,
                    max_latent_size[2] - i[2],
                    0,
                    max_latent_size[1] - i[1],
                    0,
                    max_latent_size[0] - i[0],
                ),
                value=0,
            ) for i in valid_latent_size
        ]
        attention_mask = torch.stack(attention_mask)
        if self.batch_size == 1 or self.group_frame or self.group_resolution:
            # Grouped batches share one shape, so no latent position is masked.
            assert torch.all(attention_mask.bool())

        input_ids = torch.stack(input_ids)
        cond_mask = torch.stack(cond_mask)

        return pad_batch_tubes, attention_mask, input_ids, cond_mask
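
# Worked example (illustrative values): with ae_stride_t=4, patch_size_t=1,
# ae_stride=8, patch_size=2, a (C, 93, 720, 1280) tube keeps T=93 after padding
# (93 - 1 + 4 = 96 is already a multiple of t_ds_stride=4, and 96 + 1 - 4 = 93),
# and H=720, W=1280 are multiples of ds_stride=16, giving a latent grid of
# (93 - 1) // 4 + 1 = 24 by 720 // 8 = 90 by 1280 // 8 = 160.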


def split_to_even_chunks(indices, lengths, num_chunks, batch_size):
    """
    Split a list of indices into `num_chunks` chunks of roughly equal total length.
    """
    if len(indices) % num_chunks != 0:
        # The indices do not divide evenly: fall back to simple striding.
        chunks = [indices[i::num_chunks] for i in range(num_chunks)]
    else:
        num_indices_per_chunk = len(indices) // num_chunks

        # Greedily assign each index to the chunk with the smallest total length.
        chunks = [[] for _ in range(num_chunks)]
        chunks_lengths = [0 for _ in range(num_chunks)]
        for index in indices:
            shortest_chunk = chunks_lengths.index(min(chunks_lengths))
            chunks[shortest_chunk].append(index)
            chunks_lengths[shortest_chunk] += lengths[index]
            if len(chunks[shortest_chunk]) == num_indices_per_chunk:
                # Mark full chunks so they receive no further indices.
                chunks_lengths[shortest_chunk] = float("inf")

    # Pad every chunk up to batch_size, resampling from the chunk itself (or,
    # if it is empty, copying an earlier padded chunk).
    pad_chunks = []
    for idx, chunk in enumerate(chunks):
        if batch_size != len(chunk):
            assert batch_size > len(chunk)
            if len(chunk) != 0:
                chunk = chunk + [
                    random.choice(chunk)
                    for _ in range(batch_size - len(chunk))
                ]
            else:
                # Assumes at least one earlier chunk was non-empty.
                chunk = random.choice(pad_chunks)
            print(chunks[idx], "->", chunk)
        pad_chunks.append(chunk)
    return pad_chunks
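
# e.g. split_to_even_chunks([0, 1, 2, 3], lengths=[1, 1, 1, 9], num_chunks=2,
# batch_size=2) returns [[0, 2], [1, 3]]: the greedy pass keeps the chunk
# totals as even as the lengths allow.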


def group_frame_fun(indices, lengths):
    """Sort indices in place by sample length, longest first."""
    indices.sort(key=lambda i: lengths[i], reverse=True)
    return indices


def megabatch_frame_alignment(megabatches, lengths):
    aligned_megabatches = []
    for megabatch in megabatches:
        assert len(megabatch) != 0
        len_each_megabatch = [lengths[i] for i in megabatch]
        idx_length_dict = dict([*zip(megabatch, len_each_megabatch)])
        count_dict = Counter(len_each_megabatch)

        if len(count_dict) != 1:
            # Mixed lengths: keep the indices of the most common length and
            # fill the remaining slots by resampling from them.
            sorted_by_value = sorted(count_dict.items(),
                                     key=lambda item: item[1])
            pick_length = sorted_by_value[-1][0]
            candidate_batch = [
                idx for idx, length in idx_length_dict.items()
                if length == pick_length
            ]
            random_select_batch = [
                random.choice(candidate_batch)
                for _ in range(len(idx_length_dict) - len(candidate_batch))
            ]
            aligned_megabatch = candidate_batch + random_select_batch
            aligned_megabatches.append(aligned_megabatch)
        else:
            aligned_megabatches.append(megabatch)

    return aligned_megabatches
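
# Note: this mirrors the majority-shape resampling in Collate.process, applied
# per megabatch so that every sample in a megabatch shares one frame count.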


def get_length_grouped_indices(
    lengths,
    batch_size,
    world_size,
    generator=None,
    group_frame=False,
    group_resolution=False,
    seed=42,
):
    if generator is None:
        generator = torch.Generator().manual_seed(seed)

    indices = torch.randperm(len(lengths), generator=generator).tolist()

    if group_frame:
        # Sort the shuffled indices by length so that megabatches group
        # samples with similar frame counts.
        indices = group_frame_fun(indices, lengths)

    megabatch_size = world_size * batch_size
    megabatches = [
        indices[i:i + megabatch_size]
        for i in range(0, len(lengths), megabatch_size)
    ]

    if group_frame:
        megabatches = megabatch_frame_alignment(megabatches, lengths)

    megabatches = [
        split_to_even_chunks(megabatch, lengths, world_size, batch_size)
        for megabatch in megabatches
    ]

    # Re-shuffle at megabatch granularity to restore randomness across steps.
    indices = torch.randperm(len(megabatches), generator=generator).tolist()
    shuffled_megabatches = [megabatches[i] for i in indices]

    return [
        i for megabatch in shuffled_megabatches for batch in megabatch
        for i in batch
    ]
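
# Shape of the result (illustrative): with batch_size=2 and world_size=2, each
# megabatch holds 4 length-aligned indices split into one 2-index chunk per
# rank; the megabatch order is then reshuffled and the nested lists flattened.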


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the
    dataset of roughly the same length while keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        rank: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        group_frame=False,
        group_resolution=False,
        generator=None,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.rank = rank
        self.world_size = world_size
        self.lengths = lengths
        self.group_frame = group_frame
        self.group_resolution = group_resolution
        self.generator = generator

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        indices = get_length_grouped_indices(
            self.lengths,
            self.batch_size,
            self.world_size,
            group_frame=self.group_frame,
            group_resolution=self.group_resolution,
            generator=self.generator,
        )

        def distributed_sampler(lst, rank, batch_size, world_size):
            # Each rank takes its own batch_size-sized slice out of every
            # world_size * batch_size megabatch.
            result = []
            index = rank * batch_size
            while index < len(lst):
                result.extend(lst[index:index + batch_size])
                index += batch_size * world_size
            return result

        indices = distributed_sampler(indices, self.rank, self.batch_size,
                                      self.world_size)
        return iter(indices)
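
# Usage sketch (values and `dataset` are illustrative; lengths are assumed to
# be per-sample frame counts):
#     from torch.utils.data import DataLoader
#     sampler = LengthGroupedSampler(batch_size=4, rank=0, world_size=1,
#                                    lengths=dataset.lengths, group_frame=True)
#     loader = DataLoader(dataset, batch_size=4, sampler=sampler,
#                         collate_fn=Collate(args))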