abdelac commited on
Commit
b3f12b5
·
1 Parent(s): 83b632c
Files changed (1) hide show
  1. custom_objects.py +228 -0
custom_objects.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ custom_objects.py - Fully Fixed & Compatible with TF 2.10+ / HF Spaces
3
+ """
4
+
5
+ import tensorflow as tf
6
+ from tensorflow.keras import layers
7
+
8
+ # ======================================================
9
+ # COMPATIBILITY IDENTITY LAYER
10
+ # ======================================================
11
+
12
+ # Fallback Identity for environments lacking tf.keras.layers.Identity
13
# Some TF builds ship tf.keras.layers.Identity, older ones do not;
# fall back to a minimal pass-through layer when it is absent.
if hasattr(layers, "Identity"):
    Identity = layers.Identity
else:
    class Identity(layers.Layer):
        """Pass-through layer for TF versions lacking keras.layers.Identity."""

        def call(self, inputs):
            return inputs

        def compute_output_shape(self, input_shape):
            return input_shape

        def get_config(self):
            return super().get_config()
25
+
26
+
27
+ # ======================================================
28
+ # VISION TRANSFORMER LAYERS
29
+ # ======================================================
30
+
31
class ClassToken(layers.Layer):
    """Prepend a learnable [CLS] token to a (batch, seq, dim) patch sequence.

    The token is a zero-initialized trainable weight of shape (1, 1, dim),
    tiled across the batch and concatenated in front of the sequence, so the
    output length is seq + 1.
    """

    def __init__(self, name="class_token", **kwargs):
        super().__init__(name=name, **kwargs)
        self.supports_masking = True

    def build(self, input_shape):
        embed_dim = input_shape[-1]
        # Keras 3 treats add_weight's first positional argument as `shape`,
        # so the weight name must be passed as a keyword to stay compatible.
        self.cls = self.add_weight(
            name="cls_token",
            shape=(1, 1, embed_dim),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        batch = tf.shape(x)[0]
        cls = tf.tile(self.cls, [batch, 1, 1])
        return tf.concat([cls, x], axis=1)
51
+
52
+
53
class PatchEmbeddings(layers.Layer):
    """Split an image into non-overlapping patches and linearly embed each.

    A single Conv2D with kernel == stride == patch_size is the standard ViT
    patchifier; the spatial grid is then flattened to (batch, num_patches, dim).
    """

    def __init__(self, patch_size=16, embed_dim=768, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim

    def build(self, input_shape):
        self.proj = layers.Conv2D(
            filters=self.embed_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            padding="valid",
        )
        super().build(input_shape)

    def call(self, x):
        x = self.proj(x)  # (B, H/P, W/P, D)
        shape = tf.shape(x)
        b, h, w, c = shape[0], shape[1], shape[2], shape[3]
        return tf.reshape(x, [b, h * w, c])  # (B, num_patches, D)

    def get_config(self):
        # Without these keys a saved model cannot rebuild this layer.
        config = super().get_config()
        config.update({"patch_size": self.patch_size, "embed_dim": self.embed_dim})
        return config
73
+
74
+
75
class AddPositionEmbs(layers.Layer):
    """Add learned position embeddings to a (batch, seq, dim) sequence.

    When the incoming sequence length differs from the stored embeddings
    (e.g. a different image resolution at inference), the embeddings are
    bilinearly interpolated along the sequence axis only.
    """

    def __init__(self, initializer="zeros", **kwargs):
        super().__init__(**kwargs)
        self.initializer = initializer

    def build(self, input_shape):
        # NOTE(review): assumes a static sequence length at build time —
        # input_shape[1] of None would break the weight shape; confirm callers.
        seq_len, dim = input_shape[1], input_shape[2]
        # Keras 3 treats add_weight's first positional argument as `shape`,
        # so the weight name must be passed as a keyword to stay compatible.
        self.pe = self.add_weight(
            name="position_embeddings",
            shape=(1, seq_len, dim),
            initializer=self.initializer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        # Use static shapes for the branch: comparing dynamic tf.shape()
        # tensors with `==` inside a Python `if` fails in graph mode.
        static_len = x.shape[1]
        pe_len = self.pe.shape[1]
        if static_len is not None and static_len == pe_len:
            return x + tf.cast(self.pe, x.dtype)

        # Length mismatch (or unknown): resize along the sequence axis only.
        x_len = tf.shape(x)[1]
        dim = self.pe.shape[2]
        pe = tf.reshape(self.pe, (1, pe_len, dim, 1))  # to NHWC "image"
        pe = tf.image.resize(pe, (x_len, dim))         # resize LENGTH only
        pe = tf.reshape(pe, (1, x_len, dim))           # back to (1, L, D)
        return x + tf.cast(pe, x.dtype)

    def get_config(self):
        # Persist the initializer so deserialization restores the layer.
        config = super().get_config()
        config.update({"initializer": self.initializer})
        return config
108
+
109
class TransformerBlock(layers.Layer):
    """Pre-norm Transformer encoder block: MHA + MLP, each with a residual.

    Layout per sub-block: LayerNorm -> op -> Dropout -> residual add, which is
    the standard ViT encoder arrangement.
    """

    def __init__(self, num_heads=12, mlp_dim=3072, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        dim = input_shape[-1]

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.att = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            # per-head width; assumes dim is divisible by num_heads
            key_dim=dim // self.num_heads,
        )
        self.drop1 = layers.Dropout(self.dropout_rate)

        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.d1 = layers.Dense(self.mlp_dim, activation="gelu")
        self.drop2 = layers.Dropout(self.dropout_rate)
        self.d2 = layers.Dense(dim)
        self.drop3 = layers.Dropout(self.dropout_rate)

        super().build(input_shape)

    def call(self, x, training=None):
        # Self-attention sub-block with residual connection.
        h = self.norm1(x)
        h = self.att(h, h)
        h = self.drop1(h, training=training)
        x = x + h

        # MLP sub-block with residual connection.
        h = self.norm2(x)
        h = self.d1(h)
        h = self.drop2(h, training=training)
        h = self.d2(h)
        h = self.drop3(h, training=training)
        return x + h

    def get_config(self):
        # Without these keys a saved model cannot rebuild this layer.
        config = super().get_config()
        config.update({
            "num_heads": self.num_heads,
            "mlp_dim": self.mlp_dim,
            "dropout_rate": self.dropout_rate,
        })
        return config
146
+
147
+
148
class ExtractToken(layers.Layer):
    """Select the leading (class) token from a (batch, seq, dim) sequence."""

    def call(self, x):
        # Equivalent to x[:, 0]: take index 0 along the sequence axis.
        return tf.gather(x, 0, axis=1)
151
+
152
+
153
class MlpBlock(layers.Layer):
    """Two-layer MLP (expand -> activation -> project) with dropout after each Dense.

    The second Dense projects back to the input's last dimension, so the
    output shape matches the input shape.
    """

    def __init__(self, hidden_dim=3072, dropout=0.1, activation="gelu", **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.activation = activation

    def build(self, input_shape):
        self.d1 = layers.Dense(self.hidden_dim)
        self.d2 = layers.Dense(input_shape[-1])
        self.drop1 = layers.Dropout(self.dropout)
        self.drop2 = layers.Dropout(self.dropout)
        super().build(input_shape)

    def call(self, x, training=None):
        h = self.d1(x)
        # Only "gelu" is special-cased; any other value falls back to relu.
        h = tf.nn.gelu(h) if self.activation == "gelu" else tf.nn.relu(h)
        h = self.drop1(h, training=training)
        h = self.d2(h)
        return self.drop2(h, training=training)

    def get_config(self):
        # Without these keys a saved model cannot rebuild this layer.
        config = super().get_config()
        config.update({
            "hidden_dim": self.hidden_dim,
            "dropout": self.dropout,
            "activation": self.activation,
        })
        return config
173
+
174
+
175
class SimpleMultiHeadAttention(layers.Layer):
    """Self-attention wrapper: applies MultiHeadAttention with query == value == x."""

    def __init__(self, num_heads=8, key_dim=64, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim

    def build(self, input_shape):
        self.mha = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
        )
        super().build(input_shape)

    def call(self, x):
        return self.mha(x, x)

    def get_config(self):
        # Without these keys a saved model cannot rebuild this layer.
        config = super().get_config()
        config.update({"num_heads": self.num_heads, "key_dim": self.key_dim})
        return config
190
+
191
+
192
class FixedDropout(layers.Dropout):
    # Name-only alias of keras Dropout, registered so models saved with a
    # custom "FixedDropout" layer (common in EfficientNet ports) deserialize.
    # NOTE(review): the upstream FixedDropout usually overrides noise-shape
    # handling; this alias does not — confirm the saved models don't rely on it.
    pass
194
+
195
+
196
+ # ======================================================
197
+ # RETURN ALL CUSTOM OBJECTS
198
+ # ======================================================
199
+
200
def get_custom_objects():
    """Return the name -> object map needed to deserialize the saved models.

    Includes this module's custom layers, a few standard Keras layers exposed
    under their plain names for H5 compatibility, and common activations.
    """
    custom_classes = (
        Identity,
        ClassToken,
        PatchEmbeddings,
        AddPositionEmbs,
        TransformerBlock,
        ExtractToken,
        MlpBlock,
        SimpleMultiHeadAttention,
        FixedDropout,
    )
    objects = {cls.__name__: cls for cls in custom_classes}

    # Standard layers exposed for H5 compatibility
    objects.update({
        "MultiHeadAttention": layers.MultiHeadAttention,
        "LayerNormalization": layers.LayerNormalization,
        "Dropout": layers.Dropout,
        "Dense": layers.Dense,
        "Conv2D": layers.Conv2D,
        "Flatten": layers.Flatten,
        "Reshape": layers.Reshape,
        "Activation": layers.Activation,
    })

    # Activations
    objects.update({
        "gelu": tf.nn.gelu,
        "swish": tf.nn.swish,
        "relu": tf.nn.relu,
        "sigmoid": tf.nn.sigmoid,
        "softmax": tf.nn.softmax,
    })
    return objects