#!/usr/bin/env python3
"""
Train TBX5 classifier using both forward and reverse complement embeddings.

This script combines embeddings from both strands to improve classification
accuracy.
"""

import os
import sys
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
)
import json
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# Add the parent directory to the path to import from finetuning
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'finetuning'))


class TBX5ClassifierWithRC(nn.Module):
    """
    3-layer feedforward neural network for TBX5 binding site classification
    using both forward and reverse complement embeddings.

    Architecture:
    - Input (8192 dimensions: 4096 forward + 4096 reverse complement)
      -> 2048 -> 512 -> 128 -> 1 (sigmoid)
    - ReLU activation, BatchNorm, Dropout(0.5) after each hidden layer
    """

    def __init__(self, input_dim=8192, dropout_rate=0.5):
        super(TBX5ClassifierWithRC, self).__init__()
        self.fc1 = nn.Linear(input_dim, 2048)
        self.bn1 = nn.BatchNorm1d(2048)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(2048, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(512, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.dropout3 = nn.Dropout(dropout_rate)
        self.fc4 = nn.Linear(128, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # NOTE(review): activation is applied BEFORE batch norm in each layer.
        # This is the original design and is preserved; confirm it is intended
        # (the more common order is fc -> bn -> relu).
        # Layer 1
        x = self.fc1(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.dropout1(x)
        # Layer 2
        x = self.fc2(x)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.dropout2(x)
        # Layer 3
        x = self.fc3(x)
        x = self.relu(x)
        x = self.bn3(x)
        x = self.dropout3(x)
        # Output layer
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x


def load_tbx5_embeddings_with_rc_from_csv(embeddings_dir, rc_embeddings_dir, processed_data_dir):
    """
    Load TBX5 embeddings using train/val/test splits from processed_data_new CSV files.

    Args:
        embeddings_dir: Directory containing forward embeddings
        rc_embeddings_dir: Directory containing reverse complement embeddings
        processed_data_dir: Directory containing train/val/test CSV files

    Returns:
        train/val/test data splits with combined embeddings, plus a metadata dict
    """
    print(f"Loading data using CSV splits from: {processed_data_dir}")
    print(f"Loading forward embeddings from: {embeddings_dir}")
    print(f"Loading reverse complement embeddings from: {rc_embeddings_dir}")

    # Load CSV files
    train_df = pd.read_csv(os.path.join(processed_data_dir, 'train_tbx5_data_new.csv'))
    val_df = pd.read_csv(os.path.join(processed_data_dir, 'val_tbx5_data_new.csv'))
    test_df = pd.read_csv(os.path.join(processed_data_dir, 'test_tbx5_data_new.csv'))

    print(f"Train samples: {len(train_df)}")
    print(f"Val samples: {len(val_df)}")
    print(f"Test samples: {len(test_df)}")

    def load_embeddings_for_split(df, embeddings_dir, rc_embeddings_dir):
        """Load embeddings for a specific split."""
        all_embeddings = []
        all_labels = []
        all_starts = []
        all_ends = []
        all_tbx5_scores = []
        all_chromosomes = []

        total_samples = len(df)
        found_samples = 0
        missing_files = 0
        missing_samples = 0

        # Keep track of loaded chromosome data to avoid reloading
        loaded_chrom_data = {}

        # Process samples in original order to maintain sequence
        for idx, row in df.iterrows():
            chrom_num = row['chromosome']
            chrom = f"chr{chrom_num}"
            start = row['start']
            end = row['end']
            label = row['label']
            tbx5_score = row['tbx5_score']

            # Load chromosome data if not already loaded
            if chrom not in loaded_chrom_data:
                forward_file = os.path.join(embeddings_dir, f"{chrom}_tbx5_embeddings_arrays.npz")
                rc_file = os.path.join(rc_embeddings_dir, f"{chrom}_tbx5_embeddings_rc_arrays.npz")
                if os.path.exists(forward_file) and os.path.exists(rc_file):
                    print(f" Loading {chrom}...")
                    forward_data = np.load(forward_file)
                    rc_data = np.load(rc_file)
                    loaded_chrom_data[chrom] = {
                        'forward_embeddings': forward_data['embeddings'],
                        'forward_starts': forward_data['starts'],
                        'forward_ends': forward_data['ends'],
                        'forward_tbx5_scores': forward_data['tbx5_scores'],
                        'rc_embeddings': rc_data['embeddings'],
                        'rc_starts': rc_data['starts'],
                        'rc_ends': rc_data['ends'],
                        'rc_tbx5_scores': rc_data['tbx5_scores']
                    }
                else:
                    print(f" Warning: Missing embedding files for {chrom}")
                    # Cache the miss so we only warn once per chromosome
                    loaded_chrom_data[chrom] = None
                    missing_files += 1
                    continue

            # Skip if chromosome data not available
            if loaded_chrom_data[chrom] is None:
                missing_samples += 1
                continue

            chrom_data = loaded_chrom_data[chrom]
            forward_starts = chrom_data['forward_starts']
            forward_embeddings = chrom_data['forward_embeddings']
            rc_embeddings = chrom_data['rc_embeddings']

            # Find matching sample in embeddings (use chromosome and start only)
            mask = (forward_starts == start)
            if np.any(mask):
                # If multiple matches, take the first one
                emb_idx = np.where(mask)[0][0]

                # Get embeddings
                forward_emb = forward_embeddings[emb_idx]
                rc_emb = rc_embeddings[emb_idx]

                # Combine embeddings (forward strand first, then reverse complement)
                combined_emb = np.concatenate([forward_emb, rc_emb])

                all_embeddings.append(combined_emb)
                all_labels.append(label)
                all_starts.append(start)
                all_ends.append(end)
                all_tbx5_scores.append(tbx5_score)
                all_chromosomes.append(chrom)
                found_samples += 1
            else:
                missing_samples += 1
                # Skip missing samples instead of adding zeros
                continue

        print(f" Summary: {found_samples}/{total_samples} samples loaded")
        print(f" Missing files: {missing_files} samples")
        print(f" Missing embeddings: {missing_samples} samples")

        return (
            np.array(all_embeddings),
            np.array(all_labels),
            np.array(all_starts),
            np.array(all_ends),
            np.array(all_tbx5_scores),
            all_chromosomes
        )

    # Load data for each split
    print("Loading train data...")
    X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train = load_embeddings_for_split(
        train_df, embeddings_dir, rc_embeddings_dir
    )

    print("Loading validation data...")
    X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val = load_embeddings_for_split(
        val_df, embeddings_dir, rc_embeddings_dir
    )

    print("Loading test data...")
    X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test = load_embeddings_for_split(
        test_df, embeddings_dir, rc_embeddings_dir
    )

    print(f"\nLoaded data:")
    print(f"Train: {len(X_train)} samples")
    print(f"Val: {len(X_val)} samples")
    print(f"Test: {len(X_test)} samples")
    print(f"Embedding dimension: {X_train.shape[1]}")
    print(f"Train positive samples: {np.sum(y_train)}")
    print(f"Val positive samples: {np.sum(y_val)}")
    print(f"Test positive samples: {np.sum(y_test)}")

    # Check if we have enough data
    if len(X_train) == 0:
        raise ValueError("No training data loaded! Check embedding files and CSV data.")
    if len(X_val) == 0:
        raise ValueError("No validation data loaded! Check embedding files and CSV data.")
    if len(X_test) == 0:
        raise ValueError("No test data loaded! Check embedding files and CSV data.")

    print(f"\nData quality check:")
    print(f"Train positive ratio: {np.mean(y_train):.3f}")
    print(f"Val positive ratio: {np.mean(y_val):.3f}")
    print(f"Test positive ratio: {np.mean(y_test):.3f}")

    metadata = {
        "total_samples": len(X_train) + len(X_val) + len(X_test),
        "embedding_dim": X_train.shape[1],
        "train_samples": len(X_train),
        "val_samples": len(X_val),
        "test_samples": len(X_test),
        "train_positive": int(np.sum(y_train)),
        "val_positive": int(np.sum(y_val)),
        "test_positive": int(np.sum(y_test)),
        "sequence_type": "forward_and_reverse_complement"
    }

    return (
        X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
        X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
        X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
        metadata
    )


def prepare_data_with_scaling(X_train, X_val, X_test, y_train, y_val, y_test):
    """
    Scale the features for train/val/test splits.

    The scaler is fit on the training split only to avoid information leakage
    into validation/test data.
    """
    print("Scaling features...")

    # Scale features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)
    X_test_scaled = scaler.transform(X_test)

    return X_train_scaled, X_val_scaled, X_test_scaled, scaler


def train_model(
    model,
    train_loader,
    val_loader,
    test_loader,
    device,
    output_dir,
    num_epochs=500,
    learning_rate=1e-4,
    patience=100,
    lr_patience=20,
    min_lr=1e-6,
    gradient_clip=1.0,
    save_every=5,
):
    """
    Train the model with specified optimization settings.

    Saves the best model (by validation AUC) to best_model.pth, periodic
    checkpoints every `save_every` epochs, test results and training history
    as JSON, and a training-history plot.

    Returns:
        (results, test_results_by_epoch): final test metrics dict (augmented
        with the training history lists for downstream plotting) and a dict
        of test metrics keyed by checkpointed epoch.
    """
    print(f"Training model with learning rate {learning_rate}")
    print(f"Early stopping patience: {patience}")
    print(f"Learning rate reduction patience: {lr_patience}")

    # Loss and optimizer
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=lr_patience, min_lr=min_lr
    )

    # Training history
    train_losses = []
    val_losses = []
    val_aucs = []
    test_results_by_epoch = {}  # Store test results for each saved epoch

    best_val_auc = 0.0
    best_epoch = 0
    epochs_without_improvement = 0

    print(f"Starting training for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_embeddings, batch_labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device).float()

            optimizer.zero_grad()
            # squeeze(-1) (not squeeze()) so a batch of size 1 keeps shape (1,)
            # instead of collapsing to a 0-d tensor, which breaks BCELoss.
            outputs = model(batch_embeddings).squeeze(-1)
            loss = criterion(outputs, batch_labels)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
            optimizer.step()

            train_loss += loss.item()
            predicted = (outputs > 0.5).float()
            train_correct += (predicted == batch_labels).sum().item()
            train_total += batch_labels.size(0)

        train_loss /= len(train_loader)
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_predictions = []
        val_labels = []

        with torch.no_grad():
            for batch_embeddings, batch_labels in val_loader:
                batch_embeddings = batch_embeddings.to(device)
                batch_labels = batch_labels.to(device).float()

                outputs = model(batch_embeddings).squeeze(-1)
                loss = criterion(outputs, batch_labels)

                val_loss += loss.item()
                predicted = (outputs > 0.5).float()
                val_correct += (predicted == batch_labels).sum().item()
                val_total += batch_labels.size(0)

                val_predictions.extend(outputs.cpu().numpy())
                val_labels.extend(batch_labels.cpu().numpy())

        val_loss /= len(val_loader)
        val_acc = val_correct / val_total
        val_auc = roc_auc_score(val_labels, val_predictions)

        # Update learning rate
        scheduler.step(val_loss)

        # Store history
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_aucs.append(val_auc)

        # Check for improvement
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_epoch = epoch
            epochs_without_improvement = 0

            # Save best model
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_auc': val_auc,
                'val_loss': val_loss,
                'input_dim': model.fc1.in_features,
            }, os.path.join(output_dir, 'best_model.pth'))
            print(f"New best model saved! Val AUC: {val_auc:.4f}")
        else:
            epochs_without_improvement += 1

        # Save model and evaluate every N epochs
        if (epoch + 1) % save_every == 0 or epoch == 0:
            # Save model state
            epoch_model_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.pth")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch + 1,
                'val_auc': val_auc,
                'val_loss': val_loss,
                'input_dim': model.fc1.in_features,
            }, epoch_model_path)

            # Evaluate on test set
            test_results = evaluate_model_simple(model, test_loader, device)
            test_results_by_epoch[epoch + 1] = test_results

            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"Test AUC: {test_results['auc']:.4f}")
        # Print progress for other epochs
        elif (epoch + 1) % 10 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"LR: {current_lr:.2e}")

        # Early stopping
        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
            break

    print(f"Training completed! Best validation AUC: {best_val_auc:.4f} at epoch {best_epoch+1}")

    # Load best model for testing
    checkpoint = torch.load(os.path.join(output_dir, 'best_model.pth'), map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Test evaluation
    model.eval()
    test_predictions = []
    test_labels = []
    test_loss = 0.0
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for batch_embeddings, batch_labels in test_loader:
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device).float()

            outputs = model(batch_embeddings).squeeze(-1)
            loss = criterion(outputs, batch_labels)

            test_loss += loss.item()
            predicted = (outputs > 0.5).float()
            test_correct += (predicted == batch_labels).sum().item()
            test_total += batch_labels.size(0)

            test_predictions.extend(outputs.cpu().numpy())
            test_labels.extend(batch_labels.cpu().numpy())

    test_loss /= len(test_loader)
    test_acc = test_correct / test_total
    test_auc = roc_auc_score(test_labels, test_predictions)

    # Calculate additional metrics
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, [1 if p > 0.5 else 0 for p in test_predictions], average='binary')
    cm = confusion_matrix(test_labels, [1 if p > 0.5 else 0 for p in test_predictions])

    # Save results
    results = {
        'test_auc': float(test_auc),
        'test_accuracy': float(test_acc),
        'test_loss': float(test_loss),
        'test_precision': float(precision),
        'test_recall': float(recall),
        'test_f1': float(f1),
        'confusion_matrix': cm.tolist(),
        'best_val_auc': float(best_val_auc),
        'best_epoch': int(best_epoch + 1),
        'total_epochs': int(epoch + 1),
        'sequence_type': 'forward_and_reverse_complement',
        'predictions': [float(p) for p in test_predictions],
        'labels': [float(l) for l in test_labels]
    }

    with open(os.path.join(output_dir, 'test_results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    # Save training history
    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_aucs': val_aucs,
        'best_epoch': best_epoch + 1,
        'best_val_auc': best_val_auc
    }

    with open(os.path.join(output_dir, 'training_history.json'), 'w') as f:
        json.dump(history, f, indent=2)

    # Attach the history to the returned results (after the JSON dump so that
    # test_results.json keeps its original schema) so that main() can pass it
    # to plot_training_history() — previously these keys were missing and the
    # plot was always empty.
    results['train_losses'] = train_losses
    results['val_losses'] = val_losses
    results['val_aucs'] = val_aucs

    # Plot training history
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.plot(val_aucs, label='Val AUC', color='green')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.title('Validation AUC')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 3)
    plt.plot(range(len(train_losses)), train_losses, label='Train Loss')
    plt.plot(range(len(val_losses)), val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"\n=== Test Results ===")
    print(f"Test AUC: {test_auc:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1: {f1:.4f}")
    print(f"Confusion Matrix:\n{cm}")

    return results, test_results_by_epoch


def evaluate_model_simple(model, test_loader, device):
    """Simple evaluation that returns just basic metrics."""
    model.eval()
    test_preds = []
    test_labels = []

    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            batch_X = batch_X.to(device)
            # squeeze(-1) keeps a (1,) shape for single-sample batches
            outputs = model(batch_X).squeeze(-1)
            test_preds.extend(outputs.cpu().numpy())
            test_labels.extend(batch_y.numpy())

    test_preds = np.array(test_preds)
    test_labels = np.array(test_labels)

    # Calculate basic metrics
    test_auc = roc_auc_score(test_labels, test_preds)
    test_preds_binary = (test_preds > 0.5).astype(int)
    test_acc = accuracy_score(test_labels, test_preds_binary)
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_preds_binary, average="binary"
    )

    return {
        "auc": test_auc,
        "accuracy": test_acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


def save_epoch_analysis(test_results_by_epoch, output_dir):
    """Save analysis of results across epochs."""
    epochs = sorted(test_results_by_epoch.keys())

    # Create summary DataFrame
    summary_data = []
    for epoch in epochs:
        results = test_results_by_epoch[epoch]
        summary_data.append(
            {
                "epoch": epoch,
                "test_auc": results["auc"],
                "test_accuracy": results["accuracy"],
                "test_precision": results["precision"],
                "test_recall": results["recall"],
                "test_f1": results["f1"],
            }
        )

    df = pd.DataFrame(summary_data)

    # Save to CSV
    csv_path = os.path.join(output_dir, "epoch_analysis.csv")
    df.to_csv(csv_path, index=False)

    # Save to JSON
    json_path = os.path.join(output_dir, "epoch_analysis.json")
    with open(json_path, "w") as f:
        json.dump(test_results_by_epoch, f, indent=2)

    # Print summary
    print("\n" + "=" * 50)
    print("EPOCH-WISE TEST PERFORMANCE ANALYSIS")
    print("=" * 50)

    best_auc_epoch = df.loc[df["test_auc"].idxmax()]
    best_f1_epoch = df.loc[df["test_f1"].idxmax()]

    print(
        f"Best Test AUC: {best_auc_epoch['test_auc']:.4f} at Epoch {best_auc_epoch['epoch']}"
    )
    print(
        f"Best Test F1: {best_f1_epoch['test_f1']:.4f} at Epoch {best_f1_epoch['epoch']}"
    )
    print()
    print("Epoch-wise Performance:")
    print(df.to_string(index=False, float_format="%.4f"))

    # Check for overfitting
    if len(epochs) >= 2:
        auc_trend = df["test_auc"].iloc[-1] - df["test_auc"].iloc[0]
        if auc_trend < -0.01:  # Significant decrease
            print(
                f"\n⚠️ OVERFITTING DETECTED: Test AUC decreased by {abs(auc_trend):.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        elif auc_trend > 0.01:
            print(
                f"\n✅ GOOD TRAINING: Test AUC improved by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        else:
            print(
                f"\n📊 STABLE TRAINING: Test AUC changed by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )

    return df


def plot_training_history(train_losses, val_losses, val_aucs, output_dir):
    """Plot training history."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Loss plot
    axes[0].plot(train_losses, label="Train Loss")
    axes[0].plot(val_losses, label="Val Loss")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].set_title("Training and Validation Loss")
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # AUC plot
    axes[1].plot(val_aucs, label="Val AUC", color="green")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("AUC")
    axes[1].set_title("Validation AUC")
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "training_history.png"), dpi=100)
    plt.close()


def plot_confusion_matrix(cm, output_dir):
    """Plot confusion matrix."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["Non-binding", "TBX5-binding"],
        yticklabels=["Non-binding", "TBX5-binding"],
    )
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=100)
    plt.close()


def main():
    parser = argparse.ArgumentParser(description="Train TBX5 classifier with forward and reverse complement embeddings")
    parser.add_argument(
        "--embeddings-dir",
        type=str,
        default="tbx5_embeddings",
        help="Directory containing forward embeddings (default: tbx5_embeddings)",
    )
    parser.add_argument(
        "--rc-embeddings-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Directory containing reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="result_with_rc",
        help="Output directory for results (default: result_with_rc)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training (default: 32)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=500,
        help="Number of training epochs (default: 500)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate (default: 1e-4)",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=100,
        help="Early stopping patience (default: 100)",
    )
    parser.add_argument(
        "--dropout-rate",
        type=float,
        default=0.5,
        help="Dropout rate (default: 0.5)",
    )
    parser.add_argument(
        "--processed-data-dir",
        type=str,
        default="processed_data_new",
        help="Directory containing train/val/test CSV files (default: processed_data_new)",
    )

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load embeddings using CSV splits
    print("Loading combined embeddings using CSV splits...")
    (X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
     X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
     X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
     metadata) = load_tbx5_embeddings_with_rc_from_csv(
        args.embeddings_dir, args.rc_embeddings_dir, args.processed_data_dir
    )

    # Save metadata
    with open(os.path.join(args.output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Scale features
    X_train_scaled, X_val_scaled, X_test_scaled, scaler = prepare_data_with_scaling(
        X_train, X_val, X_test, y_train, y_val, y_test
    )

    # Save scaler
    with open(os.path.join(args.output_dir, 'scaler.pkl'), 'wb') as f:
        pickle.dump(scaler, f)

    # Create data loaders
    train_dataset = TensorDataset(torch.FloatTensor(X_train_scaled), torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val_scaled), torch.LongTensor(y_val))
    test_dataset = TensorDataset(torch.FloatTensor(X_test_scaled), torch.LongTensor(y_test))

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Initialize model
    input_dim = X_train_scaled.shape[1]
    print(f"Input dimension: {input_dim}")
    model = TBX5ClassifierWithRC(input_dim=input_dim, dropout_rate=args.dropout_rate).to(device)

    # Print model architecture
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Train model
    results, test_results_by_epoch = train_model(
        model,
        train_loader,
        val_loader,
        test_loader,
        device,
        args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        patience=args.patience,
    )

    # Save epoch analysis
    save_epoch_analysis(test_results_by_epoch, args.output_dir)

    # Plot results
    plot_training_history(results.get('train_losses', []),
                          results.get('val_losses', []),
                          results.get('val_aucs', []),
                          args.output_dir)
    plot_confusion_matrix(np.array(results['confusion_matrix']), args.output_dir)

    print(f"\nTraining completed! Results saved to {args.output_dir}")
    print(f"Best test AUC: {results['test_auc']:.4f}")


if __name__ == "__main__":
    main()