| """Implementation of data preprocessing ops. |
| |
| All preprocessing ops should return a data processing functors. A data |
| is represented as a dictionary of tensors, where field "image" is reserved |
| for 3D images (height x width x channels). The functors output dictionary with |
| field "image" being modified. Potentially, other fields can also be modified |
| or added. |
| """ |
from typing import Tuple

import numpy as np

from scenic.dataset_lib.big_transfer.preprocessing import autoaugment
from scenic.dataset_lib.big_transfer.preprocessing import utils
from scenic.dataset_lib.big_transfer.registry import Registry
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from tensorflow_addons import image as image_utils
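# A minimal usage sketch (assuming `Registry.lookup` resolves the names
# registered below): each builder returns a functor that maps a data dict to
# a data dict, so ops compose by chaining calls, e.g.:
#   decode_fn = Registry.lookup("preprocess_ops.decode")()
#   resize_fn = Registry.lookup("preprocess_ops.resize")(256)
#   features = resize_fn(decode_fn({"image": encoded_jpeg_bytes}))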


@Registry.register("preprocess_ops.color_distort", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_color_distortion():
  """Applies random brightness/saturation/hue/contrast transformations."""

  def _color_distortion(image):
    image = tf.image.random_brightness(image, max_delta=128. / 255.)
    image = tf.image.random_saturation(image, lower=0.1, upper=2.0)
    image = tf.image.random_hue(image, max_delta=0.5)
    image = tf.image.random_contrast(image, lower=0.1, upper=2.0)
    return image

  return _color_distortion


@Registry.register("preprocess_ops.random_brightness", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_brightness(max_delta=0.1):
  """Applies a random brightness transformation."""

  def _random_brightness(image):
    return tf.image.random_brightness(image, max_delta)

  return _random_brightness


@Registry.register("preprocess_ops.random_saturation", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_saturation(lower=0.5, upper=2.0):
  """Applies a random saturation transformation."""

  def _random_saturation(image):
    return tf.image.random_saturation(image, lower=lower, upper=upper)

  return _random_saturation


@Registry.register("preprocess_ops.random_hue", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_hue(max_delta=0.1):
  """Applies a random hue transformation."""

  def _random_hue(image):
    return tf.image.random_hue(image, max_delta=max_delta)

  return _random_hue


@Registry.register("preprocess_ops.random_contrast", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_contrast(lower=0.5, upper=2.0):
  """Applies a random contrast transformation."""

  def _random_contrast(image):
    return tf.image.random_contrast(image, lower=lower, upper=upper)

  return _random_contrast


@Registry.register("preprocess_ops.decode", "function")
@utils.InKeyOutKey()
def get_decode(channels=3):
  """Decodes an encoded image string, see tf.io.decode_jpeg."""

  def _decode(image):
    return tf.io.decode_jpeg(image, channels=channels)

  return _decode


@Registry.register("preprocess_ops.decode_grayscale", "function")
@utils.InKeyOutKey()
def get_decode_grayscale(channels=1):
  """Decodes an encoded image string to grayscale, see tf.io.decode_jpeg."""

  def _decode_gray(image):
    return tf.io.decode_jpeg(image, channels=channels)

  return _decode_gray


@Registry.register("preprocess_ops.pad", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_pad(pad_size):
  """Pads an image.

  Args:
    pad_size: either an integer u giving both the vertical and horizontal pad
      size, or a list or tuple [u, v] of integers, where u and v are the
      vertical and horizontal pad sizes respectively.

  Returns:
    A function for padding an image.
  """
  pad_size = utils.maybe_repeat(pad_size, 2)

  def _pad(image):
    return tf.pad(
        image, [[pad_size[0], pad_size[0]], [pad_size[1], pad_size[1]],
                [0, 0]])

  return _pad


@Registry.register("preprocess_ops.resize", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_resize(resize_size, method=tf2.image.ResizeMethod.BILINEAR,
               antialias=False):
  """Resizes an image to a given size.

  Args:
    resize_size: either an integer H, where H is both the new height and width
      of the resized image, or a list or tuple [H, W] of integers, where H and
      W are the new image's height and width respectively.
    method: The type of interpolation to apply when resizing.
    antialias: Whether to use an anti-aliasing filter when downsampling an
      image.

  Returns:
    A function for resizing an image.
  """
  resize_size = utils.maybe_repeat(resize_size, 2)

  def _resize(image):
    """Resizes image to a given size."""
    # tf2.image.resize always returns floats; cast back to the input dtype.
    dtype = image.dtype
    image = tf2.image.resize(
        images=image, size=resize_size, method=method, antialias=antialias)
    return tf.cast(image, dtype)

  return _resize


@Registry.register("preprocess_ops.resize_small", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_resize_small(smaller_size, method="area", antialias=True):
  """Resizes the smaller side to `smaller_size`, keeping the aspect ratio.

  Args:
    smaller_size: an integer giving the new size of the smaller side of the
      input image.
    method: the resize method. `area` is a meaningful, backward-compatible
      default.
    antialias: See TF's image.resize method.

  Returns:
    A function that resizes an image while preserving its aspect ratio.
  """

  def _resize_small(image):
    h, w = tf.shape(image)[0], tf.shape(image)[1]
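    # Scale both sides by smaller_size / min(h, w): the shorter side becomes
    # exactly `smaller_size` and the aspect ratio is preserved.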
    ratio = (
        tf.cast(smaller_size, tf.float32) /
        tf.cast(tf.minimum(h, w), tf.float32))
    h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32)
    w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32)

    dtype = image.dtype
    image = tf2.image.resize(image, (h, w), method, antialias)
    return tf.cast(image, dtype)

  return _resize_small


@Registry.register("preprocess_ops.inception_crop", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_inception_crop(resize_size=None, area_min=5, area_max=100,
                       resize_method=tf2.image.ResizeMethod.BILINEAR,
                       resize_antialias=False):
  """Makes an inception-style image crop.

  An inception-style crop is a random image crop (its size and aspect ratio
  are random) that was used for training Inception models; see
  https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf.

  Args:
    resize_size: Resize image to [resize_size, resize_size] after crop.
    area_min: minimal crop area, as a percentage of the image area.
    area_max: maximal crop area, as a percentage of the image area.
    resize_method: The type of interpolation to apply when resizing. Valid
      values are those accepted by tf.image.resize.
    resize_antialias: Whether to use an anti-aliasing filter when downsampling
      an image.

  Returns:
    A function that applies an inception-style crop.
  """

  def _inception_crop(image):
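    # With an empty bounding-box tensor and use_image_if_no_bounding_boxes=True,
    # the sampler treats the whole image as the region of interest and draws a
    # random box covering between area_min% and area_max% of it.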
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        tf.zeros([0, 0, 4], tf.float32),
        area_range=(area_min / 100, area_max / 100),
        min_object_covered=0,
        use_image_if_no_bounding_boxes=True)
    crop = tf.slice(image, begin, size)
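    # tf.slice loses the static channel count; restore it so that downstream
    # ops can rely on a known number of channels.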
    crop.set_shape([None, None, image.shape[-1]])
    if resize_size:
      crop = get_resize(
          [resize_size, resize_size], resize_method, resize_antialias)(
              {"image": crop})["image"]
    return crop

  return _inception_crop


@Registry.register("preprocess_ops.decode_jpeg_and_inception_crop", "function")
@utils.InKeyOutKey()
def get_decode_jpeg_and_inception_crop(
    resize_size=None,
    area_min=5,
    area_max=100,
    aspect_ratio_range=None,
    resize_method=tf2.image.ResizeMethod.BILINEAR,
    resize_antialias=False):
  """Decodes a jpeg string and makes an inception-style image crop.

  An inception-style crop is a random image crop (its size and aspect ratio
  are random) that was used for training Inception models; see
  https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf.

  Args:
    resize_size: Resize image to [resize_size, resize_size] after crop.
    area_min: minimal crop area, as a percentage of the image area.
    area_max: maximal crop area, as a percentage of the image area.
    aspect_ratio_range: An optional list of floats. Defaults to [0.75, 1.33].
      The cropped area of the image must have an aspect ratio = width / height
      within this range.
    resize_method: The type of interpolation to apply when resizing. Valid
      values are those accepted by tf.image.resize.
    resize_antialias: Whether to use an anti-aliasing filter when downsampling
      an image.

  Returns:
    A function that applies an inception-style crop.
  """

  def _inception_crop(image_data):
    shape = tf.image.extract_jpeg_shape(image_data)
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        shape,
        tf.zeros([0, 0, 4], tf.float32),
        area_range=(area_min / 100, area_max / 100),
        min_object_covered=0,
        aspect_ratio_range=aspect_ratio_range,
        use_image_if_no_bounding_boxes=True)
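    # Convert the sampled box into the [y, x, h, w] crop window expected by
    # tf.image.decode_and_crop_jpeg, which decodes only the cropped region.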
    offset_y, offset_x, _ = tf.unstack(begin)
    target_height, target_width, _ = tf.unstack(size)
    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
    image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3)

    if resize_size:
      image = get_resize(
          [resize_size, resize_size], resize_method, resize_antialias)(
              {"image": image})["image"]

    return image

  return _inception_crop


@Registry.register("preprocess_ops.decode_jpeg_and_center_crop", "function")
@utils.InKeyOutKey()
def get_decode_jpeg_and_center_crop(crop_size=None):
  """Decodes a jpeg string and makes a central image crop.

  Args:
    crop_size: Crop image to [crop_size, crop_size].

  Returns:
    A function that applies a central crop.
  """
  crop_size = utils.maybe_repeat(crop_size, 2)

  def _decode_and_center_crop(image_data):
    shape = tf.image.extract_jpeg_shape(image_data)
    target_height, target_width = crop_size

    offset_y = (shape[0] - target_height) // 2
    offset_x = (shape[1] - target_width) // 2
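    # decode_and_crop_jpeg decodes only the crop window, which is cheaper than
    # decoding the full image and cropping afterwards.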
    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
    image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3)
    image.set_shape([target_height, target_width, 3])
    return image

  return _decode_and_center_crop


@Registry.register("preprocess_ops.decode_jpeg_and_random_crop", "function")
@utils.InKeyOutKey()
def get_decode_jpeg_and_random_crop(crop_size=None):
  """Decodes a jpeg string and makes a random image crop.

  Args:
    crop_size: Crop image to [crop_size, crop_size].

  Returns:
    A function that applies a random crop.
  """
  crop_size = utils.maybe_repeat(crop_size, 2)

  def _decode_and_random_crop(image_data):
    shape = tf.image.extract_jpeg_shape(image_data)[:2]
    target_height, target_width = crop_size
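    # Draw a uniform random offset in [0, limit) per spatial dimension via
    # modulo; limit = shape - crop_size + 1 keeps the window inside the image.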
    limit = shape - crop_size + 1
    offset = tf.random.uniform([2], 0, tf.int32.max, dtype=tf.int32) % limit

    crop_window = tf.stack([offset[0], offset[1], target_height, target_width])
    image = tf.image.decode_and_crop_jpeg(image_data, crop_window, channels=3)
    image.set_shape([target_height, target_width, 3])
    return image

  return _decode_and_random_crop


@Registry.register("preprocess_ops.random_crop", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_crop(crop_size):
  """Makes a random crop of a given size.

  Args:
    crop_size: either an integer H, where H is both the height and width of
      the random crop, or a list or tuple [H, W] of integers, where H and W
      are the height and width of the random crop respectively.

  Returns:
    A function that applies a random crop.
  """
  crop_size = utils.maybe_repeat(crop_size, 2)

  def _crop(image):
    return tf.random_crop(image, [crop_size[0], crop_size[1], image.shape[-1]])

  return _crop


@Registry.register("preprocess_ops.central_crop", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_central_crop(crop_size):
  """Makes a central crop of a given size.

  Args:
    crop_size: either an integer H, where H is both the height and width of
      the central crop, or a list or tuple [H, W] of integers, where H and W
      are the height and width of the central crop respectively.

  Returns:
    A function that applies a central crop.
  """
  crop_size = utils.maybe_repeat(crop_size, 2)

  def _crop(image):
    h, w = crop_size[0], crop_size[1]
    dy = (tf.shape(image)[0] - h) // 2
    dx = (tf.shape(image)[1] - w) // 2
    return tf.image.crop_to_bounding_box(image, dy, dx, h, w)

  return _crop


@Registry.register("preprocess_ops.central_crop_longer", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_central_crop_longer():
  """Center-crops the longer side so that the image becomes square.

  Returns:
    A function that applies the central crop.
  """

  def _crop(image):
    shape = tf.shape(image)
    h, w = shape[0], shape[1]
    crop_fn = tf.image.crop_to_bounding_box
    return tf.cond(
        h > w,
        lambda: crop_fn(image, h // 2 - w // 2, 0, w, w),
        lambda: crop_fn(image, 0, w // 2 - h // 2, h, h))

  return _crop


@Registry.register("preprocess_ops.flip_lr", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_flip_lr():
  """Flips an image horizontally with probability 50%."""

  def _random_flip_lr_pp(image):
    return tf.image.random_flip_left_right(image)

  return _random_flip_lr_pp


@Registry.register("preprocess_ops.flip_ud", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_flip_ud():
  """Flips an image vertically with probability 50%."""

  def _random_flip_ud_pp(image):
    return tf.image.random_flip_up_down(image)

  return _random_flip_ud_pp


@Registry.register("preprocess_ops.random_rotate", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_rotation(min_angle=0, max_angle=360):
  """Randomly rotates an image within [min_angle, max_angle] degrees."""
  if min_angle > max_angle:
    raise ValueError("min_angle (%f) must not be greater than max_angle (%f)" %
                     (min_angle, max_angle))

  min_angle = np.radians(min_angle)
  max_angle = np.radians(max_angle)

  def _random_rotation(image):
    """Rotation function."""
    num_dims = len(image.shape)
    if num_dims in [3, 4]:
      batch_size = tf.shape(image)[0] if num_dims == 4 else 1
    else:
      raise ValueError("Tensor \"image\" should have 3 or 4 dimensions.")
    random_angles = tf.random.uniform(
        shape=(batch_size,), minval=min_angle, maxval=max_angle)
    return image_utils.rotate(images=image, angles=random_angles)

  return _random_rotation


@Registry.register("preprocess_ops.random_rotate90", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_random_rotation90():
  """Randomly rotates an image by multiples of 90 degrees."""

  def _random_rotation90(image):
    """Rotation function."""
    num_rotations = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
    return tf.image.rot90(image, k=num_rotations)

  return _random_rotation90


@Registry.register("preprocess_ops.rotate", "function")
def get_rotate(create_labels=None):
  """Returns a function that makes all four 90-degree rotations and labels.

  Args:
    create_labels: whether to create new labels in the default label field of
      the input dictionary. Must be one of ['rotation', 'supervised', None].

  Returns:
    A function that applies the rotation preprocessing.
  """

  def _four_rots(img):
    """Rotates an image four times, with 90 degrees between each rotation."""
    return tf.stack([
        img,
        tf.transpose(tf.reverse_v2(img, [1]), [1, 0, 2]),
        tf.reverse_v2(img, [0, 1]),
        tf.reverse_v2(tf.transpose(img, [1, 0, 2]), [1]),
    ])

  def _rotate_pp(data):
    """Rotate preprocessing function applied on data dictionary input."""
    assert create_labels in [
        "rotation", "supervised", None
    ], ("create_labels:{} must be one of ['rotation', 'supervised', None]."
        .format(create_labels))

    if create_labels == "rotation":
      data["label"] = tf.constant([0, 1, 2, 3])
    elif create_labels == "supervised":
      if "label" in data:
        data["label"] = tf.stack(tf.tile([data["label"]], [4]))

    data["image"] = _four_rots(data["image"])
    data["rot_label"] = tf.constant([0, 1, 2, 3])

    return data

  return _rotate_pp


@Registry.register("preprocess_ops.value_range", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing(output_dtype=tf.float32)
def get_value_range(vmin=-1, vmax=1, in_min=0, in_max=255.0, clip_values=False):
  """Transforms a [in_min, in_max] image to the [vmin, vmax] range.

  Input ranges in_min/in_max can be equal-size lists to rescale the individual
  channels independently.

  Args:
    vmin: A scalar. Output min value.
    vmax: A scalar. Output max value.
    in_min: A scalar or a list of input min values to scale. If a list, the
      length should match the number of channels in the image.
    in_max: A scalar or a list of input max values to scale. If a list, the
      length should match the number of channels in the image.
    clip_values: Whether to clip the output values to the provided ranges.

  Returns:
    A function to rescale the values.
  """

  def _value_range(image):
    """Scales values in the given range."""
    in_min_t = tf.constant(in_min, tf.float32)
    in_max_t = tf.constant(in_max, tf.float32)
    image = tf.cast(image, tf.float32)
    image = (image - in_min_t) / (in_max_t - in_min_t)
    image = vmin + image * (vmax - vmin)
    if clip_values:
      image = tf.clip_by_value(image, vmin, vmax)
    return image

  return _value_range


@Registry.register("preprocess_ops.value_range_mc", "function")
def get_value_range_mc(vmin, vmax, *args):
  """Independent multi-channel rescaling."""
  if len(args) % 2:
    raise ValueError("Additional args must be a list of even length giving "
                     "`in_min` and `in_max` values, concatenated.")
  num_channels = len(args) // 2
  in_min = args[:num_channels]
  in_max = args[-num_channels:]

  return get_value_range(vmin, vmax, in_min, in_max)


@Registry.register("preprocess_ops.delete_field", "function")
def get_delete_field(key):
  """Removes the given key from the data dictionary, if present."""

  def _delete_field(datum):
    if key in datum:
      del datum[key]
    return datum

  return _delete_field


@Registry.register("preprocess_ops.replicate", "function")
@utils.InKeyOutKey()
def get_replicate(num_replicas=2):
  """Replicates an image `num_replicas` times along a new batch dimension."""

  def _replicate(image):
    tiles = [num_replicas] + [1] * len(image.shape)
    return tf.tile(image[None], tiles)

  return _replicate


@Registry.register("preprocess_ops.standardize", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing(output_dtype=tf.float32)
def get_standardize(mean, std):
  """Standardizes an image with the given mean and standard deviation."""

  def _standardize(image):
    image = tf.cast(image, dtype=tf.float32)
    return (image - mean) / std

  return _standardize


@Registry.register("preprocess_ops.select_channels", "function")
@utils.InKeyOutKey()
@utils.BatchedImagePreprocessing()
def get_select_channels(channels):
  """Returns a function to select the specified channels."""

  def _select_channels(image):
    """Returns a subset of available channels."""
    return tf.gather(image, channels, axis=-1)

  return _select_channels


@Registry.register("preprocess_ops.extract_patches", "function")
@utils.InKeyOutKey()
def get_extract_patches(patch_size, stride):
  """Extracts image patches.

  Args:
    patch_size: size of the (square) patches to extract.
    stride: stride between the starts of consecutive patches.

  Returns:
    A function for extracting patches.
  """

  def _extract_patches(image):
    """Extracts image patches."""
    h, w, c = image.get_shape().as_list()

    count_h = h // stride
    count_w = w // stride
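    # tf.extract_image_patches expects a batch dimension and flattens each
    # patch_size x patch_size x c patch into the last dimension.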
    image = tf.extract_image_patches(image[None],
                                     [1, patch_size, patch_size, 1],
                                     [1, stride, stride, 1],
                                     [1, 1, 1, 1],
                                     padding="VALID")

    return tf.reshape(image, [count_h * count_w, patch_size, patch_size, c])

  return _extract_patches


@Registry.register("preprocess_ops.onehot", "function")
def get_onehot(depth,
               key="labels",
               key_result=None,
               multi=True,
               on=1.0,
               off=0.0):
  """One-hot encodes the input.

  Args:
    depth: Length of the one-hot vector (how many classes).
    key: Key of the data to be one-hot encoded.
    key_result: Key under which to store the result (same as `key` if None).
    multi: If there are multiple labels, whether to merge them into the same
      "multi-hot" vector (True) or keep them as an extra dimension (False).
    on: Value to fill in for the positive label (default: 1).
    off: Value to fill in for negative labels (default: 0).

  Returns:
    A function that one-hot encodes the `key` field of the data dictionary.
  """

  def _onehot(data):
    labels = data[key]
    if labels.shape.rank > 0 and multi:
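      # Build the multi-hot vector with a single scatter: every label index
      # receives a 1, and clip_by_value collapses duplicate labels before the
      # rescaling to [off, on].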
      x = tf.scatter_nd(labels[:, None], tf.ones(tf.shape(labels)[0]),
                        (depth,))
      x = tf.clip_by_value(x, 0, 1) * (on - off) + off
    else:
      assert np.isclose(on + off * (depth - 1), 1), (
          "on + off * (depth - 1) must equal 1")
      x = tf.one_hot(labels, depth, on_value=on, off_value=off)
    data[key_result or key] = x
    return data

  return _onehot


@Registry.register("preprocess_ops.keep", "function")
def get_keep(*keys):
  """Keeps only the given keys."""

  def _keep(data):
    return {k: v for k, v in data.items() if k in keys}

  return _keep


@Registry.register("preprocess_ops.drop", "function")
def get_drop(*keys):
  """Drops the given keys."""

  def _drop(data):
    return {k: v for k, v in data.items() if k not in keys}

  return _drop


@Registry.register("preprocess_ops.copy", "function")
def get_copy(inkey, outkey):
  """Copies the value of `inkey` into `outkey`."""

  def _copy(data):
    data[outkey] = data[inkey]
    return data

  return _copy


@Registry.register("preprocess_ops.randaug", "function")
@utils.InKeyOutKey()
def get_randaug(num_layers: int = 2, magnitude: int = 10):
  """Creates a function that applies RandAugment.

  RandAugment is from the paper https://arxiv.org/abs/1909.13719.

  Args:
    num_layers: Integer, the number of augmentation transformations to apply
      sequentially to an image. Represented as (N) in the paper. Usually best
      values will be in the range [1, 3].
    magnitude: Integer, shared magnitude across all augmentation operations.
      Represented as (M) in the paper. Usually best values are in the range
      [5, 30].

  Returns:
    A function that applies RandAugment.
  """

  def _randaug(image):
    return autoaugment.distort_image_with_randaugment(
        image=image,
        num_layers=num_layers,
        magnitude=magnitude,
    )

  return _randaug


@Registry.register("preprocess_ops.patchify", "function")
@utils.InKeyOutKey()
def patchify(patch_size: Tuple[int, int], stride: Tuple[int, int]):
  """Patchifies an image.

  If the image has shape (h, w, c), it is patchified into patches of shape
  (h // stride[0] * w // stride[1], patch_size[0] * patch_size[1] * c).

  Args:
    patch_size: Tuple of two integers, the (height, width) of each patch.
    stride: Tuple of two integers, the vertical and horizontal strides.

  Returns:
    A function that patchifies an image.
  """

  def _extract_patches(image):
    """Extracts image patches."""
    h, w, _ = image.get_shape().as_list()

    count_h = h // stride[0]
    count_w = w // stride[1]
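    # As in get_extract_patches: add a batch dimension, let
    # tf.extract_image_patches flatten each patch, then merge the spatial
    # grid into a single leading patch dimension.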
    image = tf.extract_image_patches(image[None],
                                     [1, patch_size[0], patch_size[1], 1],
                                     [1, stride[0], stride[1], 1],
                                     [1, 1, 1, 1],
                                     padding="VALID")
    return tf.reshape(image, [count_h * count_w, -1])

  return _extract_patches