| import easing_functions as ef |
| import random |
| import torch |
| from torchvision import transforms |
| from torchvision.transforms import functional as F |
|
|
|
|
class MotionAugmentation:
    """Random, temporally coherent augmentation for a clip and its background.

    ``__call__`` receives foreground frames (``fgrs``), the matching ``phas``
    frames (presumably alpha mattes -- TODO confirm) and background frames
    (``bgrs``). It applies a random subset of these augmentations:
    affine motion, static affine, resized crop, horizontal flip, noise, color
    jitter, grayscale, sharpness, blur and frame pausing.

    The ``_motion_*`` helpers pick two random parameter endpoints. They then
    interpolate between them across the clip with a random easing curve, so
    the effect changes smoothly from frame to frame.
    """

    def __init__(self,
                 size,
                 prob_fgr_affine,
                 prob_bgr_affine,
                 prob_noise,
                 prob_color_jitter,
                 prob_grayscale,
                 prob_sharpness,
                 prob_blur,
                 prob_hflip,
                 prob_pause,
                 static_affine=True,
                 aspect_ratio_range=(0.9, 1.1)):
        """Store augmentation settings.

        Args:
            size: Output size passed to ``F.resized_crop``.
            prob_fgr_affine: Probability of animated affine motion applied to
                the foreground and ``phas`` frames together.
            prob_bgr_affine: Two independent draws each use
                ``prob_bgr_affine / 2``. One triggers background-only motion;
                the other triggers motion shared by all three clips.
            prob_noise: Probability of adding noise to foreground and background.
            prob_color_jitter: Probability, drawn independently for foreground
                and background, of animated color jitter.
            prob_grayscale: Probability of converting foreground and background
                to 3-channel grayscale.
            prob_sharpness: Probability of one random sharpness adjustment
                applied to all three clips.
            prob_blur: Three independent draws each use ``prob_blur / 3``:
                foreground-only blur, background-only blur, and blur shared by
                all three clips.
            prob_hflip: Probability, drawn independently for foreground (with
                ``phas``) and background, of a horizontal flip.
            prob_pause: Probability of freezing a run of frames in all clips.
            static_affine: If True, always apply one static random affine to
                the foreground/``phas`` clip and another to the background clip.
            aspect_ratio_range: ``ratio`` range used for the random crop.
        """
        self.size = size
        self.prob_fgr_affine = prob_fgr_affine
        self.prob_bgr_affine = prob_bgr_affine
        self.prob_noise = prob_noise
        self.prob_color_jitter = prob_color_jitter
        self.prob_grayscale = prob_grayscale
        self.prob_sharpness = prob_sharpness
        self.prob_blur = prob_blur
        self.prob_hflip = prob_hflip
        self.prob_pause = prob_pause
        self.static_affine = static_affine
        self.aspect_ratio_range = aspect_ratio_range

    def __call__(self, fgrs, phas, bgrs):
        """Augment one clip.

        Args:
            fgrs, phas, bgrs: Sequences of frames. Each frame must be accepted
                by ``F.to_tensor`` and expose ``.size`` for the affine helpers.
                These are presumably lists of PIL images (TODO confirm against
                the caller). ``_motion_affine`` overwrites list elements in place.

        Returns:
            ``(fgrs, phas, bgrs)`` as tensors of shape ``(T, C, H, W)``, each
            cropped and resized to ``self.size``.
        """
        # Animated foreground motion.
        if random.random() < self.prob_fgr_affine:
            fgrs, phas = self._motion_affine(fgrs, phas)

        # Animated background-only motion, and motion shared by all clips.
        if random.random() < self.prob_bgr_affine / 2:
            bgrs = self._motion_affine(bgrs)
        if random.random() < self.prob_bgr_affine / 2:
            fgrs, phas, bgrs = self._motion_affine(fgrs, phas, bgrs)

        # Static affine: foreground tends to shrink, background tends to grow.
        if self.static_affine:
            fgrs, phas = self._static_affine(fgrs, phas, scale_ranges=(0.5, 1))
            bgrs = self._static_affine(bgrs, scale_ranges=(1, 1.5))

        # Convert frame sequences to (T, C, H, W) tensors.
        fgrs = torch.stack([F.to_tensor(fgr) for fgr in fgrs])
        phas = torch.stack([F.to_tensor(pha) for pha in phas])
        bgrs = torch.stack([F.to_tensor(bgr) for bgr in bgrs])

        # Random crop at full area with a jittered aspect ratio, then resize.
        # Foreground and phas share crop parameters.
        params = transforms.RandomResizedCrop.get_params(fgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        fgrs = F.resized_crop(fgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        phas = F.resized_crop(phas, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        params = transforms.RandomResizedCrop.get_params(bgrs, scale=(1, 1), ratio=self.aspect_ratio_range)
        bgrs = F.resized_crop(bgrs, *params, self.size, interpolation=F.InterpolationMode.BILINEAR)

        # Horizontal flips (foreground with phas; background independently).
        if random.random() < self.prob_hflip:
            fgrs = F.hflip(fgrs)
            phas = F.hflip(phas)
        if random.random() < self.prob_hflip:
            bgrs = F.hflip(bgrs)

        # Noise on color clips only.
        if random.random() < self.prob_noise:
            fgrs, bgrs = self._motion_noise(fgrs, bgrs)

        # Animated color jitter, drawn independently per clip.
        if random.random() < self.prob_color_jitter:
            fgrs = self._motion_color_jitter(fgrs)
        if random.random() < self.prob_color_jitter:
            bgrs = self._motion_color_jitter(bgrs)

        # Grayscale applied jointly to foreground and background.
        if random.random() < self.prob_grayscale:
            fgrs = F.rgb_to_grayscale(fgrs, num_output_channels=3).contiguous()
            bgrs = F.rgb_to_grayscale(bgrs, num_output_channels=3).contiguous()

        # One sharpness factor in [0, 8) shared by all clips.
        if random.random() < self.prob_sharpness:
            sharpness = random.random() * 8
            fgrs = F.adjust_sharpness(fgrs, sharpness)
            phas = F.adjust_sharpness(phas, sharpness)
            bgrs = F.adjust_sharpness(bgrs, sharpness)

        # Animated blur: foreground-only, background-only, or shared.
        if random.random() < self.prob_blur / 3:
            fgrs, phas = self._motion_blur(fgrs, phas)
        if random.random() < self.prob_blur / 3:
            bgrs = self._motion_blur(bgrs)
        if random.random() < self.prob_blur / 3:
            fgrs, phas, bgrs = self._motion_blur(fgrs, phas, bgrs)

        # Freeze a run of frames in all clips.
        if random.random() < self.prob_pause:
            fgrs, phas, bgrs = self._motion_pause(fgrs, phas, bgrs)

        return fgrs, phas, bgrs

    @staticmethod
    def _progress(t, T):
        """Return the normalised position of frame ``t`` in a clip of ``T`` frames.

        Frame 0 maps to 0.0 and frame ``T - 1`` maps to 1.0. A single-frame
        clip maps to 0.0 instead of dividing by zero.
        """
        return t / (T - 1) if T > 1 else 0.0

    def _static_affine(self, *imgs, scale_ranges):
        """Apply one random affine transform to every frame of every given clip.

        Returns a list of new frame lists, or the single list when only one
        clip is given.
        """
        params = transforms.RandomAffine.get_params(
            degrees=(-10, 10), translate=(0.1, 0.1), scale_ranges=scale_ranges,
            shears=(-5, 5), img_size=imgs[0][0].size)
        imgs = [[F.affine(t, *params, F.InterpolationMode.BILINEAR) for t in img] for img in imgs]
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_affine(self, *imgs):
        """Animate an affine transform between two random endpoints over the clip.

        All given clips receive the same per-frame transform. Frames are
        written back into the given sequences in place, and the same objects
        are returned (a tuple, or the single sequence when only one is given).
        """
        config = dict(degrees=(-10, 10), translate=(0.1, 0.1),
                      scale_ranges=(0.9, 1.1), shears=(-5, 5), img_size=imgs[0][0].size)
        angleA, (transXA, transYA), scaleA, (shearXA, shearYA) = transforms.RandomAffine.get_params(**config)
        angleB, (transXB, transYB), scaleB, (shearXB, shearYB) = transforms.RandomAffine.get_params(**config)

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(self._progress(t, T))
            angle = lerp(angleA, angleB, percentage)
            transX = lerp(transXA, transXB, percentage)
            transY = lerp(transYA, transYB, percentage)
            scale = lerp(scaleA, scaleB, percentage)
            shearX = lerp(shearXA, shearXB, percentage)
            shearY = lerp(shearYA, shearYB, percentage)
            for img in imgs:
                img[t] = F.affine(img[t], angle, (transX, transY), scale, (shearX, shearY), F.InterpolationMode.BILINEAR)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_noise(self, *imgs):
        """Add Gaussian noise, optionally coarse-grained and monochrome, in place.

        Each clip is a ``(T, C, H, W)`` tensor. The grain size and the
        monochrome flag are shared by all clips; the noise strength is drawn
        per clip. Values are clamped to [0, 1] afterwards.
        """
        grain_size = random.random() * 3 + 1
        monochrome = random.random() < 0.5
        for img in imgs:
            T, C, H, W = img.shape
            # Clamp to >= 1 so tiny frames never produce an empty noise tensor.
            noise = torch.randn((T, 1 if monochrome else C,
                                 max(1, round(H / grain_size)),
                                 max(1, round(W / grain_size))))
            noise.mul_(random.random() * 0.2 / grain_size)
            if grain_size != 1:
                noise = F.resize(noise, (H, W))
            img.add_(noise).clamp_(0, 1)
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_color_jitter(self, *imgs):
        """Animate brightness, contrast, saturation and hue shifts over the clip.

        Frames are modified in place via ``img[t] = ...``.
        """
        brightnessA, brightnessB, contrastA, contrastB, saturationA, saturationB, hueA, hueB \
            = torch.randn(8).mul(0.1).tolist()
        strength = random.random() * 0.2
        easing = random_easing_fn()
        T = len(imgs[0])
        for t in range(T):
            percentage = easing(self._progress(t, T)) * strength
            for img in imgs:
                img[t] = F.adjust_brightness(img[t], max(1 + lerp(brightnessA, brightnessB, percentage), 0.1))
                img[t] = F.adjust_contrast(img[t], max(1 + lerp(contrastA, contrastB, percentage), 0.1))
                # Saturation uses its own endpoints (previously the brightness ones were reused).
                img[t] = F.adjust_saturation(img[t], max(1 + lerp(saturationA, saturationB, percentage), 0.1))
                img[t] = F.adjust_hue(img[t], min(0.5, max(-0.5, lerp(hueA, hueB, percentage) * 0.1)))
        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_blur(self, *imgs):
        """Animate a Gaussian blur whose sigma moves between two values in [0, 10).

        The kernel size is ``int(2 * sigma)``, bumped to the next odd number
        when it is even. Frames are modified in place via ``img[t] = ...``.
        """
        blurA = random.random() * 10
        blurB = random.random() * 10

        T = len(imgs[0])
        easing = random_easing_fn()
        for t in range(T):
            percentage = easing(self._progress(t, T))
            # Some easing curves overshoot, so clamp sigma at 0.
            blur = max(lerp(blurA, blurB, percentage), 0)
            if blur != 0:
                kernel_size = int(blur * 2)
                if kernel_size % 2 == 0:
                    kernel_size += 1
                for img in imgs:
                    img[t] = F.gaussian_blur(img[t], kernel_size, sigma=blur)

        return imgs if len(imgs) > 1 else imgs[0]

    def _motion_pause(self, *imgs):
        """Freeze a random run of frames on the frame that precedes it.

        A pause frame ``p`` in ``[0, T - 2]`` and a length ``n`` in
        ``[0, T - p - 1]`` are drawn. Frames ``p + 1 .. p + n`` of every clip
        are then overwritten in place with frame ``p``; ``n == 0`` means no
        pause. Clips shorter than 2 frames are returned unchanged.
        """
        T = len(imgs[0])
        if T < 2:
            # Nothing can be paused, and random.choice(range(T - 1)) would raise.
            return imgs if len(imgs) > 1 else imgs[0]
        pause_frame = random.choice(range(T - 1))
        pause_length = random.choice(range(T - pause_frame))
        for img in imgs:
            # Exclusive end, so exactly pause_length frames are held and the
            # run may extend to the last frame.
            img[pause_frame + 1: pause_frame + 1 + pause_length] = img[pause_frame]
        return imgs if len(imgs) > 1 else imgs[0]
| |
|
|
def lerp(a, b, percentage):
    """Linearly interpolate from ``a`` (percentage 0) to ``b`` (percentage 1)."""
    weight_a = 1 - percentage
    return a * weight_a + b * percentage
|
|
|
|
def random_easing_fn():
    """Return a freshly constructed, randomly chosen easing-curve object.

    With probability 0.2 the result is ``ef.LinearInOut()``. Otherwise a
    curve is picked uniformly from the In/Out/InOut variants of ten
    ``easing_functions`` families, plus the local ``Step`` curve.
    """
    if random.random() < 0.2:
        return ef.LinearInOut()

    families = ('Back', 'Bounce', 'Circular', 'Cubic', 'Exponential',
                'Elastic', 'Quad', 'Quartic', 'Quintic', 'Sine')
    modes = ('In', 'Out', 'InOut')
    # Family-major, mode-minor order keeps the choice reproducible under a fixed seed.
    candidates = [getattr(ef, family + 'Ease' + mode)
                  for family in families
                  for mode in modes]
    candidates.append(Step)
    chosen_cls = random.choice(candidates)
    return chosen_cls()
|
|
class Step:
    """Easing curve that jumps from 0 to 1 at the halfway point."""

    def __call__(self, value):
        """Return 0 when ``value < 0.5``, otherwise 1."""
        if value < 0.5:
            return 0
        return 1
|
|
|
|
| |
|
|
|
|
class TrainFrameSampler:
    """Sample frame indices for a training clip with random speed, offset and direction."""

    def __init__(self, speed=None):
        """
        Args:
            speed: Candidate playback-speed multipliers; one is picked per call.
                Defaults to ``[0.5, 1, 2, 3, 4, 5]``. The default list is built
                per instance, so instances never share one mutable default.
        """
        self.speed = [0.5, 1, 2, 3, 4, 5] if speed is None else speed

    def __call__(self, seq_length):
        """Return ``seq_length`` frame indices.

        Each index is ``int(i * speed) + shift`` for ``i`` in
        ``range(seq_length)``. ``speed`` is drawn from ``self.speed`` and
        ``shift`` from ``range(seq_length)``. With probability 0.5 the order
        is reversed. Returns ``[]`` when ``seq_length`` is 0 or negative,
        where ``random.choice`` would otherwise raise on an empty range.
        """
        if seq_length <= 0:
            return []

        speed = random.choice(self.speed)
        shift = random.choice(range(seq_length))
        # Speeds below 1 repeat indices (slow motion); above 1 they skip frames.
        frames = [int(i * speed) + shift for i in range(seq_length)]

        if random.random() < 0.5:
            frames.reverse()

        return frames
| |
class ValidFrameSampler:
    """Select every frame of a validation clip, in order."""

    def __call__(self, seq_length):
        """Return the indices ``0 .. seq_length - 1`` as a ``range`` object."""
        indices = range(seq_length)
        return indices
|
|