| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from munch import DefaultMunch |
| from common.data_augmentation import random_color, random_affine, random_erasing, random_misc |
|
|
|
|
def data_augmentation(images, config=None, pixels_range=None, batch_info=None, current_res=None):
    """
    Apply to a batch of images the data augmentation functions listed in `config`.

    This function is called every time a new batch of input images needs
    to be augmented before it gets presented to the model to train.
    The functions are applied one after the other, in the order in which
    they appear in `config`, each one receiving the output of the previous one.

    Inputs:
        images:
            Images to augment, a tensor with shape
            [batch_size, width, height, channels].
        config:
            Config dictionary created from the 'data_augmentation' section
            of the YAML configuration file. Each key is the name of a data
            augmentation function and each value is a dictionary of arguments
            for that function (or None to use the default arguments only).
            If `config` is None or empty, the images are returned unchanged.
        pixels_range:
            Range of pixel values of the images. It is forwarded as is to
            every function that takes a `pixels_range` argument and must
            not be set in `config`.
        batch_info:
            Information passed by the data augmentation layer.
            A Tensorflow 4D variable with the following elements:
                - batch number since the beginning of the training
                - training epoch number
                - width of the images of the previous batch
                - height of the images of the previous batch
            It is not used by this function.
        current_res:
            Only used by 'random_periodic_resizing', which receives
            `(current_res[1], current_res[0])` as `new_image_size`.
            NOTE(review): presumably the target resolution chosen by the
            caller for the current batch - confirm against the caller.

    Returns:
        The augmented images.

    Raises:
        ValueError: if `config` names an unknown function, if a function is
            given an unknown argument or a `pixels_range` argument, or if
            'random_periodic_resizing' is requested without `current_res`.
    """

    def _get_arg_values(used_args, default_args, function_name):
        """
        Build the arguments to use with a data augmentation function from
        the arguments found in the `config` dictionary (`used_args`) and
        the default arguments of the function (`default_args`).
        Arguments from `config` take precedence over the defaults.
        """
        if used_args is None:
            # The function is listed in the configuration without arguments.
            used_args = {}
        if 'pixels_range' in used_args:
            raise ValueError("\nThe `pixels_range` argument is managed by the Model Zoo and "
                             "should not be used.\nPlease update the 'data_augmentation' "
                             "section of your configuration file.")
        args = DefaultMunch.fromDict(default_args)
        for name, value in used_args.items():
            if name not in default_args:
                raise ValueError("\nFunction `{}`: unknown or unsupported argument `{}`\n"
                                 "Please check the 'data_augmentation' section of your "
                                 "configuration file.".format(function_name, name))
            args[name] = value
        return args

    # Nothing to apply: `config` defaults to None and may also be empty.
    if not config:
        return images

    # Default arguments shared by the affine transformations.
    fill_defaults = {'fill_mode': 'reflect', 'interpolation': 'bilinear', 'fill_value': 0.0}
    one_factor_affine = {'factor': None, **fill_defaults, 'change_rate': 1.0}
    two_factor_affine = {'width_factor': None, 'height_factor': None,
                         **fill_defaults, 'change_rate': 1.0}

    # The three shear variants all call `random_shear`, with a different axis.
    shear_axes = {'random_shear': 'xy', 'random_shear_x': 'x', 'random_shear_y': 'y'}

    # Function name -> (module implementing it, default arguments,
    #                   whether `pixels_range` is forwarded to it)
    functions = {
        'random_contrast': (random_color, {'factor': None, 'change_rate': 1.0}, True),
        'random_brightness': (random_color, {'factor': None, 'change_rate': 1.0}, True),
        'random_gamma': (random_color, {'gamma': None, 'change_rate': 1.0}, True),
        'random_hue': (random_color, {'delta': None, 'change_rate': 1.0}, True),
        'random_saturation': (random_color, {'delta': None, 'change_rate': 1.0}, True),
        'random_value': (random_color, {'delta': None, 'change_rate': 1.0}, True),
        'random_hsv': (random_color,
                       {'hue_delta': None, 'saturation_delta': None, 'value_delta': None,
                        'change_rate': 1.0},
                       True),
        'random_rgb_to_hsv': (random_color, {'change_rate': 0.25}, True),
        'random_rgb_to_grayscale': (random_color, {'change_rate': 0.25}, True),
        'random_sharpness': (random_color, {'factor': None, 'change_rate': 1.0}, True),
        'random_posterize': (random_color, {'bits': None, 'change_rate': 1.0}, True),
        'random_invert': (random_color, {'change_rate': 0.25}, True),
        'random_solarize': (random_color, {'change_rate': 0.25}, True),
        'random_autocontrast': (random_color, {'cutoff': 10, 'change_rate': 0.25}, True),
        'random_blur': (random_misc,
                        {'filter_size': None, 'padding': 'reflect', 'constant_values': 0,
                         'change_rate': 1.0},
                        True),
        'random_gaussian_noise': (random_misc, {'stddev': None, 'change_rate': 1.0}, True),
        'random_crop': (random_misc,
                        {'crop_center_x': (0.25, 0.75),
                         'crop_center_y': (0.25, 0.75),
                         'crop_width': (0.6, 0.9),
                         'crop_height': (0.6, 0.9),
                         'interpolation': 'bilinear',
                         'change_rate': 0.9},
                        False),
        'random_jpeg_quality': (random_misc, {'jpeg_quality': None, 'change_rate': 1.0}, True),
        'random_flip': (random_affine, {'mode': None, 'change_rate': 0.5}, False),
        'random_translation': (random_affine, two_factor_affine, False),
        'random_rotation': (random_affine, one_factor_affine, False),
        'random_shear': (random_affine, one_factor_affine, False),
        'random_shear_x': (random_affine, one_factor_affine, False),
        'random_shear_y': (random_affine, one_factor_affine, False),
        'random_zoom': (random_affine, two_factor_affine, False),
        'random_rectangle_erasing': (random_erasing,
                                     {'nrec': (0, 3),
                                      'area': (0.05, 0.2),
                                      'wh_ratio': (0.2, 1.5),
                                      'fill_method': 'random',
                                      'color': None,
                                      'change_rate': 1.0,
                                      'mode': 'image'},
                                     True),
        'random_periodic_resizing': (random_misc,
                                     {'period': None, 'image_sizes': None,
                                      'interpolation': 'bilinear'},
                                     False),
    }

    config = DefaultMunch.fromDict(config)
    for fn, fn_args in config.items():
        if fn not in functions:
            raise ValueError(f"\nUnknown or unsupported data augmentation function: `{fn}`\n"
                             "Please check the 'data_augmentation' section of your "
                             "configuration file.")

        module, defaults, forward_pixels_range = functions[fn]
        args = _get_arg_values(fn_args, defaults, fn)

        if fn == 'random_periodic_resizing':
            # `period` and `image_sizes` are accepted in the configuration but
            # are not passed on here; only the interpolation and the size
            # derived from `current_res` are.
            # NOTE(review): presumably they are consumed by the caller to
            # compute `current_res` - confirm.
            if current_res is None:
                raise ValueError("\nFunction `random_periodic_resizing`: the `current_res` "
                                 "argument of `data_augmentation` must be provided.")
            images = random_misc.random_periodic_resizing(
                images,
                interpolation=args.interpolation,
                new_image_size=(current_res[1], current_res[0]))
            continue

        # Arguments that come from this function rather than from `config`.
        extra_args = {}
        target_name = fn
        if fn in shear_axes:
            target_name = 'random_shear'
            extra_args['axis'] = shear_axes[fn]
        if forward_pixels_range:
            extra_args['pixels_range'] = pixels_range

        images = getattr(module, target_name)(images, **args, **extra_args)

    return images
|
|
|
|
def progressive_dataaug(images, config=None, pixels_range=None, batch_info=None):
    """
    Augment a batch of images with settings that depend on the training epoch.

    The epoch number is read from `batch_info[1]` and selects the schedule:
        - epoch < 40:   random flip, change rate 0.1
        - epoch < 80:   random flip, change rate 0.3
        - epoch < 120:  random flip, change rate 0.5
        - epoch < 160:  random flip (0.5), random translation, random zoom
        - otherwise:    random flip (0.5), random translation, random zoom,
                        random contrast, random brightness, random invert

    Args:
        images:
            Images to augment, a tensor with shape
            [batch_size, width, height, channels].
        config:
            Config dictionary created from the YAML file.
            Contains the names and the arguments of the data augmentation
            functions to apply to the input images. Not used so far.
        pixels_range:
            A tuple of 2 integers or floats, specifies the range of pixel
            values in the input images and output images. Any range is
            supported. It generally is either [0, 255], [0, 1] or [-1, 1].
            It is forwarded to the color functions (contrast, brightness,
            invert).
        batch_info:
            Information passed by the data augmentation layer.
            A Tensorflow 4D variable with the following elements:
                - batch number since the beginning of the training
                - training epoch number
                - width of the images of the previous batch
                - height of the images of the previous batch

    Returns:
        Augmented images with variable data augmentation settings depending
        on epoch number.
    """
    epoch = batch_info[1]
    if epoch < 40:
        images = random_affine.random_flip(images, mode="horizontal_and_vertical", change_rate=0.1)
    elif epoch < 80:
        images = random_affine.random_flip(images, mode="horizontal_and_vertical", change_rate=0.3)
    elif epoch < 120:
        images = random_affine.random_flip(images, mode="horizontal_and_vertical", change_rate=0.5)
    elif epoch < 160:
        images = random_affine.random_flip(images, mode="horizontal_and_vertical", change_rate=0.5)
        images = random_affine.random_translation(images, width_factor=0.2, height_factor=0.2)
        images = random_affine.random_zoom(images, width_factor=0.4)
    else:
        images = random_affine.random_flip(images, mode="horizontal_and_vertical", change_rate=0.5)
        images = random_affine.random_translation(images, width_factor=0.2, height_factor=0.2)
        images = random_affine.random_zoom(images, width_factor=0.4)
        images = random_color.random_contrast(images, factor=0.7, pixels_range=pixels_range)
        images = random_color.random_brightness(images, factor=0.5, pixels_range=pixels_range)
        # Forward the pixel range like the other color functions do, so the
        # inversion is computed against the actual range of the images.
        images = random_color.random_invert(images, pixels_range=pixels_range, change_rate=0.1)
    return images
|
|
| |