| |
| |
|
|
| """ |
| 3Augment implementation |
| Data augmentation (DA) based on the DINO DA (https://github.com/facebookresearch/dino) |
| and the timm DA (https://github.com/rwightman/pytorch-image-models). |
| """ |
| import torch |
| from torchvision import transforms |
|
|
| from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor |
|
|
| import numpy as np |
| from torchvision import datasets, transforms |
| import random |
|
|
|
|
|
|
| from PIL import ImageFilter, ImageOps |
| import torchvision.transforms.functional as TF |
|
|
|
|
class GaussianBlur(object):
    """
    Randomly apply a Gaussian blur to a PIL image.

    Attributes:
        prob: threshold compared against a uniform draw in [0, 1); the blur
            is applied when the draw is <= prob.
        radius_min: lower bound of the uniformly sampled blur radius.
        radius_max: upper bound of the uniformly sampled blur radius.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        """Return `img` blurred with a random radius, or `img` itself untouched."""
        if random.random() <= self.prob:
            # The radius is only sampled when the blur is actually applied.
            sampled_radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(radius=sampled_radius))
        return img
|
|
class Solarization(object):
    """
    Randomly solarize a PIL image.

    Attributes:
        p: threshold compared against a uniform draw in [0, 1); the image is
            solarized when the draw is strictly below p.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        """Return the solarized image, or `img` unchanged when the draw misses."""
        return ImageOps.solarize(img) if random.random() < self.p else img
|
|
class gray_scale(object):
    """
    Randomly convert an image to grayscale.

    Attributes:
        p: threshold compared against a uniform draw in [0, 1); the
            conversion runs when the draw is strictly below p.
        transf: the torchvision ``Grayscale(3)`` transform used for the
            conversion (3 requested output channels).
    """

    def __init__(self, p=0.2):
        self.p = p
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        """Return the grayscale version of `img`, or `img` unchanged."""
        should_convert = random.random() < self.p
        return self.transf(img) if should_convert else img
| |
| |
| |
class horizontal_flip(object):
    """
    Randomly flip an image horizontally.

    The randomness is handled here: the wrapped torchvision
    ``RandomHorizontalFlip`` is built with ``p=1.0`` so it always flips
    whenever it is invoked.

    Attributes:
        p: threshold compared against a uniform draw in [0, 1); the flip
            runs when the draw is strictly below p.
        transf: the always-on flip transform.

    Note:
        ``activate_pred`` is accepted by the constructor but is not used
        anywhere in this class.
    """

    def __init__(self, p=0.2,activate_pred=False):
        self.p = p
        self.transf = transforms.RandomHorizontalFlip(p=1.0)

    def __call__(self, img):
        """Return the flipped image, or `img` unchanged when the draw misses."""
        if random.random() < self.p:
            img = self.transf(img)
        return img
| |
| |
| |
def new_data_aug_generator(args=None):
    """
    Build the training augmentation pipeline as a ``transforms.Compose``.

    The pipeline has three stages, concatenated in this order:

    1. Geometric: either resize + reflect-padded random crop (when
       ``args.src`` is truthy) or a bicubic random-resized crop, followed
       in both cases by a random horizontal flip.
    2. Photometric: exactly one of grayscale / solarization / Gaussian blur
       picked by ``RandomChoice`` (each built with p=1.0, so the picked one
       always fires), optionally followed by colour jitter.
    3. Tensor conversion and per-channel normalisation.

    Args:
        args: namespace-like object providing ``input_size``, ``src`` and
            ``color_jitter``. Although the parameter defaults to ``None``,
            its attributes are read immediately, so a real object must be
            passed.

    Returns:
        A ``transforms.Compose`` wrapping the full list of transforms.
    """
    size = args.input_size
    use_simple_crop = args.src

    # Fixed per-channel normalisation constants (the usual ImageNet values).
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    flip = transforms.RandomHorizontalFlip()
    if use_simple_crop:
        # interpolation=3 is PIL's integer code for bicubic.
        geometric = [
            transforms.Resize(size, interpolation=3),
            transforms.RandomCrop(size, padding=4, padding_mode='reflect'),
            flip,
        ]
    else:
        geometric = [
            RandomResizedCropAndInterpolation(
                size, scale=(0.08, 1.0), interpolation='bicubic'),
            flip,
        ]

    one_of_three = transforms.RandomChoice([
        gray_scale(p=1.0),
        Solarization(p=1.0),
        GaussianBlur(p=1.0),
    ])
    photometric = [one_of_three]

    jitter = args.color_jitter
    if jitter is not None and jitter != 0:
        # The same strength is passed as the first three ColorJitter arguments.
        photometric.append(transforms.ColorJitter(jitter, jitter, jitter))

    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
    ]

    return transforms.Compose(geometric + photometric + to_normalized_tensor)
|
|