Commit ·
4e1d3f6
1
Parent(s): 48aea13
Added ablation study with randomly sampled vectors + Started working on LightningDataModule wrapper
Browse files- plots/Active_Dmax_0.6_pDC50_6.0_metrics.pdf +0 -0
- plots/Active_Dmax_0.6_pDC50_6.0_metrics_majority_vote.pdf +0 -0
- plots/ablation_study_random.pdf +0 -0
- plots/ablation_study_tanimoto.pdf +0 -0
- plots/ablation_study_uniprot.pdf +0 -0
- plots/old_Active_Dmax_0.6_pDC50_6.0_metrics.pdf +0 -0
- plots/training_metrics_random_best_model_n0.pdf +0 -0
- plots/training_metrics_random_best_model_n1.pdf +0 -0
- plots/training_metrics_random_best_model_n2.pdf +0 -0
- plots/training_metrics_random_cv_model_fold0.pdf +0 -0
- plots/training_metrics_random_cv_model_fold1.pdf +0 -0
- plots/training_metrics_random_cv_model_fold2.pdf +0 -0
- plots/training_metrics_random_cv_model_fold3.pdf +0 -0
- plots/training_metrics_random_cv_model_fold4.pdf +0 -0
- plots/training_metrics_tanimoto_best_model_n0.pdf +0 -0
- plots/training_metrics_tanimoto_best_model_n1.pdf +0 -0
- plots/training_metrics_tanimoto_best_model_n2.pdf +0 -0
- plots/training_metrics_tanimoto_cv_model_fold0.pdf +0 -0
- plots/training_metrics_tanimoto_cv_model_fold1.pdf +0 -0
- plots/training_metrics_tanimoto_cv_model_fold2.pdf +0 -0
- plots/training_metrics_tanimoto_cv_model_fold3.pdf +0 -0
- plots/training_metrics_tanimoto_cv_model_fold4.pdf +0 -0
- plots/training_metrics_uniprot_best_model_n0.pdf +0 -0
- plots/training_metrics_uniprot_best_model_n1.pdf +0 -0
- plots/training_metrics_uniprot_best_model_n2.pdf +0 -0
- plots/training_metrics_uniprot_cv_model_fold0.pdf +0 -0
- plots/training_metrics_uniprot_cv_model_fold1.pdf +0 -0
- plots/training_metrics_uniprot_cv_model_fold2.pdf +0 -0
- plots/training_metrics_uniprot_cv_model_fold3.pdf +0 -0
- plots/training_metrics_uniprot_cv_model_fold4.pdf +0 -0
- protac_degradation_predictor/optuna_utils.py +96 -52
- protac_degradation_predictor/protac_dataset.py +383 -14
- protac_degradation_predictor/pytorch_models.py +48 -31
- src/plot_experiment_results.py +49 -31
- src/run_experiments.py +1 -1
plots/Active_Dmax_0.6_pDC50_6.0_metrics.pdf
ADDED
|
Binary file (16.8 kB). View file
|
|
|
plots/Active_Dmax_0.6_pDC50_6.0_metrics_majority_vote.pdf
ADDED
|
Binary file (16.7 kB). View file
|
|
|
plots/ablation_study_random.pdf
ADDED
|
Binary file (15.8 kB). View file
|
|
|
plots/ablation_study_tanimoto.pdf
ADDED
|
Binary file (15.3 kB). View file
|
|
|
plots/ablation_study_uniprot.pdf
ADDED
|
Binary file (15.6 kB). View file
|
|
|
plots/old_Active_Dmax_0.6_pDC50_6.0_metrics.pdf
ADDED
|
Binary file (16.8 kB). View file
|
|
|
plots/training_metrics_random_best_model_n0.pdf
ADDED
|
Binary file (16.9 kB). View file
|
|
|
plots/training_metrics_random_best_model_n1.pdf
ADDED
|
Binary file (16.7 kB). View file
|
|
|
plots/training_metrics_random_best_model_n2.pdf
ADDED
|
Binary file (16.9 kB). View file
|
|
|
plots/training_metrics_random_cv_model_fold0.pdf
ADDED
|
Binary file (17.4 kB). View file
|
|
|
plots/training_metrics_random_cv_model_fold1.pdf
ADDED
|
Binary file (17.3 kB). View file
|
|
|
plots/training_metrics_random_cv_model_fold2.pdf
ADDED
|
Binary file (17.1 kB). View file
|
|
|
plots/training_metrics_random_cv_model_fold3.pdf
ADDED
|
Binary file (17.3 kB). View file
|
|
|
plots/training_metrics_random_cv_model_fold4.pdf
ADDED
|
Binary file (17.6 kB). View file
|
|
|
plots/training_metrics_tanimoto_best_model_n0.pdf
ADDED
|
Binary file (16.9 kB). View file
|
|
|
plots/training_metrics_tanimoto_best_model_n1.pdf
ADDED
|
Binary file (16.9 kB). View file
|
|
|
plots/training_metrics_tanimoto_best_model_n2.pdf
ADDED
|
Binary file (16.6 kB). View file
|
|
|
plots/training_metrics_tanimoto_cv_model_fold0.pdf
ADDED
|
Binary file (17.1 kB). View file
|
|
|
plots/training_metrics_tanimoto_cv_model_fold1.pdf
ADDED
|
Binary file (16.8 kB). View file
|
|
|
plots/training_metrics_tanimoto_cv_model_fold2.pdf
ADDED
|
Binary file (17.3 kB). View file
|
|
|
plots/training_metrics_tanimoto_cv_model_fold3.pdf
ADDED
|
Binary file (17 kB). View file
|
|
|
plots/training_metrics_tanimoto_cv_model_fold4.pdf
ADDED
|
Binary file (17.1 kB). View file
|
|
|
plots/training_metrics_uniprot_best_model_n0.pdf
ADDED
|
Binary file (16.3 kB). View file
|
|
|
plots/training_metrics_uniprot_best_model_n1.pdf
ADDED
|
Binary file (15.9 kB). View file
|
|
|
plots/training_metrics_uniprot_best_model_n2.pdf
ADDED
|
Binary file (16.1 kB). View file
|
|
|
plots/training_metrics_uniprot_cv_model_fold0.pdf
ADDED
|
Binary file (17.2 kB). View file
|
|
|
plots/training_metrics_uniprot_cv_model_fold1.pdf
ADDED
|
Binary file (17.7 kB). View file
|
|
|
plots/training_metrics_uniprot_cv_model_fold2.pdf
ADDED
|
Binary file (17 kB). View file
|
|
|
plots/training_metrics_uniprot_cv_model_fold3.pdf
ADDED
|
Binary file (16.5 kB). View file
|
|
|
plots/training_metrics_uniprot_cv_model_fold4.pdf
ADDED
|
Binary file (17.1 kB). View file
|
|
|
protac_degradation_predictor/optuna_utils.py
CHANGED
|
@@ -2,7 +2,13 @@ import os
|
|
| 2 |
from typing import Literal, List, Tuple, Optional, Dict
|
| 3 |
import logging
|
| 4 |
|
| 5 |
-
from .pytorch_models import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from .sklearn_models import (
|
| 7 |
train_sklearn_model,
|
| 8 |
suggest_random_forest,
|
|
@@ -83,6 +89,26 @@ def get_dataframe_stats(
|
|
| 83 |
return stats
|
| 84 |
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
def pytorch_model_objective(
|
| 87 |
trial: optuna.Trial,
|
| 88 |
protein2embedding: Dict,
|
|
@@ -198,18 +224,7 @@ def pytorch_model_objective(
|
|
| 198 |
|
| 199 |
# Get the majority vote for the test predictions
|
| 200 |
if test_df is not None and not fast_dev_run:
|
| 201 |
-
|
| 202 |
-
test_preds = torch.stack(test_preds)
|
| 203 |
-
test_preds, _ = torch.mode(test_preds, dim=0)
|
| 204 |
-
y = torch.tensor(test_df[active_label].tolist())
|
| 205 |
-
# Measure the test accuracy and ROC AUC
|
| 206 |
-
majority_vote_metrics = {
|
| 207 |
-
'test_acc': Accuracy(task='binary')(test_preds, y).item(),
|
| 208 |
-
'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
|
| 209 |
-
'test_precision': Precision(task='binary')(test_preds, y).item(),
|
| 210 |
-
'test_recall': Recall(task='binary')(test_preds, y).item(),
|
| 211 |
-
'test_f1': F1Score(task='binary')(test_preds, y).item(),
|
| 212 |
-
}
|
| 213 |
majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
|
| 214 |
trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
|
| 215 |
logging.info(f'Majority vote metrics: {majority_vote_metrics}')
|
|
@@ -278,6 +293,7 @@ def hyperparameter_tuning_and_training(
|
|
| 278 |
study = joblib.load(study_filename)
|
| 279 |
study_loaded = True
|
| 280 |
logging.info(f'Loaded study from {study_filename}')
|
|
|
|
| 281 |
|
| 282 |
if not study_loaded or force_study:
|
| 283 |
study.optimize(
|
|
@@ -333,12 +349,13 @@ def hyperparameter_tuning_and_training(
|
|
| 333 |
)
|
| 334 |
|
| 335 |
# Retrain N models with the best hyperparameters (measure model uncertainty)
|
|
|
|
| 336 |
test_report = []
|
| 337 |
test_preds = []
|
| 338 |
dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
|
| 339 |
for i in range(n_models_for_test):
|
| 340 |
pl.seed_everything(42 + i + 1)
|
| 341 |
-
|
| 342 |
protein2embedding=protein2embedding,
|
| 343 |
cell2embedding=cell2embedding,
|
| 344 |
smiles2fp=smiles2fp,
|
|
@@ -366,22 +383,12 @@ def hyperparameter_tuning_and_training(
|
|
| 366 |
|
| 367 |
test_report.append(metrics.copy())
|
| 368 |
test_preds.append(test_pred)
|
|
|
|
| 369 |
test_report = pd.DataFrame(test_report)
|
| 370 |
|
| 371 |
# Get the majority vote for the test predictions
|
| 372 |
if not fast_dev_run:
|
| 373 |
-
|
| 374 |
-
test_preds, _ = torch.mode(test_preds, dim=0)
|
| 375 |
-
y = torch.tensor(test_df[active_label].tolist())
|
| 376 |
-
# Measure the test accuracy and ROC AUC
|
| 377 |
-
majority_vote_metrics = {
|
| 378 |
-
'cv_models': False,
|
| 379 |
-
'test_acc': Accuracy(task='binary')(test_preds, y).item(),
|
| 380 |
-
'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
|
| 381 |
-
'test_precision': Precision(task='binary')(test_preds, y).item(),
|
| 382 |
-
'test_recall': Recall(task='binary')(test_preds, y).item(),
|
| 383 |
-
'test_f1': F1Score(task='binary')(test_preds, y).item(),
|
| 384 |
-
}
|
| 385 |
majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
|
| 386 |
majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
|
| 387 |
majority_vote_metrics_cv['cv_models'] = True
|
|
@@ -408,34 +415,71 @@ def hyperparameter_tuning_and_training(
|
|
| 408 |
logging.info('-' * 100)
|
| 409 |
logging.info(f'Ablation study with disabled embeddings: {disabled_embeddings}')
|
| 410 |
logging.info('-' * 100)
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
|
| 434 |
-
#
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
-
ablation_report.append(metrics.copy())
|
| 439 |
ablation_report = pd.DataFrame(ablation_report)
|
| 440 |
|
| 441 |
# Add a column with the split_type to all reports
|
|
|
|
| 2 |
from typing import Literal, List, Tuple, Optional, Dict
|
| 3 |
import logging
|
| 4 |
|
| 5 |
+
from .pytorch_models import (
|
| 6 |
+
train_model,
|
| 7 |
+
PROTAC_Model,
|
| 8 |
+
evaluate_model,
|
| 9 |
+
)
|
| 10 |
+
from .protac_dataset import get_datasets
|
| 11 |
+
|
| 12 |
from .sklearn_models import (
|
| 13 |
train_sklearn_model,
|
| 14 |
suggest_random_forest,
|
|
|
|
| 89 |
return stats
|
| 90 |
|
| 91 |
|
| 92 |
+
def get_majority_vote_metrics(
    test_preds: List,
    test_df: pd.DataFrame,
    active_label: str = 'Active',
) -> Dict:
    """ Compute test metrics for the majority vote over several models' predictions.

    Args:
        test_preds (List): One prediction tensor per model. They are combined
            with ``torch.stack``, so they must all share the same shape.
        test_df (pd.DataFrame): The test DataFrame; its ``active_label``
            column provides the ground-truth targets.
        active_label (str): The name of the label column in ``test_df``.

    Returns:
        Dict: The keys ``test_acc``, ``test_roc_auc``, ``test_precision``,
            ``test_recall`` and ``test_f1``, each mapped to a Python scalar.
    """
    # Per sample, keep the value predicted most often across the stacked models.
    voted_preds, _ = torch.mode(torch.stack(test_preds), dim=0)
    targets = torch.tensor(test_df[active_label].tolist())

    # NOTE(review): every metric, ROC AUC included, is fed the voted
    # predictions rather than per-model scores - confirm this is intended.
    metric_classes = {
        'test_acc': Accuracy,
        'test_roc_auc': AUROC,
        'test_precision': Precision,
        'test_recall': Recall,
        'test_f1': F1Score,
    }
    return {
        name: metric_cls(task='binary')(voted_preds, targets).item()
        for name, metric_cls in metric_classes.items()
    }
|
| 110 |
+
|
| 111 |
+
|
| 112 |
def pytorch_model_objective(
|
| 113 |
trial: optuna.Trial,
|
| 114 |
protein2embedding: Dict,
|
|
|
|
| 224 |
|
| 225 |
# Get the majority vote for the test predictions
|
| 226 |
if test_df is not None and not fast_dev_run:
|
| 227 |
+
majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
|
| 229 |
trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
|
| 230 |
logging.info(f'Majority vote metrics: {majority_vote_metrics}')
|
|
|
|
| 293 |
study = joblib.load(study_filename)
|
| 294 |
study_loaded = True
|
| 295 |
logging.info(f'Loaded study from {study_filename}')
|
| 296 |
+
logging.info(f'Study best params: {study.best_params}')
|
| 297 |
|
| 298 |
if not study_loaded or force_study:
|
| 299 |
study.optimize(
|
|
|
|
| 349 |
)
|
| 350 |
|
| 351 |
# Retrain N models with the best hyperparameters (measure model uncertainty)
|
| 352 |
+
best_models = []
|
| 353 |
test_report = []
|
| 354 |
test_preds = []
|
| 355 |
dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
|
| 356 |
for i in range(n_models_for_test):
|
| 357 |
pl.seed_everything(42 + i + 1)
|
| 358 |
+
model, trainer, metrics, test_pred = train_model(
|
| 359 |
protein2embedding=protein2embedding,
|
| 360 |
cell2embedding=cell2embedding,
|
| 361 |
smiles2fp=smiles2fp,
|
|
|
|
| 383 |
|
| 384 |
test_report.append(metrics.copy())
|
| 385 |
test_preds.append(test_pred)
|
| 386 |
+
best_models.append({'model': model, 'trainer': trainer})
|
| 387 |
test_report = pd.DataFrame(test_report)
|
| 388 |
|
| 389 |
# Get the majority vote for the test predictions
|
| 390 |
if not fast_dev_run:
|
| 391 |
+
majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
|
| 393 |
majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
|
| 394 |
majority_vote_metrics_cv['cv_models'] = True
|
|
|
|
| 415 |
logging.info('-' * 100)
|
| 416 |
logging.info(f'Ablation study with disabled embeddings: {disabled_embeddings}')
|
| 417 |
logging.info('-' * 100)
|
| 418 |
+
disabled_embeddings_str = 'disabled ' + ' '.join(disabled_embeddings)
|
| 419 |
+
test_preds = []
|
| 420 |
+
for i, model_trainer in enumerate(best_models):
|
| 421 |
+
logging.info(f'Evaluating model n.{i} on {disabled_embeddings_str}.')
|
| 422 |
+
model = model_trainer['model']
|
| 423 |
+
trainer = model_trainer['trainer']
|
| 424 |
+
_, test_ds, _ = get_datasets(
|
| 425 |
+
protein2embedding=protein2embedding,
|
| 426 |
+
cell2embedding=cell2embedding,
|
| 427 |
+
smiles2fp=smiles2fp,
|
| 428 |
+
train_df=train_val_df,
|
| 429 |
+
val_df=test_df,
|
| 430 |
+
disabled_embeddings=disabled_embeddings,
|
| 431 |
+
active_label=active_label,
|
| 432 |
+
scaler=model.scalers,
|
| 433 |
+
use_single_scaler=model.join_embeddings == 'beginning',
|
| 434 |
+
)
|
| 435 |
+
ret = evaluate_model(model, trainer, test_ds, batch_size=128)
|
| 436 |
+
# NOTE: We are passing the test set as the validation set argument
|
| 437 |
+
# Rename the keys in the metrics dictionary
|
| 438 |
+
test_preds.append(ret['val_pred'])
|
| 439 |
+
ret['val_metrics'] = {k.replace('val_', 'test_'): v for k, v in ret['val_metrics'].items()}
|
| 440 |
+
ret['val_metrics'].update(dfs_stats)
|
| 441 |
+
ret['val_metrics']['majority_vote'] = False
|
| 442 |
+
ret['val_metrics']['model_type'] = 'Pytorch'
|
| 443 |
+
ret['val_metrics']['disabled_embeddings'] = disabled_embeddings_str
|
| 444 |
+
ablation_report.append(ret['val_metrics'].copy())
|
| 445 |
|
| 446 |
+
# Get the majority vote for the test predictions
|
| 447 |
+
if not fast_dev_run:
|
| 448 |
+
majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
|
| 449 |
+
majority_vote_metrics.update(dfs_stats)
|
| 450 |
+
majority_vote_metrics['majority_vote'] = True
|
| 451 |
+
majority_vote_metrics['model_type'] = 'Pytorch'
|
| 452 |
+
majority_vote_metrics['disabled_embeddings'] = disabled_embeddings_str
|
| 453 |
+
ablation_report.append(majority_vote_metrics.copy())
|
| 454 |
+
|
| 455 |
+
# _, _, metrics = train_model(
|
| 456 |
+
# protein2embedding=protein2embedding,
|
| 457 |
+
# cell2embedding=cell2embedding,
|
| 458 |
+
# smiles2fp=smiles2fp,
|
| 459 |
+
# train_df=train_val_df,
|
| 460 |
+
# val_df=test_df,
|
| 461 |
+
# fast_dev_run=fast_dev_run,
|
| 462 |
+
# active_label=active_label,
|
| 463 |
+
# max_epochs=max_epochs,
|
| 464 |
+
# use_logger=False,
|
| 465 |
+
# logger_save_dir=logger_save_dir,
|
| 466 |
+
# logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
|
| 467 |
+
# disabled_embeddings=disabled_embeddings,
|
| 468 |
+
# batch_size=128,
|
| 469 |
+
# apply_scaling=True,
|
| 470 |
+
# **study.best_params,
|
| 471 |
+
# )
|
| 472 |
+
# # Rename the keys in the metrics dictionary
|
| 473 |
+
# metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
|
| 474 |
+
# metrics['disabled_embeddings'] = disabled_embeddings_str
|
| 475 |
+
# metrics['model_type'] = 'Pytorch'
|
| 476 |
+
# metrics.update(dfs_stats)
|
| 477 |
+
|
| 478 |
+
# # Add the training metrics
|
| 479 |
+
# train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
|
| 480 |
+
# metrics.update(train_metrics)
|
| 481 |
+
# ablation_report.append(metrics.copy())
|
| 482 |
|
|
|
|
| 483 |
ablation_report = pd.DataFrame(ablation_report)
|
| 484 |
|
| 485 |
# Add a column with the split_type to all reports
|
protac_degradation_predictor/protac_dataset.py
CHANGED
|
@@ -1,13 +1,26 @@
|
|
| 1 |
from typing import Literal, List, Tuple, Optional, Dict
|
|
|
|
| 2 |
|
| 3 |
-
from
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from imblearn.over_sampling import SMOTE, ADASYN
|
|
|
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class PROTAC_Dataset(Dataset):
|
|
|
|
| 11 |
def __init__(
|
| 12 |
self,
|
| 13 |
protac_df: pd.DataFrame,
|
|
@@ -17,6 +30,9 @@ class PROTAC_Dataset(Dataset):
|
|
| 17 |
use_smote: bool = False,
|
| 18 |
oversampler: Optional[SMOTE | ADASYN] = None,
|
| 19 |
active_label: str = 'Active',
|
|
|
|
|
|
|
|
|
|
| 20 |
):
|
| 21 |
""" Initialize the PROTAC dataset
|
| 22 |
|
|
@@ -28,13 +44,17 @@ class PROTAC_Dataset(Dataset):
|
|
| 28 |
use_smote (bool): Whether to use SMOTE for oversampling
|
| 29 |
use_ored_activity (bool): Whether to use the 'Active - OR' column
|
| 30 |
"""
|
| 31 |
-
# Filter out examples with NaN in
|
| 32 |
-
self.data = protac_df # [~protac_df[
|
| 33 |
self.protein2embedding = protein2embedding
|
| 34 |
self.cell2embedding = cell2embedding
|
| 35 |
self.smiles2fp = smiles2fp
|
| 36 |
self.active_label = active_label
|
| 37 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
|
| 40 |
self.protein_emb_dim = protein2embedding[list(
|
|
@@ -115,15 +135,15 @@ class PROTAC_Dataset(Dataset):
|
|
| 115 |
"""
|
| 116 |
if use_single_scaler:
|
| 117 |
self.use_single_scaler = True
|
| 118 |
-
scaler = StandardScaler(**scaler_kwargs)
|
| 119 |
embeddings = np.hstack([
|
| 120 |
np.array(self.data['Smiles'].tolist()),
|
| 121 |
np.array(self.data['Uniprot'].tolist()),
|
| 122 |
np.array(self.data['E3 Ligase Uniprot'].tolist()),
|
| 123 |
np.array(self.data['Cell Line Identifier'].tolist()),
|
| 124 |
])
|
| 125 |
-
scaler.fit(embeddings)
|
| 126 |
-
return scaler
|
| 127 |
else:
|
| 128 |
self.use_single_scaler = False
|
| 129 |
scalers = {}
|
|
@@ -137,6 +157,7 @@ class PROTAC_Dataset(Dataset):
|
|
| 137 |
scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
|
| 138 |
scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
|
| 139 |
|
|
|
|
| 140 |
return scalers
|
| 141 |
|
| 142 |
def apply_scaling(self, scalers: dict, use_single_scaler: bool = False):
|
|
@@ -190,11 +211,359 @@ class PROTAC_Dataset(Dataset):
|
|
| 190 |
return len(self.data)
|
| 191 |
|
| 192 |
def __getitem__(self, idx):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
elem = {
|
| 194 |
-
'smiles_emb':
|
| 195 |
-
'poi_emb':
|
| 196 |
-
'e3_emb':
|
| 197 |
-
'cell_emb':
|
| 198 |
'active': self.data[self.active_label].iloc[idx],
|
| 199 |
}
|
| 200 |
-
return elem
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Literal, List, Tuple, Optional, Dict
|
| 2 |
+
from collections import defaultdict
|
| 3 |
|
| 4 |
+
from .data_utils import (
|
| 5 |
+
get_fingerprint,
|
| 6 |
+
is_active,
|
| 7 |
+
load_cell2embedding,
|
| 8 |
+
load_protein2embedding,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
from torch.utils.data import Dataset, DataLoader
|
| 12 |
from imblearn.over_sampling import SMOTE, ADASYN
|
| 13 |
+
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
|
| 14 |
+
import numpy as np
|
| 15 |
import pandas as pd
|
| 16 |
+
import pytorch_lightning as pl
|
| 17 |
+
from rdkit import Chem
|
| 18 |
+
from rdkit.Chem import AllChem
|
| 19 |
+
from rdkit import DataStructs
|
| 20 |
|
| 21 |
|
| 22 |
class PROTAC_Dataset(Dataset):
|
| 23 |
+
|
| 24 |
def __init__(
|
| 25 |
self,
|
| 26 |
protac_df: pd.DataFrame,
|
|
|
|
| 30 |
use_smote: bool = False,
|
| 31 |
oversampler: Optional[SMOTE | ADASYN] = None,
|
| 32 |
active_label: str = 'Active',
|
| 33 |
+
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
|
| 34 |
+
scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
|
| 35 |
+
use_single_scaler: Optional[bool] = None,
|
| 36 |
):
|
| 37 |
""" Initialize the PROTAC dataset
|
| 38 |
|
|
|
|
| 44 |
use_smote (bool): Whether to use SMOTE for oversampling
|
| 45 |
use_ored_activity (bool): Whether to use the 'Active - OR' column
|
| 46 |
"""
|
| 47 |
+
# Filter out examples with NaN in active_label column
|
| 48 |
+
self.data = protac_df # [~protac_df[active_label].isna()]
|
| 49 |
self.protein2embedding = protein2embedding
|
| 50 |
self.cell2embedding = cell2embedding
|
| 51 |
self.smiles2fp = smiles2fp
|
| 52 |
self.active_label = active_label
|
| 53 |
+
self.disabled_embeddings = disabled_embeddings
|
| 54 |
+
|
| 55 |
+
# Scaling parameters
|
| 56 |
+
self.scaler = scaler
|
| 57 |
+
self.use_single_scaler = use_single_scaler
|
| 58 |
|
| 59 |
self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
|
| 60 |
self.protein_emb_dim = protein2embedding[list(
|
|
|
|
| 135 |
"""
|
| 136 |
if use_single_scaler:
|
| 137 |
self.use_single_scaler = True
|
| 138 |
+
self.scaler = StandardScaler(**scaler_kwargs)
|
| 139 |
embeddings = np.hstack([
|
| 140 |
np.array(self.data['Smiles'].tolist()),
|
| 141 |
np.array(self.data['Uniprot'].tolist()),
|
| 142 |
np.array(self.data['E3 Ligase Uniprot'].tolist()),
|
| 143 |
np.array(self.data['Cell Line Identifier'].tolist()),
|
| 144 |
])
|
| 145 |
+
self.scaler.fit(embeddings)
|
| 146 |
+
return self.scaler
|
| 147 |
else:
|
| 148 |
self.use_single_scaler = False
|
| 149 |
scalers = {}
|
|
|
|
| 157 |
scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
|
| 158 |
scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
|
| 159 |
|
| 160 |
+
self.scaler = scalers
|
| 161 |
return scalers
|
| 162 |
|
| 163 |
def apply_scaling(self, scalers: dict, use_single_scaler: bool = False):
|
|
|
|
| 211 |
return len(self.data)
|
| 212 |
|
| 213 |
def __getitem__(self, idx):
|
| 214 |
+
if 'smiles' in self.disabled_embeddings:
|
| 215 |
+
# Uniformly sample a binary vector for the fingerprint
|
| 216 |
+
smiles_emb = np.random.randint(0, 2, size=self.smiles_emb_dim).astype(np.float32)
|
| 217 |
+
if not self.use_single_scaler and self.scaler is not None:
|
| 218 |
+
smiles_emb = smiles_emb[np.newaxis, :]
|
| 219 |
+
smiles_emb = self.scaler['Smiles'].transform(smiles_emb).flatten()
|
| 220 |
+
else:
|
| 221 |
+
smiles_emb = self.data['Smiles'].iloc[idx]
|
| 222 |
+
|
| 223 |
+
if 'poi' in self.disabled_embeddings:
|
| 224 |
+
# Uniformly sample a vector for the protein
|
| 225 |
+
poi_emb = np.random.rand(self.protein_emb_dim).astype(np.float32)
|
| 226 |
+
if not self.use_single_scaler and self.scaler is not None:
|
| 227 |
+
poi_emb = poi_emb[np.newaxis, :]
|
| 228 |
+
poi_emb = self.scaler['Uniprot'].transform(poi_emb).flatten()
|
| 229 |
+
else:
|
| 230 |
+
poi_emb = self.data['Uniprot'].iloc[idx]
|
| 231 |
+
|
| 232 |
+
if 'e3' in self.disabled_embeddings:
|
| 233 |
+
# Uniformly sample a vector for the E3 ligase
|
| 234 |
+
e3_emb = np.random.rand(self.protein_emb_dim).astype(np.float32)
|
| 235 |
+
if not self.use_single_scaler and self.scaler is not None:
|
| 236 |
+
# Add extra dimension for compatibility with the scaler
|
| 237 |
+
e3_emb = e3_emb[np.newaxis, :]
|
| 238 |
+
e3_emb = self.scaler['E3 Ligase Uniprot'].transform(e3_emb)
|
| 239 |
+
e3_emb = e3_emb.flatten()
|
| 240 |
+
else:
|
| 241 |
+
e3_emb = self.data['E3 Ligase Uniprot'].iloc[idx]
|
| 242 |
+
|
| 243 |
+
if 'cell' in self.disabled_embeddings:
|
| 244 |
+
# Uniformly sample a vector for the cell line
|
| 245 |
+
cell_emb = np.random.rand(self.cell_emb_dim).astype(np.float32)
|
| 246 |
+
if not self.use_single_scaler and self.scaler is not None:
|
| 247 |
+
cell_emb = cell_emb[np.newaxis, :]
|
| 248 |
+
cell_emb = self.scaler['Cell Line Identifier'].transform(cell_emb).flatten()
|
| 249 |
+
else:
|
| 250 |
+
cell_emb = self.data['Cell Line Identifier'].iloc[idx]
|
| 251 |
+
|
| 252 |
elem = {
|
| 253 |
+
'smiles_emb': smiles_emb,
|
| 254 |
+
'poi_emb': poi_emb,
|
| 255 |
+
'e3_emb': e3_emb,
|
| 256 |
+
'cell_emb': cell_emb,
|
| 257 |
'active': self.data[self.active_label].iloc[idx],
|
| 258 |
}
|
| 259 |
+
return elem
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def get_datasets(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: Optional[pd.DataFrame] = None,
    protein2embedding: Dict = None,
    cell2embedding: Dict = None,
    smiles2fp: Dict = None,
    use_smote: bool = True,
    smote_k_neighbors: int = 5,
    active_label: str = 'Active',
    disabled_embeddings: Optional[List[Literal['smiles', 'poi', 'e3', 'cell']]] = None,
    scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
    use_single_scaler: Optional[bool] = None,
) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
    """ Get the datasets for training the PROTAC model.

    Args:
        train_df (pd.DataFrame): The training DataFrame.
        val_df (pd.DataFrame): The validation DataFrame.
        test_df (pd.DataFrame): The optional test DataFrame.
        protein2embedding (dict): Passed through to every PROTAC_Dataset.
        cell2embedding (dict): Passed through to every PROTAC_Dataset.
        smiles2fp (dict): Passed through to every PROTAC_Dataset.
        use_smote (bool): Whether the training dataset gets a SMOTE oversampler.
        smote_k_neighbors (int): The ``k_neighbors`` argument given to SMOTE.
        active_label (str): The name of the label column.
        disabled_embeddings (list): The embeddings to disable; ``None`` means none.
        scaler (StandardScaler | dict): An optional scaler forwarded to the datasets.
        use_single_scaler (bool): An optional flag forwarded to the datasets.

    Returns:
        tuple: ``(train_ds, val_ds, test_ds)``; ``test_ds`` is ``None`` when
            ``test_df`` is ``None``.
    """
    # A None default avoids one list object being shared by every call.
    if disabled_embeddings is None:
        disabled_embeddings = []

    # Only the training dataset is oversampled, so build SMOTE only if used.
    oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42) if use_smote else None

    train_ds = PROTAC_Dataset(
        train_df,
        protein2embedding,
        cell2embedding,
        smiles2fp,
        use_smote=use_smote,
        oversampler=oversampler,
        active_label=active_label,
        disabled_embeddings=disabled_embeddings,
        scaler=scaler,
        use_single_scaler=use_single_scaler,
    )

    # The evaluation datasets reuse the training dataset's scaling settings,
    # falling back to the caller's arguments where the training ones are None.
    eval_scaler = train_ds.scaler if train_ds.scaler is not None else scaler
    if train_ds.use_single_scaler is not None:
        eval_single_scaler = train_ds.use_single_scaler
    else:
        eval_single_scaler = use_single_scaler

    def make_eval_dataset(df: pd.DataFrame) -> PROTAC_Dataset:
        # Validation and test datasets are built identically, without oversampling.
        return PROTAC_Dataset(
            df,
            protein2embedding,
            cell2embedding,
            smiles2fp,
            active_label=active_label,
            disabled_embeddings=disabled_embeddings,
            scaler=eval_scaler,
            use_single_scaler=eval_single_scaler,
        )

    val_ds = make_eval_dataset(val_df)
    test_ds = make_eval_dataset(test_df) if test_df is not None else None
    return train_ds, val_ds, test_ds
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class PROTAC_DataModule(pl.LightningDataModule):
|
| 317 |
+
""" PyTorch Lightning DataModule for the PROTAC dataset.
|
| 318 |
+
|
| 319 |
+
TODO: Work in progress. It would be nice to wrap all information into a
|
| 320 |
+
single class, but it is not clear how to do it yet due to cross-validation
|
| 321 |
+
and the need to split the data into training, validation, and test sets
|
| 322 |
+
accordingly.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
protac_csv_filepath (str): The path to the PROTAC CSV file.
|
| 326 |
+
protein2embedding_filepath (str): The path to the protein to embedding dictionary.
|
| 327 |
+
cell2embedding_filepath (str): The path to the cell line to embedding dictionary.
|
| 328 |
+
pDC50_threshold (float): The threshold for the pDC50 value to consider a PROTAC active.
|
| 329 |
+
Dmax_threshold (float): The threshold for the Dmax value to consider a PROTAC active.
|
| 330 |
+
use_smote (bool): Whether to use SMOTE for oversampling.
|
| 331 |
+
smote_k_neighbors (int): The number of neighbors to use for SMOTE.
|
| 332 |
+
active_label (str): The column containing the active/inactive information.
|
| 333 |
+
disabled_embeddings (list): The list of embeddings to disable.
|
| 334 |
+
scaler (StandardScaler | dict): The scaler to use for the embeddings.
|
| 335 |
+
use_single_scaler (bool): Whether to use a single scaler for all features.
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
def __init__(
    self,
    protac_csv_filepath: str,
    protein2embedding_filepath: str,
    cell2embedding_filepath: str,
    pDC50_threshold: float = 6.0,
    Dmax_threshold: float = 0.6,
    use_smote: bool = True,
    smote_k_neighbors: int = 5,
    active_label: str = 'Active',
    disabled_embeddings: Optional[List[Literal['smiles', 'poi', 'e3', 'cell']]] = None,
    scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
    use_single_scaler: Optional[bool] = None,
):
    """ Load the PROTAC table and the embedding dictionaries.

    Args:
        protac_csv_filepath (str): The path to the PROTAC CSV file.
        protein2embedding_filepath (str): The path to the protein to embedding dictionary.
        cell2embedding_filepath (str): The path to the cell line to embedding dictionary.
        pDC50_threshold (float): The pDC50 threshold passed to ``is_active``.
        Dmax_threshold (float): The Dmax threshold passed to ``is_active``.
        use_smote (bool): Whether to use SMOTE for oversampling.
        smote_k_neighbors (int): The number of neighbors to use for SMOTE.
        active_label (str): The column that receives the active/inactive information.
        disabled_embeddings (list): The embeddings to disable; ``None`` means none.
        scaler (StandardScaler | dict): The scaler to use for the embeddings.
        use_single_scaler (bool): Whether to use a single scaler for all features.
    """
    super(PROTAC_DataModule, self).__init__()

    # Keep the configuration on the instance: setup() reads these attributes.
    self.pDC50_threshold = pDC50_threshold
    self.Dmax_threshold = Dmax_threshold
    self.use_smote = use_smote
    self.smote_k_neighbors = smote_k_neighbors
    self.active_label = active_label
    # Copy so that the caller's list is never aliased by the module.
    self.disabled_embeddings = [] if disabled_embeddings is None else list(disabled_embeddings)
    self.scaler = scaler
    self.use_single_scaler = use_single_scaler

    # Load the PROTAC dataset from the path supplied by the caller
    self.protac_df = pd.read_csv(protac_csv_filepath)
    # Map E3 Ligase Iap to IAP
    self.protac_df['E3 Ligase'] = self.protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
    self.protac_df[active_label] = self.protac_df.apply(
        lambda x: is_active(
            x['DC50 (nM)'],
            x['Dmax (%)'],
            pDC50_threshold=pDC50_threshold,
            Dmax_threshold=Dmax_threshold,
        ),
        axis=1,
    )
    self.smiles2fp, self.protac_df = self.get_smiles2fp_and_avg_tanimoto(self.protac_df)
    # Keep only the rows whose activity label is not missing
    self.active_df = self.protac_df[self.protac_df[active_label].notna()].copy()

    # Load embedding dictionaries
    self.protein2embedding = load_protein2embedding(protein2embedding_filepath)
    self.cell2embedding = load_cell2embedding(cell2embedding_filepath)

    # NOTE(review): the train/val/test split DataFrames read by setup() are
    # not created here - they must be assigned before setup() is called.
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def setup(self, stage: str):
    """ Build the train, validation and test datasets via `get_datasets`.

    Args:
        stage (str): Stage name supplied by the caller; it is not used here.
    """
    frames = (self.train_df, self.val_df, self.test_df)
    lookups = (self.protein2embedding, self.cell2embedding, self.smiles2fp)
    options = dict(
        use_smote=self.use_smote,
        smote_k_neighbors=self.smote_k_neighbors,
        active_label=self.active_label,
        disabled_embeddings=self.disabled_embeddings,
        scaler=self.scaler,
        use_single_scaler=self.use_single_scaler,
    )
    datasets = get_datasets(*frames, *lookups, **options)
    self.train_ds, self.val_ds, self.test_ds = datasets
|
| 390 |
+
|
| 391 |
+
def train_dataloader(self):
    """ Return a shuffling DataLoader over `self.train_ds` with batches of 32. """
    loader = DataLoader(self.train_ds, shuffle=True, batch_size=32)
    return loader
|
| 393 |
+
|
| 394 |
+
def val_dataloader(self):
    """ Return a DataLoader over `self.val_ds` with batches of 32 (no shuffling requested). """
    loader = DataLoader(self.val_ds, batch_size=32)
    return loader
|
| 396 |
+
|
| 397 |
+
def test_dataloader(self):
|
| 398 |
+
return DataLoader(self.test_ds, batch_size=32)
|
| 399 |
+
|
| 400 |
+
@staticmethod
|
| 401 |
+
def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
|
| 402 |
+
""" Get the indices of the test set using a random split.
|
| 403 |
+
|
| 404 |
+
Args:
|
| 405 |
+
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
|
| 406 |
+
test_split (float): The percentage of the active PROTACs to use as the test set.
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
pd.Index: The indices of the test set.
|
| 410 |
+
"""
|
| 411 |
+
return active_df.sample(frac=test_split, random_state=42).index
|
| 412 |
+
|
| 413 |
+
@staticmethod
def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
    """ Select as test set every entry whose E3 ligase is neither VHL nor CRBN.

    As a side effect, an integer 'E3 Group' column (the ordinal encoding of
    'E3 Ligase') is written into `active_df`.

    Args:
        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.

    Returns:
        pd.Index: The indices of the test set.
    """
    ligase_encoder = OrdinalEncoder()
    ligase_codes = ligase_encoder.fit_transform(active_df[['E3 Ligase']])
    active_df['E3 Group'] = ligase_codes.astype(int)

    # Rows for the two named ligases stay out of the test set.
    is_vhl_or_crbn = active_df['E3 Ligase'].isin(['VHL', 'CRBN'])
    return active_df[~is_vhl_or_crbn].index
|
| 427 |
+
|
| 428 |
+
@staticmethod
def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
    """ Build the SMILES-to-fingerprint lookup and annotate the average Tanimoto distance.

    An 'Avg Tanimoto' column is written into `protac_df`: for each row it holds
    the mean of (1 - Tanimoto similarity) between that row's SMILES and every
    unique SMILES in the table (itself included).

    Args:
        protac_df (pd.DataFrame): The DataFrame containing the PROTACs.

    Returns:
        tuple: The SMILES-to-fingerprint dictionary (fingerprints converted to
            NumPy arrays) and the annotated DataFrame.
    """
    unique_smiles = protac_df['Smiles'].unique().tolist()
    fingerprints = [get_fingerprint(smiles) for smiles in unique_smiles]

    # Each pair is computed once (row against itself and everything after it)
    # and the distance is recorded for both members of the pair.
    distances = defaultdict(list)
    for row, (smiles_a, fp_a) in enumerate(zip(unique_smiles, fingerprints)):
        similarities = DataStructs.BulkTanimotoSimilarity(fp_a, fingerprints[row:])
        for offset, similarity in enumerate(similarities):
            dist = 1 - similarity
            distances[smiles_a].append(dist)
            if offset != 0:
                # Symmetric entry for the other compound of the pair
                distances[unique_smiles[row + offset]].append(dist)

    # Average Tanimoto distance of each unique SMILES
    avg_tanimoto = {smiles: np.mean(values) for smiles, values in distances.items()}
    protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)

    smiles2fp = {
        smiles: np.array(fp)
        for smiles, fp in zip(unique_smiles, fingerprints)
    }
    return smiles2fp, protac_df
|
| 463 |
+
|
| 464 |
+
@staticmethod
def get_tanimoto_split_indices(
        active_df: pd.DataFrame,
        active_label: str,
        test_split: float,
        n_bins_tanimoto: int = 200,
) -> pd.Index:
    """ Get the indices of the test set using the Tanimoto-based split.

    As a side effect, an integer 'Tanimoto Group' column (the ordinal encoding
    of the 'Avg Tanimoto' bins) is written into `active_df`.

    Args:
        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
        active_label (str): The column containing the active/inactive information.
        test_split (float): The fraction of `active_df` the test set must stay below.
        n_bins_tanimoto (int): The number of bins used to group the 'Avg Tanimoto' values.

    Returns:
        pd.Index: The indices of the test set.
    """
    binned = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
    encoder = OrdinalEncoder()
    active_df['Tanimoto Group'] = encoder.fit_transform(binned.values.reshape(-1, 1)).astype(int)
    # Visit the groups with the highest mean 'Avg Tanimoto' first. That column
    # is an average *distance* (1 - similarity) when produced by
    # `get_smiles2fp_and_avg_tanimoto`, so the least similar compounds are
    # offered to the test set first.
    ordered_groups = (
        active_df.groupby('Tanimoto Group')['Avg Tanimoto']
        .mean()
        .sort_values(ascending=False)
        .index
    )

    num_total = len(active_df)
    test_parts = []
    # Running totals replace re-concatenating the whole test set on every iteration.
    num_entries_test = 0
    num_active_test = 0
    for group in ordered_groups:
        group_df = active_df[active_df['Tanimoto Group'] == group]
        num_entries = len(group_df)
        num_active_group = group_df[active_label].sum()
        new_size = num_entries_test + num_entries

        # The first group always seeds the test set.
        accept = not test_parts
        # Later groups are only considered while the test set stays below
        # `test_split` of all entries.
        if not accept and new_size < test_split * num_total:
            if new_size < test_split / 2 * num_total:
                # Add anything at the beginning (below half of the target size)
                accept = True
            else:
                # Be more selective: neither class may reach 60% of the test set
                new_active = num_active_test + num_active_group
                new_inactive = new_size - new_active
                accept = (new_active / new_size < 0.6) and (new_inactive / new_size < 0.6)

        if accept:
            test_parts.append(group_df)
            num_entries_test = new_size
            num_active_test += num_active_group

    return pd.concat(test_parts).index
|
| 520 |
+
|
| 521 |
+
@staticmethod
def get_target_split_indices(active_df: pd.DataFrame, active_label: str, test_split: float) -> pd.Index:
    """ Get the indices of the test set using the target-based split.

    Whole 'Uniprot' groups are moved into the test set, so a given target is
    never split between the test set and the remaining entries. As a side
    effect, an integer 'Uniprot Group' column (the ordinal encoding of
    'Uniprot') is written into `active_df`.

    Args:
        active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
        active_label (str): The column containing the active/inactive information.
        test_split (float): The fraction of `active_df` the test set must stay below.

    Returns:
        pd.Index: The indices of the test set.
    """
    encoder = OrdinalEncoder()
    active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)

    num_total = len(active_df)
    test_parts = []
    # Running totals replace re-concatenating the whole test set on every iteration.
    num_entries_test = 0
    num_active_test = 0
    # Start from the targets with the smallest number of entries
    # (`value_counts` lists the most frequent ones first).
    for group in reversed(active_df['Uniprot'].value_counts().index):
        group_df = active_df[active_df['Uniprot'] == group]
        num_entries = len(group_df)
        num_active_group = group_df[active_label].sum()
        new_size = num_entries_test + num_entries

        # The first group always seeds the test set.
        accept = not test_parts
        # Later groups are only considered while the test set stays below
        # `test_split` of all entries.
        if not accept and new_size < test_split * num_total:
            if new_size < test_split / 2 * num_total:
                # Add anything at the beginning (below half of the target size)
                accept = True
            else:
                # Be more selective: neither class may reach 60% of the test set
                new_active = num_active_test + num_active_group
                new_inactive = new_size - new_active
                accept = (new_active / new_size < 0.6) and (new_inactive / new_size < 0.6)

        if accept:
            test_parts.append(group_df)
            num_entries_test = new_size
            num_active_test += num_active_group

    return pd.concat(test_parts).index
|
protac_degradation_predictor/pytorch_models.py
CHANGED
|
@@ -3,7 +3,7 @@ import pickle
|
|
| 3 |
import logging
|
| 4 |
from typing import Literal, List, Tuple, Optional, Dict
|
| 5 |
|
| 6 |
-
from .protac_dataset import PROTAC_Dataset
|
| 7 |
from .config import config
|
| 8 |
|
| 9 |
import pandas as pd
|
|
@@ -38,7 +38,7 @@ class PROTAC_Predictor(nn.Module):
|
|
| 38 |
dropout: float = 0.2,
|
| 39 |
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
|
| 40 |
use_batch_norm: bool = False,
|
| 41 |
-
disabled_embeddings:
|
| 42 |
):
|
| 43 |
""" Initialize the PROTAC model.
|
| 44 |
|
|
@@ -69,17 +69,17 @@ class PROTAC_Predictor(nn.Module):
|
|
| 69 |
# and can be summed on a "similar scale".
|
| 70 |
if self.join_embeddings != 'beginning':
|
| 71 |
if 'poi' not in self.disabled_embeddings:
|
| 72 |
-
self.
|
| 73 |
nn.Linear(poi_emb_dim, hidden_dim),
|
| 74 |
nn.Softmax(dim=1),
|
| 75 |
)
|
| 76 |
if 'e3' not in self.disabled_embeddings:
|
| 77 |
-
self.
|
| 78 |
nn.Linear(e3_emb_dim, hidden_dim),
|
| 79 |
nn.Softmax(dim=1),
|
| 80 |
)
|
| 81 |
if 'cell' not in self.disabled_embeddings:
|
| 82 |
-
self.
|
| 83 |
nn.Linear(cell_emb_dim, hidden_dim),
|
| 84 |
nn.Softmax(dim=1),
|
| 85 |
)
|
|
@@ -95,12 +95,12 @@ class PROTAC_Predictor(nn.Module):
|
|
| 95 |
joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
|
| 96 |
joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
|
| 97 |
joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
|
|
|
|
| 98 |
elif self.join_embeddings == 'concat':
|
| 99 |
joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
|
| 100 |
elif self.join_embeddings == 'sum':
|
| 101 |
joint_dim = hidden_dim
|
| 102 |
|
| 103 |
-
self.fc0 = nn.Linear(joint_dim, joint_dim)
|
| 104 |
self.fc1 = nn.Linear(joint_dim, hidden_dim)
|
| 105 |
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
|
| 106 |
self.fc3 = nn.Linear(hidden_dim, 1)
|
|
@@ -125,11 +125,11 @@ class PROTAC_Predictor(nn.Module):
|
|
| 125 |
x = self.dropout(F.relu(self.fc0(x)))
|
| 126 |
else:
|
| 127 |
if 'poi' not in self.disabled_embeddings:
|
| 128 |
-
embeddings.append(self.
|
| 129 |
if 'e3' not in self.disabled_embeddings:
|
| 130 |
-
embeddings.append(self.
|
| 131 |
if 'cell' not in self.disabled_embeddings:
|
| 132 |
-
embeddings.append(self.
|
| 133 |
if 'smiles' not in self.disabled_embeddings:
|
| 134 |
embeddings.append(self.smiles_emb(smiles_emb))
|
| 135 |
if self.join_embeddings == 'concat':
|
|
@@ -163,7 +163,7 @@ class PROTAC_Model(pl.LightningModule):
|
|
| 163 |
train_dataset: PROTAC_Dataset = None,
|
| 164 |
val_dataset: PROTAC_Dataset = None,
|
| 165 |
test_dataset: PROTAC_Dataset = None,
|
| 166 |
-
disabled_embeddings:
|
| 167 |
apply_scaling: bool = True,
|
| 168 |
):
|
| 169 |
""" Initialize the PROTAC Pytorch Lightning model.
|
|
@@ -217,7 +217,7 @@ class PROTAC_Model(pl.LightningModule):
|
|
| 217 |
dropout=dropout,
|
| 218 |
join_embeddings=join_embeddings,
|
| 219 |
use_batch_norm=use_batch_norm,
|
| 220 |
-
disabled_embeddings=
|
| 221 |
)
|
| 222 |
|
| 223 |
stages = ['train_metrics', 'val_metrics', 'test_metrics']
|
|
@@ -429,7 +429,7 @@ def train_model(
|
|
| 429 |
logger_name: str = 'protac',
|
| 430 |
enable_checkpointing: bool = False,
|
| 431 |
checkpoint_model_name: str = 'protac',
|
| 432 |
-
disabled_embeddings: List[
|
| 433 |
return_predictions: bool = False,
|
| 434 |
) -> tuple:
|
| 435 |
""" Train a PROTAC model using the given datasets and hyperparameters.
|
|
@@ -453,31 +453,18 @@ def train_model(
|
|
| 453 |
Returns:
|
| 454 |
tuple: The trained model, the trainer, and the metrics.
|
| 455 |
"""
|
| 456 |
-
|
| 457 |
-
train_ds = PROTAC_Dataset(
|
| 458 |
train_df,
|
| 459 |
-
protein2embedding,
|
| 460 |
-
cell2embedding,
|
| 461 |
-
smiles2fp,
|
| 462 |
-
use_smote=use_smote,
|
| 463 |
-
oversampler=oversampler if use_smote else None,
|
| 464 |
-
active_label=active_label,
|
| 465 |
-
)
|
| 466 |
-
val_ds = PROTAC_Dataset(
|
| 467 |
val_df,
|
|
|
|
| 468 |
protein2embedding,
|
| 469 |
cell2embedding,
|
| 470 |
smiles2fp,
|
|
|
|
|
|
|
| 471 |
active_label=active_label,
|
|
|
|
| 472 |
)
|
| 473 |
-
if test_df is not None:
|
| 474 |
-
test_ds = PROTAC_Dataset(
|
| 475 |
-
test_df,
|
| 476 |
-
protein2embedding,
|
| 477 |
-
cell2embedding,
|
| 478 |
-
smiles2fp,
|
| 479 |
-
active_label=active_label,
|
| 480 |
-
)
|
| 481 |
loggers = [
|
| 482 |
pl.loggers.TensorBoardLogger(
|
| 483 |
save_dir=logger_save_dir,
|
|
@@ -505,7 +492,7 @@ def train_model(
|
|
| 505 |
),
|
| 506 |
pl.callbacks.EarlyStopping(
|
| 507 |
monitor='val_loss',
|
| 508 |
-
patience=
|
| 509 |
mode='min',
|
| 510 |
verbose=False,
|
| 511 |
),
|
|
@@ -586,6 +573,36 @@ def train_model(
|
|
| 586 |
return model, trainer, metrics
|
| 587 |
|
| 588 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
def load_model(
|
| 590 |
ckpt_path: str,
|
| 591 |
) -> PROTAC_Model:
|
|
|
|
| 3 |
import logging
|
| 4 |
from typing import Literal, List, Tuple, Optional, Dict
|
| 5 |
|
| 6 |
+
from .protac_dataset import PROTAC_Dataset, get_datasets
|
| 7 |
from .config import config
|
| 8 |
|
| 9 |
import pandas as pd
|
|
|
|
| 38 |
dropout: float = 0.2,
|
| 39 |
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
|
| 40 |
use_batch_norm: bool = False,
|
| 41 |
+
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
|
| 42 |
):
|
| 43 |
""" Initialize the PROTAC model.
|
| 44 |
|
|
|
|
| 69 |
# and can be summed on a "similar scale".
|
| 70 |
if self.join_embeddings != 'beginning':
|
| 71 |
if 'poi' not in self.disabled_embeddings:
|
| 72 |
+
self.poi_fc = nn.Sequential(
|
| 73 |
nn.Linear(poi_emb_dim, hidden_dim),
|
| 74 |
nn.Softmax(dim=1),
|
| 75 |
)
|
| 76 |
if 'e3' not in self.disabled_embeddings:
|
| 77 |
+
self.e3_fc = nn.Sequential(
|
| 78 |
nn.Linear(e3_emb_dim, hidden_dim),
|
| 79 |
nn.Softmax(dim=1),
|
| 80 |
)
|
| 81 |
if 'cell' not in self.disabled_embeddings:
|
| 82 |
+
self.cell_fc = nn.Sequential(
|
| 83 |
nn.Linear(cell_emb_dim, hidden_dim),
|
| 84 |
nn.Softmax(dim=1),
|
| 85 |
)
|
|
|
|
| 95 |
joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
|
| 96 |
joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
|
| 97 |
joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
|
| 98 |
+
self.fc0 = nn.Linear(joint_dim, joint_dim)
|
| 99 |
elif self.join_embeddings == 'concat':
|
| 100 |
joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
|
| 101 |
elif self.join_embeddings == 'sum':
|
| 102 |
joint_dim = hidden_dim
|
| 103 |
|
|
|
|
| 104 |
self.fc1 = nn.Linear(joint_dim, hidden_dim)
|
| 105 |
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
|
| 106 |
self.fc3 = nn.Linear(hidden_dim, 1)
|
|
|
|
| 125 |
x = self.dropout(F.relu(self.fc0(x)))
|
| 126 |
else:
|
| 127 |
if 'poi' not in self.disabled_embeddings:
|
| 128 |
+
embeddings.append(self.poi_fc(poi_emb))
|
| 129 |
if 'e3' not in self.disabled_embeddings:
|
| 130 |
+
embeddings.append(self.e3_fc(e3_emb))
|
| 131 |
if 'cell' not in self.disabled_embeddings:
|
| 132 |
+
embeddings.append(self.cell_fc(cell_emb))
|
| 133 |
if 'smiles' not in self.disabled_embeddings:
|
| 134 |
embeddings.append(self.smiles_emb(smiles_emb))
|
| 135 |
if self.join_embeddings == 'concat':
|
|
|
|
| 163 |
train_dataset: PROTAC_Dataset = None,
|
| 164 |
val_dataset: PROTAC_Dataset = None,
|
| 165 |
test_dataset: PROTAC_Dataset = None,
|
| 166 |
+
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
|
| 167 |
apply_scaling: bool = True,
|
| 168 |
):
|
| 169 |
""" Initialize the PROTAC Pytorch Lightning model.
|
|
|
|
| 217 |
dropout=dropout,
|
| 218 |
join_embeddings=join_embeddings,
|
| 219 |
use_batch_norm=use_batch_norm,
|
| 220 |
+
disabled_embeddings=[], # NOTE: This is handled in the PROTAC_Dataset classes
|
| 221 |
)
|
| 222 |
|
| 223 |
stages = ['train_metrics', 'val_metrics', 'test_metrics']
|
|
|
|
| 429 |
logger_name: str = 'protac',
|
| 430 |
enable_checkpointing: bool = False,
|
| 431 |
checkpoint_model_name: str = 'protac',
|
| 432 |
+
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
|
| 433 |
return_predictions: bool = False,
|
| 434 |
) -> tuple:
|
| 435 |
""" Train a PROTAC model using the given datasets and hyperparameters.
|
|
|
|
| 453 |
Returns:
|
| 454 |
tuple: The trained model, the trainer, and the metrics.
|
| 455 |
"""
|
| 456 |
+
train_ds, val_ds, test_ds = get_datasets(
|
|
|
|
| 457 |
train_df,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
val_df,
|
| 459 |
+
test_df,
|
| 460 |
protein2embedding,
|
| 461 |
cell2embedding,
|
| 462 |
smiles2fp,
|
| 463 |
+
use_smote=use_smote,
|
| 464 |
+
smote_k_neighbors=smote_k_neighbors,
|
| 465 |
active_label=active_label,
|
| 466 |
+
disabled_embeddings=disabled_embeddings,
|
| 467 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
loggers = [
|
| 469 |
pl.loggers.TensorBoardLogger(
|
| 470 |
save_dir=logger_save_dir,
|
|
|
|
| 492 |
),
|
| 493 |
pl.callbacks.EarlyStopping(
|
| 494 |
monitor='val_loss',
|
| 495 |
+
patience=5, # Original: 5
|
| 496 |
mode='min',
|
| 497 |
verbose=False,
|
| 498 |
),
|
|
|
|
| 573 |
return model, trainer, metrics
|
| 574 |
|
| 575 |
|
| 576 |
+
def evaluate_model(
    model: PROTAC_Model,
    trainer: pl.Trainer,
    val_ds: PROTAC_Dataset,
    test_ds: Optional[PROTAC_Dataset] = None,
    batch_size: int = 128,
) -> dict:
    """ Evaluate a PROTAC model using the given datasets.

    Args:
        model (PROTAC_Model): The model to evaluate.
        trainer (pl.Trainer): The trainer used to run validation, testing and prediction.
        val_ds (PROTAC_Dataset): The validation dataset.
        test_ds (Optional[PROTAC_Dataset]): The test dataset; skipped when None.
        batch_size (int): The batch size of the (non-shuffled) data loaders.

    Returns:
        dict: 'val_metrics' (the reported metrics whose name contains 'val') and
            'val_pred' (the concatenated, squeezed predictions); when `test_ds`
            is given, also 'test_metrics' and 'test_pred', built the same way.
    """
    ret = {}

    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    val_metrics = trainer.validate(model, val_dl, verbose=False)[0]
    # Keep only the metrics that belong to the validation stage
    val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m}
    # Get predictions on validation set
    val_pred = torch.cat(trainer.predict(model, val_dl)).squeeze()
    ret['val_metrics'] = val_metrics
    ret['val_pred'] = val_pred

    if test_ds is not None:
        test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
        test_metrics = trainer.test(model, test_dl, verbose=False)[0]
        # Keep only the metrics that belong to the test stage
        test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m}
        # Get predictions on test set
        test_pred = torch.cat(trainer.predict(model, test_dl)).squeeze()
        ret['test_metrics'] = test_metrics
        ret['test_pred'] = test_pred

    return ret
|
| 604 |
+
|
| 605 |
+
|
| 606 |
def load_model(
|
| 607 |
ckpt_path: str,
|
| 608 |
) -> PROTAC_Model:
|
src/plot_experiment_results.py
CHANGED
|
@@ -12,7 +12,7 @@ import numpy as np
|
|
| 12 |
palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
|
| 13 |
|
| 14 |
|
| 15 |
-
def plot_training_curves(df, split_type, stage='test'):
|
| 16 |
Stage = 'Test' if stage == 'test' else 'Validation'
|
| 17 |
|
| 18 |
# Clean the data
|
|
@@ -22,20 +22,28 @@ def plot_training_curves(df, split_type, stage='test'):
|
|
| 22 |
df = df.apply(pd.to_numeric, errors='coerce')
|
| 23 |
|
| 24 |
# Group by 'epoch' and aggregate by mean
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
fig, ax = plt.subplots(3, 1, figsize=(10, 15))
|
| 28 |
|
| 29 |
# Plot training loss
|
| 30 |
-
ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
|
| 31 |
-
ax[0].plot(epoch_data.index, epoch_data[f'{stage}_loss'], label=f'{Stage} Loss', linestyle='--')
|
|
|
|
|
|
|
|
|
|
| 32 |
ax[0].set_ylabel('Loss')
|
| 33 |
ax[0].legend(loc='lower right')
|
| 34 |
ax[0].grid(axis='both', alpha=0.5)
|
| 35 |
|
| 36 |
# Plot training accuracy
|
| 37 |
-
ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
|
| 38 |
-
ax[1].plot(epoch_data.index, epoch_data[f'{stage}_acc'], label=f'{Stage} Accuracy', linestyle='--')
|
|
|
|
|
|
|
| 39 |
ax[1].set_ylabel('Accuracy')
|
| 40 |
ax[1].legend(loc='lower right')
|
| 41 |
ax[1].grid(axis='both', alpha=0.5)
|
|
@@ -45,8 +53,10 @@ def plot_training_curves(df, split_type, stage='test'):
|
|
| 45 |
ax[1].yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
|
| 46 |
|
| 47 |
# Plot training ROC-AUC
|
| 48 |
-
ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
|
| 49 |
-
ax[2].plot(epoch_data.index, epoch_data[f'{stage}_roc_auc'], label=f'{Stage} ROC-AUC', linestyle='--')
|
|
|
|
|
|
|
| 50 |
ax[2].set_ylabel('ROC-AUC')
|
| 51 |
ax[2].legend(loc='lower right')
|
| 52 |
ax[2].grid(axis='both', alpha=0.5)
|
|
@@ -167,6 +177,7 @@ def plot_ablation_study(report):
|
|
| 167 |
'disabled poi',
|
| 168 |
'disabled e3',
|
| 169 |
'disabled cell',
|
|
|
|
| 170 |
'disabled poi e3 smiles',
|
| 171 |
'disabled poi e3 cell',
|
| 172 |
]
|
|
@@ -226,6 +237,7 @@ def plot_ablation_study(report):
|
|
| 226 |
'disabled e3': 'Disabled E3 information',
|
| 227 |
'disabled poi': 'Disabled target information',
|
| 228 |
'disabled cell': 'Disabled cell information',
|
|
|
|
| 229 |
'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
|
| 230 |
'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
|
| 231 |
})
|
|
@@ -323,6 +335,7 @@ def main():
|
|
| 323 |
|
| 324 |
|
| 325 |
for split_type in ['random', 'tanimoto', 'uniprot']:
|
|
|
|
| 326 |
for i in range(n_models_for_test):
|
| 327 |
logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
|
| 328 |
metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
|
|
@@ -330,36 +343,41 @@ def main():
|
|
| 330 |
# Rename 'val_' columns to 'test_' columns
|
| 331 |
metrics = metrics.rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
|
| 332 |
plot_training_curves(metrics, f'{split_type}_best_model_n{i}')
|
|
|
|
|
|
|
| 333 |
|
|
|
|
| 334 |
for i in range(cv_n_folds):
|
| 335 |
# logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
|
| 336 |
logs_dir = f'logs_{report_base_name}_{split_type}_{split_type}_cv_model_fold{i}'
|
| 337 |
metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
|
| 338 |
metrics['fold'] = i
|
| 339 |
plot_training_curves(metrics, f'{split_type}_cv_model_fold{i}', stage='val')
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
reports['
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
]
|
| 360 |
-
|
| 361 |
-
#
|
| 362 |
-
|
|
|
|
|
|
|
| 363 |
|
| 364 |
|
| 365 |
if __name__ == '__main__':
|
|
|
|
| 12 |
palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
|
| 13 |
|
| 14 |
|
| 15 |
+
def plot_training_curves(df, split_type, stage='test', multimodels=False, groupby='model_id'):
|
| 16 |
Stage = 'Test' if stage == 'test' else 'Validation'
|
| 17 |
|
| 18 |
# Clean the data
|
|
|
|
| 22 |
df = df.apply(pd.to_numeric, errors='coerce')
|
| 23 |
|
| 24 |
# Group by 'epoch' and aggregate by mean
|
| 25 |
+
if multimodels:
|
| 26 |
+
epoch_data = df.groupby([groupby, 'epoch']).mean().reset_index()
|
| 27 |
+
else:
|
| 28 |
+
epoch_data = df.groupby('epoch').mean().reset_index()
|
| 29 |
|
| 30 |
fig, ax = plt.subplots(3, 1, figsize=(10, 15))
|
| 31 |
|
| 32 |
# Plot training loss
|
| 33 |
+
# ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
|
| 34 |
+
# ax[0].plot(epoch_data.index, epoch_data[f'{stage}_loss'], label=f'{Stage} Loss', linestyle='--')
|
| 35 |
+
sns.lineplot(data=epoch_data, x='epoch', y='train_loss_epoch', ax=ax[0], label='Training Loss')
|
| 36 |
+
sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_loss', ax=ax[0], label=f'{Stage} Loss', linestyle='--')
|
| 37 |
+
|
| 38 |
ax[0].set_ylabel('Loss')
|
| 39 |
ax[0].legend(loc='lower right')
|
| 40 |
ax[0].grid(axis='both', alpha=0.5)
|
| 41 |
|
| 42 |
# Plot training accuracy
|
| 43 |
+
# ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
|
| 44 |
+
# ax[1].plot(epoch_data.index, epoch_data[f'{stage}_acc'], label=f'{Stage} Accuracy', linestyle='--')
|
| 45 |
+
sns.lineplot(data=epoch_data, x='epoch', y='train_acc_epoch', ax=ax[1], label='Training Accuracy')
|
| 46 |
+
sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_acc', ax=ax[1], label=f'{Stage} Accuracy', linestyle='--')
|
| 47 |
ax[1].set_ylabel('Accuracy')
|
| 48 |
ax[1].legend(loc='lower right')
|
| 49 |
ax[1].grid(axis='both', alpha=0.5)
|
|
|
|
| 53 |
ax[1].yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
|
| 54 |
|
| 55 |
# Plot training ROC-AUC
|
| 56 |
+
# ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
|
| 57 |
+
# ax[2].plot(epoch_data.index, epoch_data[f'{stage}_roc_auc'], label=f'{Stage} ROC-AUC', linestyle='--')
|
| 58 |
+
sns.lineplot(data=epoch_data, x='epoch', y='train_roc_auc_epoch', ax=ax[2], label='Training ROC-AUC')
|
| 59 |
+
sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_roc_auc', ax=ax[2], label=f'{Stage} ROC-AUC', linestyle='--')
|
| 60 |
ax[2].set_ylabel('ROC-AUC')
|
| 61 |
ax[2].legend(loc='lower right')
|
| 62 |
ax[2].grid(axis='both', alpha=0.5)
|
|
|
|
| 177 |
'disabled poi',
|
| 178 |
'disabled e3',
|
| 179 |
'disabled cell',
|
| 180 |
+
'disabled poi e3',
|
| 181 |
'disabled poi e3 smiles',
|
| 182 |
'disabled poi e3 cell',
|
| 183 |
]
|
|
|
|
| 237 |
'disabled e3': 'Disabled E3 information',
|
| 238 |
'disabled poi': 'Disabled target information',
|
| 239 |
'disabled cell': 'Disabled cell information',
|
| 240 |
+
'disabled poi e3': 'Disabled E3 and target info',
|
| 241 |
'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
|
| 242 |
'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
|
| 243 |
})
|
|
|
|
| 335 |
|
| 336 |
|
| 337 |
for split_type in ['random', 'tanimoto', 'uniprot']:
|
| 338 |
+
split_metrics = []
|
| 339 |
for i in range(n_models_for_test):
|
| 340 |
logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
|
| 341 |
metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
|
|
|
|
| 343 |
# Rename 'val_' columns to 'test_' columns
|
| 344 |
metrics = metrics.rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
|
| 345 |
plot_training_curves(metrics, f'{split_type}_best_model_n{i}')
|
| 346 |
+
split_metrics.append(metrics)
|
| 347 |
+
plot_training_curves(pd.concat(split_metrics), f'{split_type}_best_model', multimodels=True)
|
| 348 |
|
| 349 |
+
split_metrics_cv = []
|
| 350 |
for i in range(cv_n_folds):
|
| 351 |
# logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
|
| 352 |
logs_dir = f'logs_{report_base_name}_{split_type}_{split_type}_cv_model_fold{i}'
|
| 353 |
metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
|
| 354 |
metrics['fold'] = i
|
| 355 |
plot_training_curves(metrics, f'{split_type}_cv_model_fold{i}', stage='val')
|
| 356 |
+
split_metrics_cv.append(metrics)
|
| 357 |
+
plot_training_curves(pd.concat(split_metrics_cv), f'{split_type}_cv_model', stage='val', multimodels=True, groupby='fold')
|
| 358 |
+
|
| 359 |
+
# plot_performance_metrics(
|
| 360 |
+
# reports['cv_train'],
|
| 361 |
+
# reports['test'],
|
| 362 |
+
# title=f'{active_name}_metrics',
|
| 363 |
+
# )
|
| 364 |
+
|
| 365 |
+
# plot_performance_metrics(
|
| 366 |
+
# reports['cv_train'],
|
| 367 |
+
# reports['majority_vote'][reports['majority_vote']['cv_models'] == False],
|
| 368 |
+
# title=f'{active_name}_metrics_majority_vote',
|
| 369 |
+
# )
|
| 370 |
+
|
| 371 |
+
# plot_majority_voting_performance(reports['majority_vote'])
|
| 372 |
+
|
| 373 |
+
# reports['test']['disabled_embeddings'] = pd.NA
|
| 374 |
+
# plot_ablation_study(pd.concat([
|
| 375 |
+
# reports['ablation'],
|
| 376 |
+
# reports['test'],
|
| 377 |
+
# ]))
|
| 378 |
+
|
| 379 |
+
# # Plot hyperparameter optimization results to markdown
|
| 380 |
+
# print(reports['hparam'][['split_type', 'hidden_dim', 'learning_rate', 'dropout', 'use_smote', 'smote_k_neighbors']].to_markdown(index=False))
|
| 381 |
|
| 382 |
|
| 383 |
if __name__ == '__main__':
|
src/run_experiments.py
CHANGED
|
@@ -305,7 +305,7 @@ def main(
|
|
| 305 |
|
| 306 |
# Start the experiment
|
| 307 |
experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
|
| 308 |
-
optuna_reports = pdp.hyperparameter_tuning_and_training(
|
| 309 |
protein2embedding=protein2embedding,
|
| 310 |
cell2embedding=cell2embedding,
|
| 311 |
smiles2fp=smiles2fp,
|
|
|
|
| 305 |
|
| 306 |
# Start the experiment
|
| 307 |
experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
|
| 308 |
+
optuna_reports = pdp.hyperparameter_tuning_and_training(
|
| 309 |
protein2embedding=protein2embedding,
|
| 310 |
cell2embedding=cell2embedding,
|
| 311 |
smiles2fp=smiles2fp,
|