|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import optuna |
|
|
from optuna.trial import TrialState |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import AllChem |
|
|
from sklearn.metrics import mean_squared_error |
|
|
from sklearn.model_selection import train_test_split |
|
|
import xgboost as xgb |
|
|
import os |
|
|
from datasets import load_from_disk |
|
|
from scipy.stats import spearmanr |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Root directory of the PeptiVerse scoring project. The dataset location and
# the output directory used further down are both derived from this path.
base_path = "/scratch/pranamlab/sophtang/home/scoring/PeptiVerse"
|
|
|
|
|
def save_and_plot_predictions(y_true_train, y_pred_train, y_true_val, y_pred_val, output_path):
    """Persist train/validation predictions as CSV files and correlation plots.

    ``output_path`` is created if it does not exist. Both CSV files are
    written first, then both plots are produced via ``plot_correlation``.

    Args:
        y_true_train: True target values of the training split.
        y_pred_train: Predictions for the training split.
        y_true_val: True target values of the validation split.
        y_pred_val: Predictions for the validation split.
        output_path: Directory that receives the CSV and PNG files.
    """
    os.makedirs(output_path, exist_ok=True)

    # One record per split; file names and titles are spelled out literally so
    # the artefacts keep their established names.
    splits = (
        {
            'true': y_true_train,
            'pred': y_pred_train,
            'csv': 'train_predictions.csv',
            'png': 'train_correlation.png',
            'title': "Training Set Correlation Plot",
        },
        {
            'true': y_true_val,
            'pred': y_pred_val,
            'csv': 'val_predictions.csv',
            'png': 'val_correlation.png',
            'title': "Validation Set Correlation Plot",
        },
    )

    # Pass 1: tabular dumps of true vs. predicted values.
    for split in splits:
        frame = pd.DataFrame({
            'True Permeability': split['true'],
            'Predicted Permeability': split['pred'],
        })
        frame.to_csv(os.path.join(output_path, split['csv']), index=False)

    # Pass 2: scatter plots annotated with the Spearman correlation.
    for split in splits:
        plot_correlation(
            split['true'],
            split['pred'],
            title=split['title'],
            output_file=os.path.join(output_path, split['png']),
        )
|
|
|
|
|
def plot_correlation(y_true, y_pred, title, output_file):
    """Scatter predicted against true values and save the figure to disk.

    The plot shows the data points, a dashed identity line spanning the range
    of ``y_true``, and the Spearman correlation in the title.

    Args:
        y_true: Array-like of true target values.
        y_pred: Array-like of predictions, aligned with ``y_true``.
        title: First line of the plot title; the Spearman correlation is
            appended on a second line.
        output_file: Path the figure is written to with ``plt.savefig``.
    """
    spearman_corr, _ = spearmanr(y_true, y_pred)

    # NumPy reductions instead of the builtin min()/max(): one pass in native
    # code, and the array is only scanned once per bound.
    true_values = np.asarray(y_true)
    lower, upper = true_values.min(), true_values.max()

    fig = plt.figure(figsize=(10, 8))
    try:
        plt.scatter(y_true, y_pred, alpha=0.5, label='Data points', color='#BC80FF')
        plt.plot([lower, upper], [lower, upper], color='teal', linestyle='--', label='Ideal fit')

        plt.title(f"{title}\nSpearman Correlation: {spearman_corr:.3f}")
        plt.xlabel("True Permeability (logP)")
        # Was "Predicted Affinity (logP)": this plot is about permeability, as
        # the x-axis label and the CSV column names already say.
        plt.ylabel("Predicted Permeability (logP)")
        plt.legend()

        plt.tight_layout()
        plt.savefig(output_file)
        plt.show()
    finally:
        # This function is called repeatedly during the hyperparameter search;
        # without an explicit close every call would leave a figure alive.
        plt.close(fig)
|
|
|
|
|
|
|
|
# Load the permeability dataset that was previously saved under base_path.
dataset = load_from_disk(f'{base_path}/data/permeability')

# Stack each column used by this script into a NumPy array, in the order
# sequence -> labels -> embedding. (`sequences` is not referenced again in the
# visible part of this file.)
sequences, labels, embeddings = (
    np.stack(dataset[column]) for column in ('sequence', 'labels', 'embedding')
)
|
|
|
|
|
|
|
|
def compute_morgan_fingerprints(smiles_list, radius=2, n_bits=2048):
    """Turn SMILES strings into a 2-D array of Morgan fingerprint bit vectors.

    Args:
        smiles_list: Iterable of SMILES strings.
        radius: Morgan radius passed to RDKit.
        n_bits: Length of each fingerprint bit vector.

    Returns:
        ``numpy.ndarray`` with one row per input and ``n_bits`` columns. A
        SMILES string that RDKit cannot parse is reported on stdout and
        contributes an all-zero row, so rows stay aligned with the input.
        Empty input yields an array of shape ``(0, n_bits)``.
    """
    fps = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Integer zeros: a float placeholder would upcast the whole
            # stacked matrix to float as soon as one SMILES is invalid.
            fps.append(np.zeros(n_bits, dtype=int))
            print(f"Invalid SMILES: {smiles}")
            continue

        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
        fps.append(np.array(fp))

    if not fps:
        # np.array([]) would be 1-D with shape (0,); keep the result 2-D so
        # callers can always rely on the (n_samples, n_bits) layout.
        return np.zeros((0, n_bits), dtype=int)
    return np.array(fps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Output directory: the best model, prediction CSVs, correlation plots and the
# final metrics summary are all written here. Created eagerly at import time.
best_model_path = f"{base_path}/src/permeability"
os.makedirs(best_model_path, exist_ok=True)

# Feature matrix used for training: the stored embeddings, unchanged.
input_features = embeddings
|
|
|
|
|
def trial_info_callback(study, trial):
    """Optuna callback: report a trial when it is the study's current best.

    Args:
        study: The running study; only its ``best_trial`` attribute is read.
        trial: The trial that just finished; ``number`` and ``value`` are read.
    """
    try:
        best = study.best_trial
    except ValueError:
        # Optuna raises ValueError while no trial has completed yet (for
        # example when the first trials failed or were pruned). There is
        # nothing to report in that case.
        return

    # Trial numbers are unique within a study, so comparing them identifies
    # the best trial without comparing every field of the trial objects.
    if best.number == trial.number:
        print(f"Trial {trial.number}:")
        print(f" MSE: {trial.value}")
|
|
|
|
|
def objective(trial):
    """Optuna objective: train one XGBoost regressor and return its validation MSE.

    Hyperparameters are sampled from ``trial``, the module-level
    ``input_features`` / ``labels`` are split 80/20 with a fixed seed, and a
    booster is trained with early stopping on the validation split.

    Side effects: prints the train/validation Spearman correlations. When the
    validation MSE beats the best value stored on the study, the study's
    user attributes are updated and the model, prediction CSVs and
    correlation plots in ``best_model_path`` are overwritten.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        Validation mean squared error of the early-stopped model, as a float.
    """
    params = {
        'objective': 'reg:squarederror',
        'lambda': trial.suggest_float('lambda', 0.1, 10.0, log=True),
        'alpha': trial.suggest_float('alpha', 0.1, 10.0, log=True),
        'gamma': trial.suggest_float('gamma', 0, 5),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
        'subsample': trial.suggest_float('subsample', 0.6, 0.9),
        # NOTE(review): sampled uniformly although the range spans four orders
        # of magnitude; confirm that log=True was not intended.
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 0.1),
        'max_depth': trial.suggest_int('max_depth', 2, 30),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 20),
        'tree_method': 'hist',
        # NOTE(review): scale_pos_weight is documented as a class-imbalance
        # knob; confirm it is meant to be tuned for this regression objective.
        'scale_pos_weight': trial.suggest_float('scale_pos_weight', 0.5, 10.0, log=True),
        # NOTE(review): hard-coded GPU index; this fails on machines without
        # that device.
        'device': 'cuda:6',
    }
    num_boost_round = trial.suggest_int('num_boost_round', 10, 1000)

    # Fixed random_state: every trial is scored on the same validation split.
    X_train, X_val, y_train, y_val = train_test_split(
        input_features, labels, test_size=0.2, random_state=42
    )
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dvalid = xgb.DMatrix(X_val, label=y_val)

    model = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=[(dvalid, "validation")],
        early_stopping_rounds=50,
        verbose_eval=False,
    )

    # xgb.train hands back the booster from the LAST round, which can be up to
    # `early_stopping_rounds` rounds past the best one. Restrict prediction to
    # the best iteration so the reported metrics belong to the early-stopped
    # model. best_iteration is set because early stopping is enabled above.
    best_round_count = model.best_iteration + 1
    best_rounds = (0, best_round_count)
    preds_train = model.predict(dtrain, iteration_range=best_rounds)
    preds_val = model.predict(dvalid, iteration_range=best_rounds)

    # Plain Python floats keep the study attributes serialisable regardless of
    # the NumPy scalar type the metric functions return.
    mse = float(mean_squared_error(y_val, preds_val))
    spearman_train = float(spearmanr(y_train, preds_train)[0])
    spearman_val = float(spearmanr(y_val, preds_val)[0])
    print(f"Train Spearman: {spearman_train:.4f}, Val Spearman: {spearman_val:.4f}")

    if trial.study.user_attrs.get("best_mse", np.inf) > mse:
        trial.study.set_user_attr("best_mse", mse)
        trial.study.set_user_attr("best_spearman_train", spearman_train)
        trial.study.set_user_attr("best_spearman_val", spearman_val)
        trial.study.set_user_attr("best_trial", trial.number)

        # Save only the trees up to the best iteration, so that loading the
        # file and calling predict() reproduces the predictions stored next
        # to it.
        model[:best_round_count].save_model(os.path.join(best_model_path, "best_model.json"))
        save_and_plot_predictions(y_train, preds_train, y_val, preds_val, best_model_path)
        print(f"✓ NEW BEST! Trial {trial.number}: MSE={mse:.4f}, Train Spearman={spearman_train:.4f}, Val Spearman={spearman_val:.4f}")

    return mse
|
|
|
|
|
if __name__ == "__main__":
    # Run the hyperparameter search, then print and persist a summary.
    # NOTE(review): objective() never calls trial.report()/should_prune(), so
    # the MedianPruner configured here has nothing to act on.
    study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=200, callbacks=[trial_info_callback])

    best_trial = study.best_trial
    attrs = study.user_attrs
    # NaN default instead of None: formatting None with ":.4f" raises
    # TypeError, which would lose the whole summary over a missing attribute.
    missing = float("nan")
    banner = "=" * 60

    summary = [
        "\n" + banner,
        "OPTIMIZATION COMPLETE",
        banner,
        f"Number of finished trials: {len(study.trials)}",
        f"\nBest Trial: #{attrs.get('best_trial', 'N/A')}",
        f"Best MSE: {best_trial.value:.4f}",
        f"Best Training Spearman Correlation: {attrs.get('best_spearman_train', missing):.4f}",
        f"Best Validation Spearman Correlation: {attrs.get('best_spearman_val', missing):.4f}",
        "\nBest hyperparameters:",
    ]
    summary.extend(f" {key}: {value}" for key, value in best_trial.params.items())
    summary.append(banner)

    # The same text goes to stdout and to the metrics file.
    report = "\n".join(summary)
    print(report)

    metrics_file = os.path.join(best_model_path, "optimization_metrics.txt")
    with open(metrics_file, 'w', encoding="utf-8") as f:
        f.write(report)
    print(f"\n✓ Metrics saved to: {metrics_file}")