| import tensorflow as tf |
| from tensorflow.keras import layers |
| from tensorflow.keras.saving import register_keras_serializable |
| from tensorflow.keras.optimizers.schedules import LearningRateSchedule |
| from tensorflow.keras import backend as K |
| import numpy as np |
|
|
@register_keras_serializable()
class PositionalEncoding(layers.Layer):
    """Adds scaled sinusoidal positional encodings to a 3-D input.

    Follows the "Attention Is All You Need" scheme: sine on even feature
    columns, cosine on odd ones. The table is built in ``build`` for the
    statically known sequence length, stored in a non-trainable variable,
    and added to the input scaled by 0.1 in ``call``.

    NOTE(review): ``max_position`` is stored and serialized but never used —
    the table size comes from ``input_shape``. Confirm whether it should cap
    the table length.
    """

    def __init__(self, max_position=2048, **kwargs):
        super().__init__(**kwargs)
        self.max_position = max_position
        self.pe = None  # created in build()

    def build(self, input_shape):
        # Requires statically known (seq_length, d_model); batch dim ignored.
        _, seq_length, d_model = input_shape
        position = tf.range(seq_length, dtype=tf.float32)[:, tf.newaxis]
        # Inverse frequencies 10000^(-2i/d_model) for i = 0 .. ceil(d/2)-1.
        div_term = tf.exp(
            tf.range(0, d_model, 2, dtype=tf.float32) * (-tf.math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (seq_length, ceil(d_model / 2))
        # Interleave sin/cos along the feature axis so column 2i holds
        # sin(angles[:, i]) and column 2i+1 holds cos(angles[:, i]). This is
        # value-identical to the previous scatter-based construction for even
        # d_model, and also works for odd d_model (the trailing slice drops
        # the surplus cos column, where the scatter version raised a shape
        # error because the odd-index count did not match len(div_term)).
        pe = tf.reshape(
            tf.stack([tf.sin(angles), tf.cos(angles)], axis=-1),
            (seq_length, -1),
        )[:, :d_model]
        self.pe = tf.Variable(
            initial_value=pe[tf.newaxis, :, :],  # (1, seq_length, d_model)
            trainable=False,
            name="positional_encoding",
            dtype=tf.float32
        )

    def call(self, x):
        # Slice to the runtime sequence length and match the input dtype
        # (e.g. under mixed precision).
        pe_cast = tf.cast(self.pe[:, :tf.shape(x)[1], :], dtype=x.dtype)
        # 0.1 damping keeps the encoding small relative to the features.
        return x + 0.1 * pe_cast

    def get_config(self):
        config = super().get_config()
        config.update({
            "max_position": self.max_position,
        })
        return config
|
|
@register_keras_serializable()
class AdaptiveContextLayer(layers.Layer):
    """Keeps only the most recent fraction of the time axis.

    Given ``(batch, time, features)`` input, returns the trailing
    ``ceil(time * context_percentage)`` timesteps unchanged.
    """

    def __init__(self, context_percentage=0.2, **kwargs):
        super().__init__(**kwargs)
        self.context_percentage = context_percentage

    def call(self, inputs):
        # Window size derives from the runtime sequence length, so the layer
        # handles variable-length inputs.
        seq_len_f = tf.cast(tf.shape(inputs)[1], tf.float32)
        window = tf.cast(tf.math.ceil(seq_len_f * self.context_percentage), tf.int32)
        return inputs[:, -window:, :]

    def get_config(self):
        base = super().get_config()
        base.update({"context_percentage": self.context_percentage})
        return base
|
|
@register_keras_serializable()
class TransposeLayer(layers.Layer):
    """Swaps axes 1 and 2 of a rank-4 tensor: (b, x, y, c) -> (b, y, x, c)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # Batch and last axes stay put; the two middle axes trade places.
        return tf.transpose(inputs, perm=[0, 2, 1, 3])

    def compute_output_shape(self, input_shape):
        batch, dim_a, dim_b, dim_c = input_shape
        return (batch, dim_b, dim_a, dim_c)

    def get_config(self):
        return super().get_config()
|
|
@register_keras_serializable()
class ReshapeLayer(layers.Layer):
    """Flattens the trailing two axes of a rank-4 tensor into one."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        dyn_shape = tf.shape(inputs)
        # -1 folds the last two dimensions together at runtime.
        return tf.reshape(inputs, (dyn_shape[0], dyn_shape[1], -1))

    def compute_output_shape(self, input_shape):
        batch = input_shape[0] if input_shape[0] is not None else None
        return (batch, input_shape[1], input_shape[2] * input_shape[3])

    def get_config(self):
        return super().get_config()
|
|
@register_keras_serializable()
class CustomOneCycleLR(LearningRateSchedule):
    """One-cycle schedule: linear warmup to ``max_lr``, then annealing down.

    For the first ``pct_start`` fraction of ``steps_per_epoch * epochs`` steps
    the LR ramps linearly from 0 to ``max_lr``; afterwards it anneals down to
    ``max_lr / final_div_factor`` using cosine (``anneal_strategy='cos'``) or
    linear decay, and holds that floor once the cycle is complete.
    """

    def __init__(self, max_lr, steps_per_epoch, epochs, pct_start=0.3,
                 anneal_strategy='cos', final_div_factor=25.0, **kwargs):
        super().__init__(**kwargs)
        self.max_lr = max_lr
        self.steps_per_epoch = steps_per_epoch
        self.epochs = epochs
        self.pct_start = pct_start
        self.anneal_strategy = anneal_strategy
        self.final_div_factor = final_div_factor

    def __call__(self, step):
        # All branching uses tf ops: ``step`` is a tensor when the schedule
        # runs inside a tf.function, so a Python ``if step > total_steps``
        # would raise. The previous implementation also ignored
        # ``anneal_strategy`` and jumped discontinuously from max_lr to
        # max_lr / final_div_factor at the end of warmup, then decayed to 0.
        step = tf.cast(step, tf.float32)
        total_steps = tf.cast(self.steps_per_epoch * self.epochs, tf.float32)
        final_lr = self.max_lr / self.final_div_factor

        # Clamp so steps past the cycle hold the final LR.
        pct = tf.clip_by_value(step / total_steps, 0.0, 1.0)

        # Phase 1: linear warmup 0 -> max_lr over the first pct_start.
        warmup_lr = self.max_lr * (pct / self.pct_start)

        # Phase 2: anneal max_lr -> final_lr over the remainder of the cycle.
        anneal_pct = (pct - self.pct_start) / (1.0 - self.pct_start)
        if self.anneal_strategy == 'cos':
            # Cosine weight goes 1 -> 0 as anneal_pct goes 0 -> 1.
            weight = 0.5 * (1.0 + tf.cos(np.pi * anneal_pct))
        else:
            weight = 1.0 - anneal_pct
        anneal_lr = final_lr + (self.max_lr - final_lr) * weight

        return tf.where(pct <= self.pct_start, warmup_lr, anneal_lr)

    def get_config(self):
        config = {
            'max_lr': self.max_lr,
            'steps_per_epoch': self.steps_per_epoch,
            'epochs': self.epochs,
            'pct_start': self.pct_start,
            'anneal_strategy': self.anneal_strategy,
            'final_div_factor': self.final_div_factor
        }
        return config
|
|
@register_keras_serializable()
class TemporalBlock(layers.Layer):
    """Residual TCN block: two causal dilated Conv1D stages plus a skip path."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation_rate, dropout=0.2, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.dropout = dropout

        conv_kwargs = dict(
            filters=out_channels,
            kernel_size=kernel_size,
            dilation_rate=dilation_rate,
            padding='causal',
            kernel_initializer='he_normal'
        )
        # First conv -> BN -> ReLU -> dropout stage.
        self.conv1 = layers.Conv1D(**conv_kwargs)
        self.batch_norm1 = layers.BatchNormalization()
        self.relu1 = layers.ReLU()
        self.dropout1 = layers.Dropout(dropout)

        # Second, identically configured stage.
        self.conv2 = layers.Conv1D(**conv_kwargs)
        self.batch_norm2 = layers.BatchNormalization()
        self.relu2 = layers.ReLU()
        self.dropout2 = layers.Dropout(dropout)

        # 1x1 conv aligns channel counts on the skip path when needed.
        self.downsample = (
            layers.Conv1D(filters=out_channels, kernel_size=1, padding='same')
            if in_channels != out_channels
            else None
        )

    def call(self, x):
        stages = (
            (self.conv1, self.batch_norm1, self.relu1, self.dropout1),
            (self.conv2, self.batch_norm2, self.relu2, self.dropout2),
        )
        out = x
        for conv, norm, act, drop in stages:
            out = drop(act(norm(conv(out))))
        shortcut = x if self.downsample is None else self.downsample(x)
        # Final ReLU on the residual sum (ReLU is stateless, so reuse is safe).
        return self.relu2(out + shortcut)

    def get_config(self):
        config = super().get_config()
        config.update({
            "in_channels": self.in_channels,
            "out_channels": self.out_channels,
            "kernel_size": self.kernel_size,
            "dilation_rate": self.dilation_rate,
            "dropout": self.dropout
        })
        return config
|
|
@register_keras_serializable()
class TemporalConvNet(layers.Layer):
    """Stack of TemporalBlocks with exponentially growing dilation (1, 2, 4, ...).

    ``num_channels`` lists the output channel count of each block; block ``i``
    uses dilation ``2 ** i``. Blocks are instantiated in ``build`` once the
    input channel count is known.
    """

    def __init__(self, num_channels, kernel_size=2, dropout=0.2, **kwargs):
        super(TemporalConvNet, self).__init__(**kwargs)
        self.num_channels = num_channels
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.tcn_layers = []

    def build(self, input_shape):
        # Rebuild from scratch: the previous version appended unconditionally,
        # so a second build() invocation would duplicate every block.
        self.tcn_layers = []
        in_channels = input_shape[-1]
        for i, out_channels in enumerate(self.num_channels):
            self.tcn_layers.append(TemporalBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                dilation_rate=2 ** i,
                dropout=self.dropout
            ))
            # Each block's output feeds the next block's input.
            in_channels = out_channels
        super(TemporalConvNet, self).build(input_shape)

    def call(self, x):
        # Pass the input through each temporal block in order.
        for layer in self.tcn_layers:
            x = layer(x)
        return x

    def get_config(self):
        config = super(TemporalConvNet, self).get_config()
        config.update({
            "num_channels": self.num_channels,
            "kernel_size": self.kernel_size,
            "dropout": self.dropout
        })
        return config
|
|
@register_keras_serializable()
class CrossAttention(layers.Layer):
    """Residual multi-head cross-attention: x attends over a context sequence."""

    def __init__(self, num_heads, key_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        # Sub-layers are created lazily in build().
        self.mha = None
        self.layernorm = None
        self.add = None

    def build(self, input_shape):
        self.mha = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim
        )
        self.layernorm = layers.LayerNormalization(epsilon=1e-6)
        self.add = layers.Add()
        super().build(input_shape)

    def call(self, x, context):
        # Query comes from x, keys/values from context; the normalized
        # attention output is added back onto x (norm before the add).
        attended = self.mha(x, context)
        return self.add([x, self.layernorm(attended)])

    def get_config(self):
        cfg = super().get_config()
        cfg.update({"num_heads": self.num_heads, "key_dim": self.key_dim})
        return cfg
|
|
@register_keras_serializable()
class CNNBlock(layers.Layer):
    """Two Conv2D + BatchNorm + ReLU stages followed by 2x2 max pooling."""

    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        # Sub-layers are instantiated in build().
        self.conv1 = None
        self.bn1 = None
        self.conv2 = None
        self.bn2 = None
        self.relu = None
        self.pool = None

    def build(self, input_shape):
        self.conv1 = layers.Conv2D(self.filters, self.kernel_size, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(self.filters, self.kernel_size, padding='same')
        self.bn2 = layers.BatchNormalization()
        self.relu = layers.ReLU()  # stateless; shared by both stages
        self.pool = layers.MaxPooling2D((2, 2))
        super().build(input_shape)

    def call(self, x):
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.relu(bn(conv(x)))
        return self.pool(x)

    def get_config(self):
        cfg = super().get_config()
        cfg.update({"filters": self.filters, "kernel_size": self.kernel_size})
        return cfg
|
|
@register_keras_serializable()
class F1Score(tf.keras.metrics.Metric):
    """Streaming F1 score: harmonic mean of running precision and recall."""

    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.precision = tf.keras.metrics.Precision()
        self.recall = tf.keras.metrics.Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Both sub-metrics accumulate across batches.
        for metric in (self.precision, self.recall):
            metric.update_state(y_true, y_pred, sample_weight)

    def result(self):
        p = self.precision.result()
        r = self.recall.result()
        # Epsilon guards against 0/0 when no positives have been seen.
        return 2 * ((p * r) / (p + r + tf.keras.backend.epsilon()))

    def reset_state(self):
        for metric in (self.precision, self.recall):
            metric.reset_state()

    def get_config(self):
        return super().get_config()
|
|
def mean_axis1(x):
    """Average a tensor over axis 1 (e.g. pool out the time dimension)."""
    return tf.reduce_mean(x, axis=1)
|
|
| |
# Registry of every custom class/function defined above, intended for
# deserialization (e.g. tf.keras.models.load_model(..., custom_objects=...)).
custom_objects = {
    'CustomOneCycleLR': CustomOneCycleLR,
    'F1Score': F1Score,
    'mean_axis1': mean_axis1,
    'CNNBlock': CNNBlock,
    'CrossAttention': CrossAttention,
    'TemporalConvNet': TemporalConvNet,
    'TemporalBlock': TemporalBlock,
    'TransposeLayer': TransposeLayer,
    'ReshapeLayer': ReshapeLayer,
    'AdaptiveContextLayer': AdaptiveContextLayer,
    'PositionalEncoding': PositionalEncoding,
    # Pre-built Lambda wrapping mean_axis1.
    # NOTE(review): the output_shape below drops axis 1 of a rank-4 input
    # ((b, d1, d2, d3) -> (b, d2, d3)); confirm the saved model only feeds
    # rank-4 tensors here, since mean_axis1 itself works on any rank >= 2.
    'mean_axis1_lambda': tf.keras.layers.Lambda(
        mean_axis1,
        output_shape=lambda input_shape: (input_shape[0], input_shape[2], input_shape[3])
    ),
}