""" Custom Keras layers for CRISPR BERT model. These layers must be registered as custom_objects when loading the model. Based on code from Ziyu Mu's CRISPRArrayDetection repository. """ import tensorflow as tf @tf.keras.utils.register_keras_serializable(package="deepG") class layer_pos_embedding(tf.keras.layers.Layer): """Token + Positional Embedding layer for BERT.""" def __init__(self, maxlen=1000, vocabulary_size=6, embed_dim=600, **kwargs): super().__init__(**kwargs) self.maxlen = int(maxlen) self.vocabulary_size = int(vocabulary_size) self.embed_dim = int(embed_dim) self.token_emb = tf.keras.layers.Embedding( input_dim=self.vocabulary_size, output_dim=self.embed_dim, name="token_emb", ) self.pos_emb = tf.keras.layers.Embedding( input_dim=self.maxlen, output_dim=self.embed_dim, name="pos_emb", ) def call(self, x): x = tf.cast(x, tf.int32) L = tf.shape(x)[1] positions = tf.range(start=0, limit=L, delta=1) positions = self.pos_emb(positions) tokens = self.token_emb(x) return tokens + positions def get_config(self): cfg = super().get_config() cfg.update( dict( maxlen=self.maxlen, vocabulary_size=self.vocabulary_size, embed_dim=self.embed_dim, ) ) return cfg @tf.keras.utils.register_keras_serializable(package="deepG") class layer_transformer_block(tf.keras.layers.Layer): """Transformer block with Multi-Head Attention and Feed-Forward Network.""" def __init__( self, num_heads=16, head_size=250, dropout_rate=0.0, ff_dim=2400.0, vocabulary_size=6, embed_dim=600, **kwargs ): super().__init__(**kwargs) self.num_heads = int(num_heads) self.head_size = int(head_size) self.dropout_rate = float(dropout_rate) self.ff_dim = int(ff_dim) self.vocabulary_size = int(vocabulary_size) self.embed_dim = int(embed_dim) self.mha = tf.keras.layers.MultiHeadAttention( num_heads=self.num_heads, key_dim=self.head_size, dropout=self.dropout_rate, name="mha", ) self.ffn1 = tf.keras.layers.Dense(self.ff_dim, activation=tf.nn.gelu, name="ffn1") self.ffn2 = tf.keras.layers.Dense(self.embed_dim, name="ffn2") self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="ln1") self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="ln2") self.drop1 = tf.keras.layers.Dropout(self.dropout_rate, name="drop1") self.drop2 = tf.keras.layers.Dropout(self.dropout_rate, name="drop2") def build(self, input_shape): self.mha.build([input_shape, input_shape, input_shape]) self.ffn1.build(input_shape) self.ffn2.build((input_shape[0], input_shape[1], self.ff_dim)) self.ln1.build(input_shape) self.ln2.build(input_shape) super().build(input_shape) def call(self, x, training=False): attn = self.mha(x, x, training=training) attn = self.drop1(attn, training=training) x = x + attn x = self.ln1(x) f = self.ffn2(self.ffn1(x)) f = self.drop2(f, training=training) x = x + f x = self.ln2(x) return x def get_config(self): cfg = super().get_config() cfg.update( dict( num_heads=self.num_heads, head_size=self.head_size, dropout_rate=self.dropout_rate, ff_dim=self.ff_dim, vocabulary_size=self.vocabulary_size, embed_dim=self.embed_dim, ) ) return cfg @tf.keras.utils.register_keras_serializable(package="deepG") class BinaryFocalLoss(tf.keras.losses.Loss): """Binary Focal Loss for handling class imbalance.""" def __init__(self, alpha=0.25, gamma=2.0, name="binary_focal"): super().__init__(name=name) self.alpha = float(alpha) self.gamma = float(gamma) def call(self, y_true, y_pred): y_true = tf.cast(y_true, tf.float32) y_pred = tf.cast(y_pred, tf.float32) eps = tf.keras.backend.epsilon() y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps) ce = 
@tf.keras.utils.register_keras_serializable(package="deepG")
class BinaryFocalLoss(tf.keras.losses.Loss):
    """Binary focal loss for handling class imbalance."""

    def __init__(self, alpha=0.25, gamma=2.0, name="binary_focal", **kwargs):
        # Forward **kwargs (e.g. ``reduction``) so deserialization round-trips.
        super().__init__(name=name, **kwargs)
        self.alpha = float(alpha)
        self.gamma = float(gamma)

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        # Clip probabilities away from 0 and 1 to keep the logs finite.
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        # Element-wise binary cross-entropy.
        ce = -(
            y_true * tf.math.log(y_pred)
            + (1.0 - y_true) * tf.math.log(1.0 - y_pred)
        )
        # p_t is the predicted probability of the true class; well-classified
        # examples (p_t close to 1) are down-weighted by (1 - p_t)**gamma.
        p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
        alpha_t = y_true * self.alpha + (1.0 - y_true) * (1.0 - self.alpha)
        focal = alpha_t * tf.pow(1.0 - p_t, self.gamma) * ce
        return tf.reduce_mean(focal)

    def get_config(self):
        cfg = super().get_config()
        cfg.update({"alpha": self.alpha, "gamma": self.gamma})
        return cfg


def get_custom_objects():
    """Return the dictionary of custom objects needed for model loading."""
    return {
        "layer_pos_embedding": layer_pos_embedding,
        "layer_transformer_block": layer_transformer_block,
        "BinaryFocalLoss": BinaryFocalLoss,
    }
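
# ---------------------------------------------------------------------------
# Minimal loading sketch. ``load_crispr_bert`` is a hypothetical convenience
# wrapper, and ``path`` is caller-supplied; no checkpoint file is assumed to
# ship with this module. With the ``register_keras_serializable`` decorators
# above, passing ``custom_objects`` explicitly is a belt-and-braces measure.
# ---------------------------------------------------------------------------
def load_crispr_bert(path):
    """Load a saved CRISPR BERT model with this module's custom objects."""
    return tf.keras.models.load_model(path, custom_objects=get_custom_objects())


if __name__ == "__main__":
    # Quick sanity check of the focal loss on dummy labels and probabilities:
    # the confident correct prediction (0.9 for class 1) contributes far less
    # than the hesitant one (0.2 for class 0), as the gamma term intends.
    loss_fn = BinaryFocalLoss(alpha=0.25, gamma=2.0)
    y_true = tf.constant([[1.0], [0.0]])
    y_pred = tf.constant([[0.9], [0.2]])
    print("focal loss:", float(loss_fn(y_true, y_pred)))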