File size: 3,238 Bytes
91f8c72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import tensorflow as tf
from typing import Any, Tuple
import tensorflow_addons as tfda


class ResidualBlock(tf.keras.layers.Layer):
    """Residual block: conv -> relu -> norm -> conv -> norm, plus an identity skip.

    The skip connection adds the raw input to the second normalization's
    output, so the input channel count must equal ``filter_num`` —
    TODO confirm callers guarantee this (there is no projection shortcut).
    """

    def __init__(self, filter_num: int, filter_size: int, seed: Any = None, name=None, padding="default",
                 instance_normalization: bool = False):
        """Build the block's sublayers.

        Args:
            filter_num: Number of filters for both convolutions.
            filter_size: Kernel size for both convolutions.
            seed: Unused here; kept for interface compatibility.
            name: Optional Keras layer name.
            padding: "default" (zero-pad via ``padding="same"``) or "reflect"
                (explicit reflection padding + ``padding="valid"``).
            instance_normalization: Use InstanceNormalization instead of
                BatchNormalization for both norm layers.

        Raises:
            RuntimeError: If ``padding`` is neither "default" nor "reflect".
        """
        super(ResidualBlock, self).__init__(name=name)
        # Validate once up front instead of duplicating the check per conv.
        if padding not in ("default", "reflect"):
            raise RuntimeError("Non valid padding type.")

        self.filter_num = filter_num
        self.filter_size = filter_size
        self.seed = seed
        self.padding_type = padding

        # Identity activation kept for interface/graph compatibility.
        self.activation_1 = tf.keras.layers.Activation("linear", trainable=False)

        if padding == "reflect":
            # NOTE(review): (1, 1) reflection padding only preserves spatial
            # size for 3x3 kernels — confirm filter_size is 3 for this mode.
            self.pad_1 = ReflectionPadding2D(padding=(1, 1))
            self.pad_2 = ReflectionPadding2D(padding=(1, 1))
            conv_padding = "valid"
        else:
            conv_padding = "same"

        self.conv_1 = tf.keras.layers.Conv2D(filters=self.filter_num, kernel_size=self.filter_size,
                                             padding=conv_padding, trainable=True)
        self.conv_2 = tf.keras.layers.Conv2D(filters=self.filter_num, kernel_size=self.filter_size,
                                             padding=conv_padding, trainable=True)

        self.activation_2 = tf.keras.layers.Activation("relu")

        if instance_normalization:
            self.bn_1 = tfda.layers.InstanceNormalization(trainable=True)
            self.bn_2 = tfda.layers.InstanceNormalization(trainable=True)
        else:
            self.bn_1 = tf.keras.layers.BatchNormalization(trainable=True)
            self.bn_2 = tf.keras.layers.BatchNormalization(trainable=True)

        self.activation_3 = tf.keras.layers.Activation("relu")

        # Fix: create the Add sublayer once here. The original instantiated
        # tf.keras.layers.Add() inside call() on every forward pass, which
        # breaks Keras sublayer tracking and creates a new layer per trace.
        self.add = tf.keras.layers.Add()

    def call(self, inputs, *args, **kwargs):
        """Run the two-conv residual path and add the identity skip."""
        identity = self.activation_1(inputs)
        x = identity
        if self.padding_type == "reflect":
            x = self.pad_1(x)
        x = self.conv_1(x)
        # NOTE(review): activation before normalization here (conv->relu->bn)
        # is unusual but intentional-looking; preserved as-is.
        x = self.activation_2(x)
        x = self.bn_1(x)
        if self.padding_type == "reflect":
            x = self.pad_2(x)
        x = self.conv_2(x)
        x = self.bn_2(x)
        residual = self.add([x, identity])
        return self.activation_3(residual)


class ReflectionPadding2D(tf.keras.layers.Layer):
    """Symmetric reflection padding for NHWC inputs (batch, height, width, channels)."""

    def __init__(self, padding: Tuple[int, int]):
        """Store per-side pad amounts.

        Args:
            padding: ``(width, height)`` — pad applied to both sides of each
                spatial axis.
        """
        super(ReflectionPadding2D, self).__init__()
        self.pad_width, self.pad_height = padding
        # Hoisted out of call(): the paddings spec is constant for the layer's
        # lifetime, so there is no need to rebuild it on every forward pass.
        # tf.pad accepts a plain nested list here.
        self._paddings = [
            [0, 0],  # Batch
            [self.pad_height, self.pad_height],  # Height
            [self.pad_width, self.pad_width],  # Width
            [0, 0],  # Channels
        ]

    def call(self, inputs, *args, **kwargs):
        """Reflect-pad the spatial dimensions of ``inputs``."""
        return tf.pad(inputs, self._paddings, mode="REFLECT")