File size: 2,460 Bytes
b620cf3
 
 
 
 
 
 
 
 
 
b79a585
b620cf3
b79a585
b620cf3
b79a585
b620cf3
b79a585
b620cf3
 
 
b79a585
b620cf3
b79a585
b620cf3
 
 
 
b79a585
b620cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import tensorflow as tf
import tensorflow.keras.layers as layers


class TopLayer(layers.Layer):
    def __init__(self, filters):
        super().__init__()
        self.filters = filters

        self.conv_1x1 = layers.Conv2D(self.filters, (1, 1), activation='relu', strides=1, padding="same",
                                      name="top_layer_1x1")
        self.conv_2x2 = layers.Conv2D(self.filters//3, (2, 2), activation='relu', strides=1, padding="same",
                                      name="top_layer_2x2")
        self.conv_4x4 = layers.Conv2D(self.filters//3, (4, 4), activation='relu', strides=1, padding="same",
                                      name="top_layer_4x4")
        self.conv_8x8 = layers.Conv2D(self.filters//3, (8, 8), activation='relu', strides=1, padding="same",
                                      name="top_layer_8x8")

        self.concat = layers.Concatenate(axis=-1)
        self.point_wise_conv = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
                                             padding='same', name="top_layer_point_wise")
        self.feat_fusion = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
                                         padding='same', name="top_layer_fusion")

        self.addition = layers.Add()
        self.gelu = layers.Activation('gelu')
        self.final_conv = layers.Conv2D(self.filters, (1, 1),  activation='relu', strides=1, padding="same",
                                        name="top_layer_out")

    def call(self, inputs, training=False):
        x = self.conv_1x1(inputs, training=training)

        feats_2x2 = self.conv_2x2(x, training=training)
        feats_4x4 = self.conv_4x4(x, training=training)
        feats_8x8 = self.conv_8x8(x, training=training)

        concatenated = self.concat([feats_2x2, feats_4x4, feats_8x8])
        concatenated = self.point_wise_conv(concatenated)

        concatenated = self.feat_fusion(concatenated)
        x = self.addition([inputs, concatenated])
        x = self.gelu(x)
        x = self.final_conv(x)
        return x


class Sampling(layers.Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon