harari committed on
Commit
65a437b
·
verified ·
1 Parent(s): 731c404

Upload 3 files

Browse files
extract_tbx5_embeddings.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract embeddings for TBX5 motif data using Evo2 40B model.
4
+ - Extract embeddings from block 20 pre-normalization layer
5
+ - Use 8192bp window around motif site
6
+ - Average embeddings for 61bp sequences
7
+ - Create 4096 dimensional feature vector for each motif
8
+ """
9
+
10
+ import pandas as pd
11
+ import numpy as np
12
+ import torch
13
+ import gzip
14
+ from Bio import SeqIO
15
+ from Bio.Seq import Seq
16
+ from evo2 import Evo2
17
+ import pickle
18
+ from tqdm import tqdm
19
+ import os
20
+ import sys
21
+ import argparse
22
+
23
+ # Configure tqdm for better display in containers
24
+ tqdm.pandas()
25
+
26
# Configuration
WINDOW_SIZE = 256 # bp window around motif site (NOTE(review): module docstring and the RC script say 8192bp — confirm which value is intended)
LAYER_NAME = "blocks.26.mlp.l3" # Layer to extract embeddings from (NOTE(review): comment/docstring elsewhere says "block 20", but this is block 26 — confirm)
SEQUENCE_LENGTH = 61 # Fixed sequence length for all motifs
BATCH_SIZE = 8 # Adjust based on GPU memory for 40B model (NOTE(review): not referenced anywhere in this script — sequences are processed one at a time)
31
+
32
def load_fasta(fasta_path, chromosome):
    """Return the uppercase sequence of the first record in a gzipped FASTA file.

    Only the first record is read (one record per chromosome file is
    assumed).  Returns None when the file contains no records.
    """
    print(f"Loading chromosome {chromosome} FASTA file...")
    with gzip.open(fasta_path, "rt") as handle:
        record = next(SeqIO.parse(handle, "fasta"), None)
        if record is None:
            return None
        seq = str(record.seq).upper()
    print(f"Loaded chromosome {chromosome}, length: {len(seq):,} bp")
    return seq
41
+
42
def normalize_sequence_length(df, target_length=None):
    """
    Normalize all motif intervals to a fixed length.

    Intervals shorter than the target are extended roughly symmetrically
    (clamped at position 0); longer ones are truncated around their center.
    Coordinates are treated as inclusive on both ends.  Rows that already
    have the target length are left untouched (their 'length' cell is not
    written, matching the original behavior).

    Args:
        df: DataFrame with integer 'start' and 'end' columns.
        target_length: Desired interval length in bp; defaults to the
            module-level SEQUENCE_LENGTH (resolved at call time).

    Returns:
        A copy of df with 'start'/'end' adjusted and 'length' set on the
        modified rows.
    """
    if target_length is None:
        target_length = SEQUENCE_LENGTH
    print(f"Normalizing sequence lengths to {target_length}bp...")

    df_normalized = df.copy()

    for idx, row in df_normalized.iterrows():
        start = row['start']
        end = row['end']
        current_length = end - start + 1  # Both ends inclusive

        if current_length != target_length:
            if current_length < target_length:
                # Extend the interval, keeping it roughly centered but never
                # extending below position 0.
                extension = target_length - current_length
                new_start = max(0, start - extension // 2)
            else:
                # Truncate the interval around its center.
                excess = current_length - target_length
                new_start = start + excess // 2
            new_end = new_start + target_length - 1

            df_normalized.at[idx, 'start'] = new_start
            df_normalized.at[idx, 'end'] = new_end
            df_normalized.at[idx, 'length'] = target_length

    print(f"Normalized {len(df_normalized)} sequences to {target_length}bp")
    return df_normalized
71
+
72
def get_sequence_window(chr_seq, start, end, window_size=None):
    """
    Extract a sequence window centered on a motif site.

    Args:
        chr_seq: Full chromosome sequence.
        start: Start position of motif (1-based, inclusive).
        end: End position of motif (1-based, inclusive).
        window_size: Size of the window around the motif in bp; defaults to
            the module-level WINDOW_SIZE (resolved at call time; the original
            docstring claimed 8192bp, which did not match WINDOW_SIZE).  The
            window is clipped at the chromosome boundaries, so it may be
            shorter near the ends.

    Returns:
        Tuple (seq_window, motif_start_in_window, motif_end_in_window) where
        the motif positions are 0-based indices into seq_window.  Never
        returns None.
    """
    if window_size is None:
        window_size = WINDOW_SIZE

    # Convert to 0-based indexing
    start_0 = start - 1
    end_0 = end - 1

    # Center the window on the middle of the motif
    motif_center = (start_0 + end_0) // 2

    # Window boundaries, clipped to the chromosome
    half_window = window_size // 2
    window_start = max(0, motif_center - half_window)
    window_end = min(len(chr_seq), motif_center + half_window)

    seq_window = chr_seq[window_start:window_end]

    # Motif coordinates relative to the window
    motif_start_in_window = start_0 - window_start
    motif_end_in_window = end_0 - window_start

    return seq_window, motif_start_in_window, motif_end_in_window
106
+
107
def extract_embeddings_batch(model, sequences, layer_name=LAYER_NAME):
    """
    Extract a mean-pooled embedding vector for each sequence.

    Despite the name, sequences are fed through the model one at a time;
    each one's activations at `layer_name` are averaged over the sequence
    dimension.

    Args:
        model: Evo2 model.
        sequences: List of DNA sequences.
        layer_name: Name of layer to extract embeddings from.

    Returns:
        numpy array of shape [len(sequences), hidden_dim].
    """
    pooled = []

    for seq in sequences:
        # Tokenize and move onto the GPU (device is hardcoded to cuda:0)
        tokens = model.tokenizer.tokenize(seq)
        input_ids = torch.tensor(tokens, dtype=torch.int).unsqueeze(0).to("cuda:0")

        # Forward pass without gradients
        with torch.no_grad():
            _, embeddings = model(
                input_ids, return_embeddings=True, layer_names=[layer_name]
            )

        # [batch, seq_len, hidden] -> [batch, hidden]; cast BFloat16 to
        # Float32 before the numpy conversion.
        layer_out = embeddings[layer_name]
        pooled.append(layer_out.mean(dim=1).float().cpu().numpy())

    return np.vstack(pooled)
145
+
146
def process_motifs(model, chr_seq, motif_df, chromosome):
    """
    Process all motifs and extract embeddings.

    Args:
        model: Evo2 model
        chr_seq: Chromosome sequence
        motif_df: DataFrame with motif information
        chromosome: Chromosome identifier

    Returns:
        embeddings_dict: Dictionary keyed by each motif's DataFrame index,
        holding coordinates, metadata and the averaged embedding vector.
    """
    embeddings_dict = {}
    failed_motifs = []

    print(f"Processing {len(motif_df)} motifs on chromosome {chromosome}...")

    progress = tqdm(
        motif_df.iterrows(),
        total=len(motif_df),
        desc=f"Chr{chromosome} embeddings",
        ncols=120,
        leave=True,
        position=0,
        mininterval=1.0,
        maxinterval=10.0,
        dynamic_ncols=True,
    )
    # BUG FIX: the original printed "idx+1/len" using the DataFrame index,
    # but after filtering by chromosome that index is sparse (original CSV
    # row numbers), so the periodic progress message was wrong.  Use an
    # explicit positional counter for progress; keep `idx` as the dict key.
    for pos, (idx, row) in enumerate(progress):
        try:
            # Get motif coordinates
            start = int(row['start'])
            end = int(row['end'])

            # Print progress every 100 motifs
            if pos % 100 == 0:
                print(f"\nProcessing motif {pos+1}/{len(motif_df)} ({(pos+1)/len(motif_df)*100:.1f}%)")

            # Extract sequence window around the motif
            seq_window, motif_start, motif_end = get_sequence_window(
                chr_seq, start, end
            )

            # Defensive check (get_sequence_window never actually returns None)
            if seq_window is None:
                failed_motifs.append(idx)
                continue

            # Extract motif sequence from window (inclusive end)
            motif_seq = seq_window[motif_start:motif_end+1]

            # Verify motif length
            if len(motif_seq) != SEQUENCE_LENGTH:
                print(f"Warning: Motif length {len(motif_seq)} != {SEQUENCE_LENGTH} at position {start}-{end}")
                failed_motifs.append(idx)
                continue

            # Extract embeddings for motif sequence (single-item batch)
            embeddings = extract_embeddings_batch(model, [motif_seq])

            # Single embedding row, shape [hidden_dim]
            motif_embedding = embeddings[0]

            embeddings_dict[idx] = {
                "start": start,
                "end": end,
                "embedding": motif_embedding,
                "tbx5_score": row.get("tbx5_score", 0),
                "label": row.get("label", 0),
                "chromosome": chromosome,
            }

        except Exception as e:
            # Best-effort: log and keep going so one bad row doesn't abort the run
            print(f"Error processing motif at index {idx}: {e}")
            failed_motifs.append(idx)
            continue

    print(f"Successfully processed {len(embeddings_dict)} motifs")
    if failed_motifs:
        print(f"Failed to process {len(failed_motifs)} motifs: {failed_motifs[:10]}...")

    return embeddings_dict
227
+
228
def save_embeddings(embeddings_dict, output_path, chromosome):
    """Persist embeddings as a pickle file plus a companion .npz archive.

    The pickle holds the full per-motif dicts under "embeddings" plus a
    "metadata" dict; the .npz (written only when there is at least one
    embedding) holds column-wise arrays for easier downstream loading.
    """
    print(f"Saving embeddings to {output_path}")

    metadata = {
        "chromosome": chromosome,
        "window_size": WINDOW_SIZE,
        "sequence_length": SEQUENCE_LENGTH,
        "layer_name": LAYER_NAME,
        "embedding_dim": 4096,
        "num_motifs": len(embeddings_dict),
    }
    save_data = {"embeddings": dict(embeddings_dict), "metadata": metadata}

    # Pickle is always written, even when the dict is empty
    with open(output_path, "wb") as f:
        pickle.dump(save_data, f)

    # Companion numpy archive next to the pickle
    np_output = output_path.replace(".pkl", "_arrays.npz")

    if embeddings_dict:
        items = list(embeddings_dict.items())
        np.savez_compressed(
            np_output,
            indices=np.array([idx for idx, _ in items]),
            starts=np.array([d["start"] for _, d in items]),
            ends=np.array([d["end"] for _, d in items]),
            embeddings=np.vstack([d["embedding"] for _, d in items]),
            tbx5_scores=np.array([d["tbx5_score"] for _, d in items]),
            labels=np.array([d["label"] for _, d in items]),
            # NOTE: a dict is stored as an object array; loading it back
            # requires np.load(..., allow_pickle=True)
            metadata=metadata,
        )
        print(f"Saved numpy arrays to {np_output}")
    else:
        print("No embeddings to save in numpy format")
285
+
286
def main():
    """CLI entry point: extract Evo2 embeddings for all TBX5 motifs on one chromosome.

    Loads the chromosome FASTA and the motif CSV, filters motifs to the
    requested chromosome, normalizes them to SEQUENCE_LENGTH, extracts
    embeddings, and writes pickle + npz outputs.

    Returns:
        0 on success (including the empty-chromosome case), 1 when an input
        file is missing or the FASTA cannot be read — used as the process
        exit code via sys.exit(main()).
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Extract embeddings for TBX5 motif data"
    )
    parser.add_argument(
        "chromosome", type=str, help="Chromosome to process (e.g., 1, 2, X, Y)"
    )
    parser.add_argument(
        "--fasta-dir",
        type=str,
        default="fasta",
        help="Directory containing FASTA files (default: fasta)",
    )
    parser.add_argument(
        "--csv-file",
        type=str,
        default="processed_data/all_tbx5_data.csv",
        help="TBX5 CSV file (default: processed_data/all_tbx5_data.csv)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="tbx5_embeddings",
        help="Output directory for embeddings (default: tbx5_embeddings)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="evo2_40b",
        help="Evo2 model to use (default: evo2_40b)",
    )

    args = parser.parse_args()
    chromosome = args.chromosome

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # File paths (Ensembl GRCh38 per-chromosome gzipped FASTA naming)
    fasta_path = os.path.join(
        args.fasta_dir, f"Homo_sapiens.GRCh38.dna.chromosome.{chromosome}.fa.gz"
    )
    csv_path = args.csv_file
    output_path = os.path.join(args.output_dir, f"chr{chromosome}_tbx5_embeddings.pkl")

    # Check if files exist
    if not os.path.exists(fasta_path):
        print(f"Error: FASTA file not found at {fasta_path}")
        return 1

    if not os.path.exists(csv_path):
        print(f"Error: CSV file not found at {csv_path}")
        return 1

    # Load chromosome sequence
    chr_seq = load_fasta(fasta_path, chromosome)
    if chr_seq is None:
        print(f"Error: Failed to load chromosome {chromosome} sequence")
        return 1

    # Load TBX5 data
    print(f"Loading TBX5 data for chromosome {chromosome}...")
    motif_df = pd.read_csv(csv_path)

    # Filter for specific chromosome
    # NOTE(review): this compares against the string CLI argument; the CSV
    # 'chromosome' column must be string-typed (e.g. "1", not int 1) for
    # numeric chromosomes to match — confirm against the CSV.
    chr_motif_df = motif_df[motif_df['chromosome'] == chromosome].copy()

    if len(chr_motif_df) == 0:
        print(f"Warning: No chromosome {chromosome} motifs found in TBX5 data")
        # Create empty output file to mark completion
        save_data = {
            "embeddings": {},
            "metadata": {
                "chromosome": chromosome,
                "window_size": WINDOW_SIZE,
                "sequence_length": SEQUENCE_LENGTH,
                "layer_name": LAYER_NAME,
                "embedding_dim": 4096,
                "num_motifs": 0,
            },
        }
        with open(output_path, "wb") as f:
            pickle.dump(save_data, f)
        print(f"Created empty embeddings file for chromosome {chromosome}")
        return 0

    print(f"Found {len(chr_motif_df)} motifs on chromosome {chromosome}")

    # Normalize sequence lengths
    chr_motif_df = normalize_sequence_length(chr_motif_df)

    # Initialize model (loaded only after all inputs are validated)
    print(f"Loading {args.model} model...")
    model = Evo2(args.model)
    model.model.eval()  # Set to evaluation mode - access the actual model

    # Process motifs and extract embeddings
    embeddings_dict = process_motifs(model, chr_seq, chr_motif_df, chromosome)

    # Save results
    save_embeddings(embeddings_dict, output_path, chromosome)

    print(f"Done processing chromosome {chromosome}!")

    # Print summary statistics
    print(f"\n=== Summary for Chromosome {chromosome} ===")
    print(f"Total motifs processed: {len(embeddings_dict)}")
    print(f"Embedding dimension: 4096")
    print(f"Sequence length: {SEQUENCE_LENGTH}bp")
    print(f"Window size: {WINDOW_SIZE}bp")
    print(f"Output files:")
    print(f" - {output_path}")
    print(f" - {output_path.replace('.pkl', '_arrays.npz')}")

    return 0
402
+
403
if __name__ == "__main__":
    # Propagate main()'s return value (0/1) as the process exit code.
    sys.exit(main())
extract_tbx5_embeddings_reverse_complement.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract reverse complement embeddings for TBX5 motif data using Evo2 40B model.
4
+ - Extract embeddings from block 20 pre-normalization layer
5
+ - Use 8192bp window around motif site
6
+ - Average embeddings for 61bp sequences (reverse complement)
7
+ - Create 4096 dimensional feature vector for each motif
8
+ """
9
+
10
+ import pandas as pd
11
+ import numpy as np
12
+ import torch
13
+ import gzip
14
+ from Bio import SeqIO
15
+ from Bio.Seq import Seq
16
+ from evo2 import Evo2
17
+ import pickle
18
+ from tqdm import tqdm
19
+ import os
20
+ import sys
21
+ import argparse
22
+
23
+ # Configure tqdm for better display in containers
24
+ tqdm.pandas()
25
+
26
# Configuration
WINDOW_SIZE = 8192 # 8192bp window around motif site
LAYER_NAME = "blocks.26.mlp.l3" # Layer to extract embeddings from (NOTE(review): comment/docstring elsewhere says "block 20", but this is block 26 — confirm)
SEQUENCE_LENGTH = 61 # Fixed sequence length for all motifs
BATCH_SIZE = 8 # Adjust based on GPU memory for 40B model (NOTE(review): not referenced anywhere in this script — sequences are processed one at a time)
31
+
32
# IUPAC DNA complement table (upper- and lower-case).  Matches what
# Bio.Seq.Seq.reverse_complement produces for DNA, without the overhead of
# constructing a Seq object per call.
_COMPLEMENT_TABLE = str.maketrans(
    "ACGTRYSWKMBDHVNacgtryswkmbdhvn",
    "TGCAYRSWMKVHDBNtgcayrswmkvhdbn",
)

def get_reverse_complement(sequence):
    """Return the reverse complement of a DNA sequence (IUPAC-aware)."""
    return sequence.translate(_COMPLEMENT_TABLE)[::-1]
35
+
36
def load_fasta(fasta_path, chromosome):
    """Load a gzipped chromosome FASTA file.

    Returns the uppercase sequence of the first record (one record per
    chromosome file is assumed — the function returns inside the loop), or
    None when the file contains no records.
    """
    print(f"Loading chromosome {chromosome} FASTA file...")
    with gzip.open(fasta_path, "rt") as handle:
        for record in SeqIO.parse(handle, "fasta"):
            seq = str(record.seq).upper()
            print(f"Loaded chromosome {chromosome}, length: {len(seq):,} bp")
            return seq
    return None
45
+
46
def normalize_sequence_length(df):
    """Normalize all motif intervals to SEQUENCE_LENGTH (61bp).

    Shorter intervals are extended roughly symmetrically (clamped at 0);
    longer ones are truncated around their center.  Coordinates are
    inclusive on both ends.  Rows already at the target length are left
    untouched, so their 'length' cell is not written.
    """
    print("Normalizing sequence lengths to 61bp...")

    df_normalized = df.copy()

    for idx, row in df_normalized.iterrows():
        start = row['start']
        end = row['end']
        current_length = end - start + 1 # Both ends inclusive

        if current_length != SEQUENCE_LENGTH:
            if current_length < SEQUENCE_LENGTH:
                # Extend sequence to 61bp (never extending below position 0)
                extension = SEQUENCE_LENGTH - current_length
                new_start = max(0, start - extension // 2)
                new_end = new_start + SEQUENCE_LENGTH - 1
            else:
                # Truncate sequence to 61bp (center the sequence)
                excess = current_length - SEQUENCE_LENGTH
                new_start = start + excess // 2
                new_end = new_start + SEQUENCE_LENGTH - 1

            df_normalized.at[idx, 'start'] = new_start
            df_normalized.at[idx, 'end'] = new_end
            df_normalized.at[idx, 'length'] = SEQUENCE_LENGTH

    print(f"Normalized {len(df_normalized)} sequences to {SEQUENCE_LENGTH}bp")
    return df_normalized
75
+
76
def get_sequence_window(chr_seq, start, end, window_size=WINDOW_SIZE):
    """
    Extract sequence window around motif site.

    Args:
        chr_seq: Full chromosome sequence
        start: Start position of motif (1-based, inclusive)
        end: End position of motif (1-based, inclusive)
        window_size: Size of window around motif (default 8192bp; bound to
            WINDOW_SIZE at import time)

    Returns:
        Tuple of (seq_window, motif_start_in_window, motif_end_in_window),
        where the motif positions are 0-based indices into seq_window.  The
        window is clipped at the chromosome boundaries, so it may be shorter
        near the ends; the function never returns None.
    """
    # Convert to 0-based indexing
    start_0 = start - 1
    end_0 = end - 1

    # Calculate center of motif
    motif_center = (start_0 + end_0) // 2

    # Calculate window boundaries (clipped to the chromosome)
    half_window = window_size // 2
    window_start = max(0, motif_center - half_window)
    window_end = min(len(chr_seq), motif_center + half_window)

    # Extract sequence window
    seq_window = chr_seq[window_start:window_end]

    # Calculate motif position in window
    motif_start_in_window = start_0 - window_start
    motif_end_in_window = end_0 - window_start

    return seq_window, motif_start_in_window, motif_end_in_window
110
+
111
def extract_embeddings_batch(model, sequences, layer_name=LAYER_NAME):
    """
    Extract embeddings for a batch of sequences.

    Despite the name, sequences are fed through the model one at a time;
    the device is hardcoded to cuda:0.

    Args:
        model: Evo2 model
        sequences: List of DNA sequences
        layer_name: Name of layer to extract embeddings from

    Returns:
        embeddings: numpy array [len(sequences), hidden_dim] of mean-pooled
        embeddings, one row per input sequence
    """
    all_embeddings = []

    for seq in sequences:
        # Tokenize sequence and move it onto the GPU
        input_ids = (
            torch.tensor(
                model.tokenizer.tokenize(seq),
                dtype=torch.int,
            )
            .unsqueeze(0)
            .to("cuda:0")
        )

        # Get embeddings (no gradients needed for feature extraction)
        with torch.no_grad():
            _, embeddings = model(
                input_ids, return_embeddings=True, layer_names=[layer_name]
            )

        # Average over sequence length dimension
        # Shape: [batch_size, seq_len, hidden_dim] -> [batch_size, hidden_dim]
        # Convert from BFloat16 to Float32 before converting to numpy
        avg_embedding = embeddings[layer_name].mean(dim=1).float().cpu().numpy()
        all_embeddings.append(avg_embedding)

    return np.vstack(all_embeddings)
149
+
150
def process_motifs(model, chr_seq, motif_df, chromosome):
    """
    Process all motifs and extract reverse complement embeddings.

    Args:
        model: Evo2 model
        chr_seq: Chromosome sequence
        motif_df: DataFrame with motif information
        chromosome: Chromosome identifier

    Returns:
        embeddings_dict: Dictionary keyed by each motif's DataFrame index,
        holding coordinates, metadata and the averaged embedding of the
        motif's reverse-complement sequence
    """
    embeddings_dict = {}
    failed_motifs = []

    print(f"Processing {len(motif_df)} motifs on chromosome {chromosome} (reverse complement)...")

    for idx, row in tqdm(
        motif_df.iterrows(),
        total=len(motif_df),
        desc=f"Chr{chromosome} RC embeddings",
        ncols=100,
        leave=True,
        position=0
    ):
        try:
            # Get motif coordinates
            start = int(row['start'])
            end = int(row['end'])

            # Extract sequence window
            seq_window, motif_start, motif_end = get_sequence_window(
                chr_seq, start, end
            )

            # Defensive check (get_sequence_window never actually returns None)
            if seq_window is None:
                failed_motifs.append(idx)
                continue

            # Extract motif sequence from window (inclusive end)
            motif_seq = seq_window[motif_start:motif_end+1]

            # Verify motif length
            if len(motif_seq) != SEQUENCE_LENGTH:
                print(f"Warning: Motif length {len(motif_seq)} != {SEQUENCE_LENGTH} at position {start}-{end}")
                failed_motifs.append(idx)
                continue

            # Get reverse complement of motif sequence
            motif_seq_rc = get_reverse_complement(motif_seq)

            # Extract embeddings for reverse complement sequence
            embeddings = extract_embeddings_batch(model, [motif_seq_rc])

            # Get single embedding (shape: [1, 4096])
            motif_embedding = embeddings[0] # Shape: [4096]

            embeddings_dict[idx] = {
                "start": start,
                "end": end,
                "embedding": motif_embedding,
                "tbx5_score": row.get("tbx5_score", 0),
                "label": row.get("label", 0),
                "chromosome": chromosome,
                "sequence_type": "reverse_complement",
            }

        except Exception as e:
            # Best-effort: log and keep going so one bad row doesn't abort the run
            print(f"Error processing motif at index {idx}: {e}")
            failed_motifs.append(idx)
            continue

    print(f"Successfully processed {len(embeddings_dict)} motifs (reverse complement)")
    if failed_motifs:
        print(f"Failed to process {len(failed_motifs)} motifs: {failed_motifs[:10]}...")

    return embeddings_dict
228
+
229
def save_embeddings(embeddings_dict, output_path, chromosome):
    """Save embeddings as a pickle file plus a companion .npz archive.

    The pickle always gets written; the .npz (same path with an
    "_arrays.npz" suffix) is written only when there is at least one
    embedding.
    """
    print(f"Saving reverse complement embeddings to {output_path}")

    # Convert to format suitable for saving
    save_data = {
        "embeddings": {},
        "metadata": {
            "chromosome": chromosome,
            "window_size": WINDOW_SIZE,
            "sequence_length": SEQUENCE_LENGTH,
            "layer_name": LAYER_NAME,
            "embedding_dim": 4096,
            "num_motifs": len(embeddings_dict),
            "sequence_type": "reverse_complement",
        },
    }

    for idx, data in embeddings_dict.items():
        save_data["embeddings"][idx] = data

    # Save as pickle file
    with open(output_path, "wb") as f:
        pickle.dump(save_data, f)

    # Also save as numpy arrays for easier loading
    np_output = output_path.replace(".pkl", "_arrays.npz")

    # Extract arrays (column-wise views of the per-motif dicts)
    indices = []
    starts = []
    ends = []
    embeddings = []
    tbx5_scores = []
    labels = []

    for idx, data in embeddings_dict.items():
        indices.append(idx)
        starts.append(data["start"])
        ends.append(data["end"])
        embeddings.append(data["embedding"])
        tbx5_scores.append(data["tbx5_score"])
        labels.append(data["label"])

    if len(embeddings) > 0:
        np.savez_compressed(
            np_output,
            indices=np.array(indices),
            starts=np.array(starts),
            ends=np.array(ends),
            embeddings=np.vstack(embeddings),
            tbx5_scores=np.array(tbx5_scores),
            labels=np.array(labels),
            # NOTE: a dict is stored as an object array; loading it back
            # requires np.load(..., allow_pickle=True)
            metadata=save_data["metadata"],
        )
        print(f"Saved numpy arrays to {np_output}")
    else:
        print("No embeddings to save in numpy format")
287
+
288
def main():
    """CLI entry point: extract reverse-complement Evo2 embeddings for all
    TBX5 motifs on one chromosome.

    Returns:
        0 on success (including the empty-chromosome case), 1 when an input
        file is missing or the FASTA cannot be read — used as the process
        exit code via sys.exit(main()).
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Extract reverse complement embeddings for TBX5 motif data"
    )
    parser.add_argument(
        "chromosome", type=str, help="Chromosome to process (e.g., 1, 2, X, Y)"
    )
    parser.add_argument(
        "--fasta-dir",
        type=str,
        default="fasta",
        help="Directory containing FASTA files (default: fasta)",
    )
    parser.add_argument(
        "--csv-file",
        type=str,
        default="processed_data/all_tbx5_data.csv",
        help="TBX5 CSV file (default: processed_data/all_tbx5_data.csv)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Output directory for reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="evo2_40b",
        help="Evo2 model to use (default: evo2_40b)",
    )

    args = parser.parse_args()
    chromosome = args.chromosome

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # File paths (Ensembl GRCh38 per-chromosome gzipped FASTA naming)
    fasta_path = os.path.join(
        args.fasta_dir, f"Homo_sapiens.GRCh38.dna.chromosome.{chromosome}.fa.gz"
    )
    csv_path = args.csv_file
    output_path = os.path.join(args.output_dir, f"chr{chromosome}_tbx5_embeddings_rc.pkl")

    # Check if files exist
    if not os.path.exists(fasta_path):
        print(f"Error: FASTA file not found at {fasta_path}")
        return 1

    if not os.path.exists(csv_path):
        print(f"Error: CSV file not found at {csv_path}")
        return 1

    # Load chromosome sequence
    chr_seq = load_fasta(fasta_path, chromosome)
    if chr_seq is None:
        print(f"Error: Failed to load chromosome {chromosome} sequence")
        return 1

    # Load TBX5 data
    print(f"Loading TBX5 data for chromosome {chromosome}...")
    motif_df = pd.read_csv(csv_path)

    # Filter for specific chromosome
    # NOTE(review): this compares against the string CLI argument; the CSV
    # 'chromosome' column must be string-typed (e.g. "1", not int 1) for
    # numeric chromosomes to match — confirm against the CSV.
    chr_motif_df = motif_df[motif_df['chromosome'] == chromosome].copy()

    if len(chr_motif_df) == 0:
        print(f"Warning: No chromosome {chromosome} motifs found in TBX5 data")
        # Create empty output file to mark completion
        save_data = {
            "embeddings": {},
            "metadata": {
                "chromosome": chromosome,
                "window_size": WINDOW_SIZE,
                "sequence_length": SEQUENCE_LENGTH,
                "layer_name": LAYER_NAME,
                "embedding_dim": 4096,
                "num_motifs": 0,
                "sequence_type": "reverse_complement",
            },
        }
        with open(output_path, "wb") as f:
            pickle.dump(save_data, f)
        print(f"Created empty reverse complement embeddings file for chromosome {chromosome}")
        return 0

    print(f"Found {len(chr_motif_df)} motifs on chromosome {chromosome}")

    # Normalize sequence lengths
    chr_motif_df = normalize_sequence_length(chr_motif_df)

    # Initialize model (loaded only after all inputs are validated)
    print(f"Loading {args.model} model...")
    model = Evo2(args.model)
    model.model.eval()  # Set to evaluation mode - access the actual model

    # Process motifs and extract reverse complement embeddings
    embeddings_dict = process_motifs(model, chr_seq, chr_motif_df, chromosome)

    # Save results
    save_embeddings(embeddings_dict, output_path, chromosome)

    print(f"Done processing chromosome {chromosome} (reverse complement)!")

    # Print summary statistics
    print(f"\n=== Summary for Chromosome {chromosome} (Reverse Complement) ===")
    print(f"Total motifs processed: {len(embeddings_dict)}")
    print(f"Embedding dimension: 4096")
    print(f"Sequence length: {SEQUENCE_LENGTH}bp")
    print(f"Window size: {WINDOW_SIZE}bp")
    print(f"Sequence type: Reverse complement")
    print(f"Output files:")
    print(f" - {output_path}")
    print(f" - {output_path.replace('.pkl', '_arrays.npz')}")

    return 0
406
+
407
if __name__ == "__main__":
    # Propagate main()'s return value (0/1) as the process exit code.
    sys.exit(main())
409
+
410
+
411
+
412
+
413
+
414
+
train_tbx5_classifier_with_rc.py ADDED
@@ -0,0 +1,1027 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train TBX5 classifier using both forward and reverse complement embeddings.
4
+ This script combines embeddings from both strands to improve classification accuracy.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ from torch.utils.data import DataLoader, TensorDataset
16
+ from sklearn.model_selection import train_test_split
17
+ from sklearn.preprocessing import StandardScaler
18
+ from sklearn.metrics import (
19
+ roc_auc_score,
20
+ accuracy_score,
21
+ precision_recall_fscore_support,
22
+ confusion_matrix,
23
+ )
24
+ import json
25
+ import pickle
26
+ from tqdm import tqdm
27
+ import matplotlib.pyplot as plt
28
+ import seaborn as sns
29
+ from datetime import datetime
30
+
31
+ # Add the parent directory to the path to import from finetuning
32
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'finetuning'))
33
+
34
class TBX5ClassifierWithRC(nn.Module):
    """
    Feedforward classifier over concatenated forward + reverse-complement
    embeddings for TBX5 binding site classification.

    Architecture:
        input (8192 = 4096 forward + 4096 reverse complement)
        -> 2048 -> 512 -> 128 -> 1 (sigmoid),
        with ReLU, BatchNorm and Dropout(0.5) after each hidden layer.
    """

    def __init__(self, input_dim=8192, dropout_rate=0.5):
        super(TBX5ClassifierWithRC, self).__init__()

        # Hidden layer 1 (attribute names kept stable for checkpoint
        # compatibility: fc*/bn*/dropout*)
        self.fc1 = nn.Linear(input_dim, 2048)
        self.bn1 = nn.BatchNorm1d(2048)
        self.dropout1 = nn.Dropout(dropout_rate)

        # Hidden layer 2
        self.fc2 = nn.Linear(2048, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(dropout_rate)

        # Hidden layer 3
        self.fc3 = nn.Linear(512, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.dropout3 = nn.Dropout(dropout_rate)

        # Output head
        self.fc4 = nn.Linear(128, 1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Each hidden block applies Linear -> ReLU -> BatchNorm -> Dropout
        # (activation before BatchNorm, mirroring the original ordering).
        hidden_blocks = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        for fc, bn, drop in hidden_blocks:
            x = drop(bn(self.relu(fc(x))))

        # Output layer with sigmoid for binary probability
        return self.sigmoid(self.fc4(x))
87
+
88
def load_tbx5_embeddings_with_rc_from_csv(embeddings_dir, rc_embeddings_dir, processed_data_dir):
    """
    Load TBX5 embeddings using train/val/test splits from processed_data_new CSV files.

    For every CSV row the forward-strand and reverse-complement embeddings
    are looked up in per-chromosome .npz files and concatenated into a
    single feature vector. Samples whose embeddings cannot be found are
    skipped (not zero-filled).

    Args:
        embeddings_dir: Directory containing forward embeddings
        rc_embeddings_dir: Directory containing reverse complement embeddings
        processed_data_dir: Directory containing train/val/test CSV files

    Returns:
        train/val/test data splits with combined embeddings, followed by a
        metadata dict summarizing sample counts and the embedding dimension.
    """
    print(f"Loading data using CSV splits from: {processed_data_dir}")
    print(f"Loading forward embeddings from: {embeddings_dir}")
    print(f"Loading reverse complement embeddings from: {rc_embeddings_dir}")

    # Load CSV files (expected columns: chromosome, start, end, label, tbx5_score).
    train_df = pd.read_csv(os.path.join(processed_data_dir, 'train_tbx5_data_new.csv'))
    val_df = pd.read_csv(os.path.join(processed_data_dir, 'val_tbx5_data_new.csv'))
    test_df = pd.read_csv(os.path.join(processed_data_dir, 'test_tbx5_data_new.csv'))

    print(f"Train samples: {len(train_df)}")
    print(f"Val samples: {len(val_df)}")
    print(f"Test samples: {len(test_df)}")

    def load_embeddings_for_split(df, embeddings_dir, rc_embeddings_dir):
        """Load embeddings for a specific split."""
        all_embeddings = []
        all_labels = []
        all_starts = []
        all_ends = []
        all_tbx5_scores = []
        all_chromosomes = []

        total_samples = len(df)
        found_samples = 0
        missing_files = 0
        missing_samples = 0

        # Keep track of loaded chromosome data to avoid reloading
        loaded_chrom_data = {}

        # Process samples in original order to maintain sequence
        for idx, row in df.iterrows():
            chrom_num = row['chromosome']
            chrom = f"chr{chrom_num}"
            start = row['start']
            end = row['end']
            label = row['label']
            tbx5_score = row['tbx5_score']

            # Load chromosome data if not already loaded
            if chrom not in loaded_chrom_data:
                forward_file = os.path.join(embeddings_dir, f"{chrom}_tbx5_embeddings_arrays.npz")
                rc_file = os.path.join(rc_embeddings_dir, f"{chrom}_tbx5_embeddings_rc_arrays.npz")

                if os.path.exists(forward_file) and os.path.exists(rc_file):
                    print(f" Loading {chrom}...")
                    forward_data = np.load(forward_file)
                    rc_data = np.load(rc_file)

                    loaded_chrom_data[chrom] = {
                        'forward_embeddings': forward_data['embeddings'],
                        'forward_starts': forward_data['starts'],
                        'forward_ends': forward_data['ends'],
                        'forward_tbx5_scores': forward_data['tbx5_scores'],
                        'rc_embeddings': rc_data['embeddings'],
                        'rc_starts': rc_data['starts'],
                        'rc_ends': rc_data['ends'],
                        'rc_tbx5_scores': rc_data['tbx5_scores']
                    }
                else:
                    # Mark the chromosome as unavailable so later rows skip fast.
                    print(f" Warning: Missing embedding files for {chrom}")
                    loaded_chrom_data[chrom] = None
                    missing_files += 1
                    continue

            # Skip if chromosome data not available
            if loaded_chrom_data[chrom] is None:
                missing_samples += 1
                continue

            chrom_data = loaded_chrom_data[chrom]
            forward_starts = chrom_data['forward_starts']
            forward_embeddings = chrom_data['forward_embeddings']
            rc_embeddings = chrom_data['rc_embeddings']

            # Find matching sample in embeddings (use chromosome and start only)
            mask = (forward_starts == start)
            if np.any(mask):
                # If multiple matches, take the first one
                emb_idx = np.where(mask)[0][0]

                # Get embeddings.
                # NOTE(review): the same index is used for both arrays, i.e.
                # forward and RC .npz files are assumed to be row-aligned
                # (emitted in the same order by the extraction scripts) -
                # verify, since RC starts are never matched against `start`.
                forward_emb = forward_embeddings[emb_idx]
                rc_emb = rc_embeddings[emb_idx]

                # Combine embeddings: [forward | reverse-complement].
                combined_emb = np.concatenate([forward_emb, rc_emb])

                all_embeddings.append(combined_emb)
                all_labels.append(label)
                all_starts.append(start)
                all_ends.append(end)
                all_tbx5_scores.append(tbx5_score)
                all_chromosomes.append(chrom)

                found_samples += 1
            else:
                missing_samples += 1
                # Skip missing samples instead of adding zeros
                continue

        print(f" Summary: {found_samples}/{total_samples} samples loaded")
        print(f" Missing files: {missing_files} samples")
        print(f" Missing embeddings: {missing_samples} samples")

        return (
            np.array(all_embeddings),
            np.array(all_labels),
            np.array(all_starts),
            np.array(all_ends),
            np.array(all_tbx5_scores),
            all_chromosomes
        )

    # Load data for each split
    print("Loading train data...")
    X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train = load_embeddings_for_split(
        train_df, embeddings_dir, rc_embeddings_dir
    )

    print("Loading validation data...")
    X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val = load_embeddings_for_split(
        val_df, embeddings_dir, rc_embeddings_dir
    )

    print("Loading test data...")
    X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test = load_embeddings_for_split(
        test_df, embeddings_dir, rc_embeddings_dir
    )

    print(f"\nLoaded data:")
    print(f"Train: {len(X_train)} samples")
    print(f"Val: {len(X_val)} samples")
    print(f"Test: {len(X_test)} samples")
    print(f"Embedding dimension: {X_train.shape[1]}")
    print(f"Train positive samples: {np.sum(y_train)}")
    print(f"Val positive samples: {np.sum(y_val)}")
    print(f"Test positive samples: {np.sum(y_test)}")

    # Fail fast if any split came back empty (e.g. all files missing).
    if len(X_train) == 0:
        raise ValueError("No training data loaded! Check embedding files and CSV data.")
    if len(X_val) == 0:
        raise ValueError("No validation data loaded! Check embedding files and CSV data.")
    if len(X_test) == 0:
        raise ValueError("No test data loaded! Check embedding files and CSV data.")

    print(f"\nData quality check:")
    print(f"Train positive ratio: {np.mean(y_train):.3f}")
    print(f"Val positive ratio: {np.mean(y_val):.3f}")
    print(f"Test positive ratio: {np.mean(y_test):.3f}")

    # Summary metadata, later dumped to metadata.json by main().
    metadata = {
        "total_samples": len(X_train) + len(X_val) + len(X_test),
        "embedding_dim": X_train.shape[1],
        "train_samples": len(X_train),
        "val_samples": len(X_val),
        "test_samples": len(X_test),
        "train_positive": int(np.sum(y_train)),
        "val_positive": int(np.sum(y_val)),
        "test_positive": int(np.sum(y_test)),
        "sequence_type": "forward_and_reverse_complement"
    }

    return (
        X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
        X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
        X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
        metadata
    )
+
271
def prepare_data_with_scaling(X_train, X_val, X_test, y_train, y_val, y_test):
    """
    Standardize the feature matrices of the three splits.

    The scaler is fit on the training split only and then applied to the
    validation and test splits, avoiding information leakage. The label
    arrays are accepted for signature compatibility but are not used.

    Returns:
        (X_train_scaled, X_val_scaled, X_test_scaled, scaler)
    """
    print("Scaling features...")

    scaler = StandardScaler()
    # Fit on train, then transform the held-out splits with the same stats.
    scaled = [scaler.fit_transform(X_train)]
    scaled.extend(scaler.transform(split) for split in (X_val, X_test))

    return scaled[0], scaled[1], scaled[2], scaler
+
285
def train_model(
    model,
    train_loader,
    val_loader,
    test_loader,
    device,
    output_dir,
    num_epochs=500,
    learning_rate=1e-4,
    patience=100,
    lr_patience=20,
    min_lr=1e-6,
    gradient_clip=1.0,
    save_every=5,
):
    """
    Train the model with specified optimization settings.

    Uses Adam + BCELoss with ReduceLROnPlateau on validation loss, gradient
    clipping, early stopping on validation AUC, periodic checkpointing every
    `save_every` epochs (with a test-set evaluation at each checkpoint), and
    a final test evaluation after reloading the best-validation-AUC model.

    Args:
        model: nn.Module producing sigmoid probabilities of shape (batch, 1).
        train_loader, val_loader, test_loader: DataLoaders yielding
            (embeddings, labels) batches.
        device: torch device to run on.
        output_dir: directory for checkpoints, JSON results, and plots.
        num_epochs: maximum number of training epochs.
        learning_rate: initial Adam learning rate.
        patience: early-stopping patience (epochs without val-AUC improvement).
        lr_patience: ReduceLROnPlateau patience on validation loss.
        min_lr: floor for the scheduled learning rate.
        gradient_clip: max gradient norm for clipping.
        save_every: checkpoint/test-evaluation interval in epochs.

    Returns:
        (results, test_results_by_epoch). NOTE: `results` holds final test
        metrics and predictions only - the per-epoch loss/AUC curves are NOT
        in it; they are written to training_history.json instead.
    """
    print(f"Training model with learning rate {learning_rate}")
    print(f"Early stopping patience: {patience}")
    print(f"Learning rate reduction patience: {lr_patience}")

    # Loss and optimizer
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=lr_patience, min_lr=min_lr
    )

    # Training history
    train_losses = []
    val_losses = []
    val_aucs = []
    test_results_by_epoch = {}  # Store test results for each saved epoch
    best_val_auc = 0.0
    best_epoch = 0
    epochs_without_improvement = 0

    print(f"Starting training for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_embeddings, batch_labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device).float()

            optimizer.zero_grad()
            # NOTE(review): .squeeze() turns a (1, 1) output into a 0-d
            # tensor, which mismatches a (1,)-shaped label batch - confirm
            # the loaders never yield a final batch of size 1.
            outputs = model(batch_embeddings).squeeze()
            loss = criterion(outputs, batch_labels)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)

            optimizer.step()

            train_loss += loss.item()
            predicted = (outputs > 0.5).float()
            train_correct += (predicted == batch_labels).sum().item()
            train_total += batch_labels.size(0)

        # Per-batch mean loss and overall accuracy for the epoch.
        train_loss /= len(train_loader)
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_predictions = []
        val_labels = []

        with torch.no_grad():
            for batch_embeddings, batch_labels in val_loader:
                batch_embeddings = batch_embeddings.to(device)
                batch_labels = batch_labels.to(device).float()

                outputs = model(batch_embeddings).squeeze()
                loss = criterion(outputs, batch_labels)

                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_correct += (predicted == batch_labels).sum().item()
                val_total += batch_labels.size(0)

                val_predictions.extend(outputs.cpu().numpy())
                val_labels.extend(batch_labels.cpu().numpy())

        val_loss /= len(val_loader)
        val_acc = val_correct / val_total
        val_auc = roc_auc_score(val_labels, val_predictions)

        # Update learning rate (scheduler watches validation loss).
        scheduler.step(val_loss)

        # Store history
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_aucs.append(val_auc)

        # Check for improvement (model selection is by validation AUC).
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_epoch = epoch
            epochs_without_improvement = 0

            # Save best model
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_auc': val_auc,
                'val_loss': val_loss,
                'input_dim': model.fc1.in_features,
            }, os.path.join(output_dir, 'best_model.pth'))

            print(f"New best model saved! Val AUC: {val_auc:.4f}")
        else:
            epochs_without_improvement += 1

        # Save model and evaluate every N epochs (plus the very first epoch).
        if (epoch + 1) % save_every == 0 or epoch == 0:
            # Save model state
            epoch_model_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.pth")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch + 1,
                'val_auc': val_auc,
                'val_loss': val_loss,
                'input_dim': model.fc1.in_features,
            }, epoch_model_path)

            # Evaluate on test set
            test_results = evaluate_model_simple(model, test_loader, device)
            test_results_by_epoch[epoch + 1] = test_results

            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"Test AUC: {test_results['auc']:.4f}")

        # Print progress for other epochs
        elif (epoch + 1) % 10 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"LR: {current_lr:.2e}")

        # Early stopping
        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
            break

    print(f"Training completed! Best validation AUC: {best_val_auc:.4f} at epoch {best_epoch+1}")

    # Load best model for testing.
    # weights_only=False performs a full pickle load - only safe because
    # this checkpoint was written by this very process.
    checkpoint = torch.load(os.path.join(output_dir, 'best_model.pth'), map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Test evaluation
    model.eval()
    test_predictions = []
    test_labels = []
    test_loss = 0.0
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for batch_embeddings, batch_labels in test_loader:
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device).float()

            outputs = model(batch_embeddings).squeeze()
            loss = criterion(outputs, batch_labels)

            test_loss += loss.item()
            predicted = (outputs > 0.5).float()
            test_correct += (predicted == batch_labels).sum().item()
            test_total += batch_labels.size(0)

            test_predictions.extend(outputs.cpu().numpy())
            test_labels.extend(batch_labels.cpu().numpy())

    test_loss /= len(test_loader)
    test_acc = test_correct / test_total
    test_auc = roc_auc_score(test_labels, test_predictions)

    # Calculate additional metrics (0.5 probability threshold).
    precision, recall, f1, _ = precision_recall_fscore_support(test_labels, [1 if p > 0.5 else 0 for p in test_predictions], average='binary')
    cm = confusion_matrix(test_labels, [1 if p > 0.5 else 0 for p in test_predictions])

    # Save results (final test metrics + raw predictions; no loss curves).
    results = {
        'test_auc': float(test_auc),
        'test_accuracy': float(test_acc),
        'test_loss': float(test_loss),
        'test_precision': float(precision),
        'test_recall': float(recall),
        'test_f1': float(f1),
        'confusion_matrix': cm.tolist(),
        'best_val_auc': float(best_val_auc),
        'best_epoch': int(best_epoch + 1),
        'total_epochs': int(epoch + 1),
        'sequence_type': 'forward_and_reverse_complement',
        'predictions': [float(p) for p in test_predictions],
        'labels': [float(l) for l in test_labels]
    }

    with open(os.path.join(output_dir, 'test_results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    # Save training history (this is where the loss/AUC curves live).
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_aucs': val_aucs,
        'best_epoch': best_epoch + 1,
        'best_val_auc': best_val_auc
    }

    with open(os.path.join(output_dir, 'training_history.json'), 'w') as f:
        json.dump(history, f, indent=2)

    # Plot training history
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.plot(val_aucs, label='Val AUC', color='green')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.title('Validation AUC')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # NOTE(review): this third panel repeats panel 1's data with an explicit
    # range(); visually identical content.
    plt.subplot(1, 3, 3)
    plt.plot(range(len(train_losses)), train_losses, label='Train Loss')
    plt.plot(range(len(val_losses)), val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"\n=== Test Results ===")
    print(f"Test AUC: {test_auc:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1: {f1:.4f}")
    print(f"Confusion Matrix:\n{cm}")

    return results, test_results_by_epoch
+
559
def evaluate_model_simple(model, test_loader, device):
    """
    Evaluate *model* on *test_loader* and return scalar metrics.

    Predictions are thresholded at 0.5 for the classification metrics.

    Returns:
        dict with keys "auc", "accuracy", "precision", "recall", "f1".
    """
    model.eval()
    scores, truths = [], []

    with torch.no_grad():
        for features, targets in test_loader:
            probs = model(features.to(device)).squeeze()
            scores.extend(probs.cpu().numpy())
            truths.extend(targets.numpy())

    scores = np.array(scores)
    truths = np.array(truths)
    hard_calls = (scores > 0.5).astype(int)

    # Threshold-free and thresholded metrics.
    precision, recall, f1, _ = precision_recall_fscore_support(
        truths, hard_calls, average="binary"
    )
    return {
        "auc": roc_auc_score(truths, scores),
        "accuracy": accuracy_score(truths, hard_calls),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
+
591
def save_epoch_analysis(test_results_by_epoch, output_dir):
    """
    Persist and summarize per-epoch test metrics.

    Writes epoch_analysis.csv and epoch_analysis.json to *output_dir*,
    prints a per-epoch performance table, and reports an overfitting /
    improvement trend based on the first-vs-last test AUC.

    Args:
        test_results_by_epoch: mapping epoch -> metric dict with keys
            "auc", "accuracy", "precision", "recall", "f1".
        output_dir: destination directory for the CSV/JSON files.

    Returns:
        pandas.DataFrame with one row per evaluated epoch.
    """
    epochs = sorted(test_results_by_epoch.keys())

    # One summary row per checkpointed epoch.
    rows = [
        {
            "epoch": epoch,
            "test_auc": test_results_by_epoch[epoch]["auc"],
            "test_accuracy": test_results_by_epoch[epoch]["accuracy"],
            "test_precision": test_results_by_epoch[epoch]["precision"],
            "test_recall": test_results_by_epoch[epoch]["recall"],
            "test_f1": test_results_by_epoch[epoch]["f1"],
        }
        for epoch in epochs
    ]
    df = pd.DataFrame(rows)

    # Persist both tabular and raw forms.
    df.to_csv(os.path.join(output_dir, "epoch_analysis.csv"), index=False)
    with open(os.path.join(output_dir, "epoch_analysis.json"), "w") as f:
        json.dump(test_results_by_epoch, f, indent=2)

    print("\n" + "=" * 50)
    print("EPOCH-WISE TEST PERFORMANCE ANALYSIS")
    print("=" * 50)

    best_auc_epoch = df.loc[df["test_auc"].idxmax()]
    best_f1_epoch = df.loc[df["test_f1"].idxmax()]

    print(
        f"Best Test AUC: {best_auc_epoch['test_auc']:.4f} at Epoch {best_auc_epoch['epoch']}"
    )
    print(
        f"Best Test F1: {best_f1_epoch['test_f1']:.4f} at Epoch {best_f1_epoch['epoch']}"
    )
    print()
    print("Epoch-wise Performance:")
    print(df.to_string(index=False, float_format="%.4f"))

    # Trend check: compare last checkpoint's AUC to the first one.
    if len(epochs) >= 2:
        auc_trend = df["test_auc"].iloc[-1] - df["test_auc"].iloc[0]
        if auc_trend < -0.01:
            print(
                f"\n⚠️ OVERFITTING DETECTED: Test AUC decreased by {abs(auc_trend):.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        elif auc_trend > 0.01:
            print(
                f"\n✅ GOOD TRAINING: Test AUC improved by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        else:
            print(
                f"\n📊 STABLE TRAINING: Test AUC changed by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )

    return df
+
657
def plot_training_history(train_losses, val_losses, val_aucs, output_dir):
    """
    Plot the loss and validation-AUC curves side by side and save the
    figure as training_history.png in *output_dir*.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    loss_ax, auc_ax = axes[0], axes[1]

    # Left panel: training vs validation loss.
    loss_ax.plot(train_losses, label="Train Loss")
    loss_ax.plot(val_losses, label="Val Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training and Validation Loss")
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Right panel: validation AUC.
    auc_ax.plot(val_aucs, label="Val AUC", color="green")
    auc_ax.set_xlabel("Epoch")
    auc_ax.set_ylabel("AUC")
    auc_ax.set_title("Validation AUC")
    auc_ax.legend()
    auc_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "training_history.png"), dpi=100)
    plt.close()
+
682
def plot_confusion_matrix(cm, output_dir):
    """
    Render the 2x2 confusion matrix as an annotated heatmap and save it
    as confusion_matrix.png in *output_dir*.
    """
    class_names = ["Non-binding", "TBX5-binding"]

    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=100)
    plt.close()
+
700
def main():
    """
    Command-line entry point.

    Loads combined forward/reverse-complement embeddings according to the
    CSV splits, scales features, trains the classifier, and writes metrics,
    checkpoints, and plots to the output directory.
    """
    parser = argparse.ArgumentParser(description="Train TBX5 classifier with forward and reverse complement embeddings")
    parser.add_argument(
        "--embeddings-dir",
        type=str,
        default="tbx5_embeddings",
        help="Directory containing forward embeddings (default: tbx5_embeddings)",
    )
    parser.add_argument(
        "--rc-embeddings-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Directory containing reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="result_with_rc",
        help="Output directory for results (default: result_with_rc)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training (default: 32)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=500,
        help="Number of training epochs (default: 500)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate (default: 1e-4)",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=100,
        help="Early stopping patience (default: 100)",
    )
    parser.add_argument(
        "--dropout-rate",
        type=float,
        default=0.5,
        help="Dropout rate (default: 0.5)",
    )
    parser.add_argument(
        "--processed-data-dir",
        type=str,
        default="processed_data_new",
        help="Directory containing train/val/test CSV files (default: processed_data_new)",
    )

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Set device (GPU when available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load embeddings using CSV splits
    print("Loading combined embeddings using CSV splits...")
    (X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
     X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
     X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
     metadata) = load_tbx5_embeddings_with_rc_from_csv(
        args.embeddings_dir, args.rc_embeddings_dir, args.processed_data_dir
    )

    # Save metadata for provenance
    with open(os.path.join(args.output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Scale features (scaler fit on the training split only)
    X_train_scaled, X_val_scaled, X_test_scaled, scaler = prepare_data_with_scaling(
        X_train, X_val, X_test, y_train, y_val, y_test
    )

    # Save scaler so inference can reproduce the transform
    with open(os.path.join(args.output_dir, 'scaler.pkl'), 'wb') as f:
        pickle.dump(scaler, f)

    # Create data loaders
    train_dataset = TensorDataset(torch.FloatTensor(X_train_scaled), torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val_scaled), torch.LongTensor(y_val))
    test_dataset = TensorDataset(torch.FloatTensor(X_test_scaled), torch.LongTensor(y_test))

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Initialize model from the observed (combined) embedding dimension
    input_dim = X_train_scaled.shape[1]
    print(f"Input dimension: {input_dim}")

    model = TBX5ClassifierWithRC(input_dim=input_dim, dropout_rate=args.dropout_rate).to(device)

    # Print model architecture
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Train model
    results, test_results_by_epoch = train_model(
        model, train_loader, val_loader, test_loader, device, args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        patience=args.patience,
    )

    # Save epoch analysis
    save_epoch_analysis(test_results_by_epoch, args.output_dir)

    # Plot results.
    # BUGFIX: train_model's results dict does not contain
    # 'train_losses'/'val_losses'/'val_aucs', so the previous
    # results.get(..., []) calls always plotted empty curves. The curves
    # are persisted by train_model to training_history.json, so read them
    # back from there for plotting.
    history_path = os.path.join(args.output_dir, 'training_history.json')
    if os.path.exists(history_path):
        with open(history_path) as f:
            history = json.load(f)
        plot_training_history(
            history.get('train_losses', []),
            history.get('val_losses', []),
            history.get('val_aucs', []),
            args.output_dir,
        )
    plot_confusion_matrix(np.array(results['confusion_matrix']), args.output_dir)

    print(f"\nTraining completed! Results saved to {args.output_dir}")
    print(f"Best test AUC: {results['test_auc']:.4f}")
+
827
# Script entry point.
if __name__ == "__main__":
    main()
+
830
+
831
+
832
+
833
+
834
    # NOTE(review): everything from here to the end of the file duplicates
    # earlier content. This particular fragment repeats the tail of
    # save_epoch_analysis but is detached from any `def`, so the `return df`
    # below cannot be valid at module level - this duplicated region should
    # be deleted.
    # Check for overfitting
    if len(epochs) >= 2:
        auc_trend = df["test_auc"].iloc[-1] - df["test_auc"].iloc[0]
        if auc_trend < -0.01:  # Significant decrease
            print(
                f"\n⚠️ OVERFITTING DETECTED: Test AUC decreased by {abs(auc_trend):.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        elif auc_trend > 0.01:
            print(
                f"\n✅ GOOD TRAINING: Test AUC improved by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        else:
            print(
                f"\n📊 STABLE TRAINING: Test AUC changed by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )

    return df
+
852
# NOTE(review): duplicate of the plot_training_history defined earlier in
# this file (identical body); the redefinition silently shadows the first
# one and should be removed.
def plot_training_history(train_losses, val_losses, val_aucs, output_dir):
    """Plot training history (loss curves and validation AUC) to training_history.png."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Loss plot
    axes[0].plot(train_losses, label="Train Loss")
    axes[0].plot(val_losses, label="Val Loss")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].set_title("Training and Validation Loss")
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # AUC plot
    axes[1].plot(val_aucs, label="Val AUC", color="green")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("AUC")
    axes[1].set_title("Validation AUC")
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "training_history.png"), dpi=100)
    plt.close()
+
877
# NOTE(review): duplicate of the plot_confusion_matrix defined earlier in
# this file (identical body); the redefinition silently shadows the first
# one and should be removed.
def plot_confusion_matrix(cm, output_dir):
    """Plot the confusion matrix heatmap to confusion_matrix.png."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["Non-binding", "TBX5-binding"],
        yticklabels=["Non-binding", "TBX5-binding"],
    )
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=100)
    plt.close()
+
895
# NOTE(review): duplicate of the main() defined earlier in this file
# (identical body); the redefinition shadows the first one and should be
# removed together with the rest of this duplicated tail.
def main():
    """
    Command-line entry point: load combined forward/RC embeddings, scale
    features, train the classifier, and write metrics/plots to the output
    directory.
    """
    parser = argparse.ArgumentParser(description="Train TBX5 classifier with forward and reverse complement embeddings")
    parser.add_argument(
        "--embeddings-dir",
        type=str,
        default="tbx5_embeddings",
        help="Directory containing forward embeddings (default: tbx5_embeddings)",
    )
    parser.add_argument(
        "--rc-embeddings-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Directory containing reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="result_with_rc",
        help="Output directory for results (default: result_with_rc)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training (default: 32)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=500,
        help="Number of training epochs (default: 500)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate (default: 1e-4)",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=100,
        help="Early stopping patience (default: 100)",
    )
    parser.add_argument(
        "--dropout-rate",
        type=float,
        default=0.5,
        help="Dropout rate (default: 0.5)",
    )
    parser.add_argument(
        "--processed-data-dir",
        type=str,
        default="processed_data_new",
        help="Directory containing train/val/test CSV files (default: processed_data_new)",
    )

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load embeddings using CSV splits
    print("Loading combined embeddings using CSV splits...")
    (X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
     X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
     X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
     metadata) = load_tbx5_embeddings_with_rc_from_csv(
        args.embeddings_dir, args.rc_embeddings_dir, args.processed_data_dir
    )

    # Save metadata
    with open(os.path.join(args.output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Scale features
    X_train_scaled, X_val_scaled, X_test_scaled, scaler = prepare_data_with_scaling(
        X_train, X_val, X_test, y_train, y_val, y_test
    )

    # Save scaler
    with open(os.path.join(args.output_dir, 'scaler.pkl'), 'wb') as f:
        pickle.dump(scaler, f)

    # Create data loaders
    train_dataset = TensorDataset(torch.FloatTensor(X_train_scaled), torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val_scaled), torch.LongTensor(y_val))
    test_dataset = TensorDataset(torch.FloatTensor(X_test_scaled), torch.LongTensor(y_test))

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Initialize model
    input_dim = X_train_scaled.shape[1]
    print(f"Input dimension: {input_dim}")

    model = TBX5ClassifierWithRC(input_dim=input_dim, dropout_rate=args.dropout_rate).to(device)

    # Print model architecture
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Train model
    results, test_results_by_epoch = train_model(
        model, train_loader, val_loader, test_loader, device, args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        patience=args.patience,
    )

    # Save epoch analysis
    save_epoch_analysis(test_results_by_epoch, args.output_dir)

    # Plot results.
    # NOTE(review): results does not contain 'train_losses'/'val_losses'/
    # 'val_aucs' (train_model writes those to training_history.json), so the
    # .get(..., []) defaults always produce empty curves here.
    plot_training_history(results.get('train_losses', []), results.get('val_losses', []), results.get('val_aucs', []), args.output_dir)
    plot_confusion_matrix(np.array(results['confusion_matrix']), args.output_dir)

    print(f"\nTraining completed! Results saved to {args.output_dir}")
    print(f"Best test AUC: {results['test_auc']:.4f}")
+
1022
# NOTE(review): duplicate entry-point guard (the same guard appears earlier
# in the file); part of the duplicated tail that should be removed.
if __name__ == "__main__":
    main()
+
1025
+
1026
+
1027
+