# TIDE-II / model / convnext_modules.py
# pgatoula's picture
# Initial commit
# b620cf3
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras import backend
class LayerScale(layers.Layer):
    """Scale the input by a learnable vector of length ``projection_dim``.

    Args:
        init_values: Value every entry of the scale vector starts at.
        projection_dim: Length of the scale vector ``gamma``. The vector is
            multiplied with the input, so it has to broadcast against it.
        **kwargs: Forwarded unchanged to ``layers.Layer``.
    """

    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.init_values = init_values
        self.projection_dim = projection_dim

    def build(self, input_shape):
        """Create the trainable scale vector ``gamma``."""
        # Create the weight through add_weight instead of a bare tf.Variable
        # so that Keras owns it: the variable is named, tracked as a layer
        # weight, and created with the layer's dtype policy rather than a
        # hard-coded float32.
        self.gamma = self.add_weight(
            name="gamma",
            shape=(self.projection_dim,),
            initializer=tf.keras.initializers.Constant(self.init_values),
            trainable=True,
        )

    def call(self, x):
        """Return ``x`` multiplied element-wise (with broadcasting) by ``gamma``."""
        return x * self.gamma

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "init_values": self.init_values,
                "projection_dim": self.projection_dim,
            }
        )
        return config
class StochasticDepth(layers.Layer):
    """Randomly zero whole samples during training and rescale the rest.

    When ``training`` is truthy, one random keep/drop decision is drawn per
    entry along the first axis of the input; kept entries are divided by the
    keep probability. Otherwise the input is returned unchanged.

    Args:
        drop_path_rate: Probability of dropping an entry; the keep
            probability is ``1 - drop_path_rate``.
        **kwargs: Forwarded unchanged to ``layers.Layer``.
    """

    def __init__(self, drop_path_rate, **kwargs):
        super().__init__(**kwargs)
        self.drop_path_rate = drop_path_rate

    def call(self, x, training=None):
        """Apply stochastic depth to ``x`` when ``training`` is truthy."""
        if training:
            keep_prob = 1 - self.drop_path_rate
            # Mask shape is (batch, 1, ..., 1): the batch size is read
            # dynamically, while the number of trailing axes comes from the
            # static shape. Taking len() of the tf.shape() tensor itself only
            # works when AutoGraph rewrites the call.
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            # Draw the noise in the input's dtype so the final product does
            # not mix dtypes when x is not float32.
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=x.dtype)
            # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x

    def get_config(self):
        """Return the base layer config extended with ``drop_path_rate``."""
        config = super().get_config()
        config.update({"drop_path_rate": self.drop_path_rate})
        return config
class ConvNeXtBlock(layers.Layer):
    """Residual block: 7x7 grouped conv, Dense -> GELU -> Dense, optional scale/drop.

    The block output is ``inputs + branch(inputs)``, so the branch output
    (``projection_dim`` channels) must be addable to the input.

    NOTE(review): no normalization layer is applied between the depthwise
    convolution and the first pointwise layer -- confirm this is intentional.

    Args:
        projection_dim: Number of filters/groups of the depthwise convolution
            and output width of the second pointwise layer.
        drop_path_rate: Stochastic-depth rate; a falsy value (0.0 / None)
            replaces the stochastic-depth layer with a linear activation.
        layer_scale_init_value: Initial value for ``LayerScale``; ``None``
            disables layer scaling.
        name_prefix: Layer name, also used as prefix for every sub-layer
            name. Falls back to the ``name`` keyword and then to an
            auto-generated ``prestem<N>`` name.
        **kwargs: Standard ``layers.Layer`` keyword arguments (``name``,
            ``dtype``, ``trainable``, ...), accepted so the layer can be
            rebuilt from its config.
    """

    def __init__(self, projection_dim, drop_path_rate=0.0, layer_scale_init_value=1e-6, name_prefix=None,
                 **kwargs):
        # `name` arrives through kwargs when the layer is re-created from its
        # config; `name_prefix` keeps priority for existing callers.
        explicit_name = kwargs.pop("name", None)
        layer_name = name_prefix or explicit_name or f"prestem{backend.get_uid('prestem')}"
        super().__init__(name=layer_name, **kwargs)

        # Stored so that get_config() can report them.
        self.projection_dim = projection_dim
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value

        self.depthwise_conv = layers.Conv2D(
            filters=projection_dim, kernel_size=7, padding="same", groups=projection_dim,
            name=self.name + "_depthwise_conv"
        )
        self.pointwise_conv1 = layers.Dense(4 * projection_dim, name=self.name + "_pointwise_conv_1")
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(projection_dim, name=self.name + "_pointwise_conv_2")

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale(layer_scale_init_value, projection_dim, name=self.name + "_layer_scale")
        else:
            self.layer_scale = None

        if drop_path_rate:
            self.stochastic_depth = StochasticDepth(drop_path_rate, name=self.name + "_stochastic_depth")
        else:
            self.stochastic_depth = layers.Activation("linear", name=self.name + "_identity")

    def call(self, inputs, training=False):
        """Run the branch on ``inputs`` and add the result back to ``inputs``."""
        x = self.depthwise_conv(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        if self.layer_scale is not None:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        return inputs + x

    def get_config(self):
        """Return the base layer config extended with the block hyper-parameters."""
        config = super().get_config()
        config.update(
            {
                "projection_dim": self.projection_dim,
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
            }
        )
        return config
class ConvNeXtBlockTransposed(layers.Layer):
    """Residual block built around a 7x7 ``Conv2DTranspose``.

    The branch is ``Conv2DTranspose -> Dense -> GELU -> Dense`` followed by
    optional layer scaling and stochastic depth; the block output is
    ``inputs + branch(inputs)``.

    Args:
        projection_dim: Number of filters of the transposed convolution and
            output width of the second pointwise layer.
        drop_path_rate: Stochastic-depth rate; a falsy value (0.0 / None)
            replaces the stochastic-depth layer with a linear activation.
        layer_scale_init_value: Initial value for ``LayerScale``; ``None``
            disables layer scaling.
        name_prefix: Layer name, also used as prefix for every sub-layer
            name. Falls back to the ``name`` keyword and then to an
            auto-generated ``poststem<N>`` name.
        **kwargs: Standard ``layers.Layer`` keyword arguments (``name``,
            ``dtype``, ``trainable``, ...), accepted so the layer can be
            rebuilt from its config.
    """

    def __init__(self, projection_dim, drop_path_rate=0.0, layer_scale_init_value=1e-6, name_prefix=None,
                 **kwargs):
        # `name` arrives through kwargs when the layer is re-created from its
        # config; `name_prefix` keeps priority for existing callers.
        explicit_name = kwargs.pop("name", None)
        layer_name = name_prefix or explicit_name or f"poststem{backend.get_uid('poststem')}"
        super().__init__(name=layer_name, **kwargs)

        self.projection_dim = projection_dim
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value

        # NOTE(review): `groups` is not a documented Conv2DTranspose argument;
        # whether it is honoured, ignored or rejected depends on the Keras
        # version -- verify that this layer is really depthwise.
        self.depthwise_conv_trans = layers.Conv2DTranspose(
            filters=projection_dim, kernel_size=7, padding="same",
            groups=projection_dim, name=self.name + "_depthwise_conv_trans"
        )
        self.pointwise_conv1 = layers.Dense(4 * projection_dim, name=self.name + "_pointwise_conv_1")
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(projection_dim, name=self.name + "_pointwise_conv_2")

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale(layer_scale_init_value, projection_dim, name=self.name + "_layer_scale")
        else:
            self.layer_scale = None

        if drop_path_rate:
            self.stochastic_depth = StochasticDepth(drop_path_rate, name=self.name + "_stochastic_depth")
        else:
            self.stochastic_depth = layers.Activation("linear", name=self.name + "_identity")

    def call(self, inputs, training=False):
        """Run the branch on ``inputs`` and add the result back to ``inputs``."""
        x = self.depthwise_conv_trans(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        if self.layer_scale is not None:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        return inputs + x

    def get_config(self):
        """Return the base layer config extended with the block hyper-parameters."""
        config = super().get_config()
        config.update(
            {
                "projection_dim": self.projection_dim,
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
            }
        )
        return config