| | import pandas as pd |
| | import numpy as np |
| | import torch |
| | from sklearn.model_selection import train_test_split |
| | from sklearn.metrics import precision_recall_curve, f1_score |
| | import optuna |
| | from optuna.trial import TrialState |
| | import xgboost as xgb |
| | import os |
| | from datasets import load_from_disk |
| | from lightning.pytorch import seed_everything |
| | from rdkit import Chem, rdBase, DataStructs |
| | from typing import List |
| | from rdkit.Chem import AllChem |
| | import matplotlib.pyplot as plt |
| | from sklearn.metrics import accuracy_score, roc_auc_score |
| | import seaborn as sns |
| |
|
def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
    """
    Saves the true and predicted values for training and validation sets, and generates binary classification plots.

    Writes `train_predictions_binary.csv` / `val_predictions_binary.csv` and
    `train_classification_plot.png` / `val_classification_plot.png` into `output_path`.

    Parameters:
    y_true_train (array-like): True labels for the training set.
    y_pred_train (array-like): Predicted probabilities for the training set.
    y_true_val (array-like): True labels for the validation set.
    y_pred_val (array-like): Predicted probabilities for the validation set.
    threshold (float): Classification threshold; probabilities >= threshold become label 1.
    output_path (str): Directory to save the CSV files and plots (created if missing).
    """
    os.makedirs(output_path, exist_ok=True)

    # Coerce every input to a NumPy array up front so the threshold comparison and
    # `.astype(int)` below work for any array-like (lists, tensors exposing
    # `__array__`, pandas objects), not only for NumPy arrays.
    splits = [
        ('train', "Training Set Binary Classification Plot",
         np.asarray(y_true_train), np.asarray(y_pred_train)),
        ('val', "Validation Set Binary Classification Plot",
         np.asarray(y_true_val), np.asarray(y_pred_val)),
    ]

    # Write both CSV files first so the tabular results exist even if plotting fails.
    for prefix, _title, y_true, y_pred in splits:
        predictions = pd.DataFrame({
            'True Label': y_true,
            'Predicted Probability': y_pred,
            'Predicted Label': (y_pred >= threshold).astype(int),
        })
        predictions.to_csv(os.path.join(output_path, f'{prefix}_predictions_binary.csv'), index=False)

    # One boxplot (with the threshold line) per split.
    for prefix, title, y_true, y_pred in splits:
        plot_boxplot_with_threshold(
            y_true,
            y_pred,
            threshold,
            title=title,
            output_file=os.path.join(output_path, f'{prefix}_classification_plot.png'),
        )
| |
|
def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
    """
    Generates a scatter plot of predicted probabilities against true labels and saves it to a file.

    Parameters:
    y_true (array): True labels (x-axis).
    y_pred (array): Predicted probabilities (y-axis).
    threshold (float): Classification threshold, drawn as a horizontal dashed line.
    title (str): Title of the plot.
    output_file (str): Path to save the plot.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')

        # Horizontal line marking the decision threshold.
        plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')

        plt.title(title)
        plt.xlabel("True Labels")
        plt.ylabel("Predicted Probability")
        plt.legend()

        plt.tight_layout()
        plt.savefig(output_file)
        plt.show()
    finally:
        # pyplot keeps a reference to every figure until it is closed explicitly;
        # close it here so repeated calls do not accumulate open figures.
        plt.close(fig)
| |
|
def plot_boxplot_with_threshold(y_true, y_pred, threshold, title, output_file):
    """
    Generates a boxplot for binary classification and includes a threshold line.

    One box is drawn per distinct true label, showing the distribution of the
    predicted probabilities for that label.

    Parameters:
    y_true (array): True labels.
    y_pred (array): Predicted probabilities.
    threshold (float): Classification threshold for predictions.
    title (str): Title of the plot.
    output_file (str): Path to save the plot.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        # seaborn groups by column name, so put both arrays into one frame.
        df = pd.DataFrame({'True Label': y_true, 'Predicted Probability': y_pred})
        sns.boxplot(x='True Label', y='Predicted Probability', data=df)

        # Threshold line plus a text annotation slightly above it.
        plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')
        plt.text(
            x=0.5, y=threshold + 0.05, s=f"Threshold = {threshold}", color="red", fontsize=10
        )

        plt.title(title)
        plt.xlabel("True Label")
        plt.ylabel("Predicted Probability")
        plt.legend()

        plt.tight_layout()
        plt.savefig(output_file)
        plt.show()
    finally:
        # pyplot keeps a reference to every figure until it is closed explicitly.
        # This function may be called many times during a hyperparameter search,
        # so close the figure to avoid accumulating open figures.
        plt.close(fig)
| |
|
def plot_boxplot(y_true, y_pred, title, output_file):
    """
    Generates a boxplot of predicted probabilities grouped by true label and saves it to a file.

    Parameters:
    y_true (array): True labels.
    y_pred (array): Predicted probabilities.
    title (str): Title of the plot.
    output_file (str): Path to save the plot.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        # seaborn groups by column name, so put both arrays into one frame.
        df = pd.DataFrame({'True Label': y_true, 'Predicted Probability': y_pred})
        sns.boxplot(x='True Label', y='Predicted Probability', data=df)

        plt.title(title)
        plt.xlabel("True Label")
        plt.ylabel("Predicted Probability")

        plt.tight_layout()
        plt.savefig(output_file)
        plt.show()
    finally:
        # pyplot keeps a reference to every figure until it is closed explicitly;
        # close it here so repeated calls do not accumulate open figures.
        plt.close(fig)
| | |
def plot_binary_correlation_with_density(y_true, y_pred, threshold, title, output_file):
    """
    Generates a scatter plot with a density plot for binary classification and saves it to a file.

    Parameters:
    y_true (array): True labels; only its length is used (to build the sample index).
    y_pred (array): Predicted probabilities.
    threshold (float): Classification threshold, drawn as a horizontal dashed line.
    title (str): Title of the plot.
    output_file (str): Path to save the plot.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        # Predicted probability of every sample against its positional index.
        plt.scatter(range(len(y_true)), y_pred, alpha=0.5, label='Predicted Probabilities', color='#BC80FF')

        # NOTE(review): the KDE is drawn on the same axes as the scatter, so the
        # x-axis carries sample indices for the scatter but probability values for
        # the density curve - confirm this overlay is intended.
        sns.kdeplot(y_pred, color='green', fill=True, alpha=0.3, label='Probability Density')

        # Horizontal line marking the decision threshold.
        plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')

        plt.title(title)
        plt.xlabel("Index")
        plt.ylabel("Predicted Probability")
        plt.legend()

        plt.tight_layout()
        plt.savefig(output_file)
        plt.show()
    finally:
        # pyplot keeps a reference to every figure until it is closed explicitly;
        # close it here so repeated calls do not accumulate open figures.
        plt.close(fig)
| |
|
# Seed via Lightning's helper before any data is loaded or split.
seed_everything(42)

# Dataset previously saved to disk with the `datasets` library.
dataset = load_from_disk('/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/solubility/new_data')

# Stack each column of interest into a single NumPy array.
sequences, labels, embeddings = (
    np.stack(dataset[column]) for column in ('sequence', 'labels', 'embedding')
)

# NOTE(review): `best_f1` is not read anywhere in the visible code (the objective
# tracks its best score through the study's user attributes) - confirm it is still needed.
best_f1 = -np.inf
# Directory that receives the best model, prediction CSVs and plots.
best_model_path = "/home/st512/peptune/scripts/peptide-mdlm-mcts/scoring/functions/solubility/new_train/"
| |
|
| | |
def trial_info_callback(study, trial):
    """
    Optuna callback that prints a trial's number and score when it is the study's best trial.

    Parameters:
    study: The study being optimized; its `best_trial` is compared against `trial`.
    trial: The trial that has just finished.
    """
    try:
        best_trial = study.best_trial
    except ValueError:
        # Optuna raises ValueError while no trial has completed yet (e.g. the first
        # trial was pruned or failed); there is nothing to report in that case.
        return

    if best_trial == trial:
        print(f"Trial {trial.number}:")
        print(f" Weighted F1 Score: {trial.value}")
| | |
| |
|
def objective(trial):
    """
    Optuna objective: trains an XGBoost binary classifier on the dataset embeddings and
    returns the best weighted F1 score obtained on the validation split.

    For each trial the function
      1. samples the XGBoost hyperparameters and the number of boosting rounds,
      2. trains on a stratified 80/20 split with early stopping on the validation set,
      3. scans thresholds 0.1, 0.15, ... (< 1.0) for the best weighted validation F1,
      4. records the trial's AUC and threshold as trial user attributes, and
      5. if the F1 beats the study's best so far, saves the model, CSVs and plots
         into `best_model_path`.

    Study user attributes written: `best_f1`, `best_threshold` (from the best-F1 trial)
    and `best_auc` (the highest validation AUC seen in any trial).

    Parameters:
    trial (optuna.trial.Trial): The trial used to sample hyperparameters.

    Returns:
    float: Best weighted F1 score on the validation split for this trial.
    """
    params = {
        'objective': 'binary:logistic',
        'lambda': trial.suggest_float('lambda', 1e-8, 10.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 10.0, log=True),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.1, 1.0),
        'subsample': trial.suggest_float('subsample', 0.1, 1.0),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3),
        'max_depth': trial.suggest_int('max_depth', 2, 30),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
        'tree_method': 'hist',
        'device': 'cuda:0',
    }
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)

    # random_state is fixed, so every trial is evaluated on the same stratified 80/20 split.
    train_idx, val_idx = train_test_split(
        np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
    )
    train_subset = dataset.select(train_idx).with_format("torch")
    val_subset = dataset.select(val_idx).with_format("torch")

    train_embeddings = train_subset['embedding']
    valid_embeddings = val_subset['embedding']
    train_labels = train_subset['labels']
    valid_labels = val_subset['labels']

    dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
    dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)

    # Early stopping monitors the validation set passed through `evals`.
    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    preds_train = model.predict(dtrain)
    preds_val = model.predict(dvalid)

    # Threshold scan: keep the first threshold that yields the highest weighted F1.
    best_f1_val = -np.inf
    best_threshold = 0.5

    for threshold in np.arange(0.1, 1.0, 0.05):
        preds_val_binary = (preds_val >= threshold).astype(int)
        f1_temp = f1_score(valid_labels, preds_val_binary, average="weighted")
        if f1_temp > best_f1_val:
            best_f1_val = f1_temp
            best_threshold = threshold

    print(f"Best F1 Score: {best_f1_val:.3f} at Threshold: {best_threshold:.3f}")

    auc_val = roc_auc_score(valid_labels, preds_val)
    print(f"AUC: {auc_val:.3f}")

    # Keep the per-trial metrics with the trial itself so they can be inspected later
    # (e.g. through `study.best_trial.user_attrs`). Plain floats keep them JSON-friendly.
    trial.set_user_attr("auc", float(auc_val))
    trial.set_user_attr("threshold", float(best_threshold))

    study = trial.study

    # Track the highest validation AUC seen so far under the study-level key "best_auc".
    if study.user_attrs.get("best_auc", -np.inf) < auc_val:
        study.set_user_attr("best_auc", float(auc_val))

    # Persist the model and its diagnostics only when this trial improves on the best F1.
    if study.user_attrs.get("best_f1", -np.inf) < best_f1_val:
        study.set_user_attr("best_f1", best_f1_val)
        study.set_user_attr("best_threshold", best_threshold)
        os.makedirs(best_model_path, exist_ok=True)

        model.save_model(os.path.join(best_model_path, "best_model.json"))
        print(f"Best model saved to {os.path.join(best_model_path, 'best_model.json')}")

        save_and_plot_binary_predictions(
            train_labels,
            preds_train,
            valid_labels,
            preds_val,
            best_threshold,
            best_model_path
        )

    return best_f1_val
| |
|
if __name__ == "__main__":
    # Maximize the value returned by `objective` (the weighted validation F1).
    # NOTE(review): the visible objective never calls trial.report()/should_prune(),
    # so the MedianPruner appears to have no effect - confirm whether pruning is wanted.
    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200)

    print("Study statistics: ")
    print(f" Number of finished trials: {len(study.trials)}")
    # Report the metric that was actually optimized, plus the threshold stored with it.
    print(f" Best weighted F1: {study.best_value}")
    print(f" Best threshold: {study.user_attrs.get('best_threshold', None)}")
    # 'best_auc' is optional: print it only when it has been stored on the study,
    # instead of always showing "None".
    best_auc = study.user_attrs.get('best_auc', None)
    if best_auc is not None:
        print(f" Best AUC: {best_auc}")
    for key, value in study.best_trial.params.items():
        print(f" {key}: {value}")