"""Paired image/label data-augmentation helpers.

Every helper applies the same geometric transform to an image (or batch of
images) and to its label, so spatially aligned labels (presumably
segmentation masks -- confirm against callers) stay aligned with the inputs.
"""
import tensorflow as tf


def _paired_seeds(seed, count):
    """Draw ``count`` fresh stateless seeds as an int32 tensor of shape [count, 2].

    All seeds come from ONE random op on purpose: inside a ``tf.function`` /
    ``tf.data`` graph, two separate random ops sharing the same op seed can
    return identical values, which would correlate decisions meant to be
    independent.
    """
    return tf.random.uniform(
        [count, 2], maxval=tf.int32.max, dtype=tf.int32, seed=seed
    )


class Augment(tf.keras.layers.Layer):
    """Keras layer applying the same random augmentation to inputs and labels.

    Synchronisation relies on each input layer and its label twin being
    constructed with the same ``seed`` and called on equally shaped batches,
    so they make the same random draws -- confirm this holds for the Keras
    version in use.

    Args:
        seed: Seed shared by each input/label layer pair.
        rotation_factor: ``None`` (default) disables rotation, so the layer
            only flips. Otherwise it is passed as ``factor`` to
            ``RandomRotation`` (e.g. ``(-0.9, 0.9)``), and rotation is applied
            before flipping.
    """

    def __init__(self, seed=42, rotation_factor=None):
        super().__init__()
        self.rotate_inputs = None
        self.rotate_labels = None
        if rotation_factor is not None:
            self.rotate_inputs = tf.keras.layers.RandomRotation(
                factor=rotation_factor,
                fill_mode="constant",
                interpolation="bilinear",
                seed=seed,
                fill_value=0.0,
            )
            # Nearest-neighbour for labels: bilinear would blend discrete
            # label values at rotated pixel boundaries (assumes labels are
            # class-id masks -- TODO confirm).
            self.rotate_labels = tf.keras.layers.RandomRotation(
                factor=rotation_factor,
                fill_mode="constant",
                interpolation="nearest",
                seed=seed,
                fill_value=0.0,
            )
        # Both use the same seed, so they make the same random flips.
        self.augment_inputs = tf.keras.layers.RandomFlip(
            mode="horizontal_and_vertical", seed=seed
        )
        self.augment_labels = tf.keras.layers.RandomFlip(
            mode="horizontal_and_vertical", seed=seed
        )

    def call(self, inputs, labels):
        """Return ``(inputs, labels)`` after the (optional) rotation and the flip."""
        if self.rotate_inputs is not None:
            inputs = self.rotate_inputs(inputs)
            labels = self.rotate_labels(labels)
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels


def augment_flip(image, label, axis=0):
    """Randomly flip ``image`` and ``label`` together, with probability 1/2.

    Args:
        image: 3-D image or 4-D batch of images accepted by ``tf.image`` flips.
        label: Tensor with the same spatial layout as ``image``.
        axis: ``0`` flips left/right; any other value flips up/down.

    Returns:
        The ``(image, label)`` pair, either both flipped or both unchanged
        (decided per image for 4-D batches).
    """
    # One stateless seed shared by both calls guarantees an identical flip
    # decision. The stateful tf.image.random_flip_* ops with a repeated op
    # seed do not guarantee this: in eager mode the second call continues
    # the first call's random stream.
    pair_seed = _paired_seeds(42, 1)[0]
    if axis == 0:
        flip = tf.image.stateless_random_flip_left_right
    else:
        flip = tf.image.stateless_random_flip_up_down
    return flip(image, pair_seed), flip(label, pair_seed)


def augment_rot(image, label, kappa=1):
    """Rotate ``image`` and ``label`` counter-clockwise by ``kappa`` * 90 degrees.

    The rotation is deterministic, so both tensors always receive the same
    transform.
    """
    image = tf.image.rot90(image, k=kappa)
    label = tf.image.rot90(label, k=kappa)
    return image, label


def augment(images, labels, seed=42):
    """Random left/right and up/down flips, then a 180-degree rotation.

    Each random decision is shared between ``images`` and ``labels``.

    Args:
        images: 3-D image or 4-D batch of images.
        labels: Tensor with the same spatial layout as ``images``.
        seed: Op seed used to draw the per-call stateless flip seeds.

    Returns:
        The transformed ``(images, labels)`` pair.
    """
    # Row 0 drives the left/right flip, row 1 the up/down flip; both rows come
    # from a single op so the two flips are independent of each other.
    seeds = _paired_seeds(seed, 2)
    lr_seed = seeds[0]
    ud_seed = seeds[1]

    images = tf.image.stateless_random_flip_left_right(images, lr_seed)
    labels = tf.image.stateless_random_flip_left_right(labels, lr_seed)
    images = tf.image.stateless_random_flip_up_down(images, ud_seed)
    labels = tf.image.stateless_random_flip_up_down(labels, ud_seed)

    # NOTE(review): k is fixed at 2, so this rotation is deterministic. A
    # 180-degree rotation equals a left/right flip plus an up/down flip, so it
    # adds no orientations beyond the random flips above -- confirm intent
    # (a random k may have been wanted).
    images = tf.image.rot90(images, k=2)
    labels = tf.image.rot90(labels, k=2)
    return images, labels