File size: 4,491 Bytes
b620cf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras import backend
class LayerScale(layers.Layer):
    """Multiply the input by a learnable scale vector.

    The layer owns a single weight, ``gamma``, of shape ``(projection_dim,)``
    whose entries all start at ``init_values``. ``call`` returns
    ``x * gamma``, so ``gamma`` is broadcast against the last axis of ``x``.

    Args:
        init_values: Initial value assigned to every entry of ``gamma``.
        projection_dim: Length of ``gamma``.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.init_values = init_values
        self.projection_dim = projection_dim

    def build(self, input_shape):
        """Create the ``gamma`` weight; ``input_shape`` is not consulted."""
        # Create the weight through add_weight instead of a bare tf.Variable
        # so the layer itself registers it as a named, trainable weight that
        # follows the layer's dtype policy.
        self.gamma = self.add_weight(
            name="gamma",
            shape=(self.projection_dim,),
            initializer=tf.keras.initializers.Constant(self.init_values),
            trainable=True,
        )

    def call(self, x):
        """Return ``x`` scaled element-wise by ``gamma``."""
        return x * self.gamma

    def get_config(self):
        """Return the base config extended with this layer's constructor args."""
        config = super().get_config()
        config.update(
            {
                "init_values": self.init_values,
                "projection_dim": self.projection_dim,
            }
        )
        return config
class StochasticDepth(layers.Layer):
    """Randomly zero whole samples during training and rescale the rest.

    When ``training`` is truthy, one random keep/drop value is drawn per
    entry along the first axis of the input; kept entries are divided by
    ``1 - drop_path_rate``. Otherwise the input is returned unchanged.

    Args:
        drop_path_rate: Probability of dropping a sample. A value of 1 makes
            the keep probability 0 and therefore divides by zero in ``call``.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, drop_path_rate, **kwargs):
        super().__init__(**kwargs)
        self.drop_path_rate = drop_path_rate

    def call(self, x, training=None):
        """Apply the per-sample drop mask when ``training`` is truthy."""
        if training:
            keep_prob = 1 - self.drop_path_rate
            # Mask shape is (batch, 1, ..., 1): the batch size is read
            # dynamically, while the rank comes from the static shape so no
            # len() call is made on a tensor.
            mask_shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            # Sample in the input's dtype so the multiply below does not mix
            # float32 noise with a non-float32 input.
            random_tensor = keep_prob + tf.random.uniform(
                mask_shape, minval=0, maxval=1, dtype=x.dtype
            )
            # floor() maps [keep_prob, 1 + keep_prob) onto {0, 1}, yielding 1
            # with probability keep_prob.
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x

    def get_config(self):
        """Return the base config extended with ``drop_path_rate``."""
        config = super().get_config()
        config.update({"drop_path_rate": self.drop_path_rate})
        return config
class ConvNeXtBlock(layers.Layer):
    """Residual block: grouped 7x7 conv, two Dense layers with GELU between.

    ``call`` applies, in order: a ``Conv2D`` with ``groups=projection_dim``,
    a ``Dense`` to ``4 * projection_dim`` units, GELU, a ``Dense`` back to
    ``projection_dim`` units, an optional ``LayerScale``, and either
    ``StochasticDepth`` or an identity activation. The result is added to the
    block input, so the input is expected to have ``projection_dim`` channels.

    Args:
        projection_dim: Number of filters/units produced by the block.
        drop_path_rate: Drop probability for ``StochasticDepth``; a falsy
            value (``0.0`` or ``None``) uses an identity layer instead.
        layer_scale_init_value: Initial value for ``LayerScale``; ``None``
            disables layer scaling.
        name_prefix: Layer name, also used as the prefix of every sub-layer
            name. Defaults to an auto-numbered ``"prestem<N>"``.
        **kwargs: Forwarded to the base ``Layer`` constructor (for example
            ``trainable`` and ``dtype`` when rebuilding from a config).
    """

    def __init__(
        self,
        projection_dim,
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        name_prefix=None,
        **kwargs,
    ):
        # from_config() supplies `name` through kwargs; an explicit
        # name_prefix takes precedence, and the uid fallback is only
        # evaluated when neither is given.
        config_name = kwargs.pop("name", None)
        layer_name = (
            name_prefix
            or config_name
            or f"prestem{backend.get_uid('prestem')}"
        )
        super().__init__(name=layer_name, **kwargs)

        # Keep the constructor arguments so get_config() can report them.
        self.projection_dim = projection_dim
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value

        self.depthwise_conv = layers.Conv2D(
            filters=projection_dim,
            kernel_size=7,
            padding="same",
            groups=projection_dim,
            name=self.name + "_depthwise_conv",
        )
        self.pointwise_conv1 = layers.Dense(
            4 * projection_dim, name=self.name + "_pointwise_conv_1"
        )
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(
            projection_dim, name=self.name + "_pointwise_conv_2"
        )

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale(
                layer_scale_init_value,
                projection_dim,
                name=self.name + "_layer_scale",
            )
        else:
            self.layer_scale = None

        if drop_path_rate:
            self.stochastic_depth = StochasticDepth(
                drop_path_rate, name=self.name + "_stochastic_depth"
            )
        else:
            self.stochastic_depth = layers.Activation(
                "linear", name=self.name + "_identity"
            )

    def call(self, inputs, training=False):
        """Run the block and return ``inputs`` plus the transformed branch."""
        x = self.depthwise_conv(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        # Compare against None explicitly rather than relying on the
        # truthiness of a layer object.
        if self.layer_scale is not None:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        return inputs + x

    def get_config(self):
        """Return the base config extended with this block's constructor args."""
        config = super().get_config()
        config.update(
            {
                "projection_dim": self.projection_dim,
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
                "name_prefix": self.name,
            }
        )
        return config
class ConvNeXtBlockTransposed(layers.Layer):
    """Residual block built around a 7x7 ``Conv2DTranspose``.

    ``call`` applies, in order: a ``Conv2DTranspose``, a ``Dense`` to
    ``4 * projection_dim`` units, GELU, a ``Dense`` back to
    ``projection_dim`` units, an optional ``LayerScale``, and either
    ``StochasticDepth`` or an identity activation. The result is added to the
    block input, so the input is expected to have ``projection_dim`` channels.

    Args:
        projection_dim: Number of filters/units produced by the block.
        drop_path_rate: Drop probability for ``StochasticDepth``; a falsy
            value (``0.0`` or ``None``) uses an identity layer instead.
        layer_scale_init_value: Initial value for ``LayerScale``; ``None``
            disables layer scaling.
        name_prefix: Layer name, also used as the prefix of every sub-layer
            name. Defaults to an auto-numbered ``"poststem<N>"``.
        **kwargs: Forwarded to the base ``Layer`` constructor (for example
            ``trainable`` and ``dtype`` when rebuilding from a config).
    """

    def __init__(
        self,
        projection_dim,
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        name_prefix=None,
        **kwargs,
    ):
        # from_config() supplies `name` through kwargs; an explicit
        # name_prefix takes precedence, and the uid fallback is only
        # evaluated when neither is given.
        config_name = kwargs.pop("name", None)
        layer_name = (
            name_prefix
            or config_name
            or f"poststem{backend.get_uid('poststem')}"
        )
        super().__init__(name=layer_name, **kwargs)

        # Keep the constructor arguments so get_config() can report them.
        self.projection_dim = projection_dim
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value

        # NOTE(review): `groups` is forwarded to Conv2DTranspose as given, but
        # depending on the Keras version it may be ignored or rejected by
        # this layer type — confirm the convolution is really grouped.
        self.depthwise_conv_trans = layers.Conv2DTranspose(
            filters=projection_dim,
            kernel_size=7,
            padding="same",
            groups=projection_dim,
            name=self.name + "_depthwise_conv_trans",
        )
        self.pointwise_conv1 = layers.Dense(
            4 * projection_dim, name=self.name + "_pointwise_conv_1"
        )
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(
            projection_dim, name=self.name + "_pointwise_conv_2"
        )

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale(
                layer_scale_init_value,
                projection_dim,
                name=self.name + "_layer_scale",
            )
        else:
            self.layer_scale = None

        if drop_path_rate:
            self.stochastic_depth = StochasticDepth(
                drop_path_rate, name=self.name + "_stochastic_depth"
            )
        else:
            self.stochastic_depth = layers.Activation(
                "linear", name=self.name + "_identity"
            )

    def call(self, inputs, training=False):
        """Run the block and return ``inputs`` plus the transformed branch."""
        x = self.depthwise_conv_trans(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        # Compare against None explicitly rather than relying on the
        # truthiness of a layer object.
        if self.layer_scale is not None:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        return inputs + x

    def get_config(self):
        """Return the base config extended with this block's constructor args."""
        config = super().get_config()
        config.update(
            {
                "projection_dim": self.projection_dim,
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
                "name_prefix": self.name,
            }
        )
        return config
|