|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import torch |
|
|
from sklearn.model_selection import train_test_split |
|
|
from sklearn.metrics import precision_recall_curve, f1_score |
|
|
import optuna |
|
|
from optuna.trial import TrialState |
|
|
import xgboost as xgb |
|
|
import os |
|
|
from datasets import load_from_disk |
|
|
from lightning.pytorch import seed_everything |
|
|
from rdkit import Chem, rdBase, DataStructs |
|
|
from typing import List |
|
|
from rdkit.Chem import AllChem |
|
|
import matplotlib.pyplot as plt |
|
|
from sklearn.metrics import accuracy_score, roc_auc_score |
|
|
|
|
|
# Project root containing data/ and src/. Set the PEPTIVERSE_BASE_PATH
# environment variable to run outside the original cluster layout; when it is
# unset the original hard-coded location is used unchanged.
base_path = os.environ.get(
    "PEPTIVERSE_BASE_PATH", "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
)
|
|
|
|
|
def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
    """
    Writes per-sample predictions for the train and validation splits to CSV and
    renders one scatter plot per split.

    Parameters:
        y_true_train (array): True labels for the training set.
        y_pred_train (array): Predicted probabilities for the training set
            (must support `>=` and `.astype`, e.g. a NumPy array).
        y_true_val (array): True labels for the validation set.
        y_pred_val (array): Predicted probabilities for the validation set.
        threshold (float): Probabilities >= threshold are labelled 1, else 0.
        output_path (str): Directory (created if missing) that receives
            train/val_predictions_binary.csv and train/val_classification_plot.png.
    """
    os.makedirs(output_path, exist_ok=True)

    # Hard 0/1 labels for both splits, computed before anything is written.
    hard_train = (y_pred_train >= threshold).astype(int)
    hard_val = (y_pred_val >= threshold).astype(int)

    csv_jobs = (
        ('train_predictions_binary.csv', y_true_train, y_pred_train, hard_train),
        ('val_predictions_binary.csv', y_true_val, y_pred_val, hard_val),
    )
    for csv_name, truth, prob, hard in csv_jobs:
        frame = pd.DataFrame({
            'True Label': truth,
            'Predicted Probability': prob,
            'Predicted Label': hard,
        })
        frame.to_csv(os.path.join(output_path, csv_name), index=False)

    # Plots are produced only after both CSVs exist.
    plot_jobs = (
        (y_true_train, y_pred_train, "Training Set Binary Classification Plot", 'train_classification_plot.png'),
        (y_true_val, y_pred_val, "Validation Set Binary Classification Plot", 'val_classification_plot.png'),
    )
    for truth, prob, caption, png_name in plot_jobs:
        plot_binary_correlation(
            truth,
            prob,
            threshold,
            title=caption,
            output_file=os.path.join(output_path, png_name),
        )
|
|
|
|
|
def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
    """
    Generates a scatter plot of predicted probability against true label and
    saves it to a file.

    Parameters:
        y_true (array): True labels (x-axis).
        y_pred (array): Predicted probabilities (y-axis).
        threshold (float): Classification threshold, drawn as a dashed red
            horizontal line.
        title (str): Title of the plot.
        output_file (str): Path to save the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
    ax.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')

    ax.set_title(title)
    ax.set_xlabel("True Labels")
    ax.set_ylabel("Predicted Probability")
    ax.legend()

    fig.tight_layout()
    fig.savefig(output_file)
    plt.show()
    # Release the figure explicitly. This function runs for every new-best
    # Optuna trial, and with a non-interactive backend plt.show() closes
    # nothing, so unclosed figures would accumulate for the whole study.
    plt.close(fig)
|
|
|
|
|
# Global seeding through Lightning's helper so runs are repeatable.
seed_everything(42)

dataset = load_from_disk(f'{base_path}/data/solubility')

# Materialize the three columns as NumPy arrays, in this order. `sequences`
# (for its length) and `labels` (for stratification) drive the train/val
# split in `objective`.
sequences, labels, embeddings = (
    np.stack(dataset[column]) for column in ('sequence', 'labels', 'embedding')
)

# NOTE(review): not read anywhere in this file; the running best score is
# tracked in the Optuna study's user attributes instead.
best_f1 = -np.inf
# Destination for the best model, its prediction CSVs/plots and the final metrics file.
best_model_path = f"{base_path}/src/solubility"
|
|
|
|
|
|
|
|
def trial_info_callback(study, trial):
    """
    Optuna callback that prints a finished trial's score when it is the
    study's current best.

    Parameters:
        study (optuna.Study): The study being optimized.
        trial (optuna.trial.FrozenTrial): The trial that just finished.

    Prints nothing when the trial is not the best, or when no trial has
    completed yet.
    """
    try:
        best = study.best_trial
    except ValueError:
        # Optuna raises ValueError while no trial is COMPLETE (e.g. the first
        # trials were pruned or failed); there is no best trial to report yet.
        return

    # Trial numbers uniquely identify trials within a study.
    if best.number == trial.number:
        print(f"Trial {trial.number}:")
        print(f" Weighted F1 Score: {trial.value}")
|
|
|
|
|
|
|
|
|
|
|
# Holds the (train_X, train_y, valid_X, valid_y) arrays once computed.
_SPLIT_CACHE = {}


def _train_val_arrays():
    """
    Returns (train_embeddings, train_labels, valid_embeddings, valid_labels) for
    the stratified 80/20 split shared by every trial.

    The split is seeded (random_state=42), so it is identical across trials; it
    is computed on the first call and reused afterwards instead of re-selecting
    and re-converting the dataset for every trial.
    """
    cached = _SPLIT_CACHE.get("arrays")
    if cached is None:
        train_idx, val_idx = train_test_split(
            np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
        )
        # Keep the "torch" formatting used originally so the resulting arrays
        # (values and dtype) are exactly what earlier versions trained on.
        train_subset = dataset.select(train_idx).with_format("torch")
        val_subset = dataset.select(val_idx).with_format("torch")
        cached = (
            np.array(train_subset['embedding']),
            np.array(train_subset['labels']),
            np.array(val_subset['embedding']),
            np.array(val_subset['labels']),
        )
        _SPLIT_CACHE["arrays"] = cached
    return cached


def objective(trial):
    """
    Optuna objective: trains one XGBoost binary classifier with sampled
    hyperparameters and returns its weighted F1 on the validation split.

    When the score beats the best stored in the study's user attributes, the
    model, the train/validation prediction CSVs and plots are written to
    best_model_path.

    Parameters:
        trial (optuna.trial.Trial): Trial used to sample hyperparameters.

    Returns:
        float: Weighted F1 score on the validation split at threshold 0.5.
    """
    params = {
        'objective': 'binary:logistic',
        'lambda': trial.suggest_float('lambda', 1e-8, 50.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 50.0, log=True),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'learning_rate': trial.suggest_float('learning_rate', 0.001, 0.3),
        'max_depth': trial.suggest_int('max_depth', 2, 15),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 500),
        'gamma': trial.suggest_float('gamma', 0, 10.0),
        'tree_method': 'hist',
        # Hard-coded GPU index; adjust for the host this runs on.
        'device': 'cuda:6',
    }

    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
    threshold = 0.5

    train_embeddings, train_labels, valid_embeddings, valid_labels = _train_val_arrays()

    dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
    dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)

    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # xgb.train returns the booster from the LAST round, which can contain up
    # to `early_stopping_rounds` trees past the best validation round. Truncate
    # to the best iteration so the reported metrics and the saved model both
    # reflect the early-stopped model.
    best_iteration = getattr(model, "best_iteration", None)
    if best_iteration is not None:
        model = model[: best_iteration + 1]

    preds_val = model.predict(dvalid)

    f1_val = f1_score(valid_labels, (preds_val >= threshold).astype(int), average="weighted")
    auc_val = roc_auc_score(valid_labels, preds_val)
    print(f"Trial {trial.number}: AUC: {auc_val:.3f}, F1 Score: {f1_val:.3f}")

    # The running best is kept on the study so it survives across trials.
    current_best = trial.study.user_attrs.get("best_f1", -np.inf)
    if f1_val > current_best:
        trial.study.set_user_attr("best_f1", f1_val)
        trial.study.set_user_attr("best_auc", auc_val)
        trial.study.set_user_attr("best_trial", trial.number)
        os.makedirs(best_model_path, exist_ok=True)

        model.save_model(os.path.join(best_model_path, "best_model_f1.json"))
        print(f"✓ NEW BEST! Trial {trial.number}: F1={f1_val:.4f}, AUC={auc_val:.4f} - Model saved!")

        # Training-set predictions are only needed for the saved artifacts.
        preds_train = model.predict(dtrain)
        save_and_plot_binary_predictions(
            train_labels, preds_train, valid_labels, preds_val, threshold, best_model_path
        )

    return f1_val
|
|
|
|
|
if __name__ == "__main__":
    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200)

    rule = "=" * 60
    attrs = study.user_attrs

    # Human-readable report; printed to stdout and written to disk verbatim.
    summary = [
        "\n" + rule,
        "OPTIMIZATION COMPLETE",
        rule,
        f"Number of finished trials: {len(study.trials)}",
        f"\nBest Trial: #{attrs.get('best_trial', 'N/A')}",
        f"Best F1 Score: {attrs.get('best_f1', None):.4f}",
        f"Best AUC Score: {attrs.get('best_auc', None):.4f}",
        f"Optuna Best Trial Value: {study.best_trial.value:.4f}",
        "\nBest hyperparameters:",
    ]
    summary.extend(f" {name}: {val}" for name, val in study.best_trial.params.items())
    summary.append(rule)

    print("\n".join(summary))

    metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
    with open(metrics_file, 'w') as f:
        f.write('\n'.join(summary))
    print(f"\n✓ Metrics saved to: {metrics_file}")