Commit de956c8
Parent(s): 6a5a99e
Added model tests + checkpointing of the scaler object
protac_degradation_predictor/__init__.py
CHANGED
@@ -5,6 +5,8 @@ from .data_utils import (
     is_active,
 )
 from .pytorch_models import (
+    PROTAC_Predictor,
+    PROTAC_Model,
     train_model,
 )
 from .sklearn_models import (
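With PROTAC_Predictor and PROTAC_Model now re-exported at package level, downstream code can import them without reaching into submodules. A minimal usage sketch, mirroring the constructor call exercised by the new unit tests (the hidden_dim value is illustrative):

    from protac_degradation_predictor import PROTAC_Predictor, PROTAC_Model, train_model

    predictor = PROTAC_Predictor(hidden_dim=128)  # bare predictor network (nn.Module)
    model = PROTAC_Model(hidden_dim=128)          # Lightning wrapper used for training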
protac_degradation_predictor/optuna_utils.py
CHANGED
@@ -73,7 +73,7 @@ def pytorch_model_objective(
     dropout = trial.suggest_float('dropout', *dropout_options)
 
     # Start the CV over the folds
-    X = train_val_df.drop(columns=active_label)
+    X = train_val_df.copy().drop(columns=active_label)
     y = train_val_df[active_label].tolist()
     report = []
     for k, (train_index, val_index) in enumerate(kf.split(X, y, groups)):
@@ -108,11 +108,11 @@ def pytorch_model_objective(
         # At each fold, train and evaluate the Pytorch model
         # Train the model with the current set of hyperparameters
         _, _, metrics = train_model(
-            protein2embedding,
-            cell2embedding,
-            smiles2fp,
-            train_df,
-            val_df,
+            protein2embedding=protein2embedding,
+            cell2embedding=cell2embedding,
+            smiles2fp=smiles2fp,
+            train_df=train_df,
+            val_df=val_df,
             hidden_dim=hidden_dim,
             batch_size=batch_size,
             join_embeddings=join_embeddings,
@@ -223,7 +223,7 @@ def hyperparameter_tuning_and_training(
     test_report = []
     # Retrain N models with the best hyperparameters (measure model uncertainty)
     for i in range(n_models_for_test):
-        pl.seed_everything(42 + i)
+        pl.seed_everything(42 + i + 1)
         _, _, metrics = train_model(
             protein2embedding=protein2embedding,
             cell2embedding=cell2embedding,
@@ -235,9 +235,9 @@ def hyperparameter_tuning_and_training(
             active_label=active_label,
             max_epochs=max_epochs,
             disabled_embeddings=[],
-            logger_name=f'{logger_name}
+            logger_name=f'{logger_name}_best_model_n{i}',
             enable_checkpointing=True,
-            checkpoint_model_name=f'
+            checkpoint_model_name=f'best_model_n{i}_{split_type}',
             **study.best_params,
         )
         # Rename the keys in the metrics dictionary
@@ -245,6 +245,9 @@ def hyperparameter_tuning_and_training(
         metrics = {k.replace('train_', 'train_val_'): v for k, v in metrics.items()}
         metrics['model_type'] = 'Pytorch'
         metrics['test_model_id'] = i
+        metrics['test_len'] = len(test_df)
+        metrics['test_active_perc'] = test_df[active_label].sum() / len(test_df)
+        metrics['test_inactive_perc'] = (len(test_df) - test_df[active_label].sum()) / len(test_df)
         test_report.append(metrics.copy())
     test_report = pd.DataFrame(test_report)
 
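Two notes on the changes above: pl.seed_everything(42 + i + 1) gives each retrained model a distinct seed, presumably shifting all of them off the plain seed 42 used before, and the three new columns record the size and class balance of the fixed test set next to every model's scores. A small self-contained sketch of those composition columns (df and active_label stand in for the test dataframe and its label column):

    import pandas as pd

    def test_set_composition(df: pd.DataFrame, active_label: str) -> dict:
        # Size plus active/inactive fractions, as added to the test report above.
        n_active = df[active_label].sum()
        return {
            'test_len': len(df),
            'test_active_perc': n_active / len(df),
            'test_inactive_perc': (len(df) - n_active) / len(df),
        }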
protac_degradation_predictor/pytorch_models.py
CHANGED
@@ -1,4 +1,6 @@
 import warnings
+import pickle
+import logging
 from typing import Literal, List, Tuple, Optional, Dict
 
 from .protac_dataset import PROTAC_Dataset
@@ -125,7 +127,6 @@ class PROTAC_Predictor(nn.Module):
         return x
 
 
-
 class PROTAC_Model(pl.LightningModule):
 
     def __init__(
@@ -218,13 +219,26 @@ class PROTAC_Model(pl.LightningModule):
         '''
 
         # Apply scaling in datasets
-        if self.apply_scaling:
-            use_single_scaler = self.join_embeddings == 'beginning'
+        self.scalers = None
+        if self.apply_scaling and self.train_dataset is not None:
+            self.initialize_scalers()
+
+    def initialize_scalers(self):
+        """Initialize or reinitialize scalers based on dataset properties."""
+        if self.scalers is None:
+            use_single_scaler = self.join_embeddings == 'beginning'
             self.scalers = self.train_dataset.fit_scaling(use_single_scaler)
+            self.apply_scalers()
+
+    def apply_scalers(self):
+        """Apply scalers to all datasets."""
+        use_single_scaler = self.join_embeddings == 'beginning'
+        if self.train_dataset:
             self.train_dataset.apply_scaling(self.scalers, use_single_scaler)
+        if self.val_dataset:
             self.val_dataset.apply_scaling(self.scalers, use_single_scaler)
-
-
+        if self.test_dataset:
+            self.test_dataset.apply_scaling(self.scalers, use_single_scaler)
 
     def forward(self, poi_emb, e3_emb, cell_emb, smiles_emb):
         return self.model(poi_emb, e3_emb, cell_emb, smiles_emb)
@@ -316,6 +330,23 @@ class PROTAC_Model(pl.LightningModule):
             batch_size=self.batch_size,
             shuffle=False,
         )
+
+    def on_save_checkpoint(self, checkpoint):
+        """ Serialize the scalers to the checkpoint. """
+        checkpoint['scalers'] = pickle.dumps(self.scalers)
+
+    def on_load_checkpoint(self, checkpoint):
+        """Deserialize the scalers from the checkpoint."""
+        if 'scalers' in checkpoint:
+            self.scalers = pickle.loads(checkpoint['scalers'])
+        else:
+            self.scalers = None
+        if self.apply_scaling:
+            if self.scalers is not None:
+                # Re-apply scalers to ensure datasets are scaled
+                self.apply_scalers()
+            else:
+                logging.warning("Scalers not found in checkpoint. Consider re-fitting scalers if necessary.")
 
 
 def train_model(
@@ -421,7 +452,7 @@ def train_model(
         monitor='val_acc',
         mode='max',
         verbose=False,
-        filename=checkpoint_model_name + '-{epoch}-{
+        filename=checkpoint_model_name + '-{epoch}-{val_acc:.2f}-{val_roc_auc:.3f}',
     ))
     # Define Trainer
     trainer = pl.Trainer(
@@ -455,6 +486,9 @@ def train_model(
         warnings.simplefilter("ignore")
         trainer.fit(model)
     metrics = trainer.validate(model, verbose=False)[0]
+
+    # Add train metrics to metrics
+
     if test_df is not None:
         test_metrics = trainer.test(model, verbose=False)[0]
         metrics.update(test_metrics)
@@ -472,6 +506,15 @@ def load_model(
     Returns:
         PROTAC_Model: The loaded model.
     """
-    model = PROTAC_Model.load_from_checkpoint(ckpt_path)
+    # NOTE: The `map_location` argument is automatically handled in newer versions
+    # of PyTorch Lightning, but we keep it here for compatibility with older ones.
+    model = PROTAC_Model.load_from_checkpoint(
+        ckpt_path,
+        map_location=torch.device('cpu') if not torch.cuda.is_available() else None,
+    )
+    # NOTE: The following is left as an example for eventually re-applying scaling
+    # with other datasets...
+    # if model.apply_scaling:
+    #     model.apply_scalers()
     model.eval()
     return model
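The on_save_checkpoint/on_load_checkpoint pair uses the standard PyTorch Lightning hooks for carrying non-tensor state (here the fitted scalers) inside the .ckpt file. A self-contained sketch of the same round-trip; ToyModel is illustrative and not part of the repository:

    import pickle
    import torch
    import pytorch_lightning as pl

    class ToyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)
            self.scalers = {'feature': [0.0, 1.0]}  # any picklable object

        def on_save_checkpoint(self, checkpoint):
            # The checkpoint is a plain dict; extra keys are saved next to state_dict.
            checkpoint['scalers'] = pickle.dumps(self.scalers)

        def on_load_checkpoint(self, checkpoint):
            # Restore the scalers before the weights are used for inference.
            self.scalers = pickle.loads(checkpoint['scalers'])

    # Manual round-trip showing what Trainer save/load does with these hooks:
    m = ToyModel()
    ckpt = {'state_dict': m.state_dict()}
    m.on_save_checkpoint(ckpt)

    m2 = ToyModel()
    m2.on_load_checkpoint(ckpt)
    m2.load_state_dict(ckpt['state_dict'])
    assert m2.scalers == m.scalers

Since the scalers travel through pickle, checkpoints should only be loaded from trusted sources: unpickling an untrusted file can execute arbitrary code.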
src/run_experiments.py
CHANGED
@@ -207,10 +207,11 @@ def get_target_split_indices(active_df: pd.DataFrame, active_col: str, test_spli
 
 def main(
     active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
-    n_trials: int =
+    n_trials: int = 100,
     fast_dev_run: bool = False,
-    test_split: float = 0.
+    test_split: float = 0.1,
     cv_n_splits: int = 5,
+    max_epochs: int = 100,
     run_sklearn: bool = False,
 ):
     """ Train a PROTAC model using the given datasets and hyperparameters.
@@ -287,7 +288,7 @@ def main(
         n_models_for_test=3,
         fast_dev_run=fast_dev_run,
         n_trials=n_trials,
-        max_epochs=
+        max_epochs=max_epochs,
         logger_name=f'logs_{experiment_name}',
         active_label=active_col,
         study_filename=f'../reports/study_{experiment_name}.pkl',
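The new max_epochs parameter flows from main into what appears to be the hyperparameter tuning call, so the epoch budget can now be set per run. A hypothetical direct invocation overriding the defaults (the repository may instead expose main through a CLI wrapper):

    main(
        active_col='Active (Dmax 0.6, pDC50 6.0)',
        n_trials=50,       # fewer Optuna trials for a quicker run
        test_split=0.1,
        max_epochs=100,
    )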
tests/test_pytorch_model.py
ADDED
@@ -0,0 +1,80 @@
+import pytest
+import os
+import sys
+import logging
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from protac_degradation_predictor import PROTAC_Model, PROTAC_Predictor
+
+import torch
+
+
+def test_protac_model():
+    model = PROTAC_Model(hidden_dim=128)
+    assert model.hidden_dim == 128
+    assert model.smiles_emb_dim == 224
+    assert model.poi_emb_dim == 1024
+    assert model.e3_emb_dim == 1024
+    assert model.cell_emb_dim == 768
+    assert model.batch_size == 32
+    assert model.learning_rate == 0.001
+    assert model.dropout == 0.2
+    assert model.join_embeddings == 'concat'
+    assert model.train_dataset is None
+    assert model.val_dataset is None
+    assert model.test_dataset is None
+    assert model.disabled_embeddings == []
+    assert model.apply_scaling == False
+
+def test_protac_predictor():
+    predictor = PROTAC_Predictor(hidden_dim=128)
+    assert predictor.hidden_dim == 128
+    assert predictor.smiles_emb_dim == 224
+    assert predictor.poi_emb_dim == 1024
+    assert predictor.e3_emb_dim == 1024
+    assert predictor.cell_emb_dim == 768
+    assert predictor.join_embeddings == 'concat'
+    assert predictor.disabled_embeddings == []
+
+def test_load_model(caplog):
+    caplog.set_level(logging.WARNING)
+
+    model = PROTAC_Model.load_from_checkpoint(
+        'data/test_model.ckpt',
+        map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
+    )
+    # apply_scaling: true
+    # batch_size: 8
+    # cell_emb_dim: 768
+    # disabled_embeddings: []
+    # dropout: 0.1498104322091649
+    # e3_emb_dim: 1024
+    # hidden_dim: 768
+    # join_embeddings: concat
+    # learning_rate: 4.881387978425994e-05
+    # poi_emb_dim: 1024
+    # smiles_emb_dim: 224
+    assert model.hidden_dim == 768
+    assert model.smiles_emb_dim == 224
+    assert model.poi_emb_dim == 1024
+    assert model.e3_emb_dim == 1024
+    assert model.cell_emb_dim == 768
+    assert model.batch_size == 8
+    assert model.learning_rate == 4.881387978425994e-05
+    assert model.dropout == 0.1498104322091649
+    assert model.join_embeddings == 'concat'
+    assert model.disabled_embeddings == []
+    assert model.apply_scaling == True
+
+
+def test_checkpoint_file():
+    checkpoint = torch.load(
+        'data/test_model.ckpt',
+        map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
+    )
+    print(checkpoint.keys())
+    print(checkpoint["hyper_parameters"])
+    print([k for k, v in checkpoint["state_dict"].items()])
+
+pytest.main()
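The new suite expects a checkpoint at data/test_model.ckpt relative to the working directory. It runs under pytest as usual, and the module-level pytest.main() call also kicks off a pytest run when the file is executed directly:

    pytest tests/test_pytorch_model.py
    # or, equivalently:
    python tests/test_pytorch_model.py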