# Adapted from the DeepSTARR Colab notebook: https://colab.research.google.com/drive/1Xgak40TuxWWLh5P5ARf0-4Xo0BcRn0Gd
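# Trains either the original DeepSTARR CNN or a HyenaMSTA+ variant on STARR-seq
# enhancer activity data and reports test-set MSE, Pearson, and Spearman per task.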

import argparse
import os

# Select the GPU before TensorFlow is imported; setting CUDA_VISIBLE_DEVICES
# after the CUDA context is initialized may be silently ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import time
import traceback
import json
import tensorflow as tf
import keras.layers as kl
from keras_nlp.layers import SinePositionEncoding, TransformerEncoder
from keras.models import Model
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, History
import pandas as pd
import numpy as np
from scipy import stats
from collections import Counter
from itertools import product
from sklearn.metrics import mean_squared_error
from hyenamsta_model import HyenaMSTAPlus

startTime = time.time()

def parse_arguments():
    parser = argparse.ArgumentParser(description='DeepSTARR')
    parser.add_argument('--config', type=str, default='config/config-conv-117.json', help='Configuration file path (default: config/config-conv-117.json)')
    parser.add_argument('--indir', type=str, default='./DeepSTARR-Reimplementation-main/data/Sequences_activity_all.txt', help='Input data directory (default: ./DeepSTARR-Reimplementation-main/data/Sequences_activity_all.txt)')
    parser.add_argument('--out_dir', type=str, default='output', help='Output directory (default: output)')
    parser.add_argument('--label', type=str, default='hyenamsta_plus', help='Output label (default: hyenamsta_plus)')
    parser.add_argument('--model_type', type=str, default='hyenamsta_plus', help='Model type to use: "deepstarr" or "hyenamsta_plus" (default: hyenamsta_plus)')
    parser.add_argument('--num_motifs', type=int, default=48, help='Number of motifs for CA-MSTA (default: 48)')
    parser.add_argument('--motif_dim', type=int, default=96, help='Dimension of motif embeddings (default: 96)')
    parser.add_argument('--ca_msta_heads', type=int, default=8, help='Number of attention heads in CA-MSTA (default: 8)')
    parser.add_argument('--l2_reg', type=float, default=1e-6, help='L2 regularization strength (default: 1e-6)')
    return parser.parse_args()
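# Example invocation (the script filename here is illustrative):
#   python train_hyenamsta.py --model_type deepstarr --config config/config-conv-117.json --out_dir output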

def LoadConfig(config, args):
    with open(config, 'r') as file:
        params = json.load(file)
    
    # Add HyenaMSTA+ specific parameters
    params['model_type'] = args.model_type
    params['num_motifs'] = args.num_motifs
    params['motif_dim'] = args.motif_dim
    params['ca_msta_heads'] = args.ca_msta_heads
    params['l2_reg'] = args.l2_reg
    
    return params
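# Sketch of the config keys this script reads (values are illustrative,
# not the actual contents of config/config-conv-117.json):
# {
#   "encode": "one-hot", "pad": "same",
#   "convolution_layers": {"n_layers": 2, "filters": [256, 120], "kernel_sizes": [7, 3]},
#   "transformer_layers": {"n_layers": 0, "attn_key_dim": [], "attn_heads": []},
#   "n_dense_layer": 2, "dense_neurons1": 256, "dense_neurons2": 256,
#   "dropout_conv": "yes", "dropout_prob": 0.4,
#   "lr": 0.002, "batch_size": 128, "epochs": 100, "early_stop": 10,
#   "lr_schedule": false
# }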

def one_hot_encode(seq):
    nucleotide_dict = {'A': [1, 0, 0, 0],
                       'C': [0, 1, 0, 0],
                       'G': [0, 0, 1, 0],
                       'T': [0, 0, 0, 1],
                       'N': [0, 0, 0, 0]}  # ambiguous base -> all zeros
    return np.array([nucleotide_dict[nuc] for nuc in seq.upper()])
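# e.g. one_hot_encode("ACGTN") -> a (5, 4) array, with 'N' encoded as a zero row.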

def kmer_encode(sequence, k=3):
    sequence = sequence.upper()
    kmers = [sequence[i:i+k] for i in range(len(sequence) - k + 1)]
    kmer_counts = Counter(kmers)
    all_kmers = [''.join(p) for p in product('ACGT', repeat=k)]
    return {kmer: kmer_counts.get(kmer, 0) / len(kmers) for kmer in all_kmers}
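# For k=3 there are 4**3 = 64 possible k-mers; a 249 nt sequence has
# 249 - 3 + 1 = 247 overlapping 3-mers, so each frequency is count / 247.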

def kmer_features(seq, k=3):
    all_kmers = [''.join(p) for p in product('ACGT', repeat=k)]
    kmer_freqs = kmer_encode(seq, k)
    # Keep the singleton leading axis: prepare_input stacks these into (N, 1, 4^k)
    return np.array([[kmer_freqs[kmer] for kmer in all_kmers]])
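# Returns shape (1, 4**k), i.e. (1, 64) for k=3, matching the k-mer Input layer below.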

def prepare_input(data_set, params):
    if params['encode'] == 'one-hot':
        seq_matrix = np.array(data_set['Sequence'].apply(one_hot_encode).tolist())  # (number of sequences, length of sequences, nucleotides)
    elif params['encode'] == 'k-mer':
        seq_matrix = np.array(data_set['Sequence'].apply(kmer_features, k=3).tolist())  # (number of sequences, 1, 4^k)
    else:
        raise ValueError(f"unknown encoding method: {params['encode']}")

    Y_dev = data_set.Dev_log2_enrichment
    Y_hk = data_set.Hk_log2_enrichment
    Y = [Y_dev, Y_hk]

    return seq_matrix, Y
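# Resulting shapes: one-hot -> (N, 249, 4); k-mer -> (N, 1, 64). Y is a
# two-element list [Dev_log2_enrichment, Hk_log2_enrichment], one per output head.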

def DeepSTARR(params):
    if params['encode'] == 'one-hot':
        input = kl.Input(shape=(249, 4))
    elif params['encode'] == 'k-mer':
        input = kl.Input(shape=(1, 64))
    else:
        raise ValueError(f"unknown encoding method: {params['encode']}")

    # Convolutional blocks, threaded through x so each layer consumes the
    # previous layer's output rather than the raw input
    x = input
    for i in range(params['convolution_layers']['n_layers']):
        x = kl.Conv1D(params['convolution_layers']['filters'][i],
                      kernel_size=params['convolution_layers']['kernel_sizes'][i],
                      padding=params['pad'],
                      name=f'Conv1D_{i+1}')(x)
        x = kl.BatchNormalization()(x)
        x = kl.Activation('relu')(x)
        if params['encode'] == 'one-hot':
            x = kl.MaxPooling1D(2)(x)
        if params['dropout_conv'] == 'yes':
            x = kl.Dropout(params['dropout_prob'])(x)

    # Optional transformer encoder layers, with fixed sinusoidal position
    # encodings added before the first one
    for i in range(params['transformer_layers']['n_layers']):
        if i == 0:
            x = x + SinePositionEncoding()(x)
        x = TransformerEncoder(intermediate_dim=params['transformer_layers']['attn_key_dim'][i],
                               num_heads=params['transformer_layers']['attn_heads'][i],
                               dropout=params['dropout_prob'])(x)
    
    # Flatten the multi-dimensional convolutional output into a one-dimensional
    # vector so it can feed the fully connected layers below
    x = kl.Flatten()(x)
    
    # Fully connected layers
    # Each fully connected layer is followed by batch normalization, ReLU activation, and dropout
    for i in range(params['n_dense_layer']):
        x = kl.Dense(params['dense_neurons' + str(i+1)],
                     name=f'Dense_{i+1}')(x)
        x = kl.BatchNormalization()(x)
        x = kl.Activation('relu')(x)
        x = kl.Dropout(params['dropout_prob'])(x)
    
    # Main model bottleneck
    bottleneck = x
    
    # One single-neuron linear output head per task
    # (developmental and housekeeping enhancer activity)
    tasks = ['Dev', 'Hk']
    outputs = []
    for task in tasks:
        outputs.append(kl.Dense(1, activation='linear', name=f'Dense_{task}')(bottleneck))
    
    # Build Keras model object
    model = Model([input], outputs)
    model.compile(Adam(learning_rate=params['lr']), # Adam optimizer
                  loss=['mse', 'mse'], # loss is Mean Squared Error (MSE)
                  loss_weights=[1, 1]) # in case we want to change the weights of each output. For now keep them with same weights

    return model, params
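# model.predict on this model returns a list of two (N, 1) arrays in task order
# [Dev, Hk]; summary_statistics below indexes into that list accordingly.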

def train(selected_model, X_train, Y_train, X_valid, Y_valid, params):
    callbacks = [
        EarlyStopping(patience=params['early_stop'], monitor="val_loss", restore_best_weights=True),
        History()
    ]
    
    # Add learning rate scheduler if enabled
    if params.get('lr_schedule', False):
        def lr_scheduler(epoch, lr):
            if epoch < 20:  # Longer warm-up period
                return lr
            else:
                return lr * tf.math.exp(-0.03)  # Gentler decay
        
        callbacks.append(tf.keras.callbacks.LearningRateScheduler(lr_scheduler))
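        # exp(-0.03) ≈ 0.970, so the learning rate shrinks by about 3% per epoch
        # once the 20-epoch warm-up ends.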
    
    my_history = selected_model.fit(
        X_train, Y_train,
        validation_data=(X_valid, Y_valid), 
        batch_size=params['batch_size'],
        epochs=params['epochs'],
        callbacks=callbacks
    )

    return selected_model, my_history
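# Because EarlyStopping uses restore_best_weights=True, the returned model holds
# the weights from the best validation-loss epoch rather than the final epoch.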

def summary_statistics(X, Y, set_name, task, main_model, main_params, out_dir):
    pred = main_model.predict(X, batch_size=main_params['batch_size'])
    task_index = {'Dev': 0, 'Hk': 1}[task]  # model outputs are ordered [Dev, Hk]
    pred_task = pred[task_index].squeeze()
    mse = mean_squared_error(Y, pred_task)
    pcc = stats.pearsonr(Y, pred_task)[0]
    scc = stats.spearmanr(Y, pred_task)[0]
    print(f'{set_name} MSE {task} = {mse:0.2f}')
    print(f'{set_name} PCC {task} = {pcc:0.2f}')
    print(f'{set_name} SCC {task} = {scc:0.2f}')
    return f'{pcc:0.2f}'
 
def main(config, indir, out_dir, label, args):
    data = pd.read_table(indir)
    params = LoadConfig(config, args)

    X_train, Y_train = prepare_input(data[data['set'] == "Train"], params)
    X_valid, Y_valid = prepare_input(data[data['set'] == "Val"], params)
    X_test, Y_test = prepare_input(data[data['set'] == "Test"], params)

    # Select the architecture based on the model_type parameter
    if params['model_type'] == 'deepstarr':
        main_model, main_params = DeepSTARR(params)
    else:  # hyenamsta_plus
        main_model, main_params = HyenaMSTAPlus(params)
    main_model.summary()
    main_model, my_history = train(main_model, X_train, Y_train, X_valid, Y_valid, main_params)

    endTime = time.time()
    seconds = endTime - startTime
    print("Total training time:", round(seconds / 60, 2), "minutes")

    dev_results = summary_statistics(X_test, Y_test[0], "test", "Dev", main_model, main_params, out_dir)
    hk_results = summary_statistics(X_test, Y_test[1], "test", "Hk", main_model, main_params, out_dir)

    result = {
        "AutoDNA": {
            "means": {
                "PCC(Dev)": dev_results,
                "PCC(Hk)": hk_results
            }
        }
    }
    
    os.makedirs(out_dir, exist_ok=True)  # ensure the output directory exists before writing
    with open(f"{out_dir}/final_info.json", "w") as file:
        json.dump(result, file, indent=4)

    main_model.save(os.path.join(out_dir, label + '.h5'))

if __name__ == "__main__":
    args = parse_arguments()  # parse before the try block so `args` exists in the handler
    try:
        main(args.config, args.indir, args.out_dir, args.label, args)
    except Exception:
        print("Original error in subprocess:", flush=True)
        os.makedirs(args.out_dir, exist_ok=True)
        with open(os.path.join(args.out_dir, "traceback.log"), "w") as log_file:
            traceback.print_exc(file=log_file)
        raise