# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np

import torch
import torch.nn as nn
import torchvision.transforms as transforms

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms

from src.datasets.utils.video.randerase import RandomErasing
from src.models.utils.pos_embs import get_1d_sincos_pos_embed
from src.masks.utils import apply_masks


class FrameAggregation(nn.Module):
    """
    Process each frame independently and concatenate all tokens
    """

    def __init__(
        self,
        model,
        max_frames=10000,
        use_pos_embed=False,
        attend_across_segments=False
    ):
        super().__init__()
        self.model = model
        self.embed_dim = embed_dim = model.embed_dim
        self.num_heads = model.num_heads
        self.attend_across_segments = attend_across_segments
        # 1D-temporal pos-embedding
        self.pos_embed = None
        if use_pos_embed:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, max_frames, embed_dim),
                requires_grad=False)
            sincos = get_1d_sincos_pos_embed(embed_dim, max_frames)
            self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

    def forward(self, x, clip_indices=None):

        # TODO: implement attend_across_segments=False
        # num_clips = len(x)
        num_views_per_clip = len(x[0])

        # Concatenate views along batch dimension
        x = [torch.cat(xi, dim=0) for xi in x]
        # Concatenate clips along temporal dimension
        x = torch.cat(x, dim=2)
        B, C, T, H, W = x.size()

        # Put each frame along the batch dimension
        x = x.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)

        outputs = self.model(x)
        _, N, D = outputs.size()
        outputs = outputs.reshape(B, T, N, D).flatten(1, 2)

        # Separate views into list
        B = B // num_views_per_clip
        all_outputs = []
        for i in range(num_views_per_clip):
            o = outputs[i*B:(i+1)*B]
            # Compute positional embedding
            if (self.pos_embed is not None) and (clip_indices is not None):
                pos_embed = self.pos_embed.repeat(B, 1, 1)  # [B, F, D]
                pos_embed = apply_masks(pos_embed, clip_indices, concat=False)  # list(Tensor([B, T, D]))
                pos_embed = torch.cat(pos_embed, dim=1)  # concatenate along temporal dimension
                pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1)  # [B, T*num_clips, N, D]
                pos_embed = pos_embed.flatten(1, 2)
                o += pos_embed
            all_outputs += [o]

        return all_outputs


class ClipAggregation(nn.Module):
    """
    Process each clip independently and concatenate all tokens
    """

    def __init__(
        self,
        model,
        tubelet_size=2,
        max_frames=10000,
        use_pos_embed=False,
        attend_across_segments=False
    ):
        super().__init__()
        self.model = model
        self.tubelet_size = tubelet_size
        self.embed_dim = embed_dim = model.embed_dim
        self.num_heads = model.num_heads
        self.attend_across_segments = attend_across_segments
        # 1D-temporal pos-embedding
        self.pos_embed = None
        if use_pos_embed:
            max_T = max_frames // tubelet_size
            self.pos_embed = nn.Parameter(
                torch.zeros(1, max_T, embed_dim),
                requires_grad=False)
            sincos = get_1d_sincos_pos_embed(embed_dim, max_T)
            self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

    def forward(self, x, clip_indices=None):

        num_clips = len(x)
        num_views_per_clip = len(x[0])
        B, C, T, H, W = x[0][0].size()

        # Concatenate all spatial and temporal views along batch dimension
        x = [torch.cat(xi, dim=0) for xi in x]
        x = torch.cat(x, dim=0)
        outputs = self.model(x)
        _, N, D = outputs.size()

        T = T // self.tubelet_size  # Num temporal tokens
        N = N // T  # Num spatial tokens

        # Unroll outputs into a 2D array [spatial_views x temporal_views]
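        # `outputs` stacks every view of every clip along the batch dimension:
        # clip i occupies rows [i*eff_B, (i+1)*eff_B), and within that chunk
        # spatial view j occupies rows [j*B, (j+1)*B). The nested slicing
        # below inverts this layout into all_outputs[view][clip].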
        eff_B = B * num_views_per_clip
        all_outputs = [[] for _ in range(num_views_per_clip)]
        for i in range(num_clips):
            o = outputs[i*eff_B:(i+1)*eff_B]
            for j in range(num_views_per_clip):
                all_outputs[j].append(o[j*B:(j+1)*B])

        if not self.attend_across_segments:
            return all_outputs

        for i, outputs in enumerate(all_outputs):

            # Concatenate along temporal dimension
            outputs = [o.reshape(B, T, N, D) for o in outputs]
            outputs = torch.cat(outputs, dim=1).flatten(1, 2)

            # Compute positional embedding
            if (self.pos_embed is not None) and (clip_indices is not None):
                clip_indices = [c[:, ::self.tubelet_size] for c in clip_indices]
                pos_embed = self.pos_embed.repeat(B, 1, 1)  # [B, F, D]
                pos_embed = apply_masks(pos_embed, clip_indices, concat=False)  # list(Tensor([B, T, D]))
                pos_embed = torch.cat(pos_embed, dim=1)  # concatenate along temporal dimension
                pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1)  # [B, T*num_clips, N, D]
                pos_embed = pos_embed.flatten(1, 2)
                outputs += pos_embed

            all_outputs[i] = outputs

        return all_outputs


def make_transforms(
    training=True,
    random_horizontal_flip=True,
    random_resize_aspect_ratio=(3/4, 4/3),
    random_resize_scale=(0.3, 1.0),
    reprob=0.0,
    auto_augment=False,
    motion_shift=False,
    crop_size=224,
    num_views_per_clip=1,
    normalize=((0.485, 0.456, 0.406),
               (0.229, 0.224, 0.225))
):

    if not training and num_views_per_clip > 1:
        print('Making EvalVideoTransform, multi-view')
        _frames_augmentation = EvalVideoTransform(
            num_views_per_clip=num_views_per_clip,
            short_side_size=crop_size,
            normalize=normalize,
        )

    else:
        _frames_augmentation = VideoTransform(
            training=training,
            random_horizontal_flip=random_horizontal_flip,
            random_resize_aspect_ratio=random_resize_aspect_ratio,
            random_resize_scale=random_resize_scale,
            reprob=reprob,
            auto_augment=auto_augment,
            motion_shift=motion_shift,
            crop_size=crop_size,
            normalize=normalize,
        )
    return _frames_augmentation


class VideoTransform(object):

    def __init__(
        self,
        training=True,
        random_horizontal_flip=True,
        random_resize_aspect_ratio=(3/4, 4/3),
        random_resize_scale=(0.3, 1.0),
        reprob=0.0,
        auto_augment=False,
        motion_shift=False,
        crop_size=224,
        normalize=((0.485, 0.456, 0.406),
                   (0.229, 0.224, 0.225))
    ):
        self.training = training

        short_side_size = int(crop_size * 256 / 224)
        self.eval_transform = video_transforms.Compose([
            video_transforms.Resize(short_side_size, interpolation='bilinear'),
            video_transforms.CenterCrop(size=(crop_size, crop_size)),
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=normalize[0], std=normalize[1])
        ])

        self.random_horizontal_flip = random_horizontal_flip
        self.random_resize_aspect_ratio = random_resize_aspect_ratio
        self.random_resize_scale = random_resize_scale
        self.auto_augment = auto_augment
        self.motion_shift = motion_shift
        self.crop_size = crop_size
        self.normalize = torch.tensor(normalize)

        self.autoaug_transform = video_transforms.create_random_augment(
            input_size=(crop_size, crop_size),
            auto_augment='rand-m7-n4-mstd0.5-inc1',
            interpolation='bicubic',
        )

        self.spatial_transform = video_transforms.random_resized_crop_with_shift \
            if motion_shift else video_transforms.random_resized_crop

        self.reprob = reprob
        self.erase_transform = RandomErasing(
            reprob,
            mode='pixel',
            max_count=1,
            num_splits=1,
            device='cpu',
        )

    def __call__(self, buffer):

        if not self.training:
            return [self.eval_transform(buffer)]

        buffer = [transforms.ToPILImage()(frame) for frame in buffer]

        if self.auto_augment:
            buffer = self.autoaug_transform(buffer)

        buffer = [transforms.ToTensor()(img) for img in buffer]
        buffer = torch.stack(buffer)  # T C H W
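        # tensor_normalize below expects channels-last frames ([T, H, W, C]),
        # while the random-resized-crop expects channels-first video
        # ([C, T, H, W]), hence the two permutes surrounding it.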
        buffer = buffer.permute(0, 2, 3, 1)  # T H W C

        buffer = tensor_normalize(buffer, self.normalize[0], self.normalize[1])
        buffer = buffer.permute(3, 0, 1, 2)  # T H W C -> C T H W

        buffer = self.spatial_transform(
            images=buffer,
            target_height=self.crop_size,
            target_width=self.crop_size,
            scale=self.random_resize_scale,
            ratio=self.random_resize_aspect_ratio,
        )

        if self.random_horizontal_flip:
            buffer, _ = video_transforms.horizontal_flip(0.5, buffer)

        if self.reprob > 0:
            buffer = buffer.permute(1, 0, 2, 3)
            buffer = self.erase_transform(buffer)
            buffer = buffer.permute(1, 0, 2, 3)

        return [buffer]


class EvalVideoTransform(object):

    def __init__(
        self,
        num_views_per_clip=1,
        short_side_size=224,
        normalize=((0.485, 0.456, 0.406),
                   (0.229, 0.224, 0.225))
    ):
        self.views_per_clip = num_views_per_clip
        self.short_side_size = short_side_size
        self.spatial_resize = video_transforms.Resize(short_side_size, interpolation='bilinear')
        self.to_tensor = video_transforms.Compose([
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=normalize[0], std=normalize[1])
        ])

    def __call__(self, buffer):

        # Sample several spatial views of each clip; assumes views_per_clip > 1
        # (make_transforms only takes this path for multi-view evaluation).
        buffer = np.array(self.spatial_resize(buffer))
        T, H, W, C = buffer.shape

        num_views = self.views_per_clip
        side_len = self.short_side_size
        spatial_step = (max(H, W) - side_len) // (num_views - 1)

        all_views = []
        for i in range(num_views):
            start = i*spatial_step
            if H > W:
                view = buffer[:, start:start+side_len, :, :]
            else:
                view = buffer[:, :, start:start+side_len, :]
            view = self.to_tensor(view)
            all_views.append(view)

        return all_views


def tensor_normalize(tensor, mean, std):
    """
    Normalize a given tensor by subtracting the mean and dividing by the std.
    Args:
        tensor (tensor): tensor to normalize.
        mean (tensor or list): mean value to subtract.
        std (tensor or list): std to divide by.
    """
    if tensor.dtype == torch.uint8:
        tensor = tensor.float()
        tensor = tensor / 255.0
    if isinstance(mean, list):
        mean = torch.tensor(mean)
    if isinstance(std, list):
        std = torch.tensor(std)
    tensor = tensor - mean
    tensor = tensor / std
    return tensor
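

# Minimal usage sketch for the wrappers above, assuming a ViT-style encoder
# that exposes `embed_dim`/`num_heads` attributes and maps [B, C, T, H, W]
# videos to [B, N, D] token sequences. `DummyEncoder` is a hypothetical
# stand-in, not an encoder from this repository.
if __name__ == '__main__':

    class DummyEncoder(nn.Module):
        """Stand-in encoder: emits random tokens with a ViT-like token count."""

        def __init__(self, embed_dim=384, num_heads=6, tubelet_size=2, patch_size=16):
            super().__init__()
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.tubelet_size = tubelet_size
            self.patch_size = patch_size

        def forward(self, x):
            B, C, T, H, W = x.shape
            n_tokens = (T // self.tubelet_size) * (H // self.patch_size) * (W // self.patch_size)
            return torch.randn(B, n_tokens, self.embed_dim)

    wrapper = ClipAggregation(DummyEncoder(), tubelet_size=2, attend_across_segments=True)

    # Two temporal clips, each with two spatial views of a batch of 2 videos
    clips = [[torch.randn(2, 3, 16, 224, 224) for _ in range(2)] for _ in range(2)]
    out = wrapper(clips)
    print(len(out), out[0].shape)  # 2 spatial views, each [B, num_clips*T*N, D]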