| import tensorflow as tf | |
| import tensorflow.keras.layers as layers | |
class TopLayer(layers.Layer):
    """Multi-scale convolutional block with a residual connection.

    Pipeline, as implemented in ``call``:

    1. ``conv_1x1``: 1x1 convolution to ``filters`` channels, ReLU.
    2. Three parallel "same"-padded convolutions with 2x2, 4x4 and 8x8
       kernels. Each produces ``filters // 3`` channels with ReLU.
    3. Concatenate the three branches on the last axis. The result has
       ``3 * (filters // 3)`` channels, which is fewer than ``filters`` when
       ``filters`` is not a multiple of 3.
    4. Two bias-free, activation-free 1x1 convolutions back to ``filters``
       channels.
    5. Add the result to the block input (projected to ``filters`` channels
       first if its channel count differs), apply GELU, then apply a final
       1x1 ReLU convolution to ``filters`` channels.

    NOTE(review): the channel axis is assumed to be last (channels_last,
    e.g. (batch, H, W, C)). This matches ``Concatenate(axis=-1)`` and the
    default Conv2D data format; confirm if the global Keras
    ``image_data_format`` is changed.

    Args:
        filters: Number of output channels. Must be at least 3 so that each
            multi-scale branch gets at least one channel.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``, ``trainable``).

    Raises:
        ValueError: If ``filters`` is smaller than 3.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        if filters < 3:
            # filters // 3 would give the three multi-scale branches 0 channels.
            raise ValueError(
                f"TopLayer requires filters >= 3, got filters={filters}"
            )
        self.filters = filters
        branch_filters = self.filters // 3

        self.conv_1x1 = layers.Conv2D(self.filters, (1, 1), activation='relu', strides=1, padding="same",
                                      name="top_layer_1x1")
        self.conv_2x2 = layers.Conv2D(branch_filters, (2, 2), activation='relu', strides=1, padding="same",
                                      name="top_layer_2x2")
        self.conv_4x4 = layers.Conv2D(branch_filters, (4, 4), activation='relu', strides=1, padding="same",
                                      name="top_layer_4x4")
        self.conv_8x8 = layers.Conv2D(branch_filters, (8, 8), activation='relu', strides=1, padding="same",
                                      name="top_layer_8x8")
        self.concat = layers.Concatenate(axis=-1)
        # NOTE(review): point_wise_conv and feat_fusion are both linear 1x1
        # convolutions with no bias or activation between them, so together
        # they are equivalent to a single 1x1 convolution. They are kept as
        # two layers so the existing weight layout is unchanged.
        self.point_wise_conv = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
                                             padding='same', name="top_layer_point_wise")
        self.feat_fusion = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
                                         padding='same', name="top_layer_fusion")
        self.addition = layers.Add()
        self.gelu = layers.Activation('gelu')
        self.final_conv = layers.Conv2D(self.filters, (1, 1), activation='relu', strides=1, padding="same",
                                        name="top_layer_out")
        # Created in build() only when the input channel count != filters.
        self.shortcut = None

    def build(self, input_shape):
        """Create a 1x1 projection for the residual path if one is needed.

        The residual ``Add`` requires the block input and the fused features
        to have the same channel count. When the input's last dimension is
        known and differs from ``filters``, a bias-free 1x1 convolution
        projects the input to ``filters`` channels. Otherwise the input is
        added unchanged, as in the original design.
        """
        channels = tf.compat.dimension_value(input_shape[-1])
        if channels is not None and channels != self.filters:
            self.shortcut = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
                                          padding='same', name="top_layer_shortcut")
        super().build(input_shape)

    def call(self, inputs, training=False):
        """Apply the multi-scale block to ``inputs``.

        Args:
            inputs: Feature map; see the class docstring for layout assumptions.
            training: Forwarded to the convolution sublayers.

        Returns:
            Tensor with ``filters`` channels in the last axis.
        """
        x = self.conv_1x1(inputs, training=training)
        multi_scale = self.concat([
            self.conv_2x2(x, training=training),
            self.conv_4x4(x, training=training),
            self.conv_8x8(x, training=training),
        ])
        fused = self.feat_fusion(self.point_wise_conv(multi_scale))
        residual = inputs if self.shortcut is None else self.shortcut(inputs)
        x = self.addition([residual, fused])
        x = self.gelu(x)
        return self.final_conv(x)

    def get_config(self):
        """Return the layer config, including ``filters``, for serialization."""
        config = super().get_config()
        config.update({"filters": self.filters})
        return config
class Sampling(layers.Layer):
    """Sample ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with ``epsilon ~ N(0, 1)``.

    This is the reparameterization trick. ``z_log_var`` is treated as a
    log-variance, so ``exp(0.5 * z_log_var)`` is the standard deviation.

    Args:
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        """Draw one sample per element of ``z_mean``.

        Args:
            inputs: Pair ``(z_mean, z_log_var)``. The two tensors are expected
                to have the same (or a broadcast-compatible) shape.

        Returns:
            Tensor with the shape and dtype of ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Match z_mean's full dynamic shape and dtype. A hard-coded
        # (batch, dim) shape with the default float dtype fails for non-2-D
        # latents and for float16/float64 inputs.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon