File size: 4,491 Bytes
b620cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import tensorflow as tf
import tensorflow.keras.layers as layers

from tensorflow.keras import backend


class LayerScale(layers.Layer):
    """Multiply the input by a learnable scale vector.

    Args:
        init_values: Initial value given to every entry of the scale vector.
        projection_dim: Length of the scale vector. The vector has shape
            `(projection_dim,)`, so it broadcasts against the last axis of
            the input in `call`.
        **kwargs: Forwarded unchanged to `layers.Layer`.
    """

    def __init__(self, init_values, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.init_values = init_values
        self.projection_dim = projection_dim

    def build(self, input_shape):
        """Create the `gamma` scale vector; `input_shape` is not inspected."""
        # Create the weight through `add_weight` instead of a bare
        # `tf.Variable`: this is the Keras-documented way for a layer to own
        # state, so the variable is registered with the layer and created
        # under the layer's dtype handling.
        # NOTE(review): the variable name becomes "gamma" (it was an
        # auto-generated "Variable" before) -- confirm no checkpoint is
        # loaded strictly by variable name.
        self.gamma = self.add_weight(
            name="gamma",
            shape=(self.projection_dim,),
            initializer=tf.keras.initializers.Constant(self.init_values),
            trainable=True,
        )

    def call(self, x):
        """Return `x` scaled elementwise by `gamma` (broadcast over leading axes)."""
        return x * self.gamma

    def get_config(self):
        """Return the base layer config extended with this layer's constructor args."""
        config = super().get_config()
        config.update(
            {
                "init_values": self.init_values,
                "projection_dim": self.projection_dim,
            }
        )
        return config


class StochasticDepth(layers.Layer):
    """Randomly zero out whole batch entries while training.

    When `training` is truthy, each index along axis 0 of the input is kept
    with probability `1 - drop_path_rate`; kept entries are divided by that
    probability. Otherwise the input is returned untouched.

    Args:
        drop_path_rate: Probability of dropping an entry. A value of 1 makes
            the keep probability 0 and therefore divides by zero in `call`.
        **kwargs: Forwarded unchanged to `layers.Layer`.
    """

    def __init__(self, drop_path_rate, **kwargs):
        super().__init__(**kwargs)
        self.drop_path_rate = drop_path_rate

    def call(self, x, training=None):
        """Apply the random drop to `x` when `training` is truthy.

        Args:
            x: Input tensor; axis 0 is the axis along which entries are dropped.
            training: Truthy to enable dropping; falsy/None returns `x` as is.

        Returns:
            A tensor shaped like `x`.
        """
        if training:
            keep_prob = 1 - self.drop_path_rate
            # Mask shape is (batch, 1, ..., 1). The batch size is read
            # dynamically, but the rank comes from the static shape: calling
            # len() on the `tf.shape` tensor is not defined for symbolic
            # tensors outside autograph.
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            # Sample in the input's dtype; the float32 default cannot be
            # multiplied with float16/bfloat16 activations.
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=x.dtype)
            # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x

    def get_config(self):
        """Return the base layer config extended with `drop_path_rate`."""
        config = super().get_config()
        config.update({"drop_path_rate": self.drop_path_rate})
        return config


class ConvNeXtBlock(layers.Layer):
    """Residual block built from a grouped 7x7 convolution and two Dense layers.

    Pipeline in `call`: grouped `Conv2D` (groups == filters == projection_dim)
    -> `Dense(4 * projection_dim)` -> GELU -> `Dense(projection_dim)` ->
    optional `LayerScale` -> `StochasticDepth` (or identity) -> added to the
    block input.

    Args:
        projection_dim: Number of convolution filters/groups and width of the
            final Dense layer. The residual add presumably requires the input
            to have this many channels -- verify against callers.
        drop_path_rate: Drop rate for `StochasticDepth`; a falsy value
            replaces it with a linear (identity) activation.
        layer_scale_init_value: Initial `LayerScale` value, or None to skip
            layer scaling entirely.
        name_prefix: Name of this layer and prefix of every sub-layer name.
            Falls back to the `name` keyword, then to an auto-generated
            "prestem<N>".
        **kwargs: Standard `layers.Layer` keyword arguments (`name`, `dtype`,
            `trainable`, ...), accepted so the layer can be rebuilt from
            `get_config()`.
    """

    def __init__(self, projection_dim, drop_path_rate=0.0, layer_scale_init_value=1e-6, name_prefix=None,
                 **kwargs):
        # `get_config()` stores the layer name under the standard `name` key,
        # so honour it when `name_prefix` is not supplied.
        restored_name = kwargs.pop("name", None)
        super().__init__(
            name=name_prefix or restored_name or f"prestem{backend.get_uid('prestem')}",
            **kwargs,
        )
        # Kept so that get_config() can report the constructor arguments.
        self.projection_dim = projection_dim
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value

        self.depthwise_conv = layers.Conv2D(
            filters=projection_dim, kernel_size=7, padding="same", groups=projection_dim,
            name=self.name + "_depthwise_conv"
        )
        self.pointwise_conv1 = layers.Dense(4 * projection_dim, name=self.name + "_pointwise_conv_1")
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(projection_dim, name=self.name + "_pointwise_conv_2")

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale(layer_scale_init_value, projection_dim, name=self.name + "_layer_scale")
        else:
            self.layer_scale = None

        if drop_path_rate:
            self.stochastic_depth = StochasticDepth(drop_path_rate, name=self.name + "_stochastic_depth")
        else:
            self.stochastic_depth = layers.Activation("linear", name=self.name + "_identity")

    def call(self, inputs, training=False):
        """Run the block and return `inputs` plus the transformed branch."""
        x = self.depthwise_conv(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        # Explicit None test: do not depend on the truthiness of a Layer object.
        if self.layer_scale is not None:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        return inputs + x

    def get_config(self):
        """Return the base layer config extended with this block's constructor args."""
        config = super().get_config()
        config.update(
            {
                "projection_dim": self.projection_dim,
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
            }
        )
        return config


class ConvNeXtBlockTransposed(layers.Layer):
    """Residual block built from a 7x7 transposed convolution and two Dense layers.

    Pipeline in `call`: `Conv2DTranspose` -> `Dense(4 * projection_dim)` ->
    GELU -> `Dense(projection_dim)` -> optional `LayerScale` ->
    `StochasticDepth` (or identity) -> added to the block input.

    Args:
        projection_dim: Number of transposed-convolution filters and width of
            the final Dense layer. The residual add presumably requires the
            input to have this many channels -- verify against callers.
        drop_path_rate: Drop rate for `StochasticDepth`; a falsy value
            replaces it with a linear (identity) activation.
        layer_scale_init_value: Initial `LayerScale` value, or None to skip
            layer scaling entirely.
        name_prefix: Name of this layer and prefix of every sub-layer name.
            Falls back to the `name` keyword, then to an auto-generated
            "poststem<N>".
        **kwargs: Standard `layers.Layer` keyword arguments (`name`, `dtype`,
            `trainable`, ...), accepted so the layer can be rebuilt from
            `get_config()`.
    """

    def __init__(self, projection_dim, drop_path_rate=0.0, layer_scale_init_value=1e-6, name_prefix=None,
                 **kwargs):
        # `get_config()` stores the layer name under the standard `name` key,
        # so honour it when `name_prefix` is not supplied.
        restored_name = kwargs.pop("name", None)
        super().__init__(
            name=name_prefix or restored_name or f"poststem{backend.get_uid('poststem')}",
            **kwargs,
        )
        self.projection_dim = projection_dim
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value

        # NOTE(review): `groups` is not a documented Conv2DTranspose argument;
        # depending on the Keras version it may be ignored (giving a full, not
        # depthwise, transposed convolution) or rejected -- confirm.
        self.depthwise_conv_trans = layers.Conv2DTranspose(
            filters=projection_dim, kernel_size=7, padding="same",
            groups=projection_dim, name=self.name + "_depthwise_conv_trans"
        )
        self.pointwise_conv1 = layers.Dense(4 * projection_dim, name=self.name + "_pointwise_conv_1")
        self.act = layers.Activation("gelu", name=self.name + "_gelu")
        self.pointwise_conv2 = layers.Dense(projection_dim, name=self.name + "_pointwise_conv_2")

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale(layer_scale_init_value, projection_dim, name=self.name + "_layer_scale")
        else:
            self.layer_scale = None

        if drop_path_rate:
            self.stochastic_depth = StochasticDepth(drop_path_rate, name=self.name + "_stochastic_depth")
        else:
            self.stochastic_depth = layers.Activation("linear", name=self.name + "_identity")

    def call(self, inputs, training=False):
        """Run the block and return `inputs` plus the transformed branch."""
        x = self.depthwise_conv_trans(inputs)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)
        # Explicit None test: do not depend on the truthiness of a Layer object.
        if self.layer_scale is not None:
            x = self.layer_scale(x)
        x = self.stochastic_depth(x, training=training)
        return inputs + x

    def get_config(self):
        """Return the base layer config extended with this block's constructor args."""
        config = super().get_config()
        config.update(
            {
                "projection_dim": self.projection_dim,
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
            }
        )
        return config