| """ | |
| Custom Keras layers for CRISPR BERT model. | |
| These layers must be registered as custom_objects when loading the model. | |
| Based on code from Ziyu Mu's CRISPRArrayDetection repository. | |
| """ | |
| import tensorflow as tf | |
| class layer_pos_embedding(tf.keras.layers.Layer): | |
| """Token + Positional Embedding layer for BERT.""" | |
| def __init__(self, maxlen=1000, vocabulary_size=6, embed_dim=600, **kwargs): | |
| super().__init__(**kwargs) | |
| self.maxlen = int(maxlen) | |
| self.vocabulary_size = int(vocabulary_size) | |
| self.embed_dim = int(embed_dim) | |
| self.token_emb = tf.keras.layers.Embedding( | |
| input_dim=self.vocabulary_size, | |
| output_dim=self.embed_dim, | |
| name="token_emb", | |
| ) | |
| self.pos_emb = tf.keras.layers.Embedding( | |
| input_dim=self.maxlen, | |
| output_dim=self.embed_dim, | |
| name="pos_emb", | |
| ) | |
| def call(self, x): | |
| x = tf.cast(x, tf.int32) | |
| L = tf.shape(x)[1] | |
| positions = tf.range(start=0, limit=L, delta=1) | |
| positions = self.pos_emb(positions) | |
| tokens = self.token_emb(x) | |
| return tokens + positions | |
| def get_config(self): | |
| cfg = super().get_config() | |
| cfg.update( | |
| dict( | |
| maxlen=self.maxlen, | |
| vocabulary_size=self.vocabulary_size, | |
| embed_dim=self.embed_dim, | |
| ) | |
| ) | |
| return cfg | |
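
# Example usage (illustrative; batch size and sequence length are arbitrary):
# an integer tensor of shape (batch, seq_len) with seq_len <= maxlen maps to a
# float tensor of shape (batch, seq_len, embed_dim), e.g.
#
#     emb = layer_pos_embedding(maxlen=26, vocabulary_size=6, embed_dim=600)
#     out = emb(tf.zeros((2, 26), dtype=tf.int32))  # -> shape (2, 26, 600)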


class layer_transformer_block(tf.keras.layers.Layer):
    """Transformer block with multi-head attention and a feed-forward network."""

    def __init__(
        self,
        num_heads=16,
        head_size=250,
        dropout_rate=0.0,
        ff_dim=2400,
        vocabulary_size=6,
        embed_dim=600,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_heads = int(num_heads)
        self.head_size = int(head_size)
        self.dropout_rate = float(dropout_rate)
        self.ff_dim = int(ff_dim)
        self.vocabulary_size = int(vocabulary_size)
        self.embed_dim = int(embed_dim)
        self.mha = tf.keras.layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.head_size,
            dropout=self.dropout_rate,
            name="mha",
        )
        self.ffn1 = tf.keras.layers.Dense(self.ff_dim, activation=tf.nn.gelu, name="ffn1")
        self.ffn2 = tf.keras.layers.Dense(self.embed_dim, name="ffn2")
        self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="ln1")
        self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="ln2")
        self.drop1 = tf.keras.layers.Dropout(self.dropout_rate, name="drop1")
        self.drop2 = tf.keras.layers.Dropout(self.dropout_rate, name="drop2")

    def build(self, input_shape):
        # Build sublayers explicitly so their weights are created up front.
        self.mha.build([input_shape, input_shape, input_shape])
        self.ffn1.build(input_shape)
        self.ffn2.build((input_shape[0], input_shape[1], self.ff_dim))
        self.ln1.build(input_shape)
        self.ln2.build(input_shape)
        super().build(input_shape)

    def call(self, x, training=False):
        # Self-attention sub-block with residual connection and post-layer norm.
        attn = self.mha(x, x, training=training)
        attn = self.drop1(attn, training=training)
        x = x + attn
        x = self.ln1(x)
        # Feed-forward sub-block with residual connection and post-layer norm.
        f = self.ffn2(self.ffn1(x))
        f = self.drop2(f, training=training)
        x = x + f
        x = self.ln2(x)
        return x

    def get_config(self):
        cfg = super().get_config()
        cfg.update(
            dict(
                num_heads=self.num_heads,
                head_size=self.head_size,
                dropout_rate=self.dropout_rate,
                ff_dim=self.ff_dim,
                vocabulary_size=self.vocabulary_size,
                embed_dim=self.embed_dim,
            )
        )
        return cfg
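
# Illustrative usage (dimensions assumed, not taken from the saved model): the
# block preserves its input shape, so it stacks directly on layer_pos_embedding:
#
#     x = layer_pos_embedding(maxlen=26, vocabulary_size=6, embed_dim=600)(ids)
#     x = layer_transformer_block(num_heads=16, head_size=250, embed_dim=600)(x)
#     # x keeps shape (batch, 26, 600)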


class BinaryFocalLoss(tf.keras.losses.Loss):
    """Binary focal loss for handling class imbalance."""

    def __init__(self, alpha=0.25, gamma=2.0, name="binary_focal", **kwargs):
        # **kwargs forwards extra config (e.g. reduction) so the loss can be
        # rebuilt from get_config() when a saved model is loaded.
        super().__init__(name=name, **kwargs)
        self.alpha = float(alpha)
        self.gamma = float(gamma)
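
    # Per-element loss computed in call() below is the standard binary focal
    # loss (Lin et al., 2017):
    #     FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
    # where p_t is the predicted probability of the true class and
    # alpha_t = alpha for positive examples, (1 - alpha) for negatives.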

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        ce = -(y_true * tf.math.log(y_pred) + (1.0 - y_true) * tf.math.log(1.0 - y_pred))
        p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
        alpha_t = y_true * self.alpha + (1.0 - y_true) * (1.0 - self.alpha)
        focal = alpha_t * tf.pow(1.0 - p_t, self.gamma) * ce
        return tf.reduce_mean(focal)

    def get_config(self):
        cfg = super().get_config()
        cfg.update({"alpha": self.alpha, "gamma": self.gamma})
        return cfg


def get_custom_objects():
    """Return dictionary of custom objects needed for model loading."""
    return {
        "layer_pos_embedding": layer_pos_embedding,
        "layer_transformer_block": layer_transformer_block,
        "BinaryFocalLoss": BinaryFocalLoss,
    }
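

if __name__ == "__main__":
    # Quick smoke test / usage sketch. The dimensions and architecture below are
    # illustrative assumptions, not the Space's actual saved model, and it assumes
    # the TensorFlow/Keras version the original repository targets.
    seq_len, vocab, dim = 26, 6, 600
    inputs = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
    x = layer_pos_embedding(maxlen=seq_len, vocabulary_size=vocab, embed_dim=dim)(inputs)
    x = layer_transformer_block(num_heads=16, head_size=250, embed_dim=dim)(x)
    x = tf.keras.layers.GlobalAveragePooling1D()(x)
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss=BinaryFocalLoss())

    demo_ids = tf.random.uniform((2, seq_len), maxval=vocab, dtype=tf.int32)
    print(model(demo_ids).shape)  # expected: (2, 1)

    # Loading a saved model registers the same classes, e.g.:
    # model = tf.keras.models.load_model(
    #     "crispr_bert.h5",  # hypothetical path; substitute the real checkpoint
    #     custom_objects=get_custom_objects(),
    #     compile=False,
    # )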