import tensorflow as tf
import keras
import keras.layers as kl


class EnhancedHyenaPlusLayer(kl.Layer):
    """
    Enhanced Hyena+DNA layer with multi-scale feature extraction, residual connections,
    explicit dimension alignment, and layer normalization for improved gradient flow
    and training stability.
    """
    def __init__(self, filters, kernel_size, output_dim, use_residual=True, dilation_rate=1,
                 kernel_regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.output_dim = output_dim
        self.use_residual = use_residual
        self.dilation_rate = dilation_rate
        # Normalize so that serialized configs (dicts/strings) round-trip correctly.
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

        # Main branch: full-width convolution at the requested kernel size.
        self.conv = kl.Conv1D(filters, kernel_size, padding='same',
                              kernel_regularizer=self.kernel_regularizer)

        # Dilated branch: half the filters, with an enlarged receptive field.
        self.dilated_conv = kl.Conv1D(filters // 2, kernel_size,
                                      padding='same',
                                      dilation_rate=dilation_rate,
                                      kernel_regularizer=self.kernel_regularizer)

        # Local branch: half the filters with a fixed kernel size of 3.
        self.local_conv = kl.Conv1D(filters // 2, 3, padding='same',
                                    kernel_regularizer=self.kernel_regularizer)

        self.batch_norm = kl.BatchNormalization()
        self.activation = kl.Activation('relu')

        # Fuses the concatenated dilated/local branches back to `filters` channels.
        self.fusion = kl.Dense(filters, kernel_regularizer=self.kernel_regularizer)

        # Projects the fused features to the requested output dimension.
        self.projection = kl.Dense(output_dim, kernel_regularizer=self.kernel_regularizer)

        self.layer_norm = kl.LayerNormalization()

        # Aligns the input channels with `output_dim` so the residual sum is valid.
        self.input_projection = None
        if use_residual:
            self.input_projection = kl.Dense(output_dim,
                                             kernel_regularizer=self.kernel_regularizer)

    def call(self, inputs, training=None):
        residual = inputs

        # Three parallel feature extractors over the same input.
        x_main = self.conv(inputs)
        x_dilated = self.dilated_conv(inputs)
        x_local = self.local_conv(inputs)

        # Concatenate the half-width branches, fuse, and add the main branch.
        x_multi = tf.concat([x_dilated, x_local], axis=-1)
        x = self.fusion(x_multi) + x_main

        x = self.batch_norm(x, training=training)
        x = self.activation(x)

        x = self.projection(x)

        # Residual connection with explicit dimension alignment.
        if self.use_residual:
            residual = self.input_projection(residual)
            x = x + residual

        x = self.layer_norm(x)

        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'output_dim': self.output_dim,
            'use_residual': self.use_residual,
            'dilation_rate': self.dilation_rate,
            # Serialize the regularizer so the layer survives save/load.
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer)
            if self.kernel_regularizer else None,
        })
        return config
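
# A minimal smoke test for EnhancedHyenaPlusLayer (illustrative sizes only),
# assuming one-hot encoded input of shape (batch, 249, 4) as used by
# HyenaMSTAPlus below:
#
#   layer = EnhancedHyenaPlusLayer(filters=64, kernel_size=7, output_dim=96,
#                                  dilation_rate=2)
#   out = layer(tf.random.uniform((2, 249, 4)))
#   print(out.shape)  # (2, 249, 96): length preserved, channels -> output_dim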


class HybridContextAwareMSTA(kl.Layer):
    """
    Hybrid Context-Aware Motif-Specific Transformer Attention (HCA-MSTA) module
    with enhanced biological interpretability and selective motif attention.
    Combines the strengths of previous approaches with improved positional encoding.
    """
    def __init__(self, num_motifs, motif_dim, num_heads=4, dropout_rate=0.1,
                 kernel_regularizer=None, activity_regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.num_motifs = num_motifs
        self.motif_dim = motif_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        # Normalize so that serialized configs (dicts/strings) round-trip correctly.
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.activity_regularizer = keras.regularizers.get(activity_regularizer)

        # Learnable motif prototypes attended over by every sequence position.
        self.motif_embeddings = self.add_weight(
            shape=(num_motifs, motif_dim),
            initializer='glorot_uniform',
            regularizer=self.activity_regularizer,
            trainable=True,
            name='motif_embeddings'
        )

        # Learnable positional encoding added to the motif prototypes.
        self.motif_position_encoding = self.add_weight(
            shape=(num_motifs, motif_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='motif_position_encoding'
        )

        # Per-motif importance scores, softmax-normalized in positional_masking.
        self.motif_importance = self.add_weight(
            shape=(num_motifs, 1),
            initializer='ones',
            regularizer=self.activity_regularizer,
            trainable=True,
            name='motif_importance'
        )

        # Queries come from the sequence; keys/values from the motif bank.
        self.query_dense = kl.Dense(motif_dim, kernel_regularizer=self.kernel_regularizer)
        self.key_dense = kl.Dense(motif_dim, kernel_regularizer=self.kernel_regularizer)
        self.value_dense = kl.Dense(motif_dim, kernel_regularizer=self.kernel_regularizer)

        self.attention = kl.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=motif_dim // num_heads,
            dropout=dropout_rate
        )

        # Sigmoid gate that modulates how much motif context each position admits.
        self.gate_dense = kl.Dense(motif_dim, activation='sigmoid',
                                   kernel_regularizer=self.kernel_regularizer)

        self.output_dense = kl.Dense(motif_dim, kernel_regularizer=self.kernel_regularizer)
        self.dropout = kl.Dropout(dropout_rate)
        self.layer_norm = kl.LayerNormalization()

        # Position-wise feed-forward block, as in a standard Transformer encoder.
        self.ffn_dense1 = kl.Dense(motif_dim * 2, activation='relu',
                                   kernel_regularizer=self.kernel_regularizer)
        self.ffn_dense2 = kl.Dense(motif_dim, kernel_regularizer=self.kernel_regularizer)
        self.ffn_layer_norm = kl.LayerNormalization()
        self.ffn_dropout = kl.Dropout(dropout_rate)

    def positional_masking(self, sequence_embeddings, motif_embeddings):
        """
        Generate hybrid positional masking based on sequence and motif relevance
        with improved biological context awareness and motif importance weighting.
        Combines inverse-distance and Gaussian weighting for better biological relevance.
        """
        # Scaled dot-product similarity between sequence positions and motifs:
        # (batch, seq_len, motif_dim) x (batch, motif_dim, num_motifs).
        similarity = tf.matmul(sequence_embeddings, tf.transpose(motif_embeddings, [0, 2, 1]))
        scaled_similarity = similarity / tf.sqrt(tf.cast(self.motif_dim, tf.float32))
        attention_weights = tf.nn.softmax(scaled_similarity, axis=-1)

        seq_length = tf.shape(sequence_embeddings)[1]
        motif_length = tf.shape(motif_embeddings)[1]

        # Signed offset between each sequence position and each motif index.
        position_indices = tf.range(seq_length)[:, tf.newaxis] - tf.range(motif_length)[tf.newaxis, :]
        position_indices_float = tf.cast(position_indices, tf.float32)

        # Inverse-distance weighting: sharply favors exact positional alignment.
        inverse_weights = 1.0 / (1.0 + tf.abs(position_indices_float))

        # Gaussian weighting (std. dev. 8): decays smoothly with distance.
        gaussian_weights = tf.exp(-0.5 * tf.square(position_indices_float / 8.0))

        # Equal-weight blend of the two schemes, broadcast over the batch.
        position_weights = 0.5 * inverse_weights + 0.5 * gaussian_weights
        position_weights = tf.expand_dims(position_weights, 0)

        # Softmax over motifs, sharpened by a fixed temperature factor of 1.5.
        motif_weights = tf.nn.softmax(self.motif_importance * 1.5, axis=0)
        motif_weights = tf.expand_dims(tf.expand_dims(motif_weights, 0), 1)

        # Combine content similarity, positional plausibility, and motif importance.
        combined_weights = attention_weights * position_weights * tf.squeeze(motif_weights, -1)

        return combined_weights
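
    # Worked example of the hybrid position weighting above (illustrative numbers):
    #   offset d = 0: inverse = 1/(1+0) = 1.0,  gaussian = exp(0) = 1.0,      blend = 1.0
    #   offset d = 8: inverse = 1/9 ≈ 0.111,    gaussian = exp(-0.5) ≈ 0.607, blend ≈ 0.359
    # The Gaussian term keeps moderately offset motif matches in play, while the
    # inverse-distance term concentrates weight on exact alignment.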

    def call(self, inputs, training=None):
        batch_size = tf.shape(inputs)[0]

        # Broadcast the shared motif bank and its positional encoding over the batch.
        motifs = tf.tile(tf.expand_dims(self.motif_embeddings, 0), [batch_size, 1, 1])
        pos_encoding = tf.tile(tf.expand_dims(self.motif_position_encoding, 0), [batch_size, 1, 1])
        motifs_with_pos = motifs + pos_encoding

        # Queries come from the sequence; keys and values from the motif bank.
        query = self.query_dense(inputs)
        key = self.key_dense(motifs_with_pos)
        value = self.value_dense(motifs_with_pos)

        # Soft positional weights, supplied to the attention layer as its mask to
        # bias attention toward positionally plausible, important motifs.
        pos_mask = self.positional_masking(query, motifs_with_pos)

        attention_output = self.attention(
            query=query,
            key=key,
            value=value,
            attention_mask=pos_mask,
            training=training
        )

        # Sigmoid gate controls how much motif context each position admits.
        gate = self.gate_dense(inputs)
        gated_attention = gate * attention_output

        # Residual connection and normalization around the attention block.
        output = self.output_dense(gated_attention)
        output = self.dropout(output, training=training)
        output = self.layer_norm(output + inputs)

        # Position-wise feed-forward block with its own residual connection.
        ffn_output = self.ffn_dense1(output)
        ffn_output = self.ffn_dense2(ffn_output)
        ffn_output = self.ffn_dropout(ffn_output, training=training)
        final_output = self.ffn_layer_norm(output + ffn_output)

        return final_output

    def get_config(self):
        config = super().get_config()
        config.update({
            'num_motifs': self.num_motifs,
            'motif_dim': self.motif_dim,
            'num_heads': self.num_heads,
            'dropout_rate': self.dropout_rate,
            # Serialize regularizers so the layer survives save/load.
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer)
            if self.kernel_regularizer else None,
            'activity_regularizer': keras.regularizers.serialize(self.activity_regularizer)
            if self.activity_regularizer else None,
        })
        return config
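
# A minimal shape check for HybridContextAwareMSTA (illustrative sizes only),
# assuming inputs whose channel dimension already equals motif_dim, as produced
# by EnhancedHyenaPlusLayer above:
#
#   msta = HybridContextAwareMSTA(num_motifs=48, motif_dim=96)
#   out = msta(tf.random.uniform((2, 124, 96)))
#   print(out.shape)  # (2, 124, 96): the residual connections preserve the shape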


def HyenaMSTAPlus(params):
    """
    Enhanced HyenaMSTA+ model for enhancer activity prediction with multi-scale feature
    extraction, a hybrid attention mechanism, and improved biological context modeling.
    """
    if params['encode'] == 'one-hot':
        input_layer = kl.Input(shape=(249, 4))
    elif params['encode'] == 'k-mer':
        input_layer = kl.Input(shape=(1, 64))
    else:
        raise ValueError(f"Unknown encoding: {params['encode']}")

    # L2 weight regularization, plus a lighter L1 activity regularizer for the
    # motif-related weights.
    l2_reg = params.get('l2_reg', 1e-6)
    kernel_regularizer = tf.keras.regularizers.l2(l2_reg)
    activity_regularizer = tf.keras.regularizers.l1(l2_reg / 20)

    x = input_layer
    hyena_layers = []

    num_motifs = params.get('num_motifs', 48)
    motif_dim = params.get('motif_dim', 96)

    # Stack of enhanced Hyena layers with progressively dilated convolutions.
    for i in range(params['convolution_layers']['n_layers']):
        # Dilation grows as 1, 2, 4 and is then capped at 4.
        dilation_rate = 2 ** min(i, 2)

        hyena_layer = EnhancedHyenaPlusLayer(
            filters=params['convolution_layers']['filters'][i],
            kernel_size=params['convolution_layers']['kernel_sizes'][i],
            output_dim=motif_dim,
            dilation_rate=dilation_rate,
            kernel_regularizer=kernel_regularizer,
            name=f'EnhancedHyenaPlus_{i+1}'
        )
        x = hyena_layer(x)
        hyena_layers.append(x)

        # Downsample only one-hot inputs; k-mer inputs have sequence length 1.
        if params['encode'] == 'one-hot':
            x = kl.MaxPooling1D(2)(x)

        if params['dropout_conv'] == 'yes':
            x = kl.Dropout(params['dropout_prob'])(x)

    # Hybrid context-aware motif attention over the convolutional features.
    ca_msta = HybridContextAwareMSTA(
        num_motifs=num_motifs,
        motif_dim=motif_dim,
        num_heads=params.get('ca_msta_heads', 8),
        dropout_rate=params['dropout_prob'],
        kernel_regularizer=kernel_regularizer,
        activity_regularizer=activity_regularizer
    )
    x = ca_msta(x)

    x = kl.Flatten()(x)

    # Fully connected head shared by both output tasks.
    for i in range(params['n_dense_layer']):
        x = kl.Dense(params['dense_neurons' + str(i+1)],
                     name=f'Dense_{i+1}')(x)
        x = kl.BatchNormalization()(x)
        x = kl.Activation('relu')(x)
        x = kl.Dropout(params['dropout_prob'])(x)

    bottleneck = x

    # One linear regression head per task (developmental and housekeeping activity).
    tasks = ['Dev', 'Hk']
    outputs = []
    for task in tasks:
        outputs.append(kl.Dense(1, activation='linear', name=f'Dense_{task}')(bottleneck))

    model = keras.models.Model([input_layer], outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=params['lr']),
        loss=['mse', 'mse'],
        loss_weights=[1, 1]
    )

    return model, params
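
# Example configuration and model build (illustrative values only; every key
# below is read by HyenaMSTAPlus or the layers above, but the actual training
# hyperparameters may differ):
#
#   params = {
#       'encode': 'one-hot',
#       'convolution_layers': {'n_layers': 3,
#                              'filters': [128, 128, 64],
#                              'kernel_sizes': [7, 5, 3]},
#       'dropout_conv': 'yes',
#       'dropout_prob': 0.2,
#       'n_dense_layer': 2,
#       'dense_neurons1': 256,
#       'dense_neurons2': 128,
#       'lr': 1e-3,
#   }
#   model, params = HyenaMSTAPlus(params)
#   model.summary()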