|
|
|
|
|
""" |
|
|
Train TBX5 classifier using both forward and reverse complement embeddings. |
|
|
This script combines embeddings from both strands to improve classification accuracy. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import DataLoader, TensorDataset |
|
|
from sklearn.model_selection import train_test_split |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
from sklearn.metrics import ( |
|
|
roc_auc_score, |
|
|
accuracy_score, |
|
|
precision_recall_fscore_support, |
|
|
confusion_matrix, |
|
|
) |
|
|
import json |
|
|
import pickle |
|
|
from tqdm import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'finetuning')) |
|
|
|
|
|
class TBX5ClassifierWithRC(nn.Module):
    """
    Binary classifier over concatenated forward + reverse-complement embeddings.

    A 3-hidden-layer MLP: input (8192 = 4096 forward + 4096 reverse
    complement) -> 2048 -> 512 -> 128 -> 1, with ReLU, BatchNorm and
    Dropout(dropout_rate) after each hidden layer and a final sigmoid
    producing a binding probability.

    Attribute names (fc1..fc4, bn1..bn3, dropout1..dropout3) are kept stable
    because checkpoints store them in state_dict and the training code reads
    ``model.fc1.in_features``.
    """

    def __init__(self, input_dim=8192, dropout_rate=0.5):
        super().__init__()

        # Hidden stage 1: input_dim -> 2048
        self.fc1 = nn.Linear(input_dim, 2048)
        self.bn1 = nn.BatchNorm1d(2048)
        self.dropout1 = nn.Dropout(dropout_rate)

        # Hidden stage 2: 2048 -> 512
        self.fc2 = nn.Linear(2048, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(dropout_rate)

        # Hidden stage 3: 512 -> 128
        self.fc3 = nn.Linear(512, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.dropout3 = nn.Dropout(dropout_rate)

        # Output head: 128 -> 1 probability
        self.fc4 = nn.Linear(128, 1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a (batch, input_dim) tensor to (batch, 1) probabilities."""
        hidden_stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        # Each stage applies linear -> ReLU -> BatchNorm -> Dropout,
        # matching the original statement order exactly.
        for linear, norm, drop in hidden_stages:
            x = drop(norm(self.relu(linear(x))))
        return self.sigmoid(self.fc4(x))
|
|
|
|
|
def load_tbx5_embeddings_with_rc_from_csv(embeddings_dir, rc_embeddings_dir, processed_data_dir):
    """
    Load TBX5 embeddings using train/val/test splits from processed_data_new CSV files.

    For each row of every split CSV, the forward and reverse-complement
    embeddings of the row's (chromosome, start) window are looked up in
    per-chromosome ``.npz`` archives and concatenated into one feature vector.

    Args:
        embeddings_dir: Directory containing forward embeddings, one
            ``<chrom>_tbx5_embeddings_arrays.npz`` per chromosome.
        rc_embeddings_dir: Directory containing reverse complement embeddings,
            one ``<chrom>_tbx5_embeddings_rc_arrays.npz`` per chromosome.
        processed_data_dir: Directory containing train/val/test CSV files with
            columns ``chromosome``, ``start``, ``end``, ``label``, ``tbx5_score``.

    Returns:
        (X, y, starts, ends, tbx5_scores, chromosomes) for each of
        train/val/test, followed by a ``metadata`` dict summarizing the splits.

    Raises:
        ValueError: if any split ends up with zero loaded samples.
    """
    print(f"Loading data using CSV splits from: {processed_data_dir}")
    print(f"Loading forward embeddings from: {embeddings_dir}")
    print(f"Loading reverse complement embeddings from: {rc_embeddings_dir}")

    # Split membership is defined entirely by these three CSVs.
    train_df = pd.read_csv(os.path.join(processed_data_dir, 'train_tbx5_data_new.csv'))
    val_df = pd.read_csv(os.path.join(processed_data_dir, 'val_tbx5_data_new.csv'))
    test_df = pd.read_csv(os.path.join(processed_data_dir, 'test_tbx5_data_new.csv'))

    print(f"Train samples: {len(train_df)}")
    print(f"Val samples: {len(val_df)}")
    print(f"Test samples: {len(test_df)}")

    def load_embeddings_for_split(df, embeddings_dir, rc_embeddings_dir):
        """Load embeddings for a specific split.

        Walks the split's rows, lazily loading each chromosome's .npz pair at
        most once, and matches rows to embeddings by exact `start` position.
        Rows whose chromosome files or start position are missing are skipped
        (counted and reported, not raised).
        """
        all_embeddings = []
        all_labels = []
        all_starts = []
        all_ends = []
        all_tbx5_scores = []
        all_chromosomes = []

        total_samples = len(df)
        found_samples = 0
        missing_files = 0      # rows skipped because a chromosome's files are absent
        missing_samples = 0    # rows skipped for any reason (files absent or start not found)

        # Per-chromosome cache: chrom -> dict of arrays, or None if files missing.
        loaded_chrom_data = {}

        for idx, row in df.iterrows():
            chrom_num = row['chromosome']
            chrom = f"chr{chrom_num}"
            start = row['start']
            end = row['end']
            label = row['label']
            tbx5_score = row['tbx5_score']

            # Load this chromosome's forward + RC archives on first encounter.
            if chrom not in loaded_chrom_data:
                forward_file = os.path.join(embeddings_dir, f"{chrom}_tbx5_embeddings_arrays.npz")
                rc_file = os.path.join(rc_embeddings_dir, f"{chrom}_tbx5_embeddings_rc_arrays.npz")

                if os.path.exists(forward_file) and os.path.exists(rc_file):
                    print(f"  Loading {chrom}...")
                    forward_data = np.load(forward_file)
                    rc_data = np.load(rc_file)

                    loaded_chrom_data[chrom] = {
                        'forward_embeddings': forward_data['embeddings'],
                        'forward_starts': forward_data['starts'],
                        'forward_ends': forward_data['ends'],
                        'forward_tbx5_scores': forward_data['tbx5_scores'],
                        'rc_embeddings': rc_data['embeddings'],
                        'rc_starts': rc_data['starts'],
                        'rc_ends': rc_data['ends'],
                        'rc_tbx5_scores': rc_data['tbx5_scores']
                    }
                else:
                    # Cache the failure (None) so we warn only once per chromosome.
                    print(f"  Warning: Missing embedding files for {chrom}")
                    loaded_chrom_data[chrom] = None
                    missing_files += 1
                    continue

            if loaded_chrom_data[chrom] is None:
                missing_samples += 1
                continue

            chrom_data = loaded_chrom_data[chrom]
            forward_starts = chrom_data['forward_starts']
            forward_embeddings = chrom_data['forward_embeddings']
            rc_embeddings = chrom_data['rc_embeddings']

            # Match on start coordinate only; the same index is assumed to be
            # valid for the RC arrays — TODO confirm the two archives are
            # written in identical row order.
            mask = (forward_starts == start)
            if np.any(mask):
                emb_idx = np.where(mask)[0][0]

                forward_emb = forward_embeddings[emb_idx]
                rc_emb = rc_embeddings[emb_idx]

                # Feature vector = [forward | reverse-complement].
                combined_emb = np.concatenate([forward_emb, rc_emb])

                all_embeddings.append(combined_emb)
                all_labels.append(label)
                all_starts.append(start)
                all_ends.append(end)
                all_tbx5_scores.append(tbx5_score)
                all_chromosomes.append(chrom)

                found_samples += 1
            else:
                missing_samples += 1
                continue

        print(f"  Summary: {found_samples}/{total_samples} samples loaded")
        print(f"  Missing files: {missing_files} samples")
        print(f"  Missing embeddings: {missing_samples} samples")

        return (
            np.array(all_embeddings),
            np.array(all_labels),
            np.array(all_starts),
            np.array(all_ends),
            np.array(all_tbx5_scores),
            all_chromosomes
        )

    print("Loading train data...")
    X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train = load_embeddings_for_split(
        train_df, embeddings_dir, rc_embeddings_dir
    )

    print("Loading validation data...")
    X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val = load_embeddings_for_split(
        val_df, embeddings_dir, rc_embeddings_dir
    )

    print("Loading test data...")
    X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test = load_embeddings_for_split(
        test_df, embeddings_dir, rc_embeddings_dir
    )

    print(f"\nLoaded data:")
    print(f"Train: {len(X_train)} samples")
    print(f"Val: {len(X_val)} samples")
    print(f"Test: {len(X_test)} samples")
    print(f"Embedding dimension: {X_train.shape[1]}")
    print(f"Train positive samples: {np.sum(y_train)}")
    print(f"Val positive samples: {np.sum(y_val)}")
    print(f"Test positive samples: {np.sum(y_test)}")

    # Fail fast if any split came back empty — downstream scaling/training
    # would otherwise crash with far less helpful errors.
    if len(X_train) == 0:
        raise ValueError("No training data loaded! Check embedding files and CSV data.")
    if len(X_val) == 0:
        raise ValueError("No validation data loaded! Check embedding files and CSV data.")
    if len(X_test) == 0:
        raise ValueError("No test data loaded! Check embedding files and CSV data.")

    print(f"\nData quality check:")
    print(f"Train positive ratio: {np.mean(y_train):.3f}")
    print(f"Val positive ratio: {np.mean(y_val):.3f}")
    print(f"Test positive ratio: {np.mean(y_test):.3f}")

    # Summary persisted to metadata.json by the caller.
    metadata = {
        "total_samples": len(X_train) + len(X_val) + len(X_test),
        "embedding_dim": X_train.shape[1],
        "train_samples": len(X_train),
        "val_samples": len(X_val),
        "test_samples": len(X_test),
        "train_positive": int(np.sum(y_train)),
        "val_positive": int(np.sum(y_val)),
        "test_positive": int(np.sum(y_test)),
        "sequence_type": "forward_and_reverse_complement"
    }

    return (
        X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
        X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
        X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
        metadata
    )
|
|
|
|
|
def prepare_data_with_scaling(X_train, X_val, X_test, y_train, y_val, y_test):
    """
    Standardize the feature matrices of the three splits.

    The scaler is fit on the training split only and then applied unchanged
    to validation and test, avoiding information leakage. The label arrays
    are accepted for signature compatibility but are not used.

    Returns:
        (X_train_scaled, X_val_scaled, X_test_scaled, fitted StandardScaler)
    """
    print("Scaling features...")

    scaler = StandardScaler()
    train_scaled = scaler.fit_transform(X_train)
    val_scaled, test_scaled = (scaler.transform(split) for split in (X_val, X_test))

    return train_scaled, val_scaled, test_scaled, scaler
|
|
|
|
|
def train_model(
    model,
    train_loader,
    val_loader,
    test_loader,
    device,
    output_dir,
    num_epochs=500,
    learning_rate=1e-4,
    patience=100,
    lr_patience=20,
    min_lr=1e-6,
    gradient_clip=1.0,
    save_every=5,
):
    """
    Train the model with specified optimization settings.

    Runs an Adam + BCELoss training loop with gradient clipping, per-epoch
    validation, ReduceLROnPlateau scheduling on validation loss, best-model
    checkpointing on validation AUC, periodic epoch checkpoints (every
    ``save_every`` epochs, with a test-set evaluation at each), and early
    stopping. After training, reloads the best checkpoint, evaluates on the
    test set, and writes ``test_results.json``, ``training_history.json``
    and ``training_history.png`` to ``output_dir``.

    Args:
        model: torch module producing sigmoid probabilities of shape (B, 1).
        train_loader/val_loader/test_loader: DataLoaders of (embeddings, labels).
        device: torch device for training/inference.
        output_dir: existing directory where checkpoints/results are written.
        num_epochs: maximum number of epochs.
        learning_rate: initial Adam learning rate.
        patience: epochs without val-AUC improvement before early stop.
        lr_patience: scheduler patience (epochs) before halving the LR.
        min_lr: floor for the scheduled learning rate.
        gradient_clip: max gradient norm for clip_grad_norm_.
        save_every: period (in epochs) of the extra checkpoint + test eval.

    Returns:
        (results, test_results_by_epoch): final test metrics dict and a
        mapping of checkpoint epoch -> test metrics from evaluate_model_simple.
    """
    print(f"Training model with learning rate {learning_rate}")
    print(f"Early stopping patience: {patience}")
    print(f"Learning rate reduction patience: {lr_patience}")

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Scheduler watches validation loss (mode='min'), while checkpointing
    # below watches validation AUC — two different selection criteria.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=lr_patience, min_lr=min_lr
    )

    train_losses = []
    val_losses = []
    val_aucs = []
    test_results_by_epoch = {}
    best_val_auc = 0.0
    best_epoch = 0
    epochs_without_improvement = 0

    print(f"Starting training for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_embeddings, batch_labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device).float()

            optimizer.zero_grad()
            # NOTE(review): .squeeze() on a (1, 1) output of a final batch of
            # size 1 yields a 0-d tensor, which BCELoss will reject against a
            # 1-element label vector — consider .squeeze(-1); verify loaders
            # never produce a batch of size 1.
            outputs = model(batch_embeddings).squeeze()
            loss = criterion(outputs, batch_labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)

            optimizer.step()

            train_loss += loss.item()
            # Binarize probabilities at 0.5 for the running accuracy.
            predicted = (outputs > 0.5).float()
            train_correct += (predicted == batch_labels).sum().item()
            train_total += batch_labels.size(0)

        train_loss /= len(train_loader)
        train_acc = train_correct / train_total

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_predictions = []
        val_labels = []

        with torch.no_grad():
            for batch_embeddings, batch_labels in val_loader:
                batch_embeddings = batch_embeddings.to(device)
                batch_labels = batch_labels.to(device).float()

                outputs = model(batch_embeddings).squeeze()
                loss = criterion(outputs, batch_labels)

                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_correct += (predicted == batch_labels).sum().item()
                val_total += batch_labels.size(0)

                val_predictions.extend(outputs.cpu().numpy())
                val_labels.extend(batch_labels.cpu().numpy())

        val_loss /= len(val_loader)
        val_acc = val_correct / val_total
        val_auc = roc_auc_score(val_labels, val_predictions)

        scheduler.step(val_loss)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_aucs.append(val_auc)

        # ---- best-model checkpoint (strict improvement in val AUC) ----
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_epoch = epoch
            epochs_without_improvement = 0

            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_auc': val_auc,
                'val_loss': val_loss,
                'input_dim': model.fc1.in_features,
            }, os.path.join(output_dir, 'best_model.pth'))

            print(f"New best model saved! Val AUC: {val_auc:.4f}")
        else:
            epochs_without_improvement += 1

        # ---- periodic checkpoint + test-set snapshot ----
        if (epoch + 1) % save_every == 0 or epoch == 0:
            epoch_model_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.pth")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch + 1,
                'val_auc': val_auc,
                'val_loss': val_loss,
                'input_dim': model.fc1.in_features,
            }, epoch_model_path)

            # Evaluate the *current* weights on the test set to track
            # generalization over training (analyzed in save_epoch_analysis).
            test_results = evaluate_model_simple(model, test_loader, device)
            test_results_by_epoch[epoch + 1] = test_results

            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"Test AUC: {test_results['auc']:.4f}")

        elif (epoch + 1) % 10 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"LR: {current_lr:.2e}")

        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
            break

    print(f"Training completed! Best validation AUC: {best_val_auc:.4f} at epoch {best_epoch+1}")

    # Restore the best-val-AUC weights before the final test evaluation.
    # weights_only=False unpickles arbitrary objects — safe only because this
    # checkpoint was written by this same process.
    checkpoint = torch.load(os.path.join(output_dir, 'best_model.pth'), map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    # ---- final test evaluation with the best model ----
    model.eval()
    test_predictions = []
    test_labels = []
    test_loss = 0.0
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for batch_embeddings, batch_labels in test_loader:
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device).float()

            outputs = model(batch_embeddings).squeeze()
            loss = criterion(outputs, batch_labels)

            test_loss += loss.item()
            predicted = (outputs > 0.5).float()
            test_correct += (predicted == batch_labels).sum().item()
            test_total += batch_labels.size(0)

            test_predictions.extend(outputs.cpu().numpy())
            test_labels.extend(batch_labels.cpu().numpy())

    test_loss /= len(test_loader)
    test_acc = test_correct / test_total
    test_auc = roc_auc_score(test_labels, test_predictions)

    precision, recall, f1, _ = precision_recall_fscore_support(test_labels, [1 if p > 0.5 else 0 for p in test_predictions], average='binary')
    cm = confusion_matrix(test_labels, [1 if p > 0.5 else 0 for p in test_predictions])

    # Everything cast to JSON-serializable builtins before dumping.
    results = {
        'test_auc': float(test_auc),
        'test_accuracy': float(test_acc),
        'test_loss': float(test_loss),
        'test_precision': float(precision),
        'test_recall': float(recall),
        'test_f1': float(f1),
        'confusion_matrix': cm.tolist(),
        'best_val_auc': float(best_val_auc),
        'best_epoch': int(best_epoch + 1),
        'total_epochs': int(epoch + 1),
        'sequence_type': 'forward_and_reverse_complement',
        'predictions': [float(p) for p in test_predictions],
        'labels': [float(l) for l in test_labels]
    }

    with open(os.path.join(output_dir, 'test_results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_aucs': val_aucs,
        'best_epoch': best_epoch + 1,
        'best_val_auc': best_val_auc
    }

    with open(os.path.join(output_dir, 'training_history.json'), 'w') as f:
        json.dump(history, f, indent=2)

    # ---- training-history figure (3 panels) ----
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.plot(val_aucs, label='Val AUC', color='green')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.title('Validation AUC')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # NOTE(review): the third panel re-plots the same loss curves as panel 1.
    plt.subplot(1, 3, 3)
    plt.plot(range(len(train_losses)), train_losses, label='Train Loss')
    plt.plot(range(len(val_losses)), val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"\n=== Test Results ===")
    print(f"Test AUC: {test_auc:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1: {f1:.4f}")
    print(f"Confusion Matrix:\n{cm}")

    return results, test_results_by_epoch
|
|
|
|
|
def evaluate_model_simple(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and return basic metrics.

    Args:
        model: Trained torch module producing sigmoid probabilities of shape (B, 1).
        test_loader: DataLoader yielding (embeddings, labels) batches.
        device: Device to run inference on.

    Returns:
        dict with plain-float values for keys ``auc``, ``accuracy``,
        ``precision``, ``recall`` and ``f1``. Plain floats (not numpy
        scalars) so the dict can be passed straight to ``json.dump``
        (as save_epoch_analysis does).
    """
    model.eval()
    test_preds = []
    test_labels = []

    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            batch_X = batch_X.to(device)
            # squeeze(-1) rather than squeeze(): a final batch of size 1 would
            # otherwise collapse (1, 1) to a 0-d scalar and break extend().
            outputs = model(batch_X).squeeze(-1)
            test_preds.extend(outputs.cpu().numpy())
            test_labels.extend(batch_y.numpy())

    test_preds = np.array(test_preds)
    test_labels = np.array(test_labels)

    # Threshold probabilities at 0.5 for the classification metrics.
    test_auc = roc_auc_score(test_labels, test_preds)
    test_preds_binary = (test_preds > 0.5).astype(int)
    test_acc = accuracy_score(test_labels, test_preds_binary)
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_preds_binary, average="binary"
    )

    return {
        "auc": float(test_auc),
        "accuracy": float(test_acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }
|
|
|
|
|
def save_epoch_analysis(test_results_by_epoch, output_dir):
    """Save analysis of results across epochs.

    Args:
        test_results_by_epoch: mapping of checkpoint epoch -> metrics dict
            (keys ``auc``, ``accuracy``, ``precision``, ``recall``, ``f1``),
            as produced by evaluate_model_simple during training.
        output_dir: directory where ``epoch_analysis.csv`` and
            ``epoch_analysis.json`` are written.

    Returns:
        pandas DataFrame with one row per checkpoint epoch.
    """
    epochs = sorted(test_results_by_epoch.keys())

    # Flatten the nested dicts into one row per checkpoint epoch.
    summary_data = []
    for epoch in epochs:
        results = test_results_by_epoch[epoch]
        summary_data.append(
            {
                "epoch": epoch,
                "test_auc": results["auc"],
                "test_accuracy": results["accuracy"],
                "test_precision": results["precision"],
                "test_recall": results["recall"],
                "test_f1": results["f1"],
            }
        )

    df = pd.DataFrame(summary_data)

    csv_path = os.path.join(output_dir, "epoch_analysis.csv")
    df.to_csv(csv_path, index=False)

    # NOTE(review): json.dump will fail if the metric values are numpy
    # scalars rather than plain floats — verify evaluate_model_simple
    # returns builtins.
    json_path = os.path.join(output_dir, "epoch_analysis.json")
    with open(json_path, "w") as f:
        json.dump(test_results_by_epoch, f, indent=2)

    print("\n" + "=" * 50)
    print("EPOCH-WISE TEST PERFORMANCE ANALYSIS")
    print("=" * 50)

    # Rows holding the best AUC / best F1 (idxmax returns the first maximum).
    best_auc_epoch = df.loc[df["test_auc"].idxmax()]
    best_f1_epoch = df.loc[df["test_f1"].idxmax()]

    print(
        f"Best Test AUC: {best_auc_epoch['test_auc']:.4f} at Epoch {best_auc_epoch['epoch']}"
    )
    print(
        f"Best Test F1: {best_f1_epoch['test_f1']:.4f} at Epoch {best_f1_epoch['epoch']}"
    )
    print()
    print("Epoch-wise Performance:")
    print(df.to_string(index=False, float_format="%.4f"))

    # Crude trend check: compare first vs. last checkpoint's test AUC with a
    # +/-0.01 dead band.
    if len(epochs) >= 2:
        auc_trend = df["test_auc"].iloc[-1] - df["test_auc"].iloc[0]
        if auc_trend < -0.01:
            print(
                f"\n⚠️  OVERFITTING DETECTED: Test AUC decreased by {abs(auc_trend):.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        elif auc_trend > 0.01:
            print(
                f"\n✅ GOOD TRAINING: Test AUC improved by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        else:
            print(
                f"\n📊 STABLE TRAINING: Test AUC changed by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )

    return df
|
|
|
|
|
def plot_training_history(train_losses, val_losses, val_aucs, output_dir):
    """Save a two-panel loss / validation-AUC figure as training_history.png."""
    fig, (loss_ax, auc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: train vs. validation loss per epoch.
    loss_ax.plot(train_losses, label="Train Loss")
    loss_ax.plot(val_losses, label="Val Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training and Validation Loss")
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Right panel: validation AUC per epoch.
    auc_ax.plot(val_aucs, label="Val AUC", color="green")
    auc_ax.set_xlabel("Epoch")
    auc_ax.set_ylabel("AUC")
    auc_ax.set_title("Validation AUC")
    auc_ax.legend()
    auc_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "training_history.png"), dpi=100)
    plt.close()
|
|
|
|
def plot_confusion_matrix(cm, output_dir):
    """Render the 2x2 confusion matrix as an annotated heatmap PNG."""
    class_names = ["Non-binding", "TBX5-binding"]

    plt.figure(figsize=(6, 5))
    # Integer annotations ("d") on a blue colormap, axes labeled by class.
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=100)
    plt.close()
|
|
|
|
|
def main():
    """CLI entry point: load embeddings, scale, build the model, train, report.

    Pipeline: parse args -> load forward+RC embeddings per CSV split ->
    standardize features (train-fit scaler, persisted to scaler.pkl) ->
    build DataLoaders -> train via train_model -> save epoch analysis and
    summary plots to the output directory.
    """
    parser = argparse.ArgumentParser(description="Train TBX5 classifier with forward and reverse complement embeddings")
    parser.add_argument(
        "--embeddings-dir",
        type=str,
        default="tbx5_embeddings",
        help="Directory containing forward embeddings (default: tbx5_embeddings)",
    )
    parser.add_argument(
        "--rc-embeddings-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Directory containing reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="result_with_rc",
        help="Output directory for results (default: result_with_rc)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training (default: 32)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=500,
        help="Number of training epochs (default: 500)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate (default: 1e-4)",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=100,
        help="Early stopping patience (default: 100)",
    )
    parser.add_argument(
        "--dropout-rate",
        type=float,
        default=0.5,
        help="Dropout rate (default: 0.5)",
    )
    parser.add_argument(
        "--processed-data-dir",
        type=str,
        default="processed_data_new",
        help="Directory containing train/val/test CSV files (default: processed_data_new)",
    )

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Prefer GPU when available; everything downstream is device-agnostic.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading combined embeddings using CSV splits...")
    (X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
     X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
     X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
     metadata) = load_tbx5_embeddings_with_rc_from_csv(
        args.embeddings_dir, args.rc_embeddings_dir, args.processed_data_dir
    )

    with open(os.path.join(args.output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Fit the scaler on train only; val/test reuse the same transform.
    X_train_scaled, X_val_scaled, X_test_scaled, scaler = prepare_data_with_scaling(
        X_train, X_val, X_test, y_train, y_val, y_test
    )

    # Persist the scaler so inference can apply the identical transform.
    with open(os.path.join(args.output_dir, 'scaler.pkl'), 'wb') as f:
        pickle.dump(scaler, f)

    # Labels stored as LongTensor; train_model casts them to float per batch.
    train_dataset = TensorDataset(torch.FloatTensor(X_train_scaled), torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val_scaled), torch.LongTensor(y_val))
    test_dataset = TensorDataset(torch.FloatTensor(X_test_scaled), torch.LongTensor(y_test))

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Input width is taken from the data (8192 expected: 4096 fwd + 4096 rc).
    input_dim = X_train_scaled.shape[1]
    print(f"Input dimension: {input_dim}")

    model = TBX5ClassifierWithRC(input_dim=input_dim, dropout_rate=args.dropout_rate).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    results, test_results_by_epoch = train_model(
        model, train_loader, val_loader, test_loader, device, args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        patience=args.patience,
    )

    save_epoch_analysis(test_results_by_epoch, args.output_dir)

    # NOTE(review): train_model's `results` dict does not contain
    # 'train_losses'/'val_losses'/'val_aucs' (those live in
    # training_history.json), so these .get() calls fall back to [] and this
    # call overwrites the populated training_history.png that train_model
    # already saved with empty plots — confirm and rewire.
    plot_training_history(results.get('train_losses', []), results.get('val_losses', []), results.get('val_aucs', []), args.output_dir)
    plot_confusion_matrix(np.array(results['confusion_matrix']), args.output_dir)

    print(f"\nTraining completed! Results saved to {args.output_dir}")
    print(f"Best test AUC: {results['test_auc']:.4f}")


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): removed a stray, duplicated fragment of save_epoch_analysis's
# body that had been pasted here at module level. It referenced names that do
# not exist at module scope (`epochs`, `df`) and ended with a module-level
# `return df`, which is a SyntaxError and prevented the entire file from
# being imported or run.
|
|
|
|
|
# NOTE(review): removed a second, byte-identical definition of
# plot_training_history that was accidentally duplicated here; the earlier
# definition in this file is the single source of truth.
|
|
|
|
|
# NOTE(review): removed a second, byte-identical definition of
# plot_confusion_matrix that was accidentally duplicated here; the earlier
# definition in this file is the single source of truth.
|
|
|
|
|
# NOTE(review): removed a second, byte-identical definition of main() and a
# second `if __name__ == "__main__": main()` guard that were accidentally
# duplicated here. The earlier definition and guard remain; the duplicate
# guard would have launched the whole training pipeline a second time.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|