"""
Custom Keras layers for CRISPR BERT model.
These layers must be registered as custom_objects when loading the model.
Based on code from Ziyu Mu's CRISPRArrayDetection repository.
"""
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="deepG")
class layer_pos_embedding(tf.keras.layers.Layer):
"""Token + Positional Embedding layer for BERT."""
def __init__(self, maxlen=1000, vocabulary_size=6, embed_dim=600, **kwargs):
super().__init__(**kwargs)
self.maxlen = int(maxlen)
self.vocabulary_size = int(vocabulary_size)
self.embed_dim = int(embed_dim)
self.token_emb = tf.keras.layers.Embedding(
input_dim=self.vocabulary_size,
output_dim=self.embed_dim,
name="token_emb",
)
self.pos_emb = tf.keras.layers.Embedding(
input_dim=self.maxlen,
output_dim=self.embed_dim,
name="pos_emb",
)

    def call(self, x):
        # Inputs are integer token ids of shape (batch, seq_len); cast so the
        # embedding lookup accepts any integer-like input dtype.
        x = tf.cast(x, tf.int32)
        seq_len = tf.shape(x)[1]
        # Embed positions 0..seq_len-1 and broadcast them over the batch.
        positions = self.pos_emb(tf.range(start=0, limit=seq_len, delta=1))
        tokens = self.token_emb(x)
        return tokens + positions

def get_config(self):
cfg = super().get_config()
cfg.update(
dict(
maxlen=self.maxlen,
vocabulary_size=self.vocabulary_size,
embed_dim=self.embed_dim,
)
)
return cfg
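
# A quick shape check (illustrative only; the toy batch below is not part of
# the original module): with the default embed_dim of 600, a (2, 8) batch of
# token ids maps to a (2, 8, 600) tensor of summed token + position embeddings.
#
#   emb = layer_pos_embedding(maxlen=1000, vocabulary_size=6, embed_dim=600)
#   out = emb(tf.zeros((2, 8), dtype=tf.int32))
#   print(out.shape)  # (2, 8, 600)
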
@tf.keras.utils.register_keras_serializable(package="deepG")
class layer_transformer_block(tf.keras.layers.Layer):
"""Transformer block with Multi-Head Attention and Feed-Forward Network."""
def __init__(
self,
num_heads=16,
head_size=250,
dropout_rate=0.0,
        ff_dim=2400,
vocabulary_size=6,
embed_dim=600,
**kwargs
):
super().__init__(**kwargs)
self.num_heads = int(num_heads)
self.head_size = int(head_size)
self.dropout_rate = float(dropout_rate)
self.ff_dim = int(ff_dim)
self.vocabulary_size = int(vocabulary_size)
self.embed_dim = int(embed_dim)
self.mha = tf.keras.layers.MultiHeadAttention(
num_heads=self.num_heads,
key_dim=self.head_size,
dropout=self.dropout_rate,
name="mha",
)
self.ffn1 = tf.keras.layers.Dense(self.ff_dim, activation=tf.nn.gelu, name="ffn1")
self.ffn2 = tf.keras.layers.Dense(self.embed_dim, name="ffn2")
self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="ln1")
self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="ln2")
self.drop1 = tf.keras.layers.Dropout(self.dropout_rate, name="drop1")
self.drop2 = tf.keras.layers.Dropout(self.dropout_rate, name="drop2")

    def build(self, input_shape):
        # Build the sublayers up front so their weights exist before saved
        # weights are restored. Keras 3's MultiHeadAttention.build takes
        # separate query, value and key shapes; self-attention reuses the
        # input shape for all three.
        self.mha.build(input_shape, input_shape, input_shape)
        self.ffn1.build(input_shape)
        self.ffn2.build((input_shape[0], input_shape[1], self.ff_dim))
        self.ln1.build(input_shape)
        self.ln2.build(input_shape)
        super().build(input_shape)

    def call(self, x, training=False):
        # Self-attention sublayer with a residual connection and post-norm,
        # as in the original Transformer/BERT encoder block.
        attn = self.mha(x, x, training=training)
        attn = self.drop1(attn, training=training)
        x = self.ln1(x + attn)
        # Position-wise feed-forward sublayer, again residual + post-norm.
        f = self.ffn2(self.ffn1(x))
        f = self.drop2(f, training=training)
        return self.ln2(x + f)

def get_config(self):
cfg = super().get_config()
cfg.update(
dict(
num_heads=self.num_heads,
head_size=self.head_size,
dropout_rate=self.dropout_rate,
ff_dim=self.ff_dim,
vocabulary_size=self.vocabulary_size,
embed_dim=self.embed_dim,
)
)
return cfg
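
# Sketch of how the block is typically chained after the embedding layer
# (illustrative; the layer count and input length here are assumptions, not
# the original model definition):
#
#   inputs = tf.keras.Input(shape=(1000,), dtype=tf.int32)
#   x = layer_pos_embedding(maxlen=1000, vocabulary_size=6, embed_dim=600)(inputs)
#   for _ in range(2):
#       x = layer_transformer_block(num_heads=16, head_size=250, embed_dim=600)(x)
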
@tf.keras.utils.register_keras_serializable(package="deepG")
class BinaryFocalLoss(tf.keras.losses.Loss):
"""Binary Focal Loss for handling class imbalance."""
    def __init__(self, alpha=0.25, gamma=2.0, name="binary_focal", **kwargs):
        # Forward extra kwargs (e.g. reduction) to the parent class so the
        # loss can be rebuilt from get_config() during model loading.
        super().__init__(name=name, **kwargs)
        self.alpha = float(alpha)
        self.gamma = float(gamma)

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        # Clip probabilities away from 0 and 1 to keep the logs finite.
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        # Standard binary cross-entropy per element.
        ce = -(y_true * tf.math.log(y_pred) + (1.0 - y_true) * tf.math.log(1.0 - y_pred))
        # p_t is the model's probability for the true class; alpha_t weights
        # positives vs. negatives; (1 - p_t)^gamma down-weights easy examples.
        p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
        alpha_t = y_true * self.alpha + (1.0 - y_true) * (1.0 - self.alpha)
        focal = alpha_t * tf.pow(1.0 - p_t, self.gamma) * ce
        return tf.reduce_mean(focal)

def get_config(self):
cfg = super().get_config()
cfg.update({"alpha": self.alpha, "gamma": self.gamma})
return cfg
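
# A small sanity check (illustrative values, not from the original repo): a
# confident correct prediction incurs far less loss than a confident wrong
# one, and the (1 - p_t)^gamma term shrinks the easy example even further.
#
#   loss = BinaryFocalLoss(alpha=0.25, gamma=2.0)
#   easy = loss(tf.constant([1.0]), tf.constant([0.95]))  # ~3e-5
#   hard = loss(tf.constant([1.0]), tf.constant([0.05]))  # ~0.68
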
def get_custom_objects():
"""Return dictionary of custom objects needed for model loading."""
return {
"layer_pos_embedding": layer_pos_embedding,
"layer_transformer_block": layer_transformer_block,
"BinaryFocalLoss": BinaryFocalLoss,
}
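
# Typical loading pattern for this module (a sketch; the checkpoint filename
# below is an assumption, substitute the Space's actual saved model path):
#
#   model = tf.keras.models.load_model(
#       "crispr_bert_model.h5",
#       custom_objects=get_custom_objects(),
#       compile=False,
#   )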