|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Transforms and data augmentation for both image + bbox.
|
| """
|
|
|
| import logging
|
|
|
| import random
|
| from typing import Iterable
|
|
|
| import torch
|
| import torchvision.transforms as T
|
| import torchvision.transforms.functional as F
|
| import torchvision.transforms.v2.functional as Fv2
|
| from PIL import Image as PILImage
|
|
|
| from torchvision.transforms import InterpolationMode
|
|
|
| from training.utils.data_utils import VideoDatapoint
|
|
|
|
|
| def hflip(datapoint, index):
|
|
|
| datapoint.frames[index].data = F.hflip(datapoint.frames[index].data)
|
| for obj in datapoint.frames[index].objects:
|
| if obj.segment is not None:
|
| obj.segment = F.hflip(obj.segment)
|
|
|
| return datapoint
|
|
|
|
|
| def get_size_with_aspect_ratio(image_size, size, max_size=None):
|
| w, h = image_size
|
| if max_size is not None:
|
| min_original_size = float(min((w, h)))
|
| max_original_size = float(max((w, h)))
|
| if max_original_size / min_original_size * size > max_size:
|
| size = max_size * min_original_size / max_original_size
|
|
|
| if (w <= h and w == size) or (h <= w and h == size):
|
| return (h, w)
|
|
|
| if w < h:
|
| ow = int(round(size))
|
| oh = int(round(size * h / w))
|
| else:
|
| oh = int(round(size))
|
| ow = int(round(size * w / h))
|
|
|
| return (oh, ow)
|
|
|
|
|
| def resize(datapoint, index, size, max_size=None, square=False, v2=False):
|
|
|
|
|
| def get_size(image_size, size, max_size=None):
|
| if isinstance(size, (list, tuple)):
|
| return size[::-1]
|
| else:
|
| return get_size_with_aspect_ratio(image_size, size, max_size)
|
|
|
| if square:
|
| size = size, size
|
| else:
|
| cur_size = (
|
| datapoint.frames[index].data.size()[-2:][::-1]
|
| if v2
|
| else datapoint.frames[index].data.size
|
| )
|
| size = get_size(cur_size, size, max_size)
|
|
|
| old_size = (
|
| datapoint.frames[index].data.size()[-2:][::-1]
|
| if v2
|
| else datapoint.frames[index].data.size
|
| )
|
| if v2:
|
| datapoint.frames[index].data = Fv2.resize(
|
| datapoint.frames[index].data, size, antialias=True
|
| )
|
| else:
|
| datapoint.frames[index].data = F.resize(datapoint.frames[index].data, size)
|
|
|
| new_size = (
|
| datapoint.frames[index].data.size()[-2:][::-1]
|
| if v2
|
| else datapoint.frames[index].data.size
|
| )
|
|
|
| for obj in datapoint.frames[index].objects:
|
| if obj.segment is not None:
|
| obj.segment = F.resize(obj.segment[None, None], size).squeeze()
|
|
|
| h, w = size
|
| datapoint.frames[index].size = (h, w)
|
| return datapoint
|
|
|
|
|
| def pad(datapoint, index, padding, v2=False):
|
| old_h, old_w = datapoint.frames[index].size
|
| h, w = old_h, old_w
|
| if len(padding) == 2:
|
|
|
| datapoint.frames[index].data = F.pad(
|
| datapoint.frames[index].data, (0, 0, padding[0], padding[1])
|
| )
|
| h += padding[1]
|
| w += padding[0]
|
| else:
|
|
|
| datapoint.frames[index].data = F.pad(
|
| datapoint.frames[index].data,
|
| (padding[0], padding[1], padding[2], padding[3]),
|
| )
|
| h += padding[1] + padding[3]
|
| w += padding[0] + padding[2]
|
|
|
| datapoint.frames[index].size = (h, w)
|
|
|
| for obj in datapoint.frames[index].objects:
|
| if obj.segment is not None:
|
| if v2:
|
| if len(padding) == 2:
|
| obj.segment = Fv2.pad(obj.segment, (0, 0, padding[0], padding[1]))
|
| else:
|
| obj.segment = Fv2.pad(obj.segment, tuple(padding))
|
| else:
|
| if len(padding) == 2:
|
| obj.segment = F.pad(obj.segment, (0, 0, padding[0], padding[1]))
|
| else:
|
| obj.segment = F.pad(obj.segment, tuple(padding))
|
| return datapoint
|
|
|
|
|
| class RandomHorizontalFlip:
|
| def __init__(self, consistent_transform, p=0.5):
|
| self.p = p
|
| self.consistent_transform = consistent_transform
|
|
|
| def __call__(self, datapoint, **kwargs):
|
| if self.consistent_transform:
|
| if random.random() < self.p:
|
| for i in range(len(datapoint.frames)):
|
| datapoint = hflip(datapoint, i)
|
| return datapoint
|
| for i in range(len(datapoint.frames)):
|
| if random.random() < self.p:
|
| datapoint = hflip(datapoint, i)
|
| return datapoint
|
|
|
|
|
| class RandomResizeAPI:
|
| def __init__(
|
| self, sizes, consistent_transform, max_size=None, square=False, v2=False
|
| ):
|
| if isinstance(sizes, int):
|
| sizes = (sizes,)
|
| assert isinstance(sizes, Iterable)
|
| self.sizes = list(sizes)
|
| self.max_size = max_size
|
| self.square = square
|
| self.consistent_transform = consistent_transform
|
| self.v2 = v2
|
|
|
| def __call__(self, datapoint, **kwargs):
|
| if self.consistent_transform:
|
| size = random.choice(self.sizes)
|
| for i in range(len(datapoint.frames)):
|
| datapoint = resize(
|
| datapoint, i, size, self.max_size, square=self.square, v2=self.v2
|
| )
|
| return datapoint
|
| for i in range(len(datapoint.frames)):
|
| size = random.choice(self.sizes)
|
| datapoint = resize(
|
| datapoint, i, size, self.max_size, square=self.square, v2=self.v2
|
| )
|
| return datapoint
|
|
|
|
|
| class ToTensorAPI:
|
| def __init__(self, v2=False):
|
| self.v2 = v2
|
|
|
| def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
| for img in datapoint.frames:
|
| if self.v2:
|
| img.data = Fv2.to_image_tensor(img.data)
|
| else:
|
| img.data = F.to_tensor(img.data)
|
| return datapoint
|
|
|
|
|
| class NormalizeAPI:
|
| def __init__(self, mean, std, v2=False):
|
| self.mean = mean
|
| self.std = std
|
| self.v2 = v2
|
|
|
| def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
| for img in datapoint.frames:
|
| if self.v2:
|
| img.data = Fv2.convert_image_dtype(img.data, torch.float32)
|
| img.data = Fv2.normalize(img.data, mean=self.mean, std=self.std)
|
| else:
|
| img.data = F.normalize(img.data, mean=self.mean, std=self.std)
|
|
|
| return datapoint
|
|
|
|
|
| class ComposeAPI:
|
| def __init__(self, transforms):
|
| self.transforms = transforms
|
|
|
| def __call__(self, datapoint, **kwargs):
|
| for t in self.transforms:
|
| datapoint = t(datapoint, **kwargs)
|
| return datapoint
|
|
|
| def __repr__(self):
|
| format_string = self.__class__.__name__ + "("
|
| for t in self.transforms:
|
| format_string += "\n"
|
| format_string += " {0}".format(t)
|
| format_string += "\n)"
|
| return format_string
|
|
|
|
|
| class RandomGrayscale:
|
| def __init__(self, consistent_transform, p=0.5):
|
| self.p = p
|
| self.consistent_transform = consistent_transform
|
| self.Grayscale = T.Grayscale(num_output_channels=3)
|
|
|
| def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
| if self.consistent_transform:
|
| if random.random() < self.p:
|
| for img in datapoint.frames:
|
| img.data = self.Grayscale(img.data)
|
| return datapoint
|
| for img in datapoint.frames:
|
| if random.random() < self.p:
|
| img.data = self.Grayscale(img.data)
|
| return datapoint
|
|
|
|
|
| class ColorJitter:
|
| def __init__(self, consistent_transform, brightness, contrast, saturation, hue):
|
| self.consistent_transform = consistent_transform
|
| self.brightness = (
|
| brightness
|
| if isinstance(brightness, list)
|
| else [max(0, 1 - brightness), 1 + brightness]
|
| )
|
| self.contrast = (
|
| contrast
|
| if isinstance(contrast, list)
|
| else [max(0, 1 - contrast), 1 + contrast]
|
| )
|
| self.saturation = (
|
| saturation
|
| if isinstance(saturation, list)
|
| else [max(0, 1 - saturation), 1 + saturation]
|
| )
|
| self.hue = hue if isinstance(hue, list) or hue is None else ([-hue, hue])
|
|
|
| def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
| if self.consistent_transform:
|
|
|
| (
|
| fn_idx,
|
| brightness_factor,
|
| contrast_factor,
|
| saturation_factor,
|
| hue_factor,
|
| ) = T.ColorJitter.get_params(
|
| self.brightness, self.contrast, self.saturation, self.hue
|
| )
|
| for img in datapoint.frames:
|
| if not self.consistent_transform:
|
| (
|
| fn_idx,
|
| brightness_factor,
|
| contrast_factor,
|
| saturation_factor,
|
| hue_factor,
|
| ) = T.ColorJitter.get_params(
|
| self.brightness, self.contrast, self.saturation, self.hue
|
| )
|
| for fn_id in fn_idx:
|
| if fn_id == 0 and brightness_factor is not None:
|
| img.data = F.adjust_brightness(img.data, brightness_factor)
|
| elif fn_id == 1 and contrast_factor is not None:
|
| img.data = F.adjust_contrast(img.data, contrast_factor)
|
| elif fn_id == 2 and saturation_factor is not None:
|
| img.data = F.adjust_saturation(img.data, saturation_factor)
|
| elif fn_id == 3 and hue_factor is not None:
|
| img.data = F.adjust_hue(img.data, hue_factor)
|
| return datapoint
|
|
|
|
|
| class RandomAffine:
|
| def __init__(
|
| self,
|
| degrees,
|
| consistent_transform,
|
| scale=None,
|
| translate=None,
|
| shear=None,
|
| image_mean=(123, 116, 103),
|
| log_warning=True,
|
| num_tentatives=1,
|
| image_interpolation="bicubic",
|
| ):
|
| """
|
| The mask is required for this transform.
|
| if consistent_transform if True, then the same random affine is applied to all frames and masks.
|
| """
|
| self.degrees = degrees if isinstance(degrees, list) else ([-degrees, degrees])
|
| self.scale = scale
|
| self.shear = (
|
| shear if isinstance(shear, list) else ([-shear, shear] if shear else None)
|
| )
|
| self.translate = translate
|
| self.fill_img = image_mean
|
| self.consistent_transform = consistent_transform
|
| self.log_warning = log_warning
|
| self.num_tentatives = num_tentatives
|
|
|
| if image_interpolation == "bicubic":
|
| self.image_interpolation = InterpolationMode.BICUBIC
|
| elif image_interpolation == "bilinear":
|
| self.image_interpolation = InterpolationMode.BILINEAR
|
| else:
|
| raise NotImplementedError
|
|
|
| def __call__(self, datapoint: VideoDatapoint, **kwargs):
|
| for _tentative in range(self.num_tentatives):
|
| res = self.transform_datapoint(datapoint)
|
| if res is not None:
|
| return res
|
|
|
| if self.log_warning:
|
| logging.warning(
|
| f"Skip RandomAffine for zero-area mask in first frame after {self.num_tentatives} tentatives"
|
| )
|
| return datapoint
|
|
|
| def transform_datapoint(self, datapoint: VideoDatapoint):
|
| _, height, width = F.get_dimensions(datapoint.frames[0].data)
|
| img_size = [width, height]
|
|
|
| if self.consistent_transform:
|
|
|
| affine_params = T.RandomAffine.get_params(
|
| degrees=self.degrees,
|
| translate=self.translate,
|
| scale_ranges=self.scale,
|
| shears=self.shear,
|
| img_size=img_size,
|
| )
|
|
|
| for img_idx, img in enumerate(datapoint.frames):
|
| this_masks = [
|
| obj.segment.unsqueeze(0) if obj.segment is not None else None
|
| for obj in img.objects
|
| ]
|
| if not self.consistent_transform:
|
|
|
| affine_params = T.RandomAffine.get_params(
|
| degrees=self.degrees,
|
| translate=self.translate,
|
| scale_ranges=self.scale,
|
| shears=self.shear,
|
| img_size=img_size,
|
| )
|
|
|
| transformed_bboxes, transformed_masks = [], []
|
| for i in range(len(img.objects)):
|
| if this_masks[i] is None:
|
| transformed_masks.append(None)
|
|
|
| transformed_bboxes.append(torch.tensor([[0, 0, 1, 1]]))
|
| else:
|
| transformed_mask = F.affine(
|
| this_masks[i],
|
| *affine_params,
|
| interpolation=InterpolationMode.NEAREST,
|
| fill=0.0,
|
| )
|
| if img_idx == 0 and transformed_mask.max() == 0:
|
|
|
|
|
| return None
|
| transformed_masks.append(transformed_mask.squeeze())
|
|
|
| for i in range(len(img.objects)):
|
| img.objects[i].segment = transformed_masks[i]
|
|
|
| img.data = F.affine(
|
| img.data,
|
| *affine_params,
|
| interpolation=self.image_interpolation,
|
| fill=self.fill_img,
|
| )
|
| return datapoint
|
|
|
|
|
| def random_mosaic_frame(
|
| datapoint,
|
| index,
|
| grid_h,
|
| grid_w,
|
| target_grid_y,
|
| target_grid_x,
|
| should_hflip,
|
| ):
|
|
|
| image_data = datapoint.frames[index].data
|
| is_pil = isinstance(image_data, PILImage.Image)
|
| if is_pil:
|
| H_im = image_data.height
|
| W_im = image_data.width
|
| image_data_output = PILImage.new("RGB", (W_im, H_im))
|
| else:
|
| H_im = image_data.size(-2)
|
| W_im = image_data.size(-1)
|
| image_data_output = torch.zeros_like(image_data)
|
|
|
| downsize_cache = {}
|
| for grid_y in range(grid_h):
|
| for grid_x in range(grid_w):
|
| y_offset_b = grid_y * H_im // grid_h
|
| x_offset_b = grid_x * W_im // grid_w
|
| y_offset_e = (grid_y + 1) * H_im // grid_h
|
| x_offset_e = (grid_x + 1) * W_im // grid_w
|
| H_im_downsize = y_offset_e - y_offset_b
|
| W_im_downsize = x_offset_e - x_offset_b
|
|
|
| if (H_im_downsize, W_im_downsize) in downsize_cache:
|
| image_data_downsize = downsize_cache[(H_im_downsize, W_im_downsize)]
|
| else:
|
| image_data_downsize = F.resize(
|
| image_data,
|
| size=(H_im_downsize, W_im_downsize),
|
| interpolation=InterpolationMode.BILINEAR,
|
| antialias=True,
|
| )
|
| downsize_cache[(H_im_downsize, W_im_downsize)] = image_data_downsize
|
| if should_hflip[grid_y, grid_x].item():
|
| image_data_downsize = F.hflip(image_data_downsize)
|
|
|
| if is_pil:
|
| image_data_output.paste(image_data_downsize, (x_offset_b, y_offset_b))
|
| else:
|
| image_data_output[:, y_offset_b:y_offset_e, x_offset_b:x_offset_e] = (
|
| image_data_downsize
|
| )
|
|
|
| datapoint.frames[index].data = image_data_output
|
|
|
|
|
| for obj in datapoint.frames[index].objects:
|
| if obj.segment is None:
|
| continue
|
| assert obj.segment.shape == (H_im, W_im) and obj.segment.dtype == torch.uint8
|
| segment_output = torch.zeros_like(obj.segment)
|
|
|
| target_y_offset_b = target_grid_y * H_im // grid_h
|
| target_x_offset_b = target_grid_x * W_im // grid_w
|
| target_y_offset_e = (target_grid_y + 1) * H_im // grid_h
|
| target_x_offset_e = (target_grid_x + 1) * W_im // grid_w
|
| target_H_im_downsize = target_y_offset_e - target_y_offset_b
|
| target_W_im_downsize = target_x_offset_e - target_x_offset_b
|
|
|
| segment_downsize = F.resize(
|
| obj.segment[None, None],
|
| size=(target_H_im_downsize, target_W_im_downsize),
|
| interpolation=InterpolationMode.BILINEAR,
|
| antialias=True,
|
| )[0, 0]
|
| if should_hflip[target_grid_y, target_grid_x].item():
|
| segment_downsize = F.hflip(segment_downsize[None, None])[0, 0]
|
|
|
| segment_output[
|
| target_y_offset_b:target_y_offset_e, target_x_offset_b:target_x_offset_e
|
| ] = segment_downsize
|
| obj.segment = segment_output
|
|
|
| return datapoint
|
|
|
|
|
| class RandomMosaicVideoAPI:
|
| def __init__(self, prob=0.15, grid_h=2, grid_w=2, use_random_hflip=False):
|
| self.prob = prob
|
| self.grid_h = grid_h
|
| self.grid_w = grid_w
|
| self.use_random_hflip = use_random_hflip
|
|
|
| def __call__(self, datapoint, **kwargs):
|
| if random.random() > self.prob:
|
| return datapoint
|
|
|
|
|
| target_grid_y = random.randint(0, self.grid_h - 1)
|
| target_grid_x = random.randint(0, self.grid_w - 1)
|
|
|
| if self.use_random_hflip:
|
| should_hflip = torch.rand(self.grid_h, self.grid_w) < 0.5
|
| else:
|
| should_hflip = torch.zeros(self.grid_h, self.grid_w, dtype=torch.bool)
|
| for i in range(len(datapoint.frames)):
|
| datapoint = random_mosaic_frame(
|
| datapoint,
|
| i,
|
| grid_h=self.grid_h,
|
| grid_w=self.grid_w,
|
| target_grid_y=target_grid_y,
|
| target_grid_x=target_grid_x,
|
| should_hflip=should_hflip,
|
| )
|
|
|
| return datapoint
|
|
|