| | """ AutoAugment, RandAugment, and AugMix for PyTorch |
| | |
| | This code implements the searched ImageNet policies with various tweaks and improvements and |
| | does not include any of the search code. |
| | |
| | AA and RA Implementation adapted from: |
| | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py |
| | |
| | AugMix adapted from: |
| | https://github.com/google-research/augmix |
| | |
| | Papers: |
| | AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 |
| | Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 |
| | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 |
| | AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 |
| | |
| | Hacked together by / Copyright 2020 Ross Wightman |
| | """ |
| | import random |
| | import math |
| | import re |
| | from PIL import Image, ImageOps, ImageEnhance, ImageChops |
| | import PIL |
| | import numpy as np |
| |
|
| |
|
# (major, minor) of the installed Pillow; gates version-specific code paths below.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split('.')[:2])

# Default fill color handed to ops as ``fillcolor`` (mid-gray RGB).
_FILL = (128, 128, 128)

# Top of the magnitude scale; every level->arg function normalizes by this.
_MAX_LEVEL = 10.

# Hyper-parameters used when a caller passes none.
_HPARAMS_DEFAULT = {
    'translate_const': 250,
    'img_mean': _FILL,
}

# Resampling filters picked from at random when no interpolation is configured.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
| |
|
| |
|
def _interpolation(kwargs):
    """Pop ``resample`` from ``kwargs`` and resolve it to a single filter.

    Defaults to ``Image.BILINEAR`` when absent; when a list/tuple of filters
    is given, one of them is chosen at random.
    """
    resample = kwargs.pop('resample', Image.BILINEAR)
    return random.choice(resample) if isinstance(resample, (list, tuple)) else resample
| |
|
| |
|
def _check_args_tf(kwargs):
    """Normalize transform keyword arguments in place.

    Removes ``fillcolor`` on Pillow < 5.0 and replaces ``resample`` with a
    single concrete filter (see ``_interpolation``).
    """
    if _PIL_VER < (5, 0) and 'fillcolor' in kwargs:
        del kwargs['fillcolor']
    kwargs['resample'] = _interpolation(kwargs)
| |
|
| |
|
def shear_x(img, factor, **kwargs):
    """Shear ``img`` along x by ``factor`` using an affine transform.

    ``kwargs`` are normalized by ``_check_args_tf`` and forwarded to
    ``Image.transform``.
    """
    _check_args_tf(kwargs)
    coeffs = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
| |
|
| |
|
def shear_y(img, factor, **kwargs):
    """Shear ``img`` along y by ``factor`` using an affine transform.

    ``kwargs`` are normalized by ``_check_args_tf`` and forwarded to
    ``Image.transform``.
    """
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
| |
|
| |
|
def translate_x_rel(img, pct, **kwargs):
    """Translate ``img`` along x by ``pct`` of its width (affine offset)."""
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    coeffs = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
| |
|
| |
|
def translate_y_rel(img, pct, **kwargs):
    """Translate ``img`` along y by ``pct`` of its height (affine offset)."""
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
| |
|
| |
|
def translate_x_abs(img, pixels, **kwargs):
    """Translate ``img`` along x by an absolute ``pixels`` offset."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
| |
|
| |
|
def translate_y_abs(img, pixels, **kwargs):
    """Translate ``img`` along y by an absolute ``pixels`` offset."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
| |
|
| |
|
def rotate(img, degrees, **kwargs):
    """Rotate ``img`` by ``degrees`` about its center.

    Dispatches on the installed Pillow version (``_PIL_VER``):
      * < 5.0: ``Image.rotate`` with only ``resample``.
      * >= 5.2: ``Image.rotate`` with all normalized kwargs.
      * 5.0 - 5.1: an equivalent affine matrix applied via ``Image.transform``
        with all kwargs (presumably so ``fillcolor`` is honoured on versions
        whose ``rotate`` lacks it -- NOTE(review): confirm).
    """
    _check_args_tf(kwargs)
    if _PIL_VER < (5, 0):
        return img.rotate(degrees, resample=kwargs['resample'])
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)

    def _apply(x, y, coeffs):
        # Map point (x, y) through the 2x3 affine matrix ``coeffs``.
        a, b, c, d, e, f = coeffs
        return a * x + b * y + c, d * x + e * y + f

    width, height = img.size
    center_x, center_y = width / 2.0, height / 2.0
    shift_x, shift_y = 0, 0
    theta = -math.radians(degrees)
    coeffs = [
        round(math.cos(theta), 15),
        round(math.sin(theta), 15),
        0.0,
        round(-math.sin(theta), 15),
        round(math.cos(theta), 15),
        0.0,
    ]
    # Translate so the rotation pivots on the image center.
    coeffs[2], coeffs[5] = _apply(-center_x - shift_x, -center_y - shift_y, coeffs)
    coeffs[2] += center_x
    coeffs[5] += center_y
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
| |
|
| |
|
def auto_contrast(img, **__):
    """Apply ``ImageOps.autocontrast``; extra keyword args are ignored."""
    result = ImageOps.autocontrast(img)
    return result
| |
|
| |
|
def invert(img, **__):
    """Apply ``ImageOps.invert``; extra keyword args are ignored."""
    result = ImageOps.invert(img)
    return result
| |
|
| |
|
def equalize(img, **__):
    """Apply ``ImageOps.equalize``; extra keyword args are ignored."""
    result = ImageOps.equalize(img)
    return result
| |
|
| |
|
def solarize(img, thresh, **__):
    """Apply ``ImageOps.solarize`` with threshold ``thresh``; extra kwargs ignored."""
    result = ImageOps.solarize(img, thresh)
    return result
| |
|
| |
|
def solarize_add(img, add, thresh=128, **__):
    """Add ``add`` (capped at 255) to every pixel value below ``thresh``.

    Only 'L' and 'RGB' images are modified; any other mode is returned
    unchanged. For RGB the same 256-entry table is applied to each band.
    """
    table = [min(255, value + add) if value < thresh else value for value in range(256)]
    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB":
        table = table * 3
    return img.point(table)
| |
|
| |
|
def posterize(img, bits_to_keep, **__):
    """Keep ``bits_to_keep`` bits per channel; a no-op for 8 or more bits."""
    return img if bits_to_keep >= 8 else ImageOps.posterize(img, bits_to_keep)
| |
|
| |
|
def contrast(img, factor, **__):
    """Scale contrast by ``factor`` via ``ImageEnhance.Contrast`` (1.0 = unchanged)."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
| |
|
| |
|
def color(img, factor, **__):
    """Scale color saturation by ``factor`` via ``ImageEnhance.Color``."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
| |
|
| |
|
def brightness(img, factor, **__):
    """Scale brightness by ``factor`` via ``ImageEnhance.Brightness``."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
| |
|
| |
|
def sharpness(img, factor, **__):
    """Scale sharpness by ``factor`` via ``ImageEnhance.Sharpness``."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
| |
|
| |
|
def _randomly_negate(v):
    """Return ``-v`` with probability 0.5, otherwise ``v`` unchanged."""
    if random.random() > 0.5:
        return -v
    return v
| |
|
| |
|
def _rotate_level_to_arg(level, _hparams):
    """Map ``level`` to ``(degrees,)``: up to 30 degrees at ``_MAX_LEVEL``, random sign."""
    angle = (level / _MAX_LEVEL) * 30.
    return (_randomly_negate(angle),)
| |
|
| |
|
def _enhance_level_to_arg(level, _hparams):
    """Map ``level`` to an enhance factor; 0.1 at level 0 up to 1.9 at ``_MAX_LEVEL``."""
    factor = (level / _MAX_LEVEL) * 1.8 + 0.1
    return (factor,)
| |
|
| |
|
def _enhance_increasing_level_to_arg(level, _hparams):
    """Map ``level`` to an enhance factor that moves away from 1.0 (no change).

    The deviation is up to 0.9 at ``_MAX_LEVEL`` with a random sign, so the
    factor lies in [0.1, 1.9] for level in [0, _MAX_LEVEL] and strength grows
    with level.
    """
    delta = _randomly_negate((level / _MAX_LEVEL) * .9)
    return (1.0 + delta,)
| |
|
| |
|
def _shear_level_to_arg(level, _hparams):
    """Map ``level`` to a shear factor of up to 0.3 at ``_MAX_LEVEL``, random sign."""
    shear = (level / _MAX_LEVEL) * 0.3
    return (_randomly_negate(shear),)
| |
|
| |
|
def _translate_abs_level_to_arg(level, hparams):
    """Map ``level`` to ``(pixels,)``, an absolute translation with random sign.

    The span at ``_MAX_LEVEL`` is ``hparams['translate_const']`` pixels. When
    the key is missing, the module default in ``_HPARAMS_DEFAULT`` is used
    instead of raising KeyError. This mirrors how the relative variant
    defaults ``translate_pct``.
    """
    translate_const = hparams.get('translate_const', _HPARAMS_DEFAULT['translate_const'])
    pixels = (level / _MAX_LEVEL) * float(translate_const)
    return (_randomly_negate(pixels),)
| |
|
| |
|
def _translate_rel_level_to_arg(level, hparams):
    """Map ``level`` to a signed fraction of the image size to translate by.

    The span at ``_MAX_LEVEL`` is ``hparams['translate_pct']``, defaulting to 0.45.
    """
    pct = hparams.get('translate_pct', 0.45)
    return (_randomly_negate((level / _MAX_LEVEL) * pct),)
| |
|
| |
|
def _posterize_level_to_arg(level, _hparams):
    """Map ``level`` to bits kept: 0 at level 0 up to 4 at ``_MAX_LEVEL``.

    More bits are kept as level rises, so severity decreases with level.
    """
    bits = int((level / _MAX_LEVEL) * 4)
    return (bits,)
| |
|
| |
|
def _posterize_increasing_level_to_arg(level, hparams):
    """Inverse of ``_posterize_level_to_arg``: 4 bits at level 0 down to 0.

    Fewer bits are kept as level rises, so severity increases with level.
    """
    (bits,) = _posterize_level_to_arg(level, hparams)
    return (4 - bits,)
| |
|
| |
|
def _posterize_original_level_to_arg(level, _hparams):
    """Map ``level`` to bits kept in [4, 8]; 8 bits (a no-op) at ``_MAX_LEVEL``."""
    bits = int((level / _MAX_LEVEL) * 4) + 4
    return (bits,)
| |
|
| |
|
def _solarize_level_to_arg(level, _hparams):
    """Map ``level`` to a solarize threshold: 0 at level 0 up to 256 at ``_MAX_LEVEL``.

    A higher threshold leaves more pixels untouched, so severity decreases
    with level.
    """
    threshold = int((level / _MAX_LEVEL) * 256)
    return (threshold,)
| |
|
| |
|
def _solarize_increasing_level_to_arg(level, _hparams):
    """Inverse of ``_solarize_level_to_arg``: threshold 256 at level 0 down to 0."""
    threshold = _solarize_level_to_arg(level, _hparams)[0]
    return (256 - threshold,)
| |
|
| |
|
def _solarize_add_level_to_arg(level, _hparams):
    """Map ``level`` to the amount added by ``solarize_add``: 0 up to 110."""
    amount = int((level / _MAX_LEVEL) * 110)
    return (amount,)
| |
|
| |
|
# Op name -> function mapping a magnitude level to the op's positional args.
# ``None`` means the op takes no level-derived argument.
LEVEL_TO_ARG = dict(
    AutoContrast=None,
    Equalize=None,
    Invert=None,
    Rotate=_rotate_level_to_arg,
    # Three posterize level scalings are provided; see each function's docstring.
    Posterize=_posterize_level_to_arg,
    PosterizeIncreasing=_posterize_increasing_level_to_arg,
    PosterizeOriginal=_posterize_original_level_to_arg,
    Solarize=_solarize_level_to_arg,
    SolarizeIncreasing=_solarize_increasing_level_to_arg,
    SolarizeAdd=_solarize_add_level_to_arg,
    Color=_enhance_level_to_arg,
    ColorIncreasing=_enhance_increasing_level_to_arg,
    Contrast=_enhance_level_to_arg,
    ContrastIncreasing=_enhance_increasing_level_to_arg,
    Brightness=_enhance_level_to_arg,
    BrightnessIncreasing=_enhance_increasing_level_to_arg,
    Sharpness=_enhance_level_to_arg,
    SharpnessIncreasing=_enhance_increasing_level_to_arg,
    ShearX=_shear_level_to_arg,
    ShearY=_shear_level_to_arg,
    TranslateX=_translate_abs_level_to_arg,
    TranslateY=_translate_abs_level_to_arg,
    TranslateXRel=_translate_rel_level_to_arg,
    TranslateYRel=_translate_rel_level_to_arg,
)

# Op name -> image operation. "*Increasing" / "*Original" variants share the
# same image op and differ only in their LEVEL_TO_ARG scaling.
NAME_TO_OP = dict(
    AutoContrast=auto_contrast,
    Equalize=equalize,
    Invert=invert,
    Rotate=rotate,
    Posterize=posterize,
    PosterizeIncreasing=posterize,
    PosterizeOriginal=posterize,
    Solarize=solarize,
    SolarizeIncreasing=solarize,
    SolarizeAdd=solarize_add,
    Color=color,
    ColorIncreasing=color,
    Contrast=contrast,
    ContrastIncreasing=contrast,
    Brightness=brightness,
    BrightnessIncreasing=brightness,
    Sharpness=sharpness,
    SharpnessIncreasing=sharpness,
    ShearX=shear_x,
    ShearY=shear_y,
    TranslateX=translate_x_abs,
    TranslateY=translate_y_abs,
    TranslateXRel=translate_x_rel,
    TranslateYRel=translate_y_rel,
)
| |
|
| |
|
class AugmentOp:
    """One named augmentation, applied with probability ``prob`` at ``magnitude``.

    ``name`` selects both the image op (``NAME_TO_OP``) and the function that
    converts a magnitude level into that op's positional arguments
    (``LEVEL_TO_ARG``).
    """

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        hparams = hparams or _HPARAMS_DEFAULT
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        # Keyword args passed to every op; geometric ops consume them, the
        # others accept and discard them via ``**__``.
        self.kwargs = dict(
            fillcolor=hparams.get('img_mean', _FILL),
            resample=hparams.get('interpolation', _RANDOM_INTERPOLATION),
        )
        # Magnitude noise: > 0 samples magnitude from N(magnitude, std);
        # inf samples it uniformly from [0, magnitude]; 0 keeps it fixed.
        self.magnitude_std = self.hparams.get('magnitude_std', 0)

    def __call__(self, img):
        # Skip the op with probability (1 - prob).
        if self.prob < 1.0 and random.random() > self.prob:
            return img
        level = self.magnitude
        std = self.magnitude_std
        if std == float('inf'):
            level = random.uniform(0, level)
        elif std and std > 0:
            level = random.gauss(level, std)
        # Clip to the valid level range.
        level = min(_MAX_LEVEL, max(0, level))
        args = () if self.level_fn is None else self.level_fn(level, self.hparams)
        return self.aug_fn(img, *args, **self.kwargs)
| |
|
| |
|
def auto_augment_policy_v0(hparams):
    """Return the 'v0' ImageNet AutoAugment policy.

    The result is a list of sub-policies, each a list of ``AugmentOp``
    built from ``(name, prob, magnitude)`` triples.
    """
    sub_policies = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AugmentOp(name, prob, level, hparams=hparams) for name, prob, level in sub]
        for sub in sub_policies
    ]
| |
|
| |
|
def auto_augment_policy_v0r(hparams):
    """Return the 'v0r' policy: 'v0' with 'Posterize' swapped for 'PosterizeIncreasing'.

    Same structure as the other builders: a list of sub-policies, each a
    list of ``AugmentOp``.
    """
    sub_policies = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AugmentOp(name, prob, level, hparams=hparams) for name, prob, level in sub]
        for sub in sub_policies
    ]
| |
|
| |
|
def auto_augment_policy_original(hparams):
    """Return the 'original' ImageNet AutoAugment policy.

    Uses the 'PosterizeOriginal' level scaling. The result is a list of
    sub-policies, each a list of ``AugmentOp``.
    """
    sub_policies = [
        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    return [
        [AugmentOp(name, prob, level, hparams=hparams) for name, prob, level in sub]
        for sub in sub_policies
    ]
| |
|
| |
|
def auto_augment_policy_originalr(hparams):
    """Return 'originalr': 'original' with 'PosterizeOriginal' swapped for 'PosterizeIncreasing'.

    Same structure as the other builders: a list of sub-policies, each a
    list of ``AugmentOp``.
    """
    sub_policies = [
        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    return [
        [AugmentOp(name, prob, level, hparams=hparams) for name, prob, level in sub]
        for sub in sub_policies
    ]
| |
|
| |
|
def auto_augment_policy(name='v0', hparams=None):
    """Build the AutoAugment policy called ``name``.

    Valid names are 'original', 'originalr', 'v0' and 'v0r'. An unknown name
    fails an assertion.
    """
    hparams = hparams or _HPARAMS_DEFAULT
    builders = {
        'original': auto_augment_policy_original,
        'originalr': auto_augment_policy_originalr,
        'v0': auto_augment_policy_v0,
        'v0r': auto_augment_policy_v0r,
    }
    if name in builders:
        return builders[name](hparams)
    assert False, 'Unknown AA policy (%s)' % name
| |
|
| |
|
class AutoAugment:
    """Apply one uniformly chosen sub-policy (an ordered list of ops) per call."""

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        chosen = random.choice(self.policy)
        for augment in chosen:
            img = augment(img)
        return img
| |
|
| |
|
def auto_augment_transform(config_str, hparams):
    """
    Create an AutoAugment transform

    :param config_str: String defining the auto augmentation configuration. It consists of sections separated by
        dashes ('-'). The first section names the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
        The remaining sections, in any order, set
        'mstd' - float std deviation of magnitude noise applied (only used if hparams has no 'magnitude_std')
        Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5

    :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme; may be mutated by 'mstd'

    :return: A PyTorch compatible Transform
    """
    policy_name, *sections = config_str.split('-')
    for section in sections:
        parts = re.split(r'(\d.*)', section)
        if len(parts) < 2:
            # Section has no numeric value; ignored.
            continue
        key, val = parts[:2]
        if key == 'mstd':
            # Noise parameter is passed to the ops through hparams.
            hparams.setdefault('magnitude_std', float(val))
        else:
            assert False, 'Unknown AutoAugment config section'
    return AutoAugment(auto_augment_policy(policy_name, hparams=hparams))
| |
|
| |
|
# Default op pool sampled by RandAugment.
_RAND_TRANSFORMS = [
    'AutoContrast', 'Equalize', 'Invert', 'Rotate',
    'Posterize', 'Solarize', 'SolarizeAdd',
    'Color', 'Contrast', 'Brightness', 'Sharpness',
    'ShearX', 'ShearY', 'TranslateXRel', 'TranslateYRel',
]

# Same pool, position for position, but using the variants whose severity
# increases with magnitude.
_RAND_INCREASING_TRANSFORMS = [
    'AutoContrast', 'Equalize', 'Invert', 'Rotate',
    'PosterizeIncreasing', 'SolarizeIncreasing', 'SolarizeAdd',
    'ColorIncreasing', 'ContrastIncreasing', 'BrightnessIncreasing', 'SharpnessIncreasing',
    'ShearX', 'ShearY', 'TranslateXRel', 'TranslateYRel',
]

# Hand-set op selection weights (weight index 0); the values sum to 1.0.
_RAND_CHOICE_WEIGHTS_0 = dict(
    Rotate=0.3,
    ShearX=0.2,
    ShearY=0.2,
    TranslateXRel=0.1,
    TranslateYRel=0.1,
    Color=0.025,
    Sharpness=0.025,
    AutoContrast=0.025,
    Solarize=0.005,
    SolarizeAdd=0.005,
    Contrast=0.005,
    Brightness=0.005,
    Equalize=0.005,
    Posterize=0,
    Invert=0,
)
| |
|
| |
|
def _select_rand_weights(weight_idx=0, transforms=None):
    """Return a normalized float64 probability array aligned with ``transforms``.

    Only ``weight_idx == 0`` exists (``_RAND_CHOICE_WEIGHTS_0``). Weights are
    looked up by op name, so ``transforms`` must use names present in that
    table.
    """
    transforms = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0
    table = _RAND_CHOICE_WEIGHTS_0
    weights = np.array([table[op_name] for op_name in transforms])
    return weights / np.sum(weights)
| |
|
| |
|
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    """Build one ``AugmentOp`` (prob 0.5) per name in ``transforms``."""
    hparams = hparams or _HPARAMS_DEFAULT
    op_names = transforms or _RAND_TRANSFORMS
    return [
        AugmentOp(op_name, prob=0.5, magnitude=magnitude, hparams=hparams)
        for op_name in op_names
    ]
| |
|
| |
|
class RandAugment:
    """Apply ``num_layers`` ops sampled from ``ops`` to each image.

    With ``choice_weights`` unset, ops are drawn uniformly with replacement.
    With weights set, ops are drawn without replacement using those
    probabilities.
    """

    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        weighted = self.choice_weights is not None
        selected = np.random.choice(
            self.ops, self.num_layers, replace=not weighted, p=self.choice_weights)
        for op in selected:
            img = op(img)
        return img
| |
|
| |
|
def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform

    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
        dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The
        remaining sections, in any order, set
        'm' - integer magnitude of rand augment (default: _MAX_LEVEL)
        'n' - integer num layers (number of transform ops selected per image) (default: 2)
        'w' - integer probability weight index (index of a set of weights to influence choice of op)
        'mstd' - float std deviation of magnitude noise applied (only used if hparams has no 'magnitude_std')
        'inc' - integer (bool); nonzero uses augmentations that increase in severity with magnitude (default: 0)
        Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
        'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2

    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme; may be mutated by 'mstd'

    :return: A PyTorch compatible Transform
    """
    magnitude = None  # resolved to _MAX_LEVEL after parsing if not given
    num_layers = 2
    weight_idx = None
    use_increasing = False
    config = config_str.split('-')
    assert config[0] == 'rand'
    for c in config[1:]:
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            # Section has no numeric value; ignored.
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # Noise parameter is passed to the ops through hparams.
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'inc':
            # ``val`` is a digit string; bool('0') would be True, so parse the int.
            if int(val):
                use_increasing = True
        elif key == 'm':
            magnitude = int(val)
        elif key == 'n':
            num_layers = int(val)
        elif key == 'w':
            weight_idx = int(val)
        else:
            assert False, 'Unknown RandAugment config section'
    if magnitude is None:
        magnitude = _MAX_LEVEL
    transforms = _RAND_INCREASING_TRANSFORMS if use_increasing else _RAND_TRANSFORMS
    ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
    # Weights are positional; the increasing list is aligned with _RAND_TRANSFORMS.
    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
| |
|
| |
|
# Op pool used by AugMix chains.
_AUGMIX_TRANSFORMS = [
    'AutoContrast',
    'ColorIncreasing', 'ContrastIncreasing', 'BrightnessIncreasing', 'SharpnessIncreasing',
    'Equalize', 'Rotate', 'PosterizeIncreasing', 'SolarizeIncreasing',
    'ShearX', 'ShearY', 'TranslateXRel', 'TranslateYRel',
]
| |
|
| |
|
def augmix_ops(magnitude=10, hparams=None, transforms=None):
    """Build one always-applied ``AugmentOp`` (prob 1.0) per name in ``transforms``."""
    hparams = hparams or _HPARAMS_DEFAULT
    op_names = transforms or _AUGMIX_TRANSFORMS
    return [
        AugmentOp(op_name, prob=1.0, magnitude=magnitude, hparams=hparams)
        for op_name in op_names
    ]
| |
|
| |
|
class AugMixAugment:
    """ AugMix Transform
    Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
    https://arxiv.org/abs/1912.02781

    :param ops: sequence of image ops sampled (with replacement) to build each chain
    :param alpha: concentration of the Dirichlet chain weights and both Beta params of the final mix weight
    :param width: number of augmentation chains mixed per image
    :param depth: ops per chain; values <= 0 draw a random depth in [1, 3] for every chain
    :param blended: if True, blend each chain into the result with PIL instead of a float accumulator
    """

    def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.blended = blended

    def _calc_blended_weights(self, ws, m):
        """Convert chain weights into sequential ``Image.blend`` alphas (float32).

        Chains are blended one after another into the running image. Each
        alpha is therefore the chain's target weight ``w * m`` divided by the
        mass not yet taken by the chains blended after it. When ``ws`` sums
        to 1 (as the Dirichlet draw in ``__call__`` does), the final image is
        ``(1 - m) * img + m * sum(ws_i * chain_i)``, up to uint8 rounding.
        """
        ws = ws * m
        cump = 1.
        rws = []
        for w in ws[::-1]:
            alpha = w / cump
            cump *= (1 - alpha)
            rws.append(alpha)
        return np.array(rws[::-1], dtype=np.float32)

    def _augment_chain(self, img):
        """Apply one randomly sampled chain of ops to ``img`` and return the result."""
        depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
        ops = np.random.choice(self.ops, depth, replace=True)
        img_aug = img
        for op in ops:
            img_aug = op(img_aug)
        return img_aug

    def _apply_blended(self, img, mixing_weights, m):
        # Cheaper variant: instead of accumulating chains in a float array and
        # blending once, blend each chain (always started from the original
        # image) into the running result with precomputed alphas.
        img_orig = img.copy()
        ws = self._calc_blended_weights(mixing_weights, m)
        for w in ws:
            img_aug = self._augment_chain(img_orig)
            img = Image.blend(img, img_aug, w)
        return img

    def _apply_basic(self, img, mixing_weights, m):
        # Accumulate weighted chain outputs in float32, then blend with the
        # original image using weight ``m``.
        # ``np.asarray(PIL image)`` is laid out (height, width[, bands]) while
        # ``img.size`` is (width, height). Single-band images have no band
        # axis, so the accumulator must match that layout.
        width, height = img.size
        num_bands = len(img.getbands())
        img_shape = (height, width, num_bands) if num_bands > 1 else (height, width)
        mixed = np.zeros(img_shape, dtype=np.float32)
        for mw in mixing_weights:
            img_aug = self._augment_chain(img)
            mixed += mw * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255., out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        # Chain weights ~ Dirichlet(alpha), final mix weight ~ Beta(alpha, alpha).
        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.blended:
            mixed = self._apply_blended(img, mixing_weights, m)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed
| |
|
| |
|
def augment_and_mix_transform(config_str, hparams):
    """ Create AugMix PyTorch transform

    :param config_str: String defining configuration of AugMix augmentation. Consists of multiple sections separated
        by dashes ('-'). The first section must be 'augmix'. The remaining sections, in any order, set
        'm' - integer magnitude (severity) of augmentation mix (default: 3)
        'w' - integer width of augmentation chain (default: 3)
        'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
        'a' - float alpha of the Dirichlet / Beta mixing distributions (default: 1.0)
        'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
        'mstd' - float std deviation of magnitude noise applied (default: inf, i.e. each op samples its magnitude
            uniformly from [0, m])
        Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2

    :param hparams: Other hparams (kwargs) for the Augmentation transforms; its 'magnitude_std' is overwritten

    :return: A PyTorch compatible Transform
    """
    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.
    blended = False
    # Default to uniform magnitude sampling; an 'mstd' section overrides it below.
    hparams['magnitude_std'] = float('inf')
    config = config_str.split('-')
    assert config[0] == 'augmix'
    for c in config[1:]:
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            # Section has no numeric value; ignored.
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # Must assign, not setdefault: the default above is already present,
            # so setdefault would silently ignore this section.
            hparams['magnitude_std'] = float(val)
        elif key == 'm':
            magnitude = int(val)
        elif key == 'w':
            width = int(val)
        elif key == 'd':
            depth = int(val)
        elif key == 'a':
            alpha = float(val)
        elif key == 'b':
            # ``val`` is a digit string; bool('0') would be True, so parse the int.
            blended = bool(int(val))
        else:
            assert False, 'Unknown AugMix config section'
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
| |
|