# motif_classifier / train_tbx5_classifier_with_rc.py
# Uploaded by harari ("Upload 3 files"), commit 65a437b (verified).
#!/usr/bin/env python3
"""
Train TBX5 classifier using both forward and reverse complement embeddings.
This script combines embeddings from both strands to improve classification accuracy.
"""
import os
import sys
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
roc_auc_score,
accuracy_score,
precision_recall_fscore_support,
confusion_matrix,
)
import json
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
# Make modules in the sibling ``finetuning`` directory importable. The path is
# resolved relative to this file, not the current working directory.
# NOTE(review): none of the imports visible above pull from that directory —
# confirm this is still needed.
sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, 'finetuning'))
class TBX5ClassifierWithRC(nn.Module):
    """
    Feed-forward binary classifier over concatenated forward + reverse-complement embeddings.

    Layer widths (default ``input_dim=8192``, i.e. 4096 forward + 4096 RC):
        input -> 2048 -> 512 -> 128 -> 1 (sigmoid)

    Each hidden stage applies Linear -> ReLU -> BatchNorm1d -> Dropout, in that order.

    Args:
        input_dim: Width of each input row.
        dropout_rate: Drop probability shared by all three dropout layers.
    """

    def __init__(self, input_dim=8192, dropout_rate=0.5):
        super().__init__()
        # Submodules are created in a fixed order so seeded initialization is
        # reproducible. The attribute names become the state_dict keys that
        # saved checkpoints rely on.
        self.fc1 = nn.Linear(input_dim, 2048)
        self.bn1 = nn.BatchNorm1d(2048)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(2048, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(512, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.dropout3 = nn.Dropout(dropout_rate)
        self.fc4 = nn.Linear(128, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a 2-D ``(batch, input_dim)`` tensor to probabilities of shape ``(batch, 1)``."""
        hidden_stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        for linear, batch_norm, dropout in hidden_stages:
            # ReLU runs *before* BatchNorm in this model; trained checkpoints
            # depend on that ordering.
            x = dropout(batch_norm(self.relu(linear(x))))
        return self.sigmoid(self.fc4(x))
def _load_chrom_embeddings(embeddings_dir, rc_embeddings_dir, chrom):
    """Load one chromosome's forward and reverse-complement embedding arrays.

    Args:
        embeddings_dir: Directory holding ``{chrom}_tbx5_embeddings_arrays.npz``.
        rc_embeddings_dir: Directory holding ``{chrom}_tbx5_embeddings_rc_arrays.npz``.
        chrom: Chromosome name such as ``"chr1"``.

    Returns:
        ``None`` if either file is missing. Otherwise a dict with:
        - ``forward_embeddings`` and ``rc_embeddings``: the raw arrays.
        - ``start_to_idx``: maps each forward window start to the row index of
          its first occurrence.
    """
    forward_file = os.path.join(embeddings_dir, f"{chrom}_tbx5_embeddings_arrays.npz")
    rc_file = os.path.join(rc_embeddings_dir, f"{chrom}_tbx5_embeddings_rc_arrays.npz")
    if not (os.path.exists(forward_file) and os.path.exists(rc_file)):
        print(f" Warning: Missing embedding files for {chrom}")
        return None

    print(f" Loading {chrom}...")
    # ``np.load`` on an .npz returns an NpzFile that holds the file open.
    # The context managers release both handles once the arrays are read.
    with np.load(forward_file) as forward_data, np.load(rc_file) as rc_data:
        forward_embeddings = forward_data['embeddings']
        forward_starts = forward_data['starts']
        rc_embeddings = rc_data['embeddings']
        rc_starts = rc_data['starts']

    # Forward and RC rows are paired by row index, so both files must list the
    # same windows in the same order. Warn when they visibly do not.
    if len(rc_embeddings) != len(forward_embeddings) or not np.array_equal(rc_starts, forward_starts):
        print(f" Warning: forward/RC arrays for {chrom} differ in length or start order; "
              "pairing by row index may be wrong")

    # Build the start -> row lookup once per chromosome instead of scanning
    # every start for every CSV row. ``return_index`` gives the first
    # occurrence, keeping the "first match wins" rule.
    unique_starts, first_rows = np.unique(forward_starts, return_index=True)
    return {
        'forward_embeddings': forward_embeddings,
        'rc_embeddings': rc_embeddings,
        'start_to_idx': dict(zip(unique_starts.tolist(), first_rows.tolist())),
    }


def _load_embeddings_for_split(df, embeddings_dir, rc_embeddings_dir):
    """Pair each CSV row with its concatenated ``[forward, reverse_complement]`` embedding.

    A row is skipped (not zero-filled) when its chromosome files are missing or
    when its ``start`` has no forward embedding. Skipped rows are counted in
    the printed summary.

    Returns:
        ``(embeddings, labels, starts, ends, tbx5_scores, chromosomes)``.
        The first five are NumPy arrays and ``chromosomes`` is a list of
        ``"chr..."`` strings. All are in CSV row order.
    """
    embeddings, labels, starts, ends, tbx5_scores, chromosomes = [], [], [], [], [], []
    found_samples = 0
    missing_file_samples = 0
    missing_samples = 0
    # Per-split cache: each chromosome's arrays are loaded at most once.
    chrom_cache = {}

    # Iterate column-wise rather than with ``iterrows``. iterrows upcasts an
    # all-numeric row to float, which can turn chromosome 1 into "chr1.0".
    rows = zip(df['chromosome'], df['start'], df['end'], df['label'], df['tbx5_score'])
    for chrom_num, start, end, label, tbx5_score in rows:
        chrom = f"chr{chrom_num}"
        if chrom not in chrom_cache:
            chrom_cache[chrom] = _load_chrom_embeddings(embeddings_dir, rc_embeddings_dir, chrom)
        chrom_data = chrom_cache[chrom]
        if chrom_data is None:
            missing_file_samples += 1
            continue

        # Match on start coordinate only (chromosome is already fixed).
        emb_idx = chrom_data['start_to_idx'].get(start)
        if emb_idx is None:
            missing_samples += 1
            continue

        embeddings.append(np.concatenate([
            chrom_data['forward_embeddings'][emb_idx],
            chrom_data['rc_embeddings'][emb_idx],
        ]))
        labels.append(label)
        starts.append(start)
        ends.append(end)
        tbx5_scores.append(tbx5_score)
        chromosomes.append(chrom)
        found_samples += 1

    print(f" Summary: {found_samples}/{len(df)} samples loaded")
    print(f" Missing files: {missing_file_samples} samples")
    print(f" Missing embeddings: {missing_samples} samples")
    return (
        np.array(embeddings),
        np.array(labels),
        np.array(starts),
        np.array(ends),
        np.array(tbx5_scores),
        chromosomes,
    )


def load_tbx5_embeddings_with_rc_from_csv(embeddings_dir, rc_embeddings_dir, processed_data_dir):
    """
    Load TBX5 embeddings for the train/val/test splits defined by CSV files.

    Each split CSV (``{train,val,test}_tbx5_data_new.csv``) must have the
    columns ``chromosome``, ``start``, ``end``, ``label`` and ``tbx5_score``.
    For each row, the forward and reverse-complement embeddings are looked up
    and concatenated.

    Args:
        embeddings_dir: Directory containing forward embeddings.
        rc_embeddings_dir: Directory containing reverse complement embeddings.
        processed_data_dir: Directory containing the train/val/test CSV files.

    Returns:
        A 19-tuple:
        - For train, then val, then test: ``X, y, starts, ends, tbx5_scores,
          chromosomes``.
        - Last element: a ``metadata`` dict.

    Raises:
        ValueError: If any split ends up with no samples.
    """
    print(f"Loading data using CSV splits from: {processed_data_dir}")
    print(f"Loading forward embeddings from: {embeddings_dir}")
    print(f"Loading reverse complement embeddings from: {rc_embeddings_dir}")

    train_df = pd.read_csv(os.path.join(processed_data_dir, 'train_tbx5_data_new.csv'))
    val_df = pd.read_csv(os.path.join(processed_data_dir, 'val_tbx5_data_new.csv'))
    test_df = pd.read_csv(os.path.join(processed_data_dir, 'test_tbx5_data_new.csv'))
    print(f"Train samples: {len(train_df)}")
    print(f"Val samples: {len(val_df)}")
    print(f"Test samples: {len(test_df)}")

    print("Loading train data...")
    X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train = _load_embeddings_for_split(
        train_df, embeddings_dir, rc_embeddings_dir
    )
    print("Loading validation data...")
    X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val = _load_embeddings_for_split(
        val_df, embeddings_dir, rc_embeddings_dir
    )
    print("Loading test data...")
    X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test = _load_embeddings_for_split(
        test_df, embeddings_dir, rc_embeddings_dir
    )

    print("\nLoaded data:")
    print(f"Train: {len(X_train)} samples")
    print(f"Val: {len(X_val)} samples")
    print(f"Test: {len(X_test)} samples")

    # Fail fast on empty splits *before* reading ``X.shape[1]``. On an empty
    # (0,)-shaped array that read would raise an opaque IndexError instead.
    for split_name, X in (("training", X_train), ("validation", X_val), ("test", X_test)):
        if len(X) == 0:
            raise ValueError(f"No {split_name} data loaded! Check embedding files and CSV data.")

    print(f"Embedding dimension: {X_train.shape[1]}")
    print(f"Train positive samples: {np.sum(y_train)}")
    print(f"Val positive samples: {np.sum(y_val)}")
    print(f"Test positive samples: {np.sum(y_test)}")

    print("\nData quality check:")
    print(f"Train positive ratio: {np.mean(y_train):.3f}")
    print(f"Val positive ratio: {np.mean(y_val):.3f}")
    print(f"Test positive ratio: {np.mean(y_test):.3f}")

    metadata = {
        "total_samples": len(X_train) + len(X_val) + len(X_test),
        "embedding_dim": int(X_train.shape[1]),
        "train_samples": len(X_train),
        "val_samples": len(X_val),
        "test_samples": len(X_test),
        "train_positive": int(np.sum(y_train)),
        "val_positive": int(np.sum(y_val)),
        "test_positive": int(np.sum(y_test)),
        "sequence_type": "forward_and_reverse_complement",
    }
    return (
        X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
        X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
        X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
        metadata,
    )
def prepare_data_with_scaling(X_train, X_val, X_test, y_train, y_val, y_test):
    """
    Standardize features using statistics from the training split only.

    The scaler is fit on ``X_train`` and applied unchanged to the validation
    and test splits, so their statistics never influence the scaling. The
    ``y_*`` arguments are accepted for call-site symmetry but not used.

    Returns:
        ``(X_train_scaled, X_val_scaled, X_test_scaled, scaler)``, where
        ``scaler`` is the fitted ``StandardScaler``.
    """
    print("Scaling features...")
    fitted_scaler = StandardScaler()
    fitted_scaler.fit(X_train)
    scaled_splits = [fitted_scaler.transform(split) for split in (X_train, X_val, X_test)]
    return (*scaled_splits, fitted_scaler)
def _train_one_epoch(model, loader, criterion, optimizer, device, gradient_clip, desc):
    """Run one optimization pass over ``loader``.

    Returns:
        ``(mean_batch_loss, accuracy_at_0.5)``.

    Raises:
        ValueError: If the loader yields no batch with at least 2 samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    batches = 0
    for batch_embeddings, batch_labels in tqdm(loader, desc=desc):
        # BatchNorm1d raises in training mode when a batch has a single
        # sample. Skip such a trailing batch rather than crashing the run.
        if batch_embeddings.size(0) < 2:
            continue
        batch_embeddings = batch_embeddings.to(device)
        batch_labels = batch_labels.to(device).float()

        optimizer.zero_grad()
        # ``view(-1)`` keeps a 1-D shape matching ``batch_labels``.
        outputs = model(batch_embeddings).view(-1)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()

        total_loss += loss.item()
        correct += ((outputs > 0.5).float() == batch_labels).sum().item()
        total += batch_labels.size(0)
        batches += 1
    if batches == 0:
        raise ValueError("No training batch with at least 2 samples; BatchNorm cannot train on it.")
    return total_loss / batches, correct / total


def _evaluate_loader(model, loader, criterion, device):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        ``(mean_batch_loss, accuracy_at_0.5, predictions, labels)``, where
        ``predictions`` and ``labels`` are flat lists of floats in loader order.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    predictions = []
    labels = []
    with torch.no_grad():
        for batch_embeddings, batch_labels in loader:
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device).float()
            # ``view(-1)`` rather than ``squeeze()``: a final batch of one
            # sample would otherwise become 0-d, breaking BCELoss and ``extend``.
            outputs = model(batch_embeddings).view(-1)
            total_loss += criterion(outputs, batch_labels).item()
            correct += ((outputs > 0.5).float() == batch_labels).sum().item()
            total += batch_labels.size(0)
            predictions.extend(outputs.cpu().tolist())
            labels.extend(batch_labels.cpu().tolist())
    return total_loss / len(loader), correct / total, predictions, labels


def _save_checkpoint(model, optimizer, epoch, val_auc, val_loss, path):
    """Write model/optimizer state plus epoch and validation metrics to ``path``."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'val_auc': val_auc,
        'val_loss': val_loss,
        'input_dim': model.fc1.in_features,
    }, path)


def _save_training_plots(train_losses, val_losses, val_aucs, best_epoch, output_dir):
    """Write the three-panel ``training_history.png`` (dpi=300) with the best epoch marked."""
    best_label = f'Best Epoch ({best_epoch+1})'
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=best_label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.plot(val_aucs, label='Val AUC', color='green')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=best_label)
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.title('Validation AUC')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 3)
    plt.plot(range(len(train_losses)), train_losses, label='Train Loss')
    plt.plot(range(len(val_losses)), val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=best_label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
    plt.close()


def train_model(
    model,
    train_loader,
    val_loader,
    test_loader,
    device,
    output_dir,
    num_epochs=500,
    learning_rate=1e-4,
    patience=100,
    lr_patience=20,
    min_lr=1e-6,
    gradient_clip=1.0,
    save_every=5,
):
    """
    Train with Adam + BCELoss, select the best epoch by validation AUC, then evaluate on test.

    Training loop:
    - The learning rate is halved by ``ReduceLROnPlateau`` on validation loss.
    - Gradients are clipped to ``gradient_clip``.
    - Training stops early after ``patience`` epochs without a validation-AUC
      improvement.
    - Every ``save_every`` epochs (and after the first), a checkpoint is saved
      and the test set is scored. These per-epoch test scores are for
      reporting only; they do not influence model selection.

    Files written to ``output_dir``: ``best_model.pth``,
    ``model_epoch_{N}.pth``, ``test_results.json``, ``training_history.json``
    and ``training_history.png``.

    Returns:
        ``(results, test_results_by_epoch)``:
        - ``results``: the dict written to ``test_results.json``, computed with
          the best-validation checkpoint.
        - ``test_results_by_epoch``: maps the 1-based epoch to the
          ``evaluate_model_simple`` metrics.

    Raises:
        ValueError: If ``num_epochs < 1``.
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    print(f"Training model with learning rate {learning_rate}")
    print(f"Early stopping patience: {patience}")
    print(f"Learning rate reduction patience: {lr_patience}")

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=lr_patience, min_lr=min_lr
    )

    train_losses = []
    val_losses = []
    val_aucs = []
    test_results_by_epoch = {}
    best_model_path = os.path.join(output_dir, 'best_model.pth')
    # Start below any attainable AUC so the first epoch always writes
    # best_model.pth, which is reloaded unconditionally after the loop.
    best_val_auc = float('-inf')
    best_epoch = 0
    epochs_without_improvement = 0
    epochs_run = 0

    print(f"Starting training for {num_epochs} epochs...")
    for epoch in range(num_epochs):
        epochs_run = epoch + 1
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, gradient_clip,
            desc=f"Epoch {epoch+1}/{num_epochs}",
        )
        val_loss, val_acc, val_predictions, val_labels = _evaluate_loader(
            model, val_loader, criterion, device
        )
        val_auc = roc_auc_score(val_labels, val_predictions)

        # The LR schedule follows validation *loss*. Checkpoint selection and
        # early stopping follow validation *AUC*.
        scheduler.step(val_loss)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_aucs.append(val_auc)

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_epoch = epoch
            epochs_without_improvement = 0
            # The best checkpoint records the 0-based epoch, while periodic
            # checkpoints record the 1-based epoch. Kept for compatibility.
            _save_checkpoint(model, optimizer, epoch, val_auc, val_loss, best_model_path)
            print(f"New best model saved! Val AUC: {val_auc:.4f}")
        else:
            epochs_without_improvement += 1

        if (epoch + 1) % save_every == 0 or epoch == 0:
            epoch_model_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.pth")
            _save_checkpoint(model, optimizer, epoch + 1, val_auc, val_loss, epoch_model_path)
            test_results = evaluate_model_simple(model, test_loader, device)
            test_results_by_epoch[epoch + 1] = test_results
            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"Test AUC: {test_results['auc']:.4f}")
        elif (epoch + 1) % 10 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"LR: {current_lr:.2e}")

        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
            break

    print(f"Training completed! Best validation AUC: {best_val_auc:.4f} at epoch {best_epoch+1}")

    # Final test evaluation uses the best-validation checkpoint, not the last epoch.
    checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    test_loss, test_acc, test_predictions, test_labels = _evaluate_loader(
        model, test_loader, criterion, device
    )
    test_auc = roc_auc_score(test_labels, test_predictions)
    test_binary = [1 if p > 0.5 else 0 for p in test_predictions]
    precision, recall, f1, _ = precision_recall_fscore_support(test_labels, test_binary, average='binary')
    cm = confusion_matrix(test_labels, test_binary)

    results = {
        'test_auc': float(test_auc),
        'test_accuracy': float(test_acc),
        'test_loss': float(test_loss),
        'test_precision': float(precision),
        'test_recall': float(recall),
        'test_f1': float(f1),
        'confusion_matrix': cm.tolist(),
        'best_val_auc': float(best_val_auc),
        'best_epoch': int(best_epoch + 1),
        'total_epochs': int(epochs_run),
        'sequence_type': 'forward_and_reverse_complement',
        'predictions': [float(p) for p in test_predictions],
        'labels': [float(l) for l in test_labels],
    }
    with open(os.path.join(output_dir, 'test_results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_aucs': val_aucs,
        'best_epoch': best_epoch + 1,
        'best_val_auc': best_val_auc,
    }
    with open(os.path.join(output_dir, 'training_history.json'), 'w') as f:
        json.dump(history, f, indent=2)

    _save_training_plots(train_losses, val_losses, val_aucs, best_epoch, output_dir)

    print(f"\n=== Test Results ===")
    print(f"Test AUC: {test_auc:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1: {f1:.4f}")
    print(f"Confusion Matrix:\n{cm}")
    return results, test_results_by_epoch
def evaluate_model_simple(model, test_loader, device):
    """
    Score ``model`` on ``test_loader`` and return threshold-0.5 summary metrics.

    Switches the model to eval mode and leaves it there.

    Returns:
        Dict with keys ``auc``, ``accuracy``, ``precision``, ``recall`` and
        ``f1``, all as Python floats.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            # ``view(-1)`` rather than ``squeeze()``: a final batch of one
            # sample would otherwise collapse to a 0-d value that cannot be
            # collected like the other batches.
            outputs = model(batch_X.to(device)).view(-1)
            pred_chunks.append(outputs.cpu().numpy())
            label_chunks.append(batch_y.numpy().reshape(-1))
    test_preds = np.concatenate(pred_chunks)
    test_labels = np.concatenate(label_chunks)

    test_auc = roc_auc_score(test_labels, test_preds)
    test_preds_binary = (test_preds > 0.5).astype(int)
    test_acc = accuracy_score(test_labels, test_preds_binary)
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_preds_binary, average="binary"
    )
    return {
        "auc": float(test_auc),
        "accuracy": float(test_acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }
def save_epoch_analysis(test_results_by_epoch, output_dir):
    """
    Tabulate per-epoch test metrics, write them to disk, and print a short report.

    Args:
        test_results_by_epoch: Mapping ``epoch -> metrics`` with keys ``auc``,
            ``accuracy``, ``precision``, ``recall`` and ``f1`` (as returned by
            ``evaluate_model_simple``).
        output_dir: Directory that receives ``epoch_analysis.csv`` and
            ``epoch_analysis.json``.

    Returns:
        DataFrame with one row per epoch in ascending order. It has the
        columns ``epoch``, ``test_auc``, ``test_accuracy``,
        ``test_precision``, ``test_recall`` and ``test_f1``. The DataFrame is
        empty (columns only) when no epochs were recorded.
    """
    epochs = sorted(test_results_by_epoch)
    columns = ["epoch", "test_auc", "test_accuracy", "test_precision", "test_recall", "test_f1"]
    df = pd.DataFrame(
        [
            {
                "epoch": epoch,
                "test_auc": test_results_by_epoch[epoch]["auc"],
                "test_accuracy": test_results_by_epoch[epoch]["accuracy"],
                "test_precision": test_results_by_epoch[epoch]["precision"],
                "test_recall": test_results_by_epoch[epoch]["recall"],
                "test_f1": test_results_by_epoch[epoch]["f1"],
            }
            for epoch in epochs
        ],
        columns=columns,
    )

    df.to_csv(os.path.join(output_dir, "epoch_analysis.csv"), index=False)
    with open(os.path.join(output_dir, "epoch_analysis.json"), "w") as f:
        json.dump(test_results_by_epoch, f, indent=2)

    print("\n" + "=" * 50)
    print("EPOCH-WISE TEST PERFORMANCE ANALYSIS")
    print("=" * 50)
    if df.empty:
        print("No per-epoch test results were recorded.")
        return df

    best_auc_row = df.loc[df["test_auc"].idxmax()]
    best_f1_row = df.loc[df["test_f1"].idxmax()]
    # ``df.loc[row]`` upcasts the mixed int/float row to float. Cast the epoch
    # back to int so the report reads "Epoch 5", not "Epoch 5.0".
    print(
        f"Best Test AUC: {best_auc_row['test_auc']:.4f} at Epoch {int(best_auc_row['epoch'])}"
    )
    print(
        f"Best Test F1: {best_f1_row['test_f1']:.4f} at Epoch {int(best_f1_row['epoch'])}"
    )
    print()
    print("Epoch-wise Performance:")
    print(df.to_string(index=False, float_format="%.4f"))

    # Compare the first and last recorded epochs. A drop of more than 0.01
    # AUC is reported as overfitting.
    if len(epochs) >= 2:
        auc_trend = df["test_auc"].iloc[-1] - df["test_auc"].iloc[0]
        if auc_trend < -0.01:
            print(
                f"\n⚠️ OVERFITTING DETECTED: Test AUC decreased by {abs(auc_trend):.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        elif auc_trend > 0.01:
            print(
                f"\n✅ GOOD TRAINING: Test AUC improved by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        else:
            print(
                f"\n📊 STABLE TRAINING: Test AUC changed by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
    return df
def plot_training_history(train_losses, val_losses, val_aucs, output_dir):
    """Save a two-panel figure (train/val loss, validation AUC) to ``training_history.png`` at 100 dpi."""
    fig, (loss_ax, auc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    for series, series_label in ((train_losses, "Train Loss"), (val_losses, "Val Loss")):
        loss_ax.plot(series, label=series_label)
    auc_ax.plot(val_aucs, label="Val AUC", color="green")

    panel_specs = (
        (loss_ax, "Loss", "Training and Validation Loss"),
        (auc_ax, "AUC", "Validation AUC"),
    )
    for ax, y_text, title_text in panel_specs:
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_text)
        ax.set_title(title_text)
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "training_history.png"), dpi=100)
    plt.close(fig)
def plot_confusion_matrix(cm, output_dir):
    """
    Render ``cm`` as an annotated heatmap and save it as ``confusion_matrix.png`` (100 dpi).

    Annotations use ``fmt="d"``, so ``cm`` must hold integer counts.
    """
    class_names = ["Non-binding", "TBX5-binding"]
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=list(class_names),
        yticklabels=list(class_names),
    )
    axis_text = (
        (plt.title, "Confusion Matrix"),
        (plt.ylabel, "True Label"),
        (plt.xlabel, "Predicted Label"),
    )
    for setter, text in axis_text:
        setter(text)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=100)
    plt.close()
def main():
    """
    Command-line entry point: load split embeddings, scale them, train, and save artifacts.

    Everything is written under ``--output-dir``:
    - ``metadata.json`` and ``scaler.pkl``.
    - The checkpoints, JSON results and plots written by ``train_model``,
      ``save_epoch_analysis`` and ``plot_confusion_matrix``.
    """
    parser = argparse.ArgumentParser(description="Train TBX5 classifier with forward and reverse complement embeddings")
    parser.add_argument(
        "--embeddings-dir",
        type=str,
        default="tbx5_embeddings",
        help="Directory containing forward embeddings (default: tbx5_embeddings)",
    )
    parser.add_argument(
        "--rc-embeddings-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Directory containing reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="result_with_rc",
        help="Output directory for results (default: result_with_rc)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training (default: 32)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=500,
        help="Number of training epochs (default: 500)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate (default: 1e-4)",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=100,
        help="Early stopping patience (default: 100)",
    )
    parser.add_argument(
        "--dropout-rate",
        type=float,
        default=0.5,
        help="Dropout rate (default: 0.5)",
    )
    parser.add_argument(
        "--processed-data-dir",
        type=str,
        default="processed_data_new",
        help="Directory containing train/val/test CSV files (default: processed_data_new)",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading combined embeddings using CSV splits...")
    (X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
     X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
     X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
     metadata) = load_tbx5_embeddings_with_rc_from_csv(
        args.embeddings_dir, args.rc_embeddings_dir, args.processed_data_dir
    )

    with open(os.path.join(args.output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    X_train_scaled, X_val_scaled, X_test_scaled, scaler = prepare_data_with_scaling(
        X_train, X_val, X_test, y_train, y_val, y_test
    )
    # Persist the train-fitted scaler so inference can apply identical scaling.
    with open(os.path.join(args.output_dir, 'scaler.pkl'), 'wb') as f:
        pickle.dump(scaler, f)

    def _make_loader(features, labels, shuffle):
        # ``torch.as_tensor`` with explicit dtypes accepts integer or float
        # label arrays alike, unlike the legacy typed constructors.
        dataset = TensorDataset(
            torch.as_tensor(features, dtype=torch.float32),
            torch.as_tensor(labels, dtype=torch.long),
        )
        return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle)

    train_loader = _make_loader(X_train_scaled, y_train, shuffle=True)
    val_loader = _make_loader(X_val_scaled, y_val, shuffle=False)
    test_loader = _make_loader(X_test_scaled, y_test, shuffle=False)

    input_dim = X_train_scaled.shape[1]
    print(f"Input dimension: {input_dim}")
    model = TBX5ClassifierWithRC(input_dim=input_dim, dropout_rate=args.dropout_rate).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    results, test_results_by_epoch = train_model(
        model, train_loader, val_loader, test_loader, device, args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        patience=args.patience,
    )

    save_epoch_analysis(test_results_by_epoch, args.output_dir)

    # ``train_model`` already writes training_history.png from the real loss
    # and AUC history. ``results`` carries no history keys, so re-plotting from
    # it here would overwrite that figure with empty axes.
    plot_confusion_matrix(np.array(results['confusion_matrix']), args.output_dir)

    print(f"\nTraining completed! Results saved to {args.output_dir}")
    print(f"Best test AUC: {results['test_auc']:.4f}")


if __name__ == "__main__":
    main()
# End of module.