import pandas as pd import numpy as np import torch from sklearn.model_selection import train_test_split from sklearn.metrics import precision_recall_curve, f1_score import optuna from optuna.trial import TrialState import xgboost as xgb import os from datasets import load_from_disk from lightning.pytorch import seed_everything from rdkit import Chem, rdBase, DataStructs from typing import List from rdkit.Chem import AllChem import matplotlib.pyplot as plt from sklearn.metrics import accuracy_score, roc_auc_score base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse" def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path): """ Saves the true and predicted values for training and validation sets, and generates binary classification plots. Parameters: y_true_train (array): True labels for the training set. y_pred_train (array): Predicted probabilities for the training set. y_true_val (array): True labels for the validation set. y_pred_val (array): Predicted probabilities for the validation set. threshold (float): Classification threshold for predictions. output_path (str): Directory to save the CSV files and plots. """ os.makedirs(output_path, exist_ok=True) # Convert probabilities to binary predictions y_pred_train_binary = (y_pred_train >= threshold).astype(int) y_pred_val_binary = (y_pred_val >= threshold).astype(int) # Save training predictions train_df = pd.DataFrame({ 'True Label': y_true_train, 'Predicted Probability': y_pred_train, 'Predicted Label': y_pred_train_binary }) train_df.to_csv(os.path.join(output_path, 'train_predictions_binary.csv'), index=False) # Save validation predictions val_df = pd.DataFrame({ 'True Label': y_true_val, 'Predicted Probability': y_pred_val, 'Predicted Label': y_pred_val_binary }) val_df.to_csv(os.path.join(output_path, 'val_predictions_binary.csv'), index=False) # Plot training predictions plot_binary_correlation( y_true_train, y_pred_train, threshold, title="Training Set Binary Classification Plot", output_file=os.path.join(output_path, 'train_classification_plot.png') ) # Plot validation predictions plot_binary_correlation( y_true_val, y_pred_val, threshold, title="Validation Set Binary Classification Plot", output_file=os.path.join(output_path, 'val_classification_plot.png') ) def plot_binary_correlation(y_true, y_pred, threshold, title, output_file): """ Generates a scatter plot for binary classification and saves it to a file. Parameters: y_true (array): True labels. y_pred (array): Predicted probabilities. threshold (float): Classification threshold for predictions. title (str): Title of the plot. output_file (str): Path to save the plot. """ # Scatter plot plt.figure(figsize=(10, 8)) plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF') # Add threshold line plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}') # Add annotations plt.title(title) plt.xlabel("True Labels") plt.ylabel("Predicted Probability") plt.legend() # Save and show the plot plt.tight_layout() plt.savefig(output_file) plt.show() seed_everything(42) dataset = load_from_disk(f'{base_path}/data/solubility') sequences = np.stack(dataset['sequence']) # Ensure sequences are SMILES strings labels = np.stack(dataset['labels']) embeddings = np.stack(dataset['embedding']) # Initialize best F1 score and model path best_f1 = -np.inf best_model_path = f"{base_path}/src/solubility" # Trial callback def trial_info_callback(study, trial): if study.best_trial == trial: print(f"Trial {trial.number}:") print(f" Weighted F1 Score: {trial.value}") def objective(trial): # Define hyperparameters params = { 'objective': 'binary:logistic', 'lambda': trial.suggest_float('lambda', 1e-8, 50.0, log=True), 'alpha': trial.suggest_float('alpha', 1e-8, 50.0, log=True), 'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0), 'subsample': trial.suggest_float('subsample', 0.5, 1.0), 'learning_rate': trial.suggest_float('learning_rate', 0.001, 0.3), 'max_depth': trial.suggest_int('max_depth', 2, 15), 'min_child_weight': trial.suggest_int('min_child_weight', 1, 500), 'gamma': trial.suggest_float('gamma', 0, 10.0), 'tree_method': 'hist', 'device': 'cuda:6', } # Suggest number of boosting rounds num_boost_round = trial.suggest_int('num_boost_round', 10, 1000) threshold = 0.5 # Initial classification threshold # Split the data train_idx, val_idx = train_test_split( np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42 ) train_subset = dataset.select(train_idx).with_format("torch") val_subset = dataset.select(val_idx).with_format("torch") # Extract embeddings and labels for train/validation train_embeddings = np.array(train_subset['embedding']) valid_embeddings = np.array(val_subset['embedding']) train_labels = np.array(train_subset['labels']) valid_labels = np.array(val_subset['labels']) # Prepare training and validation sets dtrain = xgb.DMatrix(train_embeddings, label=train_labels) dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels) # Train the model model = xgb.train( params=params, dtrain=dtrain, num_boost_round=num_boost_round, evals=[(dvalid, "validation")], early_stopping_rounds=50, verbose_eval=False, ) # Predict probabilities preds_train = model.predict(dtrain) preds_val = model.predict(dvalid) # Calculate metrics f1_val = f1_score(valid_labels, (preds_val >= threshold).astype(int), average="weighted") auc_val = roc_auc_score(valid_labels, preds_val) print(f"Trial {trial.number}: AUC: {auc_val:.3f}, F1 Score: {f1_val:.3f}") # Save the model if it has the best F1 score current_best = trial.study.user_attrs.get("best_f1", -np.inf) if f1_val > current_best: trial.study.set_user_attr("best_f1", f1_val) trial.study.set_user_attr("best_auc", auc_val) trial.study.set_user_attr("best_trial", trial.number) os.makedirs(best_model_path, exist_ok=True) # Save the model model.save_model(os.path.join(best_model_path, "best_model_f1.json")) print(f"āœ“ NEW BEST! Trial {trial.number}: F1={f1_val:.4f}, AUC={auc_val:.4f} - Model saved!") # Save and plot binary predictions save_and_plot_binary_predictions( train_labels, preds_train, valid_labels, preds_val, threshold, best_model_path ) return f1_val if __name__ == "__main__": study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner()) study.optimize(objective, n_trials=200) # Prepare summary text summary = [] summary.append("\n" + "="*60) summary.append("OPTIMIZATION COMPLETE") summary.append("="*60) summary.append(f"Number of finished trials: {len(study.trials)}") summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}") summary.append(f"Best F1 Score: {study.user_attrs.get('best_f1', None):.4f}") summary.append(f"Best AUC Score: {study.user_attrs.get('best_auc', None):.4f}") summary.append(f"Optuna Best Trial Value: {study.best_trial.value:.4f}") summary.append(f"\nBest hyperparameters:") for key, value in study.best_trial.params.items(): summary.append(f" {key}: {value}") summary.append("="*60) # Print to console for line in summary: print(line) # Save to file metrics_file = os.path.join(best_model_path, "optimization_metrics.txt") with open(metrics_file, 'w') as f: f.write('\n'.join(summary)) print(f"\nāœ“ Metrics saved to: {metrics_file}")