File size: 2,140 Bytes
64ea7b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57

import tensorflow as tf

class Augment(tf.keras.layers.Layer):
    """Apply a random rotation followed by a random flip to inputs and labels.

    Each inputs/labels pair of layers is constructed with the same ``seed``
    so that both are intended to make the same random changes.

    Args:
        seed: Seed passed to every random layer created here.
    """

    def __init__(self, seed=42):
        super().__init__()
        # both use the same seed, so they'll make the same random changes.
        # The rotation layers live in their own attributes; previously they
        # were assigned to augment_inputs/augment_labels and then overwritten
        # by the flip layers, so no rotation was ever applied.
        # NOTE(review): labels are rotated with bilinear interpolation, as in
        # the original layer definition. If labels are integer class masks,
        # "nearest" is probably wanted - confirm against the data.
        self.rotate_inputs = tf.keras.layers.RandomRotation(factor=(-0.9, 0.9), fill_mode="constant",
                                                            interpolation="bilinear", seed=seed, fill_value=0.0)
        self.rotate_labels = tf.keras.layers.RandomRotation(factor=(-0.9, 0.9), fill_mode="constant",
                                                            interpolation="bilinear", seed=seed, fill_value=0.0)
        self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal_and_vertical", seed=seed)
        self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal_and_vertical", seed=seed)

    def call(self, inputs, labels):
        """Return ``(inputs, labels)`` after rotation and then flipping."""
        inputs = self.augment_inputs(self.rotate_inputs(inputs))
        labels = self.augment_labels(self.rotate_labels(labels))
        return inputs, labels


def augment_flip(image, label, axis=0, seed=42):
    """Randomly flip an image and its label together along one axis.

    Args:
        image: Image tensor accepted by the ``tf.image`` flip ops.
        label: Label tensor; receives the same flip decision as ``image``.
        axis: ``0`` selects a left-right flip; any other value selects an
            up-down flip.
        seed: Op-level seed used when drawing the per-call flip seed.

    Returns:
        Tuple ``(image, label)`` after the (possibly applied) flip.
    """
    # Draw a single seed pair and hand it to the stateless flip op for both
    # tensors. Two separate stateful ops that merely share an op seed are not
    # guaranteed to make the same draw (they differ under eager execution),
    # which would leave image and label misaligned.
    flip_seed = tf.random.uniform([2], maxval=tf.int32.max, dtype=tf.int32, seed=seed)

    if axis == 0:
        flip = tf.image.stateless_random_flip_left_right
    else:
        flip = tf.image.stateless_random_flip_up_down

    image = flip(image, seed=flip_seed)
    label = flip(label, seed=flip_seed)

    return image, label


def augment_rot(image, label, kappa=1):
    """Rotate an image and its label by the same number of quarter turns.

    Args:
        image: Image tensor passed to ``tf.image.rot90``.
        label: Label tensor passed to ``tf.image.rot90``.
        kappa: Forwarded as ``k``, the number of 90-degree rotations.

    Returns:
        Tuple ``(image, label)`` of the rotated tensors.
    """
    rotated = [tf.image.rot90(tensor, k=kappa) for tensor in (image, label)]
    return rotated[0], rotated[1]


def augment(images, labels, seed=42):
    """Randomly flip images and labels together, then rotate both by 180 degrees.

    A left-right flip and an up-down flip are each decided independently,
    and every decision is shared between ``images`` and ``labels``. The
    final ``rot90`` with ``k=2`` is always applied to both.

    Args:
        images: Image tensor accepted by the ``tf.image`` flip/rotate ops.
        labels: Label tensor; receives the same transformations as ``images``.
        seed: Op-level seed used when drawing the per-call flip seeds.

    Returns:
        Tuple ``(images, labels)`` after augmentation.
    """
    # One seed pair per flip direction, reused for images and labels.
    # Stateless ops fed the same seed make the same draw, so the pair stays
    # aligned; separate stateful ops sharing an op seed do not guarantee that
    # (they differ under eager execution). Distinct pairs for the two
    # directions keep the left-right and up-down decisions independent.
    flip_seeds = tf.random.uniform([2, 2], maxval=tf.int32.max, dtype=tf.int32, seed=seed)
    lr_seed = flip_seeds[0]
    ud_seed = flip_seeds[1]

    images = tf.image.stateless_random_flip_left_right(images, seed=lr_seed)
    labels = tf.image.stateless_random_flip_left_right(labels, seed=lr_seed)
    images = tf.image.stateless_random_flip_up_down(images, seed=ud_seed)
    labels = tf.image.stateless_random_flip_up_down(labels, seed=ud_seed)
    images = tf.image.rot90(images, k=2)
    labels = tf.image.rot90(labels, k=2)

    return images, labels