| import tensorflow as tf | |
| import tensorflow.keras.layers as layers | |
| from tensorflow.keras import backend | |
class LayerScale(layers.Layer):
    """Multiply the input by a learnable per-channel scale vector ``gamma``.

    ``gamma`` has length ``projection_dim`` and every entry starts at
    ``init_values``; ``call`` returns ``x * gamma``, broadcasting over the
    last axis of ``x``.

    Args:
        init_values: Initial value for every entry of ``gamma``.
        projection_dim: Length of ``gamma``; presumably equal to the size of
            the input's last axis so the product broadcasts per channel.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.init_values = init_values
        self.projection_dim = projection_dim

    def build(self, input_shape):
        # Create gamma through add_weight rather than a bare tf.Variable so it
        # is registered as a named trainable weight of this layer and follows
        # the layer's dtype policy (e.g. mixed precision), instead of always
        # being a float32 variable that clashes with float16 activations.
        self.gamma = self.add_weight(
            name="gamma",
            shape=(self.projection_dim,),
            initializer=tf.keras.initializers.Constant(self.init_values),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        return x * self.gamma

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "init_values": self.init_values,
                "projection_dim": self.projection_dim,
            }
        )
        return config
class StochasticDepth(layers.Layer):
    """Drop the whole input per sample at random while training (drop path).

    When ``training`` is truthy, each sample along the leading axis is kept
    with probability ``1 - drop_path_rate`` (and rescaled by
    ``1 / (1 - drop_path_rate)``) or zeroed otherwise. When ``training`` is
    falsy the input is returned unchanged.

    Args:
        drop_path_rate: Probability of dropping a sample. A rate of 1 makes
            ``keep_prob`` zero and divides by zero, so it should be < 1.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, drop_path_rate, **kwargs):
        super().__init__(**kwargs)
        self.drop_path_rate = drop_path_rate

    def call(self, x, training=None):
        if not training:
            return x
        keep_prob = 1 - self.drop_path_rate
        # One random draw per sample (leading axis), broadcast over every other
        # axis. Use the static rank: len() of the symbolic tf.shape(x) tensor
        # only works eagerly or when autograph rewrites the call.
        shape = (tf.shape(x)[0],) + (1,) * (x.shape.rank - 1)
        # Draw the noise in x's dtype so the final product does not mix a
        # float32 mask with float16 activations under mixed precision.
        random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=x.dtype)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        random_tensor = tf.floor(random_tensor)
        return (x / keep_prob) * random_tensor

    def get_config(self):
        config = super().get_config()
        config.update({"drop_path_rate": self.drop_path_rate})
        return config
class ConvNeXtBlock(layers.Layer):
    """Residual block: grouped 7x7 conv followed by a pointwise MLP.

    Computes ``inputs + drop_path(layer_scale(mlp(conv(inputs))))`` where
    ``conv`` is a 7x7 ``Conv2D`` with ``groups=projection_dim`` (depthwise when
    the input has ``projection_dim`` channels) and ``mlp`` is
    ``Dense(4 * projection_dim) -> GELU -> Dense(projection_dim)``. The
    residual add presumably requires the input's last axis to equal
    ``projection_dim`` (assumes channels-last input -- confirm).

    NOTE(review): the reference ConvNeXt block applies a LayerNormalization
    after the depthwise conv; this block has none -- confirm that is intended.

    Args:
        projection_dim: Channel count of the conv and of the MLP output.
        drop_path_rate: ``StochasticDepth`` rate for the residual branch; a
            falsy value swaps in an identity ("linear") activation instead.
        layer_scale_init_value: Initial ``LayerScale`` value, or ``None`` to
            skip ``LayerScale``.
        name_prefix: Name of this layer, also used as the prefix of every
            sub-layer name. Defaults to a unique ``"prestem<N>"``.
        **kwargs: Forwarded to ``keras.layers.Layer`` (``name``, ``dtype``,
            ``trainable``, ...), which lets ``from_config`` rebuild the layer.
    """

    def __init__(
        self,
        projection_dim,
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        name_prefix=None,
        **kwargs
    ):
        # ``name`` arrives here when the layer is rebuilt from get_config();
        # name_prefix still takes precedence, and the uid counter is only
        # consumed when neither is given (same short-circuit as before).
        name = kwargs.pop("name", None)
        super().__init__(
            name=name_prefix or name or f"prestem{backend.get_uid('prestem')}",
            **kwargs
        )
        self.projection_dim = projection_dim
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value

        self.depthwise_conv = layers.Conv2D(
            filters=projection_dim,
            kernel_size=7,
            padding="same",
            groups=projection_dim,
            name=self.name + "_depthwise_conv",
        )
        self.pointwise_conv1 = layers.Dense(
            4 * projection_dim, name=self.name + "_pointwise_conv_1"
        )
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(
            projection_dim, name=self.name + "_pointwise_conv_2"
        )
        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale(
                layer_scale_init_value,
                projection_dim,
                name=self.name + "_layer_scale",
            )
        else:
            self.layer_scale = None
        if drop_path_rate:
            self.stochastic_depth = StochasticDepth(
                drop_path_rate, name=self.name + "_stochastic_depth"
            )
        else:
            self.stochastic_depth = layers.Activation(
                "linear", name=self.name + "_identity"
            )

    def call(self, inputs, training=False):
        x = self.depthwise_conv(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        if self.layer_scale is not None:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        return inputs + x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "projection_dim": self.projection_dim,
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
            }
        )
        return config
class ConvNeXtBlockTransposed(layers.Layer):
    """Residual block: 7x7 transposed conv followed by a pointwise MLP.

    Computes ``inputs + drop_path(layer_scale(mlp(conv_t(inputs))))`` where
    ``conv_t`` is a stride-1, "same"-padded 7x7 ``Conv2DTranspose`` with
    ``projection_dim`` filters and ``mlp`` is
    ``Dense(4 * projection_dim) -> GELU -> Dense(projection_dim)``. The
    residual add presumably requires the input's last axis to equal
    ``projection_dim`` (assumes channels-last input -- confirm).

    NOTE(review): ``groups`` is passed to ``Conv2DTranspose``; some Keras
    versions ignore it there and build a full (non-depthwise) kernel, and
    others reject it -- confirm the layer really is depthwise.

    Args:
        projection_dim: Channel count of the conv and of the MLP output.
        drop_path_rate: ``StochasticDepth`` rate for the residual branch; a
            falsy value swaps in an identity ("linear") activation instead.
        layer_scale_init_value: Initial ``LayerScale`` value, or ``None`` to
            skip ``LayerScale``.
        name_prefix: Name of this layer, also used as the prefix of every
            sub-layer name. Defaults to a unique ``"poststem<N>"``.
        **kwargs: Forwarded to ``keras.layers.Layer`` (``name``, ``dtype``,
            ``trainable``, ...), which lets ``from_config`` rebuild the layer.
    """

    def __init__(
        self,
        projection_dim,
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        name_prefix=None,
        **kwargs
    ):
        # ``name`` arrives here when the layer is rebuilt from get_config();
        # name_prefix still takes precedence, and the uid counter is only
        # consumed when neither is given (same short-circuit as before).
        name = kwargs.pop("name", None)
        super().__init__(
            name=name_prefix or name or f"poststem{backend.get_uid('poststem')}",
            **kwargs
        )
        self.projection_dim = projection_dim
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value

        self.depthwise_conv_trans = layers.Conv2DTranspose(
            filters=projection_dim,
            kernel_size=7,
            padding="same",
            groups=projection_dim,
            name=self.name + "_depthwise_conv_trans",
        )
        self.pointwise_conv1 = layers.Dense(
            4 * projection_dim, name=self.name + "_pointwise_conv_1"
        )
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(
            projection_dim, name=self.name + "_pointwise_conv_2"
        )
        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale(
                layer_scale_init_value,
                projection_dim,
                name=self.name + "_layer_scale",
            )
        else:
            self.layer_scale = None
        if drop_path_rate:
            self.stochastic_depth = StochasticDepth(
                drop_path_rate, name=self.name + "_stochastic_depth"
            )
        else:
            self.stochastic_depth = layers.Activation(
                "linear", name=self.name + "_identity"
            )

    def call(self, inputs, training=False):
        x = self.depthwise_conv_trans(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        if self.layer_scale is not None:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        return inputs + x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "projection_dim": self.projection_dim,
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
            }
        )
        return config