import os

import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import xgboost as xgb
from datasets import load_from_disk
from lightning.pytorch import seed_everything
from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score
from sklearn.model_selection import train_test_split

base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"

def save_and_plot_binary_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, threshold, output_path):
    """
    Saves the true and predicted values for training and validation sets, and generates binary classification plots.

    Parameters:
        y_true_train (array): True labels for the training set.
        y_pred_train (array): Predicted probabilities for the training set.
        y_true_val (array): True labels for the validation set.
        y_pred_val (array): Predicted probabilities for the validation set.
        threshold (float): Classification threshold for predictions.
        output_path (str): Directory to save the CSV files and plots.
    """
    os.makedirs(output_path, exist_ok=True)

    # Convert probabilities to binary predictions
    y_pred_train_binary = (y_pred_train >= threshold).astype(int)
    y_pred_val_binary = (y_pred_val >= threshold).astype(int)

    # Save training predictions
    train_df = pd.DataFrame({
        'True Label': y_true_train,
        'Predicted Probability': y_pred_train,
        'Predicted Label': y_pred_train_binary
    })
    train_df.to_csv(os.path.join(output_path, 'train_predictions_binary.csv'), index=False)

    # Save validation predictions
    val_df = pd.DataFrame({
        'True Label': y_true_val,
        'Predicted Probability': y_pred_val,
        'Predicted Label': y_pred_val_binary
    })
    val_df.to_csv(os.path.join(output_path, 'val_predictions_binary.csv'), index=False)

    # Plot training predictions
    plot_binary_correlation(
        y_true_train,
        y_pred_train,
        threshold,
        title="Training Set Binary Classification Plot",
        output_file=os.path.join(output_path, 'train_classification_plot.png')
    )

    # Plot validation predictions
    plot_binary_correlation(
        y_true_val,
        y_pred_val,
        threshold,
        title="Validation Set Binary Classification Plot",
        output_file=os.path.join(output_path, 'val_classification_plot.png')
    )

def plot_binary_correlation(y_true, y_pred, threshold, title, output_file):
    """
    Generates a scatter plot for binary classification and saves it to a file.

    Parameters:
        y_true (array): True labels.
        y_pred (array): Predicted probabilities.
        threshold (float): Classification threshold for predictions.
        title (str): Title of the plot.
        output_file (str): Path to save the plot.
    """
    # Scatter plot
    plt.figure(figsize=(10, 8))
    plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')

    # Add threshold line
    plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold = {threshold}')

    # Add annotations
    plt.title(title)
    plt.xlabel("True Labels")
    plt.ylabel("Predicted Probability")
    plt.legend()

    # Save the plot and release the figure; plt.show() would block (or leak
    # open figures) across hundreds of Optuna trials
    plt.tight_layout()
    plt.savefig(output_file)
    plt.close()

seed_everything(42)

dataset = load_from_disk(f'{base_path}/data/solubility')

sequences = np.array(dataset['sequence'])  # peptide sequences; used only to size the split
labels = np.array(dataset['labels'])       # binary solubility labels
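
# Assumed record layout (inferred from usage, not verified here): each row
# carries a peptide "sequence" string, a binary "labels" value, and a
# precomputed fixed-length "embedding" vector, e.g.
#   {"sequence": "ACDEF...", "labels": 1, "embedding": [0.12, -0.34, ...]}
# Embeddings are materialized per train/val split inside objective().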

# Directory where the best model and its diagnostics are written
best_model_path = f"{base_path}/src/solubility"

# Optuna callback: log a trial whenever it becomes the study-wide best
def trial_info_callback(study, trial):
    if study.best_trial.number == trial.number:
        print(f"Trial {trial.number}:")
        print(f"  Weighted F1 Score: {trial.value}")


def objective(trial):
    # Define hyperparameters
    params = {
        'objective': 'binary:logistic',
        'lambda': trial.suggest_float('lambda', 1e-8, 50.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 50.0, log=True),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.3, 1.0),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'learning_rate': trial.suggest_float('learning_rate', 0.001, 0.3),
        'max_depth': trial.suggest_int('max_depth', 2, 15),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 500),
        'gamma': trial.suggest_float('gamma', 0, 10.0),
        'tree_method': 'hist',
        'device': 'cuda:6',  # adjust to an available device, e.g. 'cuda:0' or 'cpu'
    }

    # Upper bound on boosting rounds; early stopping below may cut it short
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)
    threshold = 0.5  # Fixed classification threshold

    # Fixed-seed split: every trial sees the same train/validation partition,
    # so trial scores are directly comparable
    train_idx, val_idx = train_test_split(
        np.arange(len(sequences)), test_size=0.2, stratify=labels, random_state=42
    )
    train_subset = dataset.select(train_idx)
    val_subset = dataset.select(val_idx)

    # Extract precomputed embeddings and labels for train/validation
    train_embeddings = np.array(train_subset['embedding'])
    valid_embeddings = np.array(val_subset['embedding'])
    train_labels = np.array(train_subset['labels'])
    valid_labels = np.array(val_subset['labels'])

    # Prepare training and validation sets
    dtrain = xgb.DMatrix(train_embeddings, label=train_labels)
    dvalid = xgb.DMatrix(valid_embeddings, label=valid_labels)

    # Train the model
    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # Predict probabilities
    preds_train = model.predict(dtrain)
    preds_val = model.predict(dvalid)

    # Calculate metrics
    f1_val = f1_score(valid_labels, (preds_val >= threshold).astype(int), average="weighted")
    auc_val = roc_auc_score(valid_labels, preds_val)
    print(f"Trial {trial.number}: AUC: {auc_val:.3f}, F1 Score: {f1_val:.3f}")

    # Save the model if it has the best F1 score
    current_best = trial.study.user_attrs.get("best_f1", -np.inf)
    if f1_val > current_best:
        trial.study.set_user_attr("best_f1", f1_val)
        trial.study.set_user_attr("best_auc", auc_val)
        trial.study.set_user_attr("best_trial", trial.number)
        os.makedirs(best_model_path, exist_ok=True)

        # Save the model
        model.save_model(os.path.join(best_model_path, "best_model_f1.json"))
        print(f"✓ NEW BEST! Trial {trial.number}: F1={f1_val:.4f}, AUC={auc_val:.4f} - Model saved!")

        # Save and plot binary predictions
        save_and_plot_binary_predictions(
            train_labels, preds_train, valid_labels, preds_val, threshold, best_model_path
        )

    return f1_val
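

# A minimal helper sketch (not wired into the study above): the
# precision_recall_curve import suggests threshold tuning, and one common
# recipe is to sweep sklearn's candidate cutoffs and keep the one that
# maximizes F1, rather than the fixed 0.5 used in objective().
def pick_f1_threshold(y_true, y_prob):
    """Return the probability cutoff that maximizes F1 on held-out data."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    # precision/recall carry one more point than thresholds; drop the last
    f1_scores = 2 * precision[:-1] * recall[:-1] / np.clip(
        precision[:-1] + recall[:-1], 1e-12, None
    )
    return thresholds[int(np.argmax(f1_scores))]
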

if __name__ == "__main__":
    # NOTE: MedianPruner only takes effect if the objective reports
    # intermediate values via trial.report(); it is inert with objective() above
    study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])
    
    # Prepare summary text
    summary = []
    summary.append("\n" + "="*60)
    summary.append("OPTIMIZATION COMPLETE")
    summary.append("="*60)
    summary.append(f"Number of finished trials: {len(study.trials)}")
    summary.append(f"\nBest Trial: #{study.user_attrs.get('best_trial', 'N/A')}")
    summary.append(f"Best F1 Score: {study.user_attrs.get('best_f1', None):.4f}")
    summary.append(f"Best AUC Score: {study.user_attrs.get('best_auc', None):.4f}")
    summary.append(f"Optuna Best Trial Value: {study.best_trial.value:.4f}")
    summary.append(f"\nBest hyperparameters:")
    for key, value in study.best_trial.params.items():
        summary.append(f"  {key}: {value}")
    summary.append("="*60)
    
    # Print to console
    for line in summary:
        print(line)
    
    # Save to file
    metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
    with open(metrics_file, 'w') as f:
        f.write('\n'.join(summary))
    print(f"\n✓ Metrics saved to: {metrics_file}")