""" Custom Keras layers for BERT metagenome model. These layers must be registered as custom_objects when loading the model. """ import tensorflow as tf @tf.keras.utils.register_keras_serializable(package="deepG") class layer_pos_embedding(tf.keras.layers.Layer): """Token + Positional Embedding layer for BERT.""" def __init__(self, maxlen=1000, vocabulary_size=6, embed_dim=600, **kwargs): super().__init__(**kwargs) self.maxlen = int(maxlen) self.vocabulary_size = int(vocabulary_size) self.embed_dim = int(embed_dim) self.token_emb = tf.keras.layers.Embedding( input_dim=self.vocabulary_size, output_dim=self.embed_dim, name="token_emb", ) self.pos_emb = tf.keras.layers.Embedding( input_dim=self.maxlen, output_dim=self.embed_dim, name="pos_emb", ) def call(self, x): x = tf.cast(x, tf.int32) L = tf.shape(x)[1] positions = tf.range(start=0, limit=L, delta=1) positions = self.pos_emb(positions) tokens = self.token_emb(x) return tokens + positions def get_config(self): cfg = super().get_config() cfg.update( dict( maxlen=self.maxlen, vocabulary_size=self.vocabulary_size, embed_dim=self.embed_dim, ) ) return cfg @tf.keras.utils.register_keras_serializable(package="deepG") class layer_transformer_block(tf.keras.layers.Layer): """Transformer block with Multi-Head Attention and Feed-Forward Network.""" def __init__( self, num_heads=16, head_size=250, dropout_rate=0.0, ff_dim=2400.0, vocabulary_size=6, embed_dim=600, **kwargs ): super().__init__(**kwargs) self.num_heads = int(num_heads) self.head_size = int(head_size) self.dropout_rate = float(dropout_rate) self.ff_dim = int(ff_dim) self.vocabulary_size = int(vocabulary_size) self.embed_dim = int(embed_dim) self.mha = tf.keras.layers.MultiHeadAttention( num_heads=self.num_heads, key_dim=self.head_size, dropout=self.dropout_rate, name="mha", ) self.ffn1 = tf.keras.layers.Dense(self.ff_dim, activation=tf.nn.gelu, name="ffn1") self.ffn2 = tf.keras.layers.Dense(self.embed_dim, name="ffn2") self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="ln1") self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="ln2") self.drop1 = tf.keras.layers.Dropout(self.dropout_rate, name="drop1") self.drop2 = tf.keras.layers.Dropout(self.dropout_rate, name="drop2") def call(self, x, training=False): attn = self.mha(x, x, training=training) attn = self.drop1(attn, training=training) x = x + attn x = self.ln1(x) f = self.ffn2(self.ffn1(x)) f = self.drop2(f, training=training) x = x + f x = self.ln2(x) return x def get_config(self): cfg = super().get_config() cfg.update( dict( num_heads=self.num_heads, head_size=self.head_size, dropout_rate=self.dropout_rate, ff_dim=self.ff_dim, vocabulary_size=self.vocabulary_size, embed_dim=self.embed_dim, ) ) return cfg def get_custom_objects(): """Return dictionary of custom objects needed for model loading.""" return { "layer_pos_embedding": layer_pos_embedding, "layer_transformer_block": layer_transformer_block, }