# TIDE-II / model / tidev2_utils.py
# Last commit: b79a585 by pgatoula ("Minor corrections")
import tensorflow as tf
import tensorflow.keras.layers as layers
class TopLayer(layers.Layer):
    """Multi-scale convolutional block with a residual connection.

    The input is projected to ``filters`` channels by a 1x1 ReLU convolution,
    then processed in parallel by 2x2, 4x4 and 8x8 ReLU convolutions (each
    producing ``filters // 3`` channels). The three branch outputs are
    concatenated, mixed by two bias-free 1x1 convolutions back to ``filters``
    channels, added to the *original* ``inputs``, passed through GELU and
    finally through a 1x1 ReLU convolution.

    Args:
        filters: Number of output channels. Because the residual addition
            uses the raw ``inputs`` (not the 1x1 projection), the channel
            dimension of ``inputs`` must equal ``filters``.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ``trainable``, ...), forwarded to the base class so the layer can
            be named and rebuilt via ``from_config``.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # Each multi-scale branch gets a third of the channels; when
        # ``filters`` is not divisible by 3 the concatenation has
        # 3 * (filters // 3) channels, which the point-wise conv maps back.
        branch_filters = self.filters // 3

        self.conv_1x1 = layers.Conv2D(self.filters, (1, 1), activation='relu', strides=1, padding="same",
                                      name="top_layer_1x1")
        self.conv_2x2 = layers.Conv2D(branch_filters, (2, 2), activation='relu', strides=1, padding="same",
                                      name="top_layer_2x2")
        self.conv_4x4 = layers.Conv2D(branch_filters, (4, 4), activation='relu', strides=1, padding="same",
                                      name="top_layer_4x4")
        self.conv_8x8 = layers.Conv2D(branch_filters, (8, 8), activation='relu', strides=1, padding="same",
                                      name="top_layer_8x8")
        self.concat = layers.Concatenate(axis=-1)
        self.point_wise_conv = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
                                             padding='same', name="top_layer_point_wise")
        self.feat_fusion = layers.Conv2D(self.filters, (1, 1), 1, activation=None, use_bias=False,
                                         padding='same', name="top_layer_fusion")
        self.addition = layers.Add()
        self.gelu = layers.Activation('gelu')
        self.final_conv = layers.Conv2D(self.filters, (1, 1), activation='relu', strides=1, padding="same",
                                        name="top_layer_out")

    def call(self, inputs, training=False):
        """Apply the block.

        Args:
            inputs: Tensor with channels last (``Concatenate(axis=-1)``);
                its last dimension must equal ``filters`` for the residual add.
            training: Accepted for Keras API compatibility. None of the
                sub-layers (convolutions, add, activation) behave differently
                in training, so it is not forwarded explicitly; Keras still
                propagates the mode through its call context.

        Returns:
            Tensor with ``filters`` channels and the same spatial size as
            ``inputs`` (all convolutions use stride 1 and ``"same"`` padding).
        """
        x = self.conv_1x1(inputs)
        feats_2x2 = self.conv_2x2(x)
        feats_4x4 = self.conv_4x4(x)
        feats_8x8 = self.conv_8x8(x)
        concatenated = self.concat([feats_2x2, feats_4x4, feats_8x8])
        concatenated = self.point_wise_conv(concatenated)
        concatenated = self.feat_fusion(concatenated)
        # Residual connection on the raw input (requires matching channels).
        x = self.addition([inputs, concatenated])
        x = self.gelu(x)
        return self.final_conv(x)

    def get_config(self):
        """Return the constructor arguments so Keras can serialize the layer."""
        config = super().get_config()
        config.update({"filters": self.filters})
        return config
class Sampling(layers.Layer):
    """Reparameterization-trick sampler.

    Given ``(z_mean, z_log_var)``, draws ``epsilon ~ N(0, 1)`` with the same
    shape and dtype as ``z_mean`` and returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon``, i.e. a sample with mean
    ``z_mean`` and standard deviation ``exp(0.5 * z_log_var)`` through which
    gradients flow to both inputs.

    Args:
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ``trainable``, ...). Accepting them is required for
            ``from_config`` to rebuild the layer from its default config.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        """Sample a latent tensor.

        Args:
            inputs: Pair ``(z_mean, z_log_var)`` of tensors with identical
                shape (any rank).

        Returns:
            Tensor with the shape and dtype of ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # Match z_mean's full dynamic shape (not just (batch, dim)) and dtype,
        # so the layer works for any rank and under mixed precision (float16),
        # where the backend's default float32 noise would cause a dtype clash.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon