Spaces:
Paused
Paused
| # Copyright 2022 Google LLC | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Dataset augmentation for frame interpolation.""" | |
| from typing import Callable, Dict, List | |
| import gin.tf | |
| import numpy as np | |
| import tensorflow as tf | |
| import tensorflow.math as tfm | |
| import tensorflow_addons.image as tfa_image | |
| _PI = 3.141592653589793 | |
| def _rotate_flow_vectors(flow: tf.Tensor, angle_rad: float) -> tf.Tensor: | |
| r"""Rotate the (u,v) vector of each pixel with angle in radians. | |
| Flow matrix system of coordinates. | |
| . . . . u (x) | |
| . | |
| . | |
| . v (-y) | |
| Rotation system of coordinates. | |
| . y | |
| . | |
| . | |
| . . . . x | |
| Args: | |
| flow: Flow map which has been image-rotated. | |
| angle_rad: The rotation angle in radians. | |
| Returns: | |
| A flow with the same map but each (u,v) vector rotated by angle_rad. | |
| """ | |
| u, v = tf.split(flow, 2, axis=-1) | |
| # rotu = u * cos(angle) - (-v) * sin(angle) | |
| rot_u = tfm.cos(angle_rad) * u + tfm.sin(angle_rad) * v | |
| # rotv = -(u * sin(theta) + (-v) * cos(theta)) | |
| rot_v = -tfm.sin(angle_rad) * u + tfm.cos(angle_rad) * v | |
| return tf.concat((rot_u, rot_v), axis=-1) | |
| def flow_rot90(flow: tf.Tensor, k: int) -> tf.Tensor: | |
| """Rotates a flow by a multiple of 90 degrees. | |
| Args: | |
| flow: The flow image shaped (H, W, 2) to rotate by multiples of 90 degrees. | |
| k: The multiplier factor. | |
| Returns: | |
| A flow image of the same shape as the input rotated by multiples of 90 | |
| degrees. | |
| """ | |
| angle_rad = tf.cast(k, dtype=tf.float32) * 90. * (_PI/180.) | |
| flow = tf.image.rot90(flow, k) | |
| return _rotate_flow_vectors(flow, angle_rad) | |
| def rotate_flow(flow: tf.Tensor, angle_rad: float) -> tf.Tensor: | |
| """Rotates a flow by a the provided angle in radians. | |
| Args: | |
| flow: The flow image shaped (H, W, 2) to rotate by multiples of 90 degrees. | |
| angle_rad: The angle to ratate the flow in radians. | |
| Returns: | |
| A flow image of the same shape as the input rotated by the provided angle in | |
| radians. | |
| """ | |
| flow = tfa_image.rotate( | |
| flow, | |
| angles=angle_rad, | |
| interpolation='bilinear', | |
| fill_mode='reflect') | |
| return _rotate_flow_vectors(flow, angle_rad) | |
| def flow_flip(flow: tf.Tensor) -> tf.Tensor: | |
| """Flips a flow left to right. | |
| Args: | |
| flow: The flow image shaped (H, W, 2) to flip left to right. | |
| Returns: | |
| A flow image of the same shape as the input flipped left to right. | |
| """ | |
| flow = tf.image.flip_left_right(tf.identity(flow)) | |
| flow_u, flow_v = tf.split(flow, 2, axis=-1) | |
| return tf.stack([-1 * flow_u, flow_v], axis=-1) | |
| def random_image_rot90(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: | |
| """Rotates a stack of images by a random multiples of 90 degrees. | |
| Args: | |
| images: A tf.Tensor shaped (H, W, num_channels) of images stacked along the | |
| channel's axis. | |
| Returns: | |
| A tf.Tensor of the same rank as the `images` after random rotation by | |
| multiples of 90 degrees applied counter-clock wise. | |
| """ | |
| random_k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32) | |
| for key in images: | |
| images[key] = tf.image.rot90(images[key], k=random_k) | |
| return images | |
| def random_flip(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: | |
| """Flips a stack of images randomly. | |
| Args: | |
| images: A tf.Tensor shaped (H, W, num_channels) of images stacked along the | |
| channel's axis. | |
| Returns: | |
| A tf.Tensor of the images after random left to right flip. | |
| """ | |
| prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32) | |
| prob = tf.cast(prob, tf.bool) | |
| def _identity(image): | |
| return image | |
| def _flip_left_right(image): | |
| return tf.image.flip_left_right(image) | |
| # pylint: disable=cell-var-from-loop | |
| for key in images: | |
| images[key] = tf.cond(prob, lambda: _flip_left_right(images[key]), | |
| lambda: _identity(images[key])) | |
| return images | |
| def random_reverse(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: | |
| """Reverses a stack of images randomly. | |
| Args: | |
| images: A dictionary of tf.Tensors, each shaped (H, W, num_channels), with | |
| each tensor being a stack of iamges along the last channel axis. | |
| Returns: | |
| A dictionary of tf.Tensors, each shaped the same as the input images dict. | |
| """ | |
| prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32) | |
| prob = tf.cast(prob, tf.bool) | |
| def _identity(images): | |
| return images | |
| def _reverse(images): | |
| images['x0'], images['x1'] = images['x1'], images['x0'] | |
| return images | |
| return tf.cond(prob, lambda: _reverse(images), lambda: _identity(images)) | |
| def random_rotate(images: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: | |
| """Rotates image randomly with [-45 to 45 degrees]. | |
| Args: | |
| images: A tf.Tensor shaped (H, W, num_channels) of images stacked along the | |
| channel's axis. | |
| Returns: | |
| A tf.Tensor of the images after random rotation with a bound of -72 to 72 | |
| degrees. | |
| """ | |
| prob = tf.random.uniform((), minval=0, maxval=2, dtype=tf.int32) | |
| prob = tf.cast(prob, tf.float32) | |
| random_angle = tf.random.uniform((), | |
| minval=-0.25 * np.pi, | |
| maxval=0.25 * np.pi, | |
| dtype=tf.float32) | |
| for key in images: | |
| images[key] = tfa_image.rotate( | |
| images[key], | |
| angles=random_angle * prob, | |
| interpolation='bilinear', | |
| fill_mode='constant') | |
| return images | |
| def data_augmentations( | |
| names: List[str]) -> Dict[str, Callable[..., tf.Tensor]]: | |
| """Creates the data augmentation functions. | |
| Args: | |
| names: The list of augmentation function names. | |
| Returns: | |
| A dictionary of Callables to the augmentation functions, keyed by their | |
| names. | |
| """ | |
| augmentations = dict() | |
| for name in names: | |
| if name == 'random_image_rot90': | |
| augmentations[name] = random_image_rot90 | |
| elif name == 'random_rotate': | |
| augmentations[name] = random_rotate | |
| elif name == 'random_flip': | |
| augmentations[name] = random_flip | |
| elif name == 'random_reverse': | |
| augmentations[name] = random_reverse | |
| else: | |
| raise AttributeError('Invalid augmentation function %s' % name) | |
| return augmentations | |