| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Base augmentations operators.""" |
|
|
| import numpy as np |
| from PIL import Image, ImageOps, ImageEnhance |
|
|
| |
# Side length, in pixels, of the square (32x32, CIFAR-sized) images this
# module's operators are written for.
IMAGE_SIZE = 32
| import torch |
| from torchvision import transforms |
|
|
|
|
def int_parameter(level, maxval):
    """Scale ``maxval`` by ``level / 10`` and truncate the result to an int.

    Args:
      level: Operation strength; a level of 10 yields ``maxval``.
      maxval: Value corresponding to a level of 10.

    Returns:
      ``int(level * maxval / 10)``; ``int`` truncates toward zero.
    """
    scaled = level * maxval / 10
    return int(scaled)
|
|
|
|
def float_parameter(level, maxval):
    """Scale ``maxval`` by ``level / 10`` and return the result as a float.

    Args:
      level: Operation strength; a level of 10 yields ``maxval``.
      maxval: Value corresponding to a level of 10.

    Returns:
      ``float(level) * maxval / 10``.
    """
    numerator = float(level) * maxval
    return numerator / 10.
|
|
|
|
def sample_level(n):
    """Draw a float uniformly from [0.1, n) using NumPy's global RNG."""
    low_bound = 0.1
    return np.random.uniform(low_bound, n)
|
|
|
|
def autocontrast(pil_img, _):
    """Apply PIL autocontrast; the severity argument is ignored."""
    stretched = ImageOps.autocontrast(pil_img)
    return stretched
|
|
|
|
def equalize(pil_img, _):
    """Equalize the image histogram; the severity argument is ignored."""
    equalized = ImageOps.equalize(pil_img)
    return equalized
|
|
|
|
def posterize(pil_img, level):
    """Posterize to ``4 - k`` bits per channel, where k is a random draw scaled by ``level``.

    k = ``int_parameter(sample_level(level), 4)``, so larger levels tend to
    remove more bits.
    """
    bits_removed = int_parameter(sample_level(level), 4)
    kept_bits = 4 - bits_removed
    return ImageOps.posterize(pil_img, kept_bits)
|
|
|
|
def rotate(pil_img, level):
    """Rotate by a random integer angle whose sign is chosen at random.

    The magnitude in degrees is ``int_parameter(sample_level(level), 30)``;
    bilinear resampling is used.
    """
    angle = int_parameter(sample_level(level), 30)
    flip_sign = np.random.uniform() > 0.5
    signed_angle = -angle if flip_sign else angle
    return pil_img.rotate(signed_angle, resample=Image.BILINEAR)
|
|
|
|
def solarize(pil_img, level):
    """Solarize with threshold ``256 - int_parameter(sample_level(level), 256)``.

    Higher levels tend to lower the threshold, so more pixels get inverted.
    """
    threshold = 256 - int_parameter(sample_level(level), 256)
    return ImageOps.solarize(pil_img, threshold)
|
|
|
|
def shear_x(pil_img, level):
    """Shear the image horizontally by a random signed factor.

    The factor magnitude is ``float_parameter(sample_level(level), 0.3)``.
    Its sign is flipped with probability 1/2. Bilinear resampling is used.

    Args:
      pil_img: PIL image to transform.
      level: Severity used to scale the shear magnitude.

    Returns:
      A sheared PIL image with the same size as ``pil_img``.
    """
    factor = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        factor = -factor
    # Use the input's own size (like `rotate`) instead of a fixed
    # IMAGE_SIZE x IMAGE_SIZE canvas, which cropped or padded any
    # non-32x32 input. The result is unchanged for 32x32 images.
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, factor, 0, 0, 1, 0),
                             resample=Image.BILINEAR)
|
|
|
|
def shear_y(pil_img, level):
    """Shear the image vertically by a random signed factor.

    The factor magnitude is ``float_parameter(sample_level(level), 0.3)``.
    Its sign is flipped with probability 1/2. Bilinear resampling is used.

    Args:
      pil_img: PIL image to transform.
      level: Severity used to scale the shear magnitude.

    Returns:
      A sheared PIL image with the same size as ``pil_img``.
    """
    factor = float_parameter(sample_level(level), 0.3)
    if np.random.uniform() > 0.5:
        factor = -factor
    # Use the input's own size (like `rotate`) instead of a fixed
    # IMAGE_SIZE x IMAGE_SIZE canvas, which cropped or padded any
    # non-32x32 input. The result is unchanged for 32x32 images.
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, 0, factor, 1, 0),
                             resample=Image.BILINEAR)
|
|
|
|
def translate_x(pil_img, level):
    """Translate the image horizontally by a random signed pixel offset.

    The offset magnitude is ``int_parameter(sample_level(level), width / 3)``,
    where ``width`` is the input image's width. Its sign is flipped with
    probability 1/2. Bilinear resampling is used.

    Args:
      pil_img: PIL image to transform.
      level: Severity used to scale the offset.

    Returns:
      A translated PIL image with the same size as ``pil_img``.
    """
    width, _ = pil_img.size
    # Derive both the offset range and the output canvas from the input
    # instead of IMAGE_SIZE, so non-32x32 images are not cropped or padded.
    # The result is unchanged for 32x32 images.
    offset = int_parameter(sample_level(level), width / 3)
    if np.random.random() > 0.5:
        offset = -offset
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, offset, 0, 1, 0),
                             resample=Image.BILINEAR)
|
|
|
|
def translate_y(pil_img, level):
    """Translate the image vertically by a random signed pixel offset.

    The offset magnitude is ``int_parameter(sample_level(level), height / 3)``,
    where ``height`` is the input image's height. Its sign is flipped with
    probability 1/2. Bilinear resampling is used.

    Args:
      pil_img: PIL image to transform.
      level: Severity used to scale the offset.

    Returns:
      A translated PIL image with the same size as ``pil_img``.
    """
    _, height = pil_img.size
    # Derive both the offset range and the output canvas from the input
    # instead of IMAGE_SIZE, so non-32x32 images are not cropped or padded.
    # The result is unchanged for 32x32 images.
    offset = int_parameter(sample_level(level), height / 3)
    if np.random.random() > 0.5:
        offset = -offset
    return pil_img.transform(pil_img.size,
                             Image.AFFINE, (1, 0, 0, 0, 1, offset),
                             resample=Image.BILINEAR)
|
|
|
|
| |
def color(pil_img, level):
    """Adjust color saturation by factor ``0.1 + float_parameter(sample_level(level), 1.8)``."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Color(pil_img)
    return enhancer.enhance(factor)
|
|
|
|
| |
def contrast(pil_img, level):
    """Adjust contrast by factor ``0.1 + float_parameter(sample_level(level), 1.8)``."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Contrast(pil_img)
    return enhancer.enhance(factor)
|
|
|
|
| |
def brightness(pil_img, level):
    """Adjust brightness by factor ``0.1 + float_parameter(sample_level(level), 1.8)``."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Brightness(pil_img)
    return enhancer.enhance(factor)
|
|
|
|
| |
def sharpness(pil_img, level):
    """Adjust sharpness by factor ``0.1 + float_parameter(sample_level(level), 1.8)``."""
    factor = float_parameter(sample_level(level), 1.8) + 0.1
    enhancer = ImageEnhance.Sharpness(pil_img)
    return enhancer.enhance(factor)
|
|
def random_resized_crop(pil_img, level):
    """Apply torchvision's RandomResizedCrop with output size IMAGE_SIZE.

    Args:
      pil_img: PIL image to crop.
      level: Unused; accepted so the signature matches the other operators.

    Returns:
      The randomly cropped and resized image.
    """
    # Use the module-level IMAGE_SIZE rather than a duplicated magic 32, so
    # the crop size stays in sync with the rest of the module.
    return transforms.RandomResizedCrop(IMAGE_SIZE)(pil_img)
|
|
def random_flip(pil_img, level):
    """Horizontally flip the image with probability 0.5; ``level`` is unused."""
    flipper = transforms.RandomHorizontalFlip(p=0.5)
    return flipper(pil_img)
|
|
def grayscale(pil_img, level):
    """Convert to grayscale, replicated to 3 output channels; ``level`` is unused."""
    to_gray = transforms.Grayscale(num_output_channels=3)
    return to_gray(pil_img)
|
|
# Operator pool without the color/contrast/brightness/sharpness enhancements.
augmentations = [
    autocontrast,
    equalize,
    posterize,
    rotate,
    solarize,
    shear_x,
    shear_y,
    translate_x,
    translate_y,
    grayscale,
]
|
|
# Full operator pool, including the ImageEnhance-based operators.
# random_resized_crop and random_flip are not part of either pool.
augmentations_all = [
    autocontrast,
    equalize,
    posterize,
    rotate,
    solarize,
    shear_x,
    shear_y,
    translate_x,
    translate_y,
    color,
    contrast,
    brightness,
    sharpness,
    grayscale,
]
|
|
def aug_cifar(image, preprocess, mixture_width=3, mixture_depth=-1,
              aug_severity=3, aug_list=None, mix_with_original=False):
    """Perform AugMix-style augmentation chains and return their mixture.

    ``mixture_width`` chains are applied to independent copies of ``image``.
    Each chain applies a sequence of randomly chosen operators, and the
    preprocessed chain outputs are summed with weights drawn from a
    Dirichlet(1, ..., 1) distribution.

    Args:
      image: PIL.Image input image (must support ``.copy()``).
      preprocess: Preprocessing function which should return a torch tensor.
      mixture_width: Number of augmentation chains to mix.
      mixture_depth: Operators per chain. If not positive, each chain draws
        its own depth uniformly from {1, 2, 3}.
      aug_severity: Severity ``level`` passed to every operator call.
      aug_list: Sequence of ``op(pil_img, level)`` callables to sample from.
        Defaults to ``augmentations_all``.
      mix_with_original: If True, return the skip-connection blend
        ``(1 - m) * preprocess(image) + m * mix`` with ``m ~ Beta(1, 1)``.
        If False (the default), return the chain mixture alone.

    Returns:
      A tensor shaped like ``preprocess(image)``.
    """
    if aug_list is None:
        aug_list = augmentations_all

    ws = np.float32(np.random.dirichlet([1] * mixture_width))
    # Drawn unconditionally, even when unused, so the global NumPy RNG
    # sequence does not depend on ``mix_with_original``.
    m = np.float32(np.random.beta(1, 1))

    # preprocess(image) fixes the output shape/dtype; reused for the blend.
    image_tensor = preprocess(image)
    mix = torch.zeros_like(image_tensor)
    for i in range(mixture_width):
        image_aug = image.copy()
        depth = (mixture_depth if mixture_depth > 0
                 else np.random.randint(1, 4))
        for _ in range(depth):
            op = np.random.choice(aug_list)
            image_aug = op(image_aug, aug_severity)
        mix += ws[i] * preprocess(image_aug)

    if mix_with_original:
        return (1 - m) * image_tensor + m * mix
    return mix