import numpy as np

import torch
import torch.nn as nn
import torchvision.transforms as transforms

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms

from src.datasets.utils.video.randerase import RandomErasing

from src.models.utils.pos_embs import get_1d_sincos_pos_embed
from src.masks.utils import apply_masks


class FrameAggregation(nn.Module):
    """
    Process each frame independently and concatenate all tokens
    """

    def __init__(
        self,
        model,
        max_frames=10000,
        use_pos_embed=False,
        attend_across_segments=False
    ):
        super().__init__()
        self.model = model
        self.embed_dim = embed_dim = model.embed_dim
        self.num_heads = model.num_heads
        self.attend_across_segments = attend_across_segments
        # Frozen 1D sin-cos positional embedding over frame indices
        self.pos_embed = None
        if use_pos_embed:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, max_frames, embed_dim),
                requires_grad=False)
            sincos = get_1d_sincos_pos_embed(embed_dim, max_frames)
            self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

    def forward(self, x, clip_indices=None):

        # x is a nested list: one entry per clip, each holding that clip's
        # spatial views as tensors of shape [B, C, T, H, W]
        num_views_per_clip = len(x[0])

        # Concatenate views along the batch dimension
        x = [torch.cat(xi, dim=0) for xi in x]
        # Concatenate clips along the temporal dimension
        x = torch.cat(x, dim=2)
        B, C, T, H, W = x.size()

        # Put each frame along the batch dimension
        x = x.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)

        outputs = self.model(x)
        _, N, D = outputs.size()
        outputs = outputs.reshape(B, T, N, D).flatten(1, 2)

        # Separate the spatial views back into a list
        B = B // num_views_per_clip
        all_outputs = []
        for i in range(num_views_per_clip):
            o = outputs[i*B:(i+1)*B]
            # Add the temporal positional embedding of the sampled frames
            if (self.pos_embed is not None) and (clip_indices is not None):
                pos_embed = self.pos_embed.repeat(B, 1, 1)  # [B, max_frames, D]
                pos_embed = apply_masks(pos_embed, clip_indices, concat=False)  # list of [B, T, D]
                pos_embed = torch.cat(pos_embed, dim=1)  # concatenate clips along time
                pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1)  # [B, T, N, D]
                pos_embed = pos_embed.flatten(1, 2)
                o += pos_embed
            all_outputs += [o]

        return all_outputs
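
# Hedged usage sketch (not part of the original module): FrameAggregation
# wraps a frozen image backbone (a hypothetical `image_encoder` whose forward
# returns [batch, tokens, embed_dim] and which exposes `embed_dim` and
# `num_heads`). Every frame is encoded independently and the resulting tokens
# are concatenated per spatial view, e.g.:
#
#   frame_agg = FrameAggregation(image_encoder, use_pos_embed=False)
#   # 2 clips x 1 spatial view, each of shape [B, C, T, H, W]
#   x = [[torch.randn(4, 3, 16, 224, 224)] for _ in range(2)]
#   feats = frame_agg(x)  # list with one tensor per view: [B, 2*T*N, D]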


class ClipAggregation(nn.Module):
    """
    Process each clip independently and concatenate all tokens
    """

    def __init__(
        self,
        model,
        tubelet_size=2,
        max_frames=10000,
        use_pos_embed=False,
        attend_across_segments=False
    ):
        super().__init__()
        self.model = model
        self.tubelet_size = tubelet_size
        self.embed_dim = embed_dim = model.embed_dim
        self.num_heads = model.num_heads
        self.attend_across_segments = attend_across_segments
        # Frozen 1D sin-cos positional embedding over tubelet (frame-group) indices
        self.pos_embed = None
        if use_pos_embed:
            max_T = max_frames // tubelet_size
            self.pos_embed = nn.Parameter(
                torch.zeros(1, max_T, embed_dim),
                requires_grad=False)
            sincos = get_1d_sincos_pos_embed(embed_dim, max_T)
            self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

    def forward(self, x, clip_indices=None):

        num_clips = len(x)
        num_views_per_clip = len(x[0])
        B, C, T, H, W = x[0][0].size()

        # Concatenate all spatial views and clips along the batch dimension
        x = [torch.cat(xi, dim=0) for xi in x]
        x = torch.cat(x, dim=0)
        outputs = self.model(x)
        _, N, D = outputs.size()

        T = T // self.tubelet_size  # temporal tokens per clip
        N = N // T  # spatial tokens per tubelet

        # Unroll outputs into a 2D array indexed [spatial_view][clip]
        eff_B = B * num_views_per_clip
        all_outputs = [[] for _ in range(num_views_per_clip)]
        for i in range(num_clips):
            o = outputs[i*eff_B:(i+1)*eff_B]
            for j in range(num_views_per_clip):
                all_outputs[j].append(o[j*B:(j+1)*B])

        if not self.attend_across_segments:
            return all_outputs

        # Subsample clip indices to the tubelet resolution once, before the
        # per-view loop, so every view uses the same indices
        if (self.pos_embed is not None) and (clip_indices is not None):
            clip_indices = [c[:, ::self.tubelet_size] for c in clip_indices]

        for i, outputs in enumerate(all_outputs):

            # Concatenate this view's clips along the temporal dimension
            outputs = [o.reshape(B, T, N, D) for o in outputs]
            outputs = torch.cat(outputs, dim=1).flatten(1, 2)

            # Add the temporal positional embedding of the sampled clips
            if (self.pos_embed is not None) and (clip_indices is not None):
                pos_embed = self.pos_embed.repeat(B, 1, 1)  # [B, max_T, D]
                pos_embed = apply_masks(pos_embed, clip_indices, concat=False)  # list of [B, T, D]
                pos_embed = torch.cat(pos_embed, dim=1)  # concatenate clips along time
                pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1)  # [B, num_clips*T, N, D]
                pos_embed = pos_embed.flatten(1, 2)
                outputs += pos_embed

            all_outputs[i] = outputs

        return all_outputs
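
# Hedged usage sketch (not part of the original module): ClipAggregation
# wraps a video backbone (a hypothetical `video_encoder` returning
# [batch, tokens, embed_dim] for a clip of shape [B, C, T, H, W]).
# With attend_across_segments=True the clips of each spatial view are merged
# into one token sequence:
#
#   clip_agg = ClipAggregation(video_encoder, tubelet_size=2,
#                              attend_across_segments=True)
#   x = [[torch.randn(4, 3, 16, 224, 224)] for _ in range(2)]  # 2 clips x 1 view
#   feats = clip_agg(x)  # list with one tensor per view: [B, 2*(T//2)*N, D]
#
# With attend_across_segments=False the result is instead a nested list
# indexed as feats[view][clip], each entry of shape [B, tokens, D].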


def make_transforms(
    training=True,
    random_horizontal_flip=True,
    random_resize_aspect_ratio=(3/4, 4/3),
    random_resize_scale=(0.3, 1.0),
    reprob=0.0,
    auto_augment=False,
    motion_shift=False,
    crop_size=224,
    num_views_per_clip=1,
    normalize=((0.485, 0.456, 0.406),
               (0.229, 0.224, 0.225))
):

    if not training and num_views_per_clip > 1:
        print('Making EvalVideoTransform, multi-view')
        _frames_augmentation = EvalVideoTransform(
            num_views_per_clip=num_views_per_clip,
            short_side_size=crop_size,
            normalize=normalize,
        )

    else:
        _frames_augmentation = VideoTransform(
            training=training,
            random_horizontal_flip=random_horizontal_flip,
            random_resize_aspect_ratio=random_resize_aspect_ratio,
            random_resize_scale=random_resize_scale,
            reprob=reprob,
            auto_augment=auto_augment,
            motion_shift=motion_shift,
            crop_size=crop_size,
            normalize=normalize,
        )
    return _frames_augmentation
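
# Hedged usage sketch (not part of the original module); `frames` is assumed
# to be a decoded clip, i.e. a sequence of T RGB frames as uint8 arrays of
# shape (H, W, 3):
#
#   train_tf = make_transforms(training=True, crop_size=224, reprob=0.25)
#   clip = train_tf(frames)  # list with one tensor of shape [3, T, 224, 224]
#
#   eval_tf = make_transforms(training=False, num_views_per_clip=3)
#   views = eval_tf(frames)  # list of 3 spatial views, each [3, T, 224, 224]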


class VideoTransform(object):

    def __init__(
        self,
        training=True,
        random_horizontal_flip=True,
        random_resize_aspect_ratio=(3/4, 4/3),
        random_resize_scale=(0.3, 1.0),
        reprob=0.0,
        auto_augment=False,
        motion_shift=False,
        crop_size=224,
        normalize=((0.485, 0.456, 0.406),
                   (0.229, 0.224, 0.225))
    ):

        self.training = training

        short_side_size = int(crop_size * 256 / 224)
        self.eval_transform = video_transforms.Compose([
            video_transforms.Resize(short_side_size, interpolation='bilinear'),
            video_transforms.CenterCrop(size=(crop_size, crop_size)),
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=normalize[0], std=normalize[1])
        ])

        self.random_horizontal_flip = random_horizontal_flip
        self.random_resize_aspect_ratio = random_resize_aspect_ratio
        self.random_resize_scale = random_resize_scale
        self.auto_augment = auto_augment
        self.motion_shift = motion_shift
        self.crop_size = crop_size
        self.normalize = torch.tensor(normalize)

        self.autoaug_transform = video_transforms.create_random_augment(
            input_size=(crop_size, crop_size),
            auto_augment='rand-m7-n4-mstd0.5-inc1',
            interpolation='bicubic',
        )

        self.spatial_transform = video_transforms.random_resized_crop_with_shift \
            if motion_shift else video_transforms.random_resized_crop

        self.reprob = reprob
        self.erase_transform = RandomErasing(
            reprob,
            mode='pixel',
            max_count=1,
            num_splits=1,
            device='cpu',
        )

    def __call__(self, buffer):

        if not self.training:
            return [self.eval_transform(buffer)]

        buffer = [transforms.ToPILImage()(frame) for frame in buffer]

        if self.auto_augment:
            buffer = self.autoaug_transform(buffer)

        buffer = [transforms.ToTensor()(img) for img in buffer]
        buffer = torch.stack(buffer)  # T C H W
        buffer = buffer.permute(0, 2, 3, 1)  # T H W C

        buffer = tensor_normalize(buffer, self.normalize[0], self.normalize[1])
        buffer = buffer.permute(3, 0, 1, 2)  # C T H W

        buffer = self.spatial_transform(
            images=buffer,
            target_height=self.crop_size,
            target_width=self.crop_size,
            scale=self.random_resize_scale,
            ratio=self.random_resize_aspect_ratio,
        )
        if self.random_horizontal_flip:
            buffer, _ = video_transforms.horizontal_flip(0.5, buffer)

        if self.reprob > 0:
            buffer = buffer.permute(1, 0, 2, 3)  # T C H W for RandomErasing
            buffer = self.erase_transform(buffer)
            buffer = buffer.permute(1, 0, 2, 3)  # back to C T H W

        return [buffer]


class EvalVideoTransform(object):

    def __init__(
        self,
        num_views_per_clip=1,
        short_side_size=224,
        normalize=((0.485, 0.456, 0.406),
                   (0.229, 0.224, 0.225))
    ):
        self.views_per_clip = num_views_per_clip
        self.short_side_size = short_side_size
        self.spatial_resize = video_transforms.Resize(short_side_size, interpolation='bilinear')
        self.to_tensor = video_transforms.Compose([
            volume_transforms.ClipToTensor(),
            video_transforms.Normalize(mean=normalize[0], std=normalize[1])
        ])

    def __call__(self, buffer):

        # Resize the short side of every frame, then slide a square window
        # along the long side to extract the spatial views
        buffer = np.array(self.spatial_resize(buffer))
        T, H, W, C = buffer.shape

        # Evenly spaced crop offsets (assumes views_per_clip > 1, which is how
        # make_transforms instantiates this class)
        num_views = self.views_per_clip
        side_len = self.short_side_size
        spatial_step = (max(H, W) - side_len) // (num_views - 1)

        all_views = []
        for i in range(num_views):
            start = i*spatial_step
            if H > W:
                view = buffer[:, start:start+side_len, :, :]
            else:
                view = buffer[:, :, start:start+side_len, :]
            view = self.to_tensor(view)
            all_views.append(view)

        return all_views
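
# Worked example of the cropping above (illustrative numbers, not from the
# original source): with short_side_size=224, a resized clip of H=224, W=400
# and num_views=3, spatial_step = (400 - 224) // 2 = 88, so the square views
# start at columns 0, 88 and 176 and together cover the full long side.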


def tensor_normalize(tensor, mean, std):
    """
    Normalize a given tensor by subtracting the mean and dividing by the std.
    Args:
        tensor (tensor): tensor to normalize.
        mean (tensor or list): mean value to subtract.
        std (tensor or list): std to divide by.
    """
    if tensor.dtype == torch.uint8:
        tensor = tensor.float()
        tensor = tensor / 255.0
    if type(mean) == list:
        mean = torch.tensor(mean)
    if type(std) == list:
        std = torch.tensor(std)
    tensor = tensor - mean
    tensor = tensor / std
    return tensor
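

if __name__ == '__main__':
    # Minimal sanity check for tensor_normalize (illustrative only, not part
    # of the original module): a uint8 frame buffer is scaled to [0, 1] and
    # then standardized channel-wise with the ImageNet statistics used above.
    frames = torch.randint(0, 256, (8, 16, 16, 3), dtype=torch.uint8)  # T H W C
    out = tensor_normalize(frames, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    print(out.shape, out.dtype)  # torch.Size([8, 16, 16, 3]) torch.float32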