# File size: 8,234 Bytes
# Revision: 813c6b1
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, f1_score
import optuna
from optuna.trial import TrialState
import xgboost as xgb
import os
from datasets import load_from_disk
from lightning.pytorch import seed_everything
from rdkit import Chem, rdBase, DataStructs
from typing import List
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score
# Root directory that holds the dataset (``data/``) and model outputs (``src/``)
# used below. It can be overridden with the PEPTIVERSE_BASE_PATH environment
# variable so the script is runnable outside the original cluster; when the
# variable is unset the original hard-coded location is used.
base_path = os.environ.get(
    "PEPTIVERSE_BASE_PATH",
    "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse",
)
def _binary_predictions_frame(y_true, y_pred, threshold):
    """Build the per-sample table that is written to a predictions CSV."""
    # np.asarray lets callers pass plain lists as well as arrays; a bare list
    # would fail on the ``>=`` comparison / ``.astype`` below.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return pd.DataFrame({
        'True Label': y_true,
        'Predicted Probability': y_pred,
        # Probabilities at or above the threshold count as the positive class.
        'Predicted Label': (y_pred >= threshold).astype(int),
    })


def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
    """
    Save true/predicted values for the training and validation sets and plot them.

    Two CSV files (``train_predictions_binary.csv`` and
    ``val_predictions_binary.csv``) and two PNG plots
    (``train_classification_plot.png`` and ``val_classification_plot.png``)
    are written into ``output_path``, which is created if it does not exist.

    Parameters:
    y_true_train (array-like): True labels for the training set.
    y_pred_train (array-like): Predicted probabilities for the training set.
    y_true_val (array-like): True labels for the validation set.
    y_pred_val (array-like): Predicted probabilities for the validation set.
    threshold (float): Classification threshold for predictions.
    output_path (str): Directory to save the CSV files and plots.
    """
    os.makedirs(output_path, exist_ok=True)

    splits = (
        ('train', "Training Set Binary Classification Plot", y_true_train, y_pred_train),
        ('val', "Validation Set Binary Classification Plot", y_true_val, y_pred_val),
    )

    # Write both CSV files first, then both plots (same order as before), so a
    # plotting failure never prevents the raw predictions from being saved.
    for prefix, _title, y_true, y_pred in splits:
        frame = _binary_predictions_frame(y_true, y_pred, threshold)
        frame.to_csv(os.path.join(output_path, f'{prefix}_predictions_binary.csv'), index=False)

    for prefix, title, y_true, y_pred in splits:
        plot_binary_correlation(
            y_true,
            y_pred,
            threshold,
            title=title,
            output_file=os.path.join(output_path, f'{prefix}_classification_plot.png'),
        )
def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
    """
    Generate a scatter plot for binary classification and save it to a file.

    The figure is closed once it has been saved and shown, so repeated calls
    (for example once per improving Optuna trial) do not accumulate open
    figures.

    Parameters:
    y_true (array): True labels (x axis).
    y_pred (array): Predicted probabilities (y axis).
    threshold (float): Classification threshold, drawn as a horizontal line.
    title (str): Title of the plot.
    output_file (str): Path to save the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Scatter plot
        ax.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
        # Add threshold line
        ax.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')
        # Add annotations
        ax.set_title(title)
        ax.set_xlabel("True Labels")
        ax.set_ylabel("Predicted Probability")
        ax.legend()
        # Save and show the plot
        fig.tight_layout()
        fig.savefig(output_file)
        plt.show()
    finally:
        # Release the figure even if saving/showing fails; pyplot otherwise
        # keeps every figure alive until it is explicitly closed.
        plt.close(fig)
# Module-level setup: runs on import, before any trial is executed.
# Fix the global seed (via Lightning's seed_everything) to 42.
seed_everything(42)
# Load the solubility dataset previously saved to disk under base_path.
dataset = load_from_disk(f'{base_path}/data/solubility')
# Stack the three dataset columns into NumPy arrays.
# In the code visible here `sequences` is only used for its length and
# `labels` for stratifying the split; `embeddings` is not read again.
sequences = np.stack(dataset['sequence']) # NOTE(review): no validation happens here; the original comment expected SMILES strings - confirm
labels = np.stack(dataset['labels'])
embeddings = np.stack(dataset['embedding'])
# Initialize best F1 score and model path
# NOTE(review): `best_f1` is not read by the code visible here (the objective
# tracks its best score in the study's user attributes) - confirm before removing.
best_f1 = -np.inf
# Directory that receives the best model, prediction CSVs, plots and metrics.
best_model_path = f"{base_path}/src/solubility"
# Trial callback
def trial_info_callback(study, trial):
    """Print the number and value of ``trial`` when it is the study's best trial.

    The ``(study, trial)`` signature matches what Optuna passes to entries of
    ``study.optimize(..., callbacks=[...])``.

    Parameters:
    study: Object exposing ``best_trial`` (an Optuna study).
    trial: The trial that just finished; must expose ``number`` and ``value``.
    """
    try:
        best = study.best_trial
    except ValueError:
        # Optuna raises ValueError while no trial has completed yet (e.g. the
        # first trials failed or were pruned); there is nothing to report.
        return
    if best == trial:
        print(f"Trial {trial.number}:")
        print(f" Weighted F1 Score: {trial.value}")
# Cache for the train/validation arrays. The split uses a fixed random_state,
# so it is identical for every trial and only needs to be computed once.
_SPLIT_CACHE = {}


def _get_train_valid_arrays():
    """Return (train_embeddings, valid_embeddings, train_labels, valid_labels), computed once."""
    if "arrays" not in _SPLIT_CACHE:
        # Stratified 80/20 split over row indices of the module-level dataset.
        train_idx, val_idx = train_test_split(
            np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
        )
        train_subset = dataset.select(train_idx).with_format("torch")
        val_subset = dataset.select(val_idx).with_format("torch")
        _SPLIT_CACHE["arrays"] = (
            np.array(train_subset['embedding']),
            np.array(val_subset['embedding']),
            np.array(train_subset['labels']),
            np.array(val_subset['labels']),
        )
    return _SPLIT_CACHE["arrays"]


def objective(trial):
    """Optuna objective: train one XGBoost classifier and return its validation F1.

    Hyperparameters are sampled from ``trial``, a booster is trained on the
    fixed training split with early stopping on the validation split, and the
    weighted F1 score at a 0.5 threshold is returned. When the score beats the
    best value stored in the study's user attributes, the model is saved to
    ``best_model_path/best_model_f1.json`` and the prediction CSVs and plots
    are regenerated in the same directory.

    The training device defaults to ``cuda:6`` and can be overridden with the
    ``XGB_DEVICE`` environment variable (e.g. ``cpu`` or ``cuda:0``).

    Parameters:
    trial (optuna.trial.Trial): Trial used to sample hyperparameters.

    Returns:
    float: Weighted F1 score on the validation split.
    """
    # Define hyperparameters
    params = {
        'objective': 'binary:logistic',
        'lambda': trial.suggest_float('lambda', 1e-8, 50.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 50.0, log=True),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'learning_rate': trial.suggest_float('learning_rate', 0.001, 0.3),
        'max_depth': trial.suggest_int('max_depth', 2, 15),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 500),
        'gamma': trial.suggest_float('gamma', 0, 10.0),
        'tree_method': 'hist',
        # Previously hard-coded to GPU index 6; still the default.
        'device': os.environ.get('XGB_DEVICE', 'cuda:6'),
    }
    # Suggest number of boosting rounds
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
    threshold = 0.5  # Fixed classification threshold

    # Fetch the (cached) train/validation arrays
    train_embeddings, valid_embeddings, train_labels, valid_labels = _get_train_valid_arrays()

    # Prepare training and validation sets
    dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
    dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)

    # Train the model
    # NOTE(review): the validation split drives early stopping and is also the
    # split the trial is scored on; whether predict() below is limited to the
    # best iteration depends on the installed XGBoost version - confirm.
    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # Predict probabilities
    preds_train = model.predict(dtrain)
    preds_val = model.predict(dvalid)

    # Calculate metrics
    f1_val = f1_score(valid_labels, (preds_val >= threshold).astype(int), average="weighted")
    auc_val = roc_auc_score(valid_labels, preds_val)
    print(f"Trial {trial.number}: AUC: {auc_val:.3f}, F1 Score: {f1_val:.3f}")

    # Save the model if it has the best F1 score seen so far in this study
    current_best = trial.study.user_attrs.get("best_f1", -np.inf)
    if f1_val > current_best:
        trial.study.set_user_attr("best_f1", f1_val)
        trial.study.set_user_attr("best_auc", auc_val)
        trial.study.set_user_attr("best_trial", trial.number)
        os.makedirs(best_model_path, exist_ok=True)
        # Save the model
        model.save_model(os.path.join(best_model_path, "best_model_f1.json"))
        print(f"✓ NEW BEST! Trial {trial.number}: F1={f1_val:.4f}, AUC={auc_val:.4f} - Model saved!")
        # Save and plot binary predictions
        save_and_plot_binary_predictions(
            train_labels, preds_train, valid_labels, preds_val, threshold, best_model_path
        )
    return f1_val
if __name__ == "__main__":
    # Run the hyperparameter search and write a text summary next to the model.
    # NOTE(review): the objective in this file never calls trial.report(), so
    # the MedianPruner has no intermediate values to act on - confirm intent.
    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200)

    def _fmt_score(value):
        """Format a score with four decimals, or 'N/A' when it was never recorded."""
        # Formatting None with ':.4f' raises TypeError, so handle it explicitly.
        return "N/A" if value is None else f"{value:.4f}"

    # Prepare summary text
    summary = [
        "\n" + "=" * 60,
        "OPTIMIZATION COMPLETE",
        "=" * 60,
        f"Number of finished trials: {len(study.trials)}",
        f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}",
        f"Best F1 Score: {_fmt_score(study.user_attrs.get('best_f1'))}",
        f"Best AUC Score: {_fmt_score(study.user_attrs.get('best_auc'))}",
        f"Optuna Best Trial Value: {study.best_trial.value:.4f}",
        "\nBest hyperparameters:",
    ]
    summary.extend(f" {key}: {value}" for key, value in study.best_trial.params.items())
    summary.append("=" * 60)

    # Print to console
    for line in summary:
        print(line)

    # Save to file; make sure the target directory exists before writing.
    os.makedirs(best_model_path, exist_ok=True)
    metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
    with open(metrics_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(summary))
    print(f"\n✓ Metrics saved to: {metrics_file}")