Dianor committed on
Commit 6c096c8 · verified · 1 Parent(s): 012f49b

Create aitestrl_layers.py

Files changed (1)
  1. aitestrl_layers.py +360 -0
aitestrl_layers.py ADDED
@@ -0,0 +1,360 @@
+ import tensorflow as tf
+ from tensorflow.keras import layers
+ from tensorflow.keras.saving import register_keras_serializable
+ from tensorflow.keras.optimizers.schedules import LearningRateSchedule
+ from tensorflow.keras import backend as K
+ import numpy as np
+
+ @register_keras_serializable()
+ class PositionalEncoding(layers.Layer):
+     def __init__(self, max_position=2048, **kwargs):
+         super().__init__(**kwargs)
+         self.max_position = max_position
+         self.pe = None
+
+     def build(self, input_shape):
+         _, seq_length, d_model = input_shape
+         position = tf.range(seq_length, dtype=tf.float32)[:, tf.newaxis]
+         div_term = tf.exp(
+             tf.range(0, d_model, 2, dtype=tf.float32) * (-tf.math.log(10000.0) / d_model)
+         )
+         pe = tf.zeros((seq_length, d_model))
+         pe = tf.tensor_scatter_nd_update(
+             pe,
+             tf.stack([
+                 tf.repeat(tf.range(seq_length), tf.shape(div_term)),
+                 tf.tile(tf.range(0, d_model, 2), [seq_length])
+             ], axis=1),
+             tf.reshape(tf.sin(position * div_term), [-1])
+         )
+         pe = tf.tensor_scatter_nd_update(
+             pe,
+             tf.stack([
+                 tf.repeat(tf.range(seq_length), tf.shape(div_term)),
+                 tf.tile(tf.range(1, d_model, 2), [seq_length])
+             ], axis=1),
+             tf.reshape(tf.cos(position * div_term), [-1])
+         )
+         self.pe = tf.Variable(
+             initial_value=pe[tf.newaxis, :, :],
+             trainable=False,
+             name="positional_encoding",
+             dtype=tf.float32
+         )
+
+     def call(self, x):
+         pe_cast = tf.cast(self.pe[:, :tf.shape(x)[1], :], dtype=x.dtype)
+         return x + 0.1 * pe_cast
+
+     def get_config(self):
+         config = super().get_config()
+         config.update({
+             "max_position": self.max_position,
+         })
+         return config
+
+ @register_keras_serializable()
+ class AdaptiveContextLayer(layers.Layer):
+     def __init__(self, context_percentage=0.2, **kwargs):
+         super().__init__(**kwargs)
+         self.context_percentage = context_percentage
+
+     def call(self, inputs):
+         sequence_length = tf.shape(inputs)[1]
+         window_size = tf.cast(tf.math.ceil(tf.cast(sequence_length, tf.float32) * self.context_percentage), tf.int32)
+         return inputs[:, -window_size:, :]
+
+     def get_config(self):
+         config = super().get_config()
+         config.update({
+             "context_percentage": self.context_percentage
+         })
+         return config
+
+ @register_keras_serializable()
+ class TransposeLayer(layers.Layer):
+     def __init__(self, **kwargs):
+         super(TransposeLayer, self).__init__(**kwargs)
+
+     def call(self, inputs):
+         return tf.transpose(inputs, perm=[0, 2, 1, 3])
+
+     def compute_output_shape(self, input_shape):
+         return (input_shape[0], input_shape[2], input_shape[1], input_shape[3])
+
+     def get_config(self):
+         config = super(TransposeLayer, self).get_config()
+         return config
+
+ @register_keras_serializable()
+ class ReshapeLayer(layers.Layer):
+     def __init__(self, **kwargs):
+         super(ReshapeLayer, self).__init__(**kwargs)
+
+     def call(self, inputs):
+         return tf.reshape(inputs, (tf.shape(inputs)[0], tf.shape(inputs)[1], -1))
+
+     def compute_output_shape(self, input_shape):
+         if input_shape[0] is None:
+             batch_size = None
+         else:
+             batch_size = input_shape[0]
+         return (batch_size, input_shape[1], input_shape[2] * input_shape[3])
+
+     def get_config(self):
+         config = super(ReshapeLayer, self).get_config()
+         return config
+
+ @register_keras_serializable()
+ class CustomOneCycleLR(LearningRateSchedule):
+     def __init__(self, max_lr, steps_per_epoch, epochs, pct_start=0.3,
+                  anneal_strategy='cos', final_div_factor=25.0, **kwargs):
+         super().__init__(**kwargs)
+         self.max_lr = max_lr
+         self.steps_per_epoch = steps_per_epoch
+         self.epochs = epochs
+         self.pct_start = pct_start
+         self.anneal_strategy = anneal_strategy
+         self.final_div_factor = final_div_factor
+
+     def __call__(self, step):
+         total_steps = self.steps_per_epoch * self.epochs
+         if step > total_steps:
+             return self.max_lr / self.final_div_factor
+
+         pct = step / total_steps
+         if pct <= self.pct_start:
+             return self.max_lr * (pct / self.pct_start)
+         else:
+             pct = (pct - self.pct_start) / (1 - self.pct_start)
+             return self.max_lr * (1 - pct) / self.final_div_factor
+
+     def get_config(self):
+         config = {
+             'max_lr': self.max_lr,
+             'steps_per_epoch': self.steps_per_epoch,
+             'epochs': self.epochs,
+             'pct_start': self.pct_start,
+             'anneal_strategy': self.anneal_strategy,
+             'final_div_factor': self.final_div_factor
+         }
+         return config
+
+ @register_keras_serializable()
+ class TemporalBlock(layers.Layer):
+     def __init__(self, in_channels, out_channels, kernel_size, dilation_rate, dropout=0.2, **kwargs):
+         super(TemporalBlock, self).__init__(**kwargs)
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.dropout = dropout
+
+         self.conv1 = layers.Conv1D(
+             filters=out_channels,
+             kernel_size=kernel_size,
+             dilation_rate=dilation_rate,
+             padding='causal',
+             kernel_initializer='he_normal'
+         )
+         self.batch_norm1 = layers.BatchNormalization()
+         self.relu1 = layers.ReLU()
+         self.dropout1 = layers.Dropout(dropout)
+
+         self.conv2 = layers.Conv1D(
+             filters=out_channels,
+             kernel_size=kernel_size,
+             dilation_rate=dilation_rate,
+             padding='causal',
+             kernel_initializer='he_normal'
+         )
+         self.batch_norm2 = layers.BatchNormalization()
+         self.relu2 = layers.ReLU()
+         self.dropout2 = layers.Dropout(dropout)
+
+         if in_channels != out_channels:
+             self.downsample = layers.Conv1D(
+                 filters=out_channels,
+                 kernel_size=1,
+                 padding='same'
+             )
+         else:
+             self.downsample = None
+
+     def call(self, x):
+         out = self.conv1(x)
+         out = self.batch_norm1(out)
+         out = self.relu1(out)
+         out = self.dropout1(out)
+
+         out = self.conv2(out)
+         out = self.batch_norm2(out)
+         out = self.relu2(out)
+         out = self.dropout2(out)
+
+         res = self.downsample(x) if self.downsample is not None else x
+         return self.relu2(out + res)
+
+     def get_config(self):
+         config = super(TemporalBlock, self).get_config()
+         config.update({
+             "in_channels": self.in_channels,
+             "out_channels": self.out_channels,
+             "kernel_size": self.kernel_size,
+             "dilation_rate": self.dilation_rate,
+             "dropout": self.dropout
+         })
+         return config
+
+ @register_keras_serializable()
+ class TemporalConvNet(layers.Layer):
+     def __init__(self, num_channels, kernel_size=2, dropout=0.2, **kwargs):
+         super(TemporalConvNet, self).__init__(**kwargs)
+         self.num_channels = num_channels
+         self.kernel_size = kernel_size
+         self.dropout = dropout
+         self.tcn_layers = []
+
+     def build(self, input_shape):
+         in_channels = input_shape[-1]
+         for i, out_channels in enumerate(self.num_channels):
+             dilation_size = 2 ** i
+             tblock = TemporalBlock(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=self.kernel_size,
+                 dilation_rate=dilation_size,
+                 dropout=self.dropout
+             )
+             self.tcn_layers.append(tblock)
+             in_channels = out_channels
+
+     def call(self, x):
+         for layer in self.tcn_layers:
+             x = layer(x)
+         return x
+
+     def get_config(self):
+         config = super(TemporalConvNet, self).get_config()
+         config.update({
+             "num_channels": self.num_channels,
+             "kernel_size": self.kernel_size,
+             "dropout": self.dropout
+         })
+         return config
+
+ @register_keras_serializable()
+ class CrossAttention(layers.Layer):
+     def __init__(self, num_heads, key_dim, **kwargs):
+         super(CrossAttention, self).__init__(**kwargs)
+         self.num_heads = num_heads
+         self.key_dim = key_dim
+         self.mha = None
+         self.layernorm = None
+         self.add = None
+
+     def build(self, input_shape):
+         self.mha = layers.MultiHeadAttention(
+             num_heads=self.num_heads,
+             key_dim=self.key_dim
+         )
+         self.layernorm = layers.LayerNormalization(epsilon=1e-6)
+         self.add = layers.Add()
+         super(CrossAttention, self).build(input_shape)
+
+     def call(self, x, context):
+         attn_output = self.mha(x, context)
+         return self.add([x, self.layernorm(attn_output)])
+
+     def get_config(self):
+         config = super(CrossAttention, self).get_config()
+         config.update({
+             "num_heads": self.num_heads,
+             "key_dim": self.key_dim
+         })
+         return config
+
+ @register_keras_serializable()
+ class CNNBlock(layers.Layer):
+     def __init__(self, filters, kernel_size, **kwargs):
+         super(CNNBlock, self).__init__(**kwargs)
+         self.filters = filters
+         self.kernel_size = kernel_size
+         self.conv1 = None
+         self.bn1 = None
+         self.conv2 = None
+         self.bn2 = None
+         self.relu = None
+         self.pool = None
+
+     def build(self, input_shape):
+         self.conv1 = layers.Conv2D(self.filters, self.kernel_size, padding='same')
+         self.bn1 = layers.BatchNormalization()
+         self.conv2 = layers.Conv2D(self.filters, self.kernel_size, padding='same')
+         self.bn2 = layers.BatchNormalization()
+         self.relu = layers.ReLU()
+         self.pool = layers.MaxPooling2D((2, 2))
+         super(CNNBlock, self).build(input_shape)
+
+     def call(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.relu(x)
+         return self.pool(x)
+
+     def get_config(self):
+         config = super(CNNBlock, self).get_config()
+         config.update({
+             "filters": self.filters,
+             "kernel_size": self.kernel_size
+         })
+         return config
+
+ @register_keras_serializable()
+ class F1Score(tf.keras.metrics.Metric):
+     def __init__(self, name='f1_score', **kwargs):
+         super().__init__(name=name, **kwargs)
+         self.precision = tf.keras.metrics.Precision()
+         self.recall = tf.keras.metrics.Recall()
+
+     def update_state(self, y_true, y_pred, sample_weight=None):
+         self.precision.update_state(y_true, y_pred, sample_weight)
+         self.recall.update_state(y_true, y_pred, sample_weight)
+
+     def result(self):
+         p = self.precision.result()
+         r = self.recall.result()
+         return 2 * ((p * r) / (p + r + tf.keras.backend.epsilon()))
+
+     def reset_state(self):
+         self.precision.reset_state()
+         self.recall.reset_state()
+
+     def get_config(self):
+         config = super(F1Score, self).get_config()
+         return config
+
+ def mean_axis1(x):
+     return K.mean(x, axis=1)
+
+ # Dictionary of custom objects for loading the model
+ custom_objects = {
+     'CustomOneCycleLR': CustomOneCycleLR,
+     'F1Score': F1Score,
+     'mean_axis1': mean_axis1,
+     'CNNBlock': CNNBlock,
+     'CrossAttention': CrossAttention,
+     'TemporalConvNet': TemporalConvNet,
+     'TemporalBlock': TemporalBlock,
+     'TransposeLayer': TransposeLayer,
+     'ReshapeLayer': ReshapeLayer,
+     'AdaptiveContextLayer': AdaptiveContextLayer,
+     'PositionalEncoding': PositionalEncoding,
+     'mean_axis1_lambda': tf.keras.layers.Lambda(
+         mean_axis1,
+         output_shape=lambda input_shape: (input_shape[0], input_shape[2], input_shape[3])
+     ),
+ }
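
The custom_objects dictionary at the end of the file is intended for reloading a saved model that uses these custom layers, schedule, and metric. A minimal usage sketch, assuming the model has already been trained and saved (the file name "model.keras" is hypothetical and not part of this commit):

import tensorflow as tf
from aitestrl_layers import custom_objects

# Hypothetical checkpoint path; substitute the actual file produced by training.
model = tf.keras.models.load_model("model.keras", custom_objects=custom_objects)
model.summary()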