ribesstefano commited on
Commit
4e1d3f6
·
1 Parent(s): 48aea13

Added ablation study with randomly sampled vectors + Started working on LightningDataModule wrapper

Browse files
Files changed (35) hide show
  1. plots/Active_Dmax_0.6_pDC50_6.0_metrics.pdf +0 -0
  2. plots/Active_Dmax_0.6_pDC50_6.0_metrics_majority_vote.pdf +0 -0
  3. plots/ablation_study_random.pdf +0 -0
  4. plots/ablation_study_tanimoto.pdf +0 -0
  5. plots/ablation_study_uniprot.pdf +0 -0
  6. plots/old_Active_Dmax_0.6_pDC50_6.0_metrics.pdf +0 -0
  7. plots/training_metrics_random_best_model_n0.pdf +0 -0
  8. plots/training_metrics_random_best_model_n1.pdf +0 -0
  9. plots/training_metrics_random_best_model_n2.pdf +0 -0
  10. plots/training_metrics_random_cv_model_fold0.pdf +0 -0
  11. plots/training_metrics_random_cv_model_fold1.pdf +0 -0
  12. plots/training_metrics_random_cv_model_fold2.pdf +0 -0
  13. plots/training_metrics_random_cv_model_fold3.pdf +0 -0
  14. plots/training_metrics_random_cv_model_fold4.pdf +0 -0
  15. plots/training_metrics_tanimoto_best_model_n0.pdf +0 -0
  16. plots/training_metrics_tanimoto_best_model_n1.pdf +0 -0
  17. plots/training_metrics_tanimoto_best_model_n2.pdf +0 -0
  18. plots/training_metrics_tanimoto_cv_model_fold0.pdf +0 -0
  19. plots/training_metrics_tanimoto_cv_model_fold1.pdf +0 -0
  20. plots/training_metrics_tanimoto_cv_model_fold2.pdf +0 -0
  21. plots/training_metrics_tanimoto_cv_model_fold3.pdf +0 -0
  22. plots/training_metrics_tanimoto_cv_model_fold4.pdf +0 -0
  23. plots/training_metrics_uniprot_best_model_n0.pdf +0 -0
  24. plots/training_metrics_uniprot_best_model_n1.pdf +0 -0
  25. plots/training_metrics_uniprot_best_model_n2.pdf +0 -0
  26. plots/training_metrics_uniprot_cv_model_fold0.pdf +0 -0
  27. plots/training_metrics_uniprot_cv_model_fold1.pdf +0 -0
  28. plots/training_metrics_uniprot_cv_model_fold2.pdf +0 -0
  29. plots/training_metrics_uniprot_cv_model_fold3.pdf +0 -0
  30. plots/training_metrics_uniprot_cv_model_fold4.pdf +0 -0
  31. protac_degradation_predictor/optuna_utils.py +96 -52
  32. protac_degradation_predictor/protac_dataset.py +383 -14
  33. protac_degradation_predictor/pytorch_models.py +48 -31
  34. src/plot_experiment_results.py +49 -31
  35. src/run_experiments.py +1 -1
plots/Active_Dmax_0.6_pDC50_6.0_metrics.pdf ADDED
Binary file (16.8 kB). View file
 
plots/Active_Dmax_0.6_pDC50_6.0_metrics_majority_vote.pdf ADDED
Binary file (16.7 kB). View file
 
plots/ablation_study_random.pdf ADDED
Binary file (15.8 kB). View file
 
plots/ablation_study_tanimoto.pdf ADDED
Binary file (15.3 kB). View file
 
plots/ablation_study_uniprot.pdf ADDED
Binary file (15.6 kB). View file
 
plots/old_Active_Dmax_0.6_pDC50_6.0_metrics.pdf ADDED
Binary file (16.8 kB). View file
 
plots/training_metrics_random_best_model_n0.pdf ADDED
Binary file (16.9 kB). View file
 
plots/training_metrics_random_best_model_n1.pdf ADDED
Binary file (16.7 kB). View file
 
plots/training_metrics_random_best_model_n2.pdf ADDED
Binary file (16.9 kB). View file
 
plots/training_metrics_random_cv_model_fold0.pdf ADDED
Binary file (17.4 kB). View file
 
plots/training_metrics_random_cv_model_fold1.pdf ADDED
Binary file (17.3 kB). View file
 
plots/training_metrics_random_cv_model_fold2.pdf ADDED
Binary file (17.1 kB). View file
 
plots/training_metrics_random_cv_model_fold3.pdf ADDED
Binary file (17.3 kB). View file
 
plots/training_metrics_random_cv_model_fold4.pdf ADDED
Binary file (17.6 kB). View file
 
plots/training_metrics_tanimoto_best_model_n0.pdf ADDED
Binary file (16.9 kB). View file
 
plots/training_metrics_tanimoto_best_model_n1.pdf ADDED
Binary file (16.9 kB). View file
 
plots/training_metrics_tanimoto_best_model_n2.pdf ADDED
Binary file (16.6 kB). View file
 
plots/training_metrics_tanimoto_cv_model_fold0.pdf ADDED
Binary file (17.1 kB). View file
 
plots/training_metrics_tanimoto_cv_model_fold1.pdf ADDED
Binary file (16.8 kB). View file
 
plots/training_metrics_tanimoto_cv_model_fold2.pdf ADDED
Binary file (17.3 kB). View file
 
plots/training_metrics_tanimoto_cv_model_fold3.pdf ADDED
Binary file (17 kB). View file
 
plots/training_metrics_tanimoto_cv_model_fold4.pdf ADDED
Binary file (17.1 kB). View file
 
plots/training_metrics_uniprot_best_model_n0.pdf ADDED
Binary file (16.3 kB). View file
 
plots/training_metrics_uniprot_best_model_n1.pdf ADDED
Binary file (15.9 kB). View file
 
plots/training_metrics_uniprot_best_model_n2.pdf ADDED
Binary file (16.1 kB). View file
 
plots/training_metrics_uniprot_cv_model_fold0.pdf ADDED
Binary file (17.2 kB). View file
 
plots/training_metrics_uniprot_cv_model_fold1.pdf ADDED
Binary file (17.7 kB). View file
 
plots/training_metrics_uniprot_cv_model_fold2.pdf ADDED
Binary file (17 kB). View file
 
plots/training_metrics_uniprot_cv_model_fold3.pdf ADDED
Binary file (16.5 kB). View file
 
plots/training_metrics_uniprot_cv_model_fold4.pdf ADDED
Binary file (17.1 kB). View file
 
protac_degradation_predictor/optuna_utils.py CHANGED
@@ -2,7 +2,13 @@ import os
2
  from typing import Literal, List, Tuple, Optional, Dict
3
  import logging
4
 
5
- from .pytorch_models import train_model, PROTAC_Model
 
 
 
 
 
 
6
  from .sklearn_models import (
7
  train_sklearn_model,
8
  suggest_random_forest,
@@ -83,6 +89,26 @@ def get_dataframe_stats(
83
  return stats
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def pytorch_model_objective(
87
  trial: optuna.Trial,
88
  protein2embedding: Dict,
@@ -198,18 +224,7 @@ def pytorch_model_objective(
198
 
199
  # Get the majority vote for the test predictions
200
  if test_df is not None and not fast_dev_run:
201
- # Get the majority vote for the test predictions
202
- test_preds = torch.stack(test_preds)
203
- test_preds, _ = torch.mode(test_preds, dim=0)
204
- y = torch.tensor(test_df[active_label].tolist())
205
- # Measure the test accuracy and ROC AUC
206
- majority_vote_metrics = {
207
- 'test_acc': Accuracy(task='binary')(test_preds, y).item(),
208
- 'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
209
- 'test_precision': Precision(task='binary')(test_preds, y).item(),
210
- 'test_recall': Recall(task='binary')(test_preds, y).item(),
211
- 'test_f1': F1Score(task='binary')(test_preds, y).item(),
212
- }
213
  majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
214
  trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
215
  logging.info(f'Majority vote metrics: {majority_vote_metrics}')
@@ -278,6 +293,7 @@ def hyperparameter_tuning_and_training(
278
  study = joblib.load(study_filename)
279
  study_loaded = True
280
  logging.info(f'Loaded study from {study_filename}')
 
281
 
282
  if not study_loaded or force_study:
283
  study.optimize(
@@ -333,12 +349,13 @@ def hyperparameter_tuning_and_training(
333
  )
334
 
335
  # Retrain N models with the best hyperparameters (measure model uncertainty)
 
336
  test_report = []
337
  test_preds = []
338
  dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
339
  for i in range(n_models_for_test):
340
  pl.seed_everything(42 + i + 1)
341
- _, trainer, metrics, test_pred = train_model(
342
  protein2embedding=protein2embedding,
343
  cell2embedding=cell2embedding,
344
  smiles2fp=smiles2fp,
@@ -366,22 +383,12 @@ def hyperparameter_tuning_and_training(
366
 
367
  test_report.append(metrics.copy())
368
  test_preds.append(test_pred)
 
369
  test_report = pd.DataFrame(test_report)
370
 
371
  # Get the majority vote for the test predictions
372
  if not fast_dev_run:
373
- test_preds = torch.stack(test_preds)
374
- test_preds, _ = torch.mode(test_preds, dim=0)
375
- y = torch.tensor(test_df[active_label].tolist())
376
- # Measure the test accuracy and ROC AUC
377
- majority_vote_metrics = {
378
- 'cv_models': False,
379
- 'test_acc': Accuracy(task='binary')(test_preds, y).item(),
380
- 'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
381
- 'test_precision': Precision(task='binary')(test_preds, y).item(),
382
- 'test_recall': Recall(task='binary')(test_preds, y).item(),
383
- 'test_f1': F1Score(task='binary')(test_preds, y).item(),
384
- }
385
  majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
386
  majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
387
  majority_vote_metrics_cv['cv_models'] = True
@@ -408,34 +415,71 @@ def hyperparameter_tuning_and_training(
408
  logging.info('-' * 100)
409
  logging.info(f'Ablation study with disabled embeddings: {disabled_embeddings}')
410
  logging.info('-' * 100)
411
- _, _, metrics = train_model(
412
- protein2embedding=protein2embedding,
413
- cell2embedding=cell2embedding,
414
- smiles2fp=smiles2fp,
415
- train_df=train_val_df,
416
- val_df=test_df,
417
- fast_dev_run=fast_dev_run,
418
- active_label=active_label,
419
- max_epochs=max_epochs,
420
- use_logger=False,
421
- logger_save_dir=logger_save_dir,
422
- logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
423
- disabled_embeddings=disabled_embeddings,
424
- batch_size=128,
425
- apply_scaling=True,
426
- **study.best_params,
427
- )
428
- # Rename the keys in the metrics dictionary
429
- metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
430
- metrics['disabled_embeddings'] = 'disabled ' + ' '.join(disabled_embeddings)
431
- metrics['model_type'] = 'Pytorch'
432
- metrics.update(dfs_stats)
 
 
 
 
 
433
 
434
- # Add the training metrics
435
- train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
436
- metrics.update(train_metrics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
 
438
- ablation_report.append(metrics.copy())
439
  ablation_report = pd.DataFrame(ablation_report)
440
 
441
  # Add a column with the split_type to all reports
 
2
  from typing import Literal, List, Tuple, Optional, Dict
3
  import logging
4
 
5
+ from .pytorch_models import (
6
+ train_model,
7
+ PROTAC_Model,
8
+ evaluate_model,
9
+ )
10
+ from .protac_dataset import get_datasets
11
+
12
  from .sklearn_models import (
13
  train_sklearn_model,
14
  suggest_random_forest,
 
89
  return stats
90
 
91
 
92
+ def get_majority_vote_metrics(
93
+ test_preds: List,
94
+ test_df: pd.DataFrame,
95
+ active_label: str = 'Active',
96
+ ) -> Dict:
97
+ """ Get the majority vote metrics. """
98
+ test_preds = torch.stack(test_preds)
99
+ test_preds, _ = torch.mode(test_preds, dim=0)
100
+ y = torch.tensor(test_df[active_label].tolist())
101
+ # Measure the test accuracy and ROC AUC
102
+ majority_vote_metrics = {
103
+ 'test_acc': Accuracy(task='binary')(test_preds, y).item(),
104
+ 'test_roc_auc': AUROC(task='binary')(test_preds, y).item(),
105
+ 'test_precision': Precision(task='binary')(test_preds, y).item(),
106
+ 'test_recall': Recall(task='binary')(test_preds, y).item(),
107
+ 'test_f1': F1Score(task='binary')(test_preds, y).item(),
108
+ }
109
+ return majority_vote_metrics
110
+
111
+
112
  def pytorch_model_objective(
113
  trial: optuna.Trial,
114
  protein2embedding: Dict,
 
224
 
225
  # Get the majority vote for the test predictions
226
  if test_df is not None and not fast_dev_run:
227
+ majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
 
 
 
 
 
 
 
 
 
 
 
228
  majority_vote_metrics.update(get_dataframe_stats(train_df, val_df, test_df, active_label))
229
  trial.set_user_attr('majority_vote_metrics', majority_vote_metrics)
230
  logging.info(f'Majority vote metrics: {majority_vote_metrics}')
 
293
  study = joblib.load(study_filename)
294
  study_loaded = True
295
  logging.info(f'Loaded study from {study_filename}')
296
+ logging.info(f'Study best params: {study.best_params}')
297
 
298
  if not study_loaded or force_study:
299
  study.optimize(
 
349
  )
350
 
351
  # Retrain N models with the best hyperparameters (measure model uncertainty)
352
+ best_models = []
353
  test_report = []
354
  test_preds = []
355
  dfs_stats = get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label)
356
  for i in range(n_models_for_test):
357
  pl.seed_everything(42 + i + 1)
358
+ model, trainer, metrics, test_pred = train_model(
359
  protein2embedding=protein2embedding,
360
  cell2embedding=cell2embedding,
361
  smiles2fp=smiles2fp,
 
383
 
384
  test_report.append(metrics.copy())
385
  test_preds.append(test_pred)
386
+ best_models.append({'model': model, 'trainer': trainer})
387
  test_report = pd.DataFrame(test_report)
388
 
389
  # Get the majority vote for the test predictions
390
  if not fast_dev_run:
391
+ majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
 
 
 
 
 
 
 
 
 
 
 
392
  majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
393
  majority_vote_metrics_cv = study.best_trial.user_attrs['majority_vote_metrics']
394
  majority_vote_metrics_cv['cv_models'] = True
 
415
  logging.info('-' * 100)
416
  logging.info(f'Ablation study with disabled embeddings: {disabled_embeddings}')
417
  logging.info('-' * 100)
418
+ disabled_embeddings_str = 'disabled ' + ' '.join(disabled_embeddings)
419
+ test_preds = []
420
+ for i, model_trainer in enumerate(best_models):
421
+ logging.info(f'Evaluating model n.{i} on {disabled_embeddings_str}.')
422
+ model = model_trainer['model']
423
+ trainer = model_trainer['trainer']
424
+ _, test_ds, _ = get_datasets(
425
+ protein2embedding=protein2embedding,
426
+ cell2embedding=cell2embedding,
427
+ smiles2fp=smiles2fp,
428
+ train_df=train_val_df,
429
+ val_df=test_df,
430
+ disabled_embeddings=disabled_embeddings,
431
+ active_label=active_label,
432
+ scaler=model.scalers,
433
+ use_single_scaler=model.join_embeddings == 'beginning',
434
+ )
435
+ ret = evaluate_model(model, trainer, test_ds, batch_size=128)
436
+ # NOTE: We are passing the test set as the validation set argument
437
+ # Rename the keys in the metrics dictionary
438
+ test_preds.append(ret['val_pred'])
439
+ ret['val_metrics'] = {k.replace('val_', 'test_'): v for k, v in ret['val_metrics'].items()}
440
+ ret['val_metrics'].update(dfs_stats)
441
+ ret['val_metrics']['majority_vote'] = False
442
+ ret['val_metrics']['model_type'] = 'Pytorch'
443
+ ret['val_metrics']['disabled_embeddings'] = disabled_embeddings_str
444
+ ablation_report.append(ret['val_metrics'].copy())
445
 
446
+ # Get the majority vote for the test predictions
447
+ if not fast_dev_run:
448
+ majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
449
+ majority_vote_metrics.update(dfs_stats)
450
+ majority_vote_metrics['majority_vote'] = True
451
+ majority_vote_metrics['model_type'] = 'Pytorch'
452
+ majority_vote_metrics['disabled_embeddings'] = disabled_embeddings_str
453
+ ablation_report.append(majority_vote_metrics.copy())
454
+
455
+ # _, _, metrics = train_model(
456
+ # protein2embedding=protein2embedding,
457
+ # cell2embedding=cell2embedding,
458
+ # smiles2fp=smiles2fp,
459
+ # train_df=train_val_df,
460
+ # val_df=test_df,
461
+ # fast_dev_run=fast_dev_run,
462
+ # active_label=active_label,
463
+ # max_epochs=max_epochs,
464
+ # use_logger=False,
465
+ # logger_save_dir=logger_save_dir,
466
+ # logger_name=f'{logger_name}_disabled-{"-".join(disabled_embeddings)}',
467
+ # disabled_embeddings=disabled_embeddings,
468
+ # batch_size=128,
469
+ # apply_scaling=True,
470
+ # **study.best_params,
471
+ # )
472
+ # # Rename the keys in the metrics dictionary
473
+ # metrics = {k.replace('val_', 'test_'): v for k, v in metrics.items()}
474
+ # metrics['disabled_embeddings'] = disabled_embeddings_str
475
+ # metrics['model_type'] = 'Pytorch'
476
+ # metrics.update(dfs_stats)
477
+
478
+ # # Add the training metrics
479
+ # train_metrics = {m: v.item() for m, v in trainer.callback_metrics.items() if 'train' in m}
480
+ # metrics.update(train_metrics)
481
+ # ablation_report.append(metrics.copy())
482
 
 
483
  ablation_report = pd.DataFrame(ablation_report)
484
 
485
  # Add a column with the split_type to all reports
protac_degradation_predictor/protac_dataset.py CHANGED
@@ -1,13 +1,26 @@
1
  from typing import Literal, List, Tuple, Optional, Dict
 
2
 
3
- from torch.utils.data import Dataset
4
- import numpy as np
 
 
 
 
 
 
5
  from imblearn.over_sampling import SMOTE, ADASYN
 
 
6
  import pandas as pd
7
- from sklearn.preprocessing import StandardScaler
 
 
 
8
 
9
 
10
  class PROTAC_Dataset(Dataset):
 
11
  def __init__(
12
  self,
13
  protac_df: pd.DataFrame,
@@ -17,6 +30,9 @@ class PROTAC_Dataset(Dataset):
17
  use_smote: bool = False,
18
  oversampler: Optional[SMOTE | ADASYN] = None,
19
  active_label: str = 'Active',
 
 
 
20
  ):
21
  """ Initialize the PROTAC dataset
22
 
@@ -28,13 +44,17 @@ class PROTAC_Dataset(Dataset):
28
  use_smote (bool): Whether to use SMOTE for oversampling
29
  use_ored_activity (bool): Whether to use the 'Active - OR' column
30
  """
31
- # Filter out examples with NaN in active_col column
32
- self.data = protac_df # [~protac_df[active_col].isna()]
33
  self.protein2embedding = protein2embedding
34
  self.cell2embedding = cell2embedding
35
  self.smiles2fp = smiles2fp
36
  self.active_label = active_label
37
- self.use_single_scaler = None
 
 
 
 
38
 
39
  self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
40
  self.protein_emb_dim = protein2embedding[list(
@@ -115,15 +135,15 @@ class PROTAC_Dataset(Dataset):
115
  """
116
  if use_single_scaler:
117
  self.use_single_scaler = True
118
- scaler = StandardScaler(**scaler_kwargs)
119
  embeddings = np.hstack([
120
  np.array(self.data['Smiles'].tolist()),
121
  np.array(self.data['Uniprot'].tolist()),
122
  np.array(self.data['E3 Ligase Uniprot'].tolist()),
123
  np.array(self.data['Cell Line Identifier'].tolist()),
124
  ])
125
- scaler.fit(embeddings)
126
- return scaler
127
  else:
128
  self.use_single_scaler = False
129
  scalers = {}
@@ -137,6 +157,7 @@ class PROTAC_Dataset(Dataset):
137
  scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
138
  scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
139
 
 
140
  return scalers
141
 
142
  def apply_scaling(self, scalers: dict, use_single_scaler: bool = False):
@@ -190,11 +211,359 @@ class PROTAC_Dataset(Dataset):
190
  return len(self.data)
191
 
192
  def __getitem__(self, idx):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  elem = {
194
- 'smiles_emb': self.data['Smiles'].iloc[idx],
195
- 'poi_emb': self.data['Uniprot'].iloc[idx],
196
- 'e3_emb': self.data['E3 Ligase Uniprot'].iloc[idx],
197
- 'cell_emb': self.data['Cell Line Identifier'].iloc[idx],
198
  'active': self.data[self.active_label].iloc[idx],
199
  }
200
- return elem
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Literal, List, Tuple, Optional, Dict
2
+ from collections import defaultdict
3
 
4
+ from .data_utils import (
5
+ get_fingerprint,
6
+ is_active,
7
+ load_cell2embedding,
8
+ load_protein2embedding,
9
+ )
10
+
11
+ from torch.utils.data import Dataset, DataLoader
12
  from imblearn.over_sampling import SMOTE, ADASYN
13
+ from sklearn.preprocessing import StandardScaler, OrdinalEncoder
14
+ import numpy as np
15
  import pandas as pd
16
+ import pytorch_lightning as pl
17
+ from rdkit import Chem
18
+ from rdkit.Chem import AllChem
19
+ from rdkit import DataStructs
20
 
21
 
22
  class PROTAC_Dataset(Dataset):
23
+
24
  def __init__(
25
  self,
26
  protac_df: pd.DataFrame,
 
30
  use_smote: bool = False,
31
  oversampler: Optional[SMOTE | ADASYN] = None,
32
  active_label: str = 'Active',
33
+ disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
34
+ scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
35
+ use_single_scaler: Optional[bool] = None,
36
  ):
37
  """ Initialize the PROTAC dataset
38
 
 
44
  use_smote (bool): Whether to use SMOTE for oversampling
45
  use_ored_activity (bool): Whether to use the 'Active - OR' column
46
  """
47
+ # Filter out examples with NaN in active_label column
48
+ self.data = protac_df # [~protac_df[active_label].isna()]
49
  self.protein2embedding = protein2embedding
50
  self.cell2embedding = cell2embedding
51
  self.smiles2fp = smiles2fp
52
  self.active_label = active_label
53
+ self.disabled_embeddings = disabled_embeddings
54
+
55
+ # Scaling parameters
56
+ self.scaler = scaler
57
+ self.use_single_scaler = use_single_scaler
58
 
59
  self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0]
60
  self.protein_emb_dim = protein2embedding[list(
 
135
  """
136
  if use_single_scaler:
137
  self.use_single_scaler = True
138
+ self.scaler = StandardScaler(**scaler_kwargs)
139
  embeddings = np.hstack([
140
  np.array(self.data['Smiles'].tolist()),
141
  np.array(self.data['Uniprot'].tolist()),
142
  np.array(self.data['E3 Ligase Uniprot'].tolist()),
143
  np.array(self.data['Cell Line Identifier'].tolist()),
144
  ])
145
+ self.scaler.fit(embeddings)
146
+ return self.scaler
147
  else:
148
  self.use_single_scaler = False
149
  scalers = {}
 
157
  scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
158
  scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
159
 
160
+ self.scaler = scalers
161
  return scalers
162
 
163
  def apply_scaling(self, scalers: dict, use_single_scaler: bool = False):
 
211
  return len(self.data)
212
 
213
  def __getitem__(self, idx):
214
+ if 'smiles' in self.disabled_embeddings:
215
+ # Uniformly sample a binary vector for the fingerprint
216
+ smiles_emb = np.random.randint(0, 2, size=self.smiles_emb_dim).astype(np.float32)
217
+ if not self.use_single_scaler and self.scaler is not None:
218
+ smiles_emb = smiles_emb[np.newaxis, :]
219
+ smiles_emb = self.scaler['Smiles'].transform(smiles_emb).flatten()
220
+ else:
221
+ smiles_emb = self.data['Smiles'].iloc[idx]
222
+
223
+ if 'poi' in self.disabled_embeddings:
224
+ # Uniformly sample a vector for the protein
225
+ poi_emb = np.random.rand(self.protein_emb_dim).astype(np.float32)
226
+ if not self.use_single_scaler and self.scaler is not None:
227
+ poi_emb = poi_emb[np.newaxis, :]
228
+ poi_emb = self.scaler['Uniprot'].transform(poi_emb).flatten()
229
+ else:
230
+ poi_emb = self.data['Uniprot'].iloc[idx]
231
+
232
+ if 'e3' in self.disabled_embeddings:
233
+ # Uniformly sample a vector for the E3 ligase
234
+ e3_emb = np.random.rand(self.protein_emb_dim).astype(np.float32)
235
+ if not self.use_single_scaler and self.scaler is not None:
236
+ # Add extra dimension for compatibility with the scaler
237
+ e3_emb = e3_emb[np.newaxis, :]
238
+ e3_emb = self.scaler['E3 Ligase Uniprot'].transform(e3_emb)
239
+ e3_emb = e3_emb.flatten()
240
+ else:
241
+ e3_emb = self.data['E3 Ligase Uniprot'].iloc[idx]
242
+
243
+ if 'cell' in self.disabled_embeddings:
244
+ # Uniformly sample a vector for the cell line
245
+ cell_emb = np.random.rand(self.cell_emb_dim).astype(np.float32)
246
+ if not self.use_single_scaler and self.scaler is not None:
247
+ cell_emb = cell_emb[np.newaxis, :]
248
+ cell_emb = self.scaler['Cell Line Identifier'].transform(cell_emb).flatten()
249
+ else:
250
+ cell_emb = self.data['Cell Line Identifier'].iloc[idx]
251
+
252
  elem = {
253
+ 'smiles_emb': smiles_emb,
254
+ 'poi_emb': poi_emb,
255
+ 'e3_emb': e3_emb,
256
+ 'cell_emb': cell_emb,
257
  'active': self.data[self.active_label].iloc[idx],
258
  }
259
+ return elem
260
+
261
+
262
+ def get_datasets(
263
+ train_df: pd.DataFrame,
264
+ val_df: pd.DataFrame,
265
+ test_df: Optional[pd.DataFrame] = None,
266
+ protein2embedding: Dict = None,
267
+ cell2embedding: Dict = None,
268
+ smiles2fp: Dict = None,
269
+ use_smote: bool = True,
270
+ smote_k_neighbors: int = 5,
271
+ active_label: str = 'Active',
272
+ disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
273
+ scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
274
+ use_single_scaler: Optional[bool] = None,
275
+ ) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
276
+ """ Get the datasets for training the PROTAC model. """
277
+ oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
278
+ train_ds = PROTAC_Dataset(
279
+ train_df,
280
+ protein2embedding,
281
+ cell2embedding,
282
+ smiles2fp,
283
+ use_smote=use_smote,
284
+ oversampler=oversampler if use_smote else None,
285
+ active_label=active_label,
286
+ disabled_embeddings=disabled_embeddings,
287
+ scaler=scaler,
288
+ use_single_scaler=use_single_scaler,
289
+ )
290
+ val_ds = PROTAC_Dataset(
291
+ val_df,
292
+ protein2embedding,
293
+ cell2embedding,
294
+ smiles2fp,
295
+ active_label=active_label,
296
+ disabled_embeddings=disabled_embeddings,
297
+ scaler=train_ds.scaler if train_ds.scaler is not None else scaler,
298
+ use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
299
+ )
300
+ if test_df is not None:
301
+ test_ds = PROTAC_Dataset(
302
+ test_df,
303
+ protein2embedding,
304
+ cell2embedding,
305
+ smiles2fp,
306
+ active_label=active_label,
307
+ disabled_embeddings=disabled_embeddings,
308
+ scaler=train_ds.scaler if train_ds.scaler is not None else scaler,
309
+ use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
310
+ )
311
+ else:
312
+ test_ds = None
313
+ return train_ds, val_ds, test_ds
314
+
315
+
316
+ class PROTAC_DataModule(pl.LightningDataModule):
317
+ """ PyTorch Lightning DataModule for the PROTAC dataset.
318
+
319
+ TODO: Work in progress. It would be nice to wrap all information into a
320
+ single class, but it is not clear how to do it yet due to cross-validation
321
+ and the need to split the data into training, validation, and test sets
322
+ accordingly.
323
+
324
+ Args:
325
+ protac_csv_filepath (str): The path to the PROTAC CSV file.
326
+ protein2embedding_filepath (str): The path to the protein to embedding dictionary.
327
+ cell2embedding_filepath (str): The path to the cell line to embedding dictionary.
328
+ pDC50_threshold (float): The threshold for the pDC50 value to consider a PROTAC active.
329
+ Dmax_threshold (float): The threshold for the Dmax value to consider a PROTAC active.
330
+ use_smote (bool): Whether to use SMOTE for oversampling.
331
+ smote_k_neighbors (int): The number of neighbors to use for SMOTE.
332
+ active_label (str): The column containing the active/inactive information.
333
+ disabled_embeddings (list): The list of embeddings to disable.
334
+ scaler (StandardScaler | dict): The scaler to use for the embeddings.
335
+ use_single_scaler (bool): Whether to use a single scaler for all features.
336
+ """
337
+
338
+ def __init__(
339
+ self,
340
+ protac_csv_filepath: str,
341
+ protein2embedding_filepath: str,
342
+ cell2embedding_filepath: str,
343
+ pDC50_threshold: float = 6.0,
344
+ Dmax_threshold: float = 0.6,
345
+ use_smote: bool = True,
346
+ smote_k_neighbors: int = 5,
347
+ active_label: str = 'Active',
348
+ disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
349
+ scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
350
+ use_single_scaler: Optional[bool] = None,
351
+ ):
352
+ super(PROTAC_DataModule, self).__init__()
353
+
354
+ # Load the PROTAC dataset
355
+ self.protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv')
356
+ # Map E3 Ligase Iap to IAP
357
+ self.protac_df['E3 Ligase'] = self.protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
358
+ self.protac_df[active_label] = self.protac_df.apply(
359
+ lambda x: is_active(
360
+ x['DC50 (nM)'],
361
+ x['Dmax (%)'],
362
+ pDC50_threshold=pDC50_threshold,
363
+ Dmax_threshold=Dmax_threshold,
364
+ ),
365
+ axis=1,
366
+ )
367
+ self.smiles2fp, self.protac_df = self.get_smiles2fp_and_avg_tanimoto(self.protac_df)
368
+ self.active_df = self.protac_df[self.protac_df[active_label].notna()].copy()
369
+
370
+ # Load embedding dictionaries
371
+ self.protein2embedding = load_protein2embedding(protein2embedding_filepath)
372
+ self.cell2embedding = load_cell2embedding(cell2embedding_filepath)
373
+
374
+
375
+ def setup(self, stage: str):
376
+ self.train_ds, self.val_ds, self.test_ds = get_datasets(
377
+ self.train_df,
378
+ self.val_df,
379
+ self.test_df,
380
+ self.protein2embedding,
381
+ self.cell2embedding,
382
+ self.smiles2fp,
383
+ use_smote=self.use_smote,
384
+ smote_k_neighbors=self.smote_k_neighbors,
385
+ active_label=self.active_label,
386
+ disabled_embeddings=self.disabled_embeddings,
387
+ scaler=self.scaler,
388
+ use_single_scaler=self.use_single_scaler,
389
+ )
390
+
391
+ def train_dataloader(self):
392
+ return DataLoader(self.train_ds, batch_size=32, shuffle=True)
393
+
394
+ def val_dataloader(self):
395
+ return DataLoader(self.val_ds, batch_size=32)
396
+
397
+ def test_dataloader(self):
398
+ return DataLoader(self.test_ds, batch_size=32)
399
+
400
+ @staticmethod
401
+ def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
402
+ """ Get the indices of the test set using a random split.
403
+
404
+ Args:
405
+ active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
406
+ test_split (float): The percentage of the active PROTACs to use as the test set.
407
+
408
+ Returns:
409
+ pd.Index: The indices of the test set.
410
+ """
411
+ return active_df.sample(frac=test_split, random_state=42).index
412
+
413
+ @staticmethod
414
+ def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
415
+ """ Get the indices of the test set using the E3 ligase split.
416
+
417
+ Args:
418
+ active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
419
+
420
+ Returns:
421
+ pd.Index: The indices of the test set.
422
+ """
423
+ encoder = OrdinalEncoder()
424
+ active_df['E3 Group'] = encoder.fit_transform(active_df[['E3 Ligase']]).astype(int)
425
+ test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
426
+ return test_df.index
427
+
428
+ @staticmethod
429
+ def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
430
+ """ Get the SMILES to fingerprint dictionary and the average Tanimoto similarity.
431
+
432
+ Args:
433
+ protac_df (pd.DataFrame): The DataFrame containing the PROTACs.
434
+
435
+ Returns:
436
+ tuple: The SMILES to fingerprint dictionary and the average Tanimoto similarity.
437
+ """
438
+ unique_smiles = protac_df['Smiles'].unique().tolist()
439
+
440
+ smiles2fp = {}
441
+ for smiles in unique_smiles:
442
+ smiles2fp[smiles] = get_fingerprint(smiles)
443
+
444
+ tanimoto_matrix = defaultdict(list)
445
+ fps = list(smiles2fp.values())
446
+
447
+ # Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity
448
+ for i, (smiles1, fp1) in enumerate(zip(unique_smiles, fps)):
449
+ similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:]) # Only compute for i to end, avoiding duplicates
450
+ for j, similarity in enumerate(similarities):
451
+ distance = 1 - similarity
452
+ tanimoto_matrix[smiles1].append(distance) # Store as distance
453
+ if i != i + j:
454
+ tanimoto_matrix[unique_smiles[i + j]].append(distance) # Symmetric filling
455
+
456
+ # Calculate average Tanimoto distance for each unique SMILES
457
+ avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
458
+ protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
459
+
460
+ smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
461
+
462
+ return smiles2fp, protac_df
463
+
464
+ @staticmethod
465
+ def get_tanimoto_split_indices(
466
+ active_df: pd.DataFrame,
467
+ active_label: str,
468
+ test_split: float,
469
+ n_bins_tanimoto: int = 200,
470
+ ) -> pd.Index:
471
+ """ Get the indices of the test set using the Tanimoto-based split.
472
+
473
+ Args:
474
+ active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
475
+ n_bins_tanimoto (int): The number of bins to use for the Tanimoto similarity.
476
+
477
+ Returns:
478
+ pd.Index: The indices of the test set.
479
+ """
480
+ tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
481
+ encoder = OrdinalEncoder()
482
+ active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int)
483
+ # Sort the groups so that samples with the highest tanimoto similarity,
484
+ # i.e., the "less similar" ones, are placed in the test set first
485
+ tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
486
+
487
+ test_df = []
488
+ # For each group, get the number of active and inactive entries. Then, add those
489
+ # entries to the test_df if: 1) the test_df lenght + the group entries is less
490
+ # 20% of the active_df lenght, and 2) the percentage of True and False entries
491
+ # in the active_label in test_df is roughly 50%.
492
+ for group in tanimoto_groups:
493
+ group_df = active_df[active_df['Tanimoto Group'] == group]
494
+ if test_df == []:
495
+ test_df.append(group_df)
496
+ continue
497
+
498
+ num_entries = len(group_df)
499
+ num_active_group = group_df[active_label].sum()
500
+ num_inactive_group = num_entries - num_active_group
501
+
502
+ tmp_test_df = pd.concat(test_df)
503
+ num_entries_test = len(tmp_test_df)
504
+ num_active_test = tmp_test_df[active_label].sum()
505
+ num_inactive_test = num_entries_test - num_active_test
506
+
507
+ # Check if the group entries can be added to the test_df
508
+ if num_entries_test + num_entries < test_split * len(active_df):
509
+ # Add anything at the beggining
510
+ if num_entries_test + num_entries < test_split / 2 * len(active_df):
511
+ test_df.append(group_df)
512
+ continue
513
+ # Be more selective and make sure that the percentage of active and
514
+ # inactive is balanced
515
+ if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
516
+ if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
517
+ test_df.append(group_df)
518
+ test_df = pd.concat(test_df)
519
+ return test_df.index
520
+
521
+ @staticmethod
522
+ def get_target_split_indices(active_df: pd.DataFrame, active_label: str, test_split: float) -> pd.Index:
523
+ """ Get the indices of the test set using the target-based split.
524
+
525
+ Args:
526
+ active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
527
+ active_label (str): The column containing the active/inactive information.
528
+ test_split (float): The percentage of the active PROTACs to use as the test set.
529
+
530
+ Returns:
531
+ pd.Index: The indices of the test set.
532
+ """
533
+ encoder = OrdinalEncoder()
534
+ active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)
535
+
536
+ test_df = []
537
+ # For each group, get the number of active and inactive entries. Then, add those
538
+ # entries to the test_df if: 1) the test_df lenght + the group entries is less
539
+ # 20% of the active_df lenght, and 2) the percentage of True and False entries
540
+ # in the active_label in test_df is roughly 50%.
541
+ # Start the loop from the groups containing the smallest number of entries.
542
+ for group in reversed(active_df['Uniprot'].value_counts().index):
543
+ group_df = active_df[active_df['Uniprot'] == group]
544
+ if test_df == []:
545
+ test_df.append(group_df)
546
+ continue
547
+
548
+ num_entries = len(group_df)
549
+ num_active_group = group_df[active_label].sum()
550
+ num_inactive_group = num_entries - num_active_group
551
+
552
+ tmp_test_df = pd.concat(test_df)
553
+ num_entries_test = len(tmp_test_df)
554
+ num_active_test = tmp_test_df[active_label].sum()
555
+ num_inactive_test = num_entries_test - num_active_test
556
+
557
+ # Check if the group entries can be added to the test_df
558
+ if num_entries_test + num_entries < test_split * len(active_df):
559
+ # Add anything at the beggining
560
+ if num_entries_test + num_entries < test_split / 2 * len(active_df):
561
+ test_df.append(group_df)
562
+ continue
563
+ # Be more selective and make sure that the percentage of active and
564
+ # inactive is balanced
565
+ if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
566
+ if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
567
+ test_df.append(group_df)
568
+ test_df = pd.concat(test_df)
569
+ return test_df.index
protac_degradation_predictor/pytorch_models.py CHANGED
@@ -3,7 +3,7 @@ import pickle
3
  import logging
4
  from typing import Literal, List, Tuple, Optional, Dict
5
 
6
- from .protac_dataset import PROTAC_Dataset
7
  from .config import config
8
 
9
  import pandas as pd
@@ -38,7 +38,7 @@ class PROTAC_Predictor(nn.Module):
38
  dropout: float = 0.2,
39
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
40
  use_batch_norm: bool = False,
41
- disabled_embeddings: list = [],
42
  ):
43
  """ Initialize the PROTAC model.
44
 
@@ -69,17 +69,17 @@ class PROTAC_Predictor(nn.Module):
69
  # and can be summed on a "similar scale".
70
  if self.join_embeddings != 'beginning':
71
  if 'poi' not in self.disabled_embeddings:
72
- self.poi_emb = nn.Sequential(
73
  nn.Linear(poi_emb_dim, hidden_dim),
74
  nn.Softmax(dim=1),
75
  )
76
  if 'e3' not in self.disabled_embeddings:
77
- self.e3_emb = nn.Sequential(
78
  nn.Linear(e3_emb_dim, hidden_dim),
79
  nn.Softmax(dim=1),
80
  )
81
  if 'cell' not in self.disabled_embeddings:
82
- self.cell_emb = nn.Sequential(
83
  nn.Linear(cell_emb_dim, hidden_dim),
84
  nn.Softmax(dim=1),
85
  )
@@ -95,12 +95,12 @@ class PROTAC_Predictor(nn.Module):
95
  joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
96
  joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
97
  joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
 
98
  elif self.join_embeddings == 'concat':
99
  joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
100
  elif self.join_embeddings == 'sum':
101
  joint_dim = hidden_dim
102
 
103
- self.fc0 = nn.Linear(joint_dim, joint_dim)
104
  self.fc1 = nn.Linear(joint_dim, hidden_dim)
105
  self.fc2 = nn.Linear(hidden_dim, hidden_dim)
106
  self.fc3 = nn.Linear(hidden_dim, 1)
@@ -125,11 +125,11 @@ class PROTAC_Predictor(nn.Module):
125
  x = self.dropout(F.relu(self.fc0(x)))
126
  else:
127
  if 'poi' not in self.disabled_embeddings:
128
- embeddings.append(self.poi_emb(poi_emb))
129
  if 'e3' not in self.disabled_embeddings:
130
- embeddings.append(self.e3_emb(e3_emb))
131
  if 'cell' not in self.disabled_embeddings:
132
- embeddings.append(self.cell_emb(cell_emb))
133
  if 'smiles' not in self.disabled_embeddings:
134
  embeddings.append(self.smiles_emb(smiles_emb))
135
  if self.join_embeddings == 'concat':
@@ -163,7 +163,7 @@ class PROTAC_Model(pl.LightningModule):
163
  train_dataset: PROTAC_Dataset = None,
164
  val_dataset: PROTAC_Dataset = None,
165
  test_dataset: PROTAC_Dataset = None,
166
- disabled_embeddings: list = [],
167
  apply_scaling: bool = True,
168
  ):
169
  """ Initialize the PROTAC Pytorch Lightning model.
@@ -217,7 +217,7 @@ class PROTAC_Model(pl.LightningModule):
217
  dropout=dropout,
218
  join_embeddings=join_embeddings,
219
  use_batch_norm=use_batch_norm,
220
- disabled_embeddings=disabled_embeddings,
221
  )
222
 
223
  stages = ['train_metrics', 'val_metrics', 'test_metrics']
@@ -429,7 +429,7 @@ def train_model(
429
  logger_name: str = 'protac',
430
  enable_checkpointing: bool = False,
431
  checkpoint_model_name: str = 'protac',
432
- disabled_embeddings: List[str] = [],
433
  return_predictions: bool = False,
434
  ) -> tuple:
435
  """ Train a PROTAC model using the given datasets and hyperparameters.
@@ -453,31 +453,18 @@ def train_model(
453
  Returns:
454
  tuple: The trained model, the trainer, and the metrics.
455
  """
456
- oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
457
- train_ds = PROTAC_Dataset(
458
  train_df,
459
- protein2embedding,
460
- cell2embedding,
461
- smiles2fp,
462
- use_smote=use_smote,
463
- oversampler=oversampler if use_smote else None,
464
- active_label=active_label,
465
- )
466
- val_ds = PROTAC_Dataset(
467
  val_df,
 
468
  protein2embedding,
469
  cell2embedding,
470
  smiles2fp,
 
 
471
  active_label=active_label,
 
472
  )
473
- if test_df is not None:
474
- test_ds = PROTAC_Dataset(
475
- test_df,
476
- protein2embedding,
477
- cell2embedding,
478
- smiles2fp,
479
- active_label=active_label,
480
- )
481
  loggers = [
482
  pl.loggers.TensorBoardLogger(
483
  save_dir=logger_save_dir,
@@ -505,7 +492,7 @@ def train_model(
505
  ),
506
  pl.callbacks.EarlyStopping(
507
  monitor='val_loss',
508
- patience=10, # Original: 5
509
  mode='min',
510
  verbose=False,
511
  ),
@@ -586,6 +573,36 @@ def train_model(
586
  return model, trainer, metrics
587
 
588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  def load_model(
590
  ckpt_path: str,
591
  ) -> PROTAC_Model:
 
3
  import logging
4
  from typing import Literal, List, Tuple, Optional, Dict
5
 
6
+ from .protac_dataset import PROTAC_Dataset, get_datasets
7
  from .config import config
8
 
9
  import pandas as pd
 
38
  dropout: float = 0.2,
39
  join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
40
  use_batch_norm: bool = False,
41
+ disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
42
  ):
43
  """ Initialize the PROTAC model.
44
 
 
69
  # and can be summed on a "similar scale".
70
  if self.join_embeddings != 'beginning':
71
  if 'poi' not in self.disabled_embeddings:
72
+ self.poi_fc = nn.Sequential(
73
  nn.Linear(poi_emb_dim, hidden_dim),
74
  nn.Softmax(dim=1),
75
  )
76
  if 'e3' not in self.disabled_embeddings:
77
+ self.e3_fc = nn.Sequential(
78
  nn.Linear(e3_emb_dim, hidden_dim),
79
  nn.Softmax(dim=1),
80
  )
81
  if 'cell' not in self.disabled_embeddings:
82
+ self.cell_fc = nn.Sequential(
83
  nn.Linear(cell_emb_dim, hidden_dim),
84
  nn.Softmax(dim=1),
85
  )
 
95
  joint_dim += poi_emb_dim if 'poi' not in self.disabled_embeddings else 0
96
  joint_dim += e3_emb_dim if 'e3' not in self.disabled_embeddings else 0
97
  joint_dim += cell_emb_dim if 'cell' not in self.disabled_embeddings else 0
98
+ self.fc0 = nn.Linear(joint_dim, joint_dim)
99
  elif self.join_embeddings == 'concat':
100
  joint_dim = hidden_dim * (4 - len(self.disabled_embeddings))
101
  elif self.join_embeddings == 'sum':
102
  joint_dim = hidden_dim
103
 
 
104
  self.fc1 = nn.Linear(joint_dim, hidden_dim)
105
  self.fc2 = nn.Linear(hidden_dim, hidden_dim)
106
  self.fc3 = nn.Linear(hidden_dim, 1)
 
125
  x = self.dropout(F.relu(self.fc0(x)))
126
  else:
127
  if 'poi' not in self.disabled_embeddings:
128
+ embeddings.append(self.poi_fc(poi_emb))
129
  if 'e3' not in self.disabled_embeddings:
130
+ embeddings.append(self.e3_fc(e3_emb))
131
  if 'cell' not in self.disabled_embeddings:
132
+ embeddings.append(self.cell_fc(cell_emb))
133
  if 'smiles' not in self.disabled_embeddings:
134
  embeddings.append(self.smiles_emb(smiles_emb))
135
  if self.join_embeddings == 'concat':
 
163
  train_dataset: PROTAC_Dataset = None,
164
  val_dataset: PROTAC_Dataset = None,
165
  test_dataset: PROTAC_Dataset = None,
166
+ disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
167
  apply_scaling: bool = True,
168
  ):
169
  """ Initialize the PROTAC Pytorch Lightning model.
 
217
  dropout=dropout,
218
  join_embeddings=join_embeddings,
219
  use_batch_norm=use_batch_norm,
220
+ disabled_embeddings=[], # NOTE: This is handled in the PROTAC_Dataset classes
221
  )
222
 
223
  stages = ['train_metrics', 'val_metrics', 'test_metrics']
 
429
  logger_name: str = 'protac',
430
  enable_checkpointing: bool = False,
431
  checkpoint_model_name: str = 'protac',
432
+ disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
433
  return_predictions: bool = False,
434
  ) -> tuple:
435
  """ Train a PROTAC model using the given datasets and hyperparameters.
 
453
  Returns:
454
  tuple: The trained model, the trainer, and the metrics.
455
  """
456
+ train_ds, val_ds, test_ds = get_datasets(
 
457
  train_df,
 
 
 
 
 
 
 
 
458
  val_df,
459
+ test_df,
460
  protein2embedding,
461
  cell2embedding,
462
  smiles2fp,
463
+ use_smote=use_smote,
464
+ smote_k_neighbors=smote_k_neighbors,
465
  active_label=active_label,
466
+ disabled_embeddings=disabled_embeddings,
467
  )
 
 
 
 
 
 
 
 
468
  loggers = [
469
  pl.loggers.TensorBoardLogger(
470
  save_dir=logger_save_dir,
 
492
  ),
493
  pl.callbacks.EarlyStopping(
494
  monitor='val_loss',
495
+ patience=5, # Original: 5
496
  mode='min',
497
  verbose=False,
498
  ),
 
573
  return model, trainer, metrics
574
 
575
 
576
+ def evaluate_model(
577
+ model: PROTAC_Model,
578
+ trainer: pl.Trainer,
579
+ val_ds: PROTAC_Dataset,
580
+ test_ds: Optional[PROTAC_Dataset] = None,
581
+ batch_size: int = 128,
582
+ ) -> tuple:
583
+ """ Evaluate a PROTAC model using the given datasets. """
584
+ ret = {}
585
+
586
+ val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
587
+ val_metrics = trainer.validate(model, val_dl, verbose=False)[0]
588
+ val_metrics = {m: v for m, v in val_metrics.items() if 'val' in m}
589
+ # Get predictions on validation set
590
+ val_pred = torch.cat(trainer.predict(model, val_dl)).squeeze()
591
+ ret['val_metrics'] = val_metrics
592
+ ret['val_pred'] = val_pred
593
+
594
+ if test_ds is not None:
595
+ test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
596
+ test_metrics = trainer.test(model, test_dl, verbose=False)[0]
597
+ test_metrics = {m: v for m, v in test_metrics.items() if 'test' in m}
598
+ # Get predictions on test set
599
+ test_pred = torch.cat(trainer.predict(model, test_dl)).squeeze()
600
+ ret['test_metrics'] = test_metrics
601
+ ret['test_pred'] = test_pred
602
+
603
+ return ret
604
+
605
+
606
  def load_model(
607
  ckpt_path: str,
608
  ) -> PROTAC_Model:
src/plot_experiment_results.py CHANGED
@@ -12,7 +12,7 @@ import numpy as np
12
  palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
13
 
14
 
15
- def plot_training_curves(df, split_type, stage='test'):
16
  Stage = 'Test' if stage == 'test' else 'Validation'
17
 
18
  # Clean the data
@@ -22,20 +22,28 @@ def plot_training_curves(df, split_type, stage='test'):
22
  df = df.apply(pd.to_numeric, errors='coerce')
23
 
24
  # Group by 'epoch' and aggregate by mean
25
- epoch_data = df.groupby('epoch').mean()
 
 
 
26
 
27
  fig, ax = plt.subplots(3, 1, figsize=(10, 15))
28
 
29
  # Plot training loss
30
- ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
31
- ax[0].plot(epoch_data.index, epoch_data[f'{stage}_loss'], label=f'{Stage} Loss', linestyle='--')
 
 
 
32
  ax[0].set_ylabel('Loss')
33
  ax[0].legend(loc='lower right')
34
  ax[0].grid(axis='both', alpha=0.5)
35
 
36
  # Plot training accuracy
37
- ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
38
- ax[1].plot(epoch_data.index, epoch_data[f'{stage}_acc'], label=f'{Stage} Accuracy', linestyle='--')
 
 
39
  ax[1].set_ylabel('Accuracy')
40
  ax[1].legend(loc='lower right')
41
  ax[1].grid(axis='both', alpha=0.5)
@@ -45,8 +53,10 @@ def plot_training_curves(df, split_type, stage='test'):
45
  ax[1].yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
46
 
47
  # Plot training ROC-AUC
48
- ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
49
- ax[2].plot(epoch_data.index, epoch_data[f'{stage}_roc_auc'], label=f'{Stage} ROC-AUC', linestyle='--')
 
 
50
  ax[2].set_ylabel('ROC-AUC')
51
  ax[2].legend(loc='lower right')
52
  ax[2].grid(axis='both', alpha=0.5)
@@ -167,6 +177,7 @@ def plot_ablation_study(report):
167
  'disabled poi',
168
  'disabled e3',
169
  'disabled cell',
 
170
  'disabled poi e3 smiles',
171
  'disabled poi e3 cell',
172
  ]
@@ -226,6 +237,7 @@ def plot_ablation_study(report):
226
  'disabled e3': 'Disabled E3 information',
227
  'disabled poi': 'Disabled target information',
228
  'disabled cell': 'Disabled cell information',
 
229
  'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
230
  'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
231
  })
@@ -323,6 +335,7 @@ def main():
323
 
324
 
325
  for split_type in ['random', 'tanimoto', 'uniprot']:
 
326
  for i in range(n_models_for_test):
327
  logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
328
  metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
@@ -330,36 +343,41 @@ def main():
330
  # Rename 'val_' columns to 'test_' columns
331
  metrics = metrics.rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
332
  plot_training_curves(metrics, f'{split_type}_best_model_n{i}')
 
 
333
 
 
334
  for i in range(cv_n_folds):
335
  # logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
336
  logs_dir = f'logs_{report_base_name}_{split_type}_{split_type}_cv_model_fold{i}'
337
  metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
338
  metrics['fold'] = i
339
  plot_training_curves(metrics, f'{split_type}_cv_model_fold{i}', stage='val')
340
-
341
- plot_performance_metrics(
342
- reports['cv_train'],
343
- reports['test'],
344
- title=f'{active_name}_metrics',
345
- )
346
-
347
- plot_performance_metrics(
348
- reports['cv_train'],
349
- reports['majority_vote'][reports['majority_vote']['cv_models'] == False],
350
- title=f'{active_name}_metrics_majority_vote',
351
- )
352
-
353
- plot_majority_voting_performance(reports['majority_vote'])
354
-
355
- reports['test']['disabled_embeddings'] = pd.NA
356
- plot_ablation_study(pd.concat([
357
- reports['ablation'],
358
- reports['test'],
359
- ]))
360
-
361
- # Plot hyperparameter optimization results to markdown
362
- print(reports['hparam'][['split_type', 'hidden_dim', 'learning_rate', 'dropout', 'use_smote', 'smote_k_neighbors']].to_markdown(index=False))
 
 
363
 
364
 
365
  if __name__ == '__main__':
 
12
  palette = ['#83B8FE', '#FFA54C', '#94ED67', '#FF7FFF']
13
 
14
 
15
+ def plot_training_curves(df, split_type, stage='test', multimodels=False, groupby='model_id'):
16
  Stage = 'Test' if stage == 'test' else 'Validation'
17
 
18
  # Clean the data
 
22
  df = df.apply(pd.to_numeric, errors='coerce')
23
 
24
  # Group by 'epoch' and aggregate by mean
25
+ if multimodels:
26
+ epoch_data = df.groupby([groupby, 'epoch']).mean().reset_index()
27
+ else:
28
+ epoch_data = df.groupby('epoch').mean().reset_index()
29
 
30
  fig, ax = plt.subplots(3, 1, figsize=(10, 15))
31
 
32
  # Plot training loss
33
+ # ax[0].plot(epoch_data.index, epoch_data['train_loss_epoch'], label='Training Loss')
34
+ # ax[0].plot(epoch_data.index, epoch_data[f'{stage}_loss'], label=f'{Stage} Loss', linestyle='--')
35
+ sns.lineplot(data=epoch_data, x='epoch', y='train_loss_epoch', ax=ax[0], label='Training Loss')
36
+ sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_loss', ax=ax[0], label=f'{Stage} Loss', linestyle='--')
37
+
38
  ax[0].set_ylabel('Loss')
39
  ax[0].legend(loc='lower right')
40
  ax[0].grid(axis='both', alpha=0.5)
41
 
42
  # Plot training accuracy
43
+ # ax[1].plot(epoch_data.index, epoch_data['train_acc_epoch'], label='Training Accuracy')
44
+ # ax[1].plot(epoch_data.index, epoch_data[f'{stage}_acc'], label=f'{Stage} Accuracy', linestyle='--')
45
+ sns.lineplot(data=epoch_data, x='epoch', y='train_acc_epoch', ax=ax[1], label='Training Accuracy')
46
+ sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_acc', ax=ax[1], label=f'{Stage} Accuracy', linestyle='--')
47
  ax[1].set_ylabel('Accuracy')
48
  ax[1].legend(loc='lower right')
49
  ax[1].grid(axis='both', alpha=0.5)
 
53
  ax[1].yaxis.set_major_formatter(plt.matplotlib.ticker.PercentFormatter(1, decimals=0))
54
 
55
  # Plot training ROC-AUC
56
+ # ax[2].plot(epoch_data.index, epoch_data['train_roc_auc_epoch'], label='Training ROC-AUC')
57
+ # ax[2].plot(epoch_data.index, epoch_data[f'{stage}_roc_auc'], label=f'{Stage} ROC-AUC', linestyle='--')
58
+ sns.lineplot(data=epoch_data, x='epoch', y='train_roc_auc_epoch', ax=ax[2], label='Training ROC-AUC')
59
+ sns.lineplot(data=epoch_data, x='epoch', y=f'{stage}_roc_auc', ax=ax[2], label=f'{Stage} ROC-AUC', linestyle='--')
60
  ax[2].set_ylabel('ROC-AUC')
61
  ax[2].legend(loc='lower right')
62
  ax[2].grid(axis='both', alpha=0.5)
 
177
  'disabled poi',
178
  'disabled e3',
179
  'disabled cell',
180
+ 'disabled poi e3',
181
  'disabled poi e3 smiles',
182
  'disabled poi e3 cell',
183
  ]
 
237
  'disabled e3': 'Disabled E3 information',
238
  'disabled poi': 'Disabled target information',
239
  'disabled cell': 'Disabled cell information',
240
+ 'disabled poi e3': 'Disabled E3 and target info',
241
  'disabled poi e3 smiles': 'Disabled compound, E3, and target info\n(only cell information left)',
242
  'disabled poi e3 cell': 'Disabled cell, E3, and target info\n(only compound information left)',
243
  })
 
335
 
336
 
337
  for split_type in ['random', 'tanimoto', 'uniprot']:
338
+ split_metrics = []
339
  for i in range(n_models_for_test):
340
  logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
341
  metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
 
343
  # Rename 'val_' columns to 'test_' columns
344
  metrics = metrics.rename(columns={'val_loss': 'test_loss', 'val_acc': 'test_acc', 'val_roc_auc': 'test_roc_auc'})
345
  plot_training_curves(metrics, f'{split_type}_best_model_n{i}')
346
+ split_metrics.append(metrics)
347
+ plot_training_curves(pd.concat(split_metrics), f'{split_type}_best_model', multimodels=True)
348
 
349
+ split_metrics_cv = []
350
  for i in range(cv_n_folds):
351
  # logs_dir = f'logs_{report_base_name}_{split_type}_best_model_n{i}'
352
  logs_dir = f'logs_{report_base_name}_{split_type}_{split_type}_cv_model_fold{i}'
353
  metrics = pd.read_csv(f'logs/{logs_dir}/{logs_dir}/metrics.csv')
354
  metrics['fold'] = i
355
  plot_training_curves(metrics, f'{split_type}_cv_model_fold{i}', stage='val')
356
+ split_metrics_cv.append(metrics)
357
+ plot_training_curves(pd.concat(split_metrics_cv), f'{split_type}_cv_model', stage='val', multimodels=True, groupby='fold')
358
+
359
+ # plot_performance_metrics(
360
+ # reports['cv_train'],
361
+ # reports['test'],
362
+ # title=f'{active_name}_metrics',
363
+ # )
364
+
365
+ # plot_performance_metrics(
366
+ # reports['cv_train'],
367
+ # reports['majority_vote'][reports['majority_vote']['cv_models'] == False],
368
+ # title=f'{active_name}_metrics_majority_vote',
369
+ # )
370
+
371
+ # plot_majority_voting_performance(reports['majority_vote'])
372
+
373
+ # reports['test']['disabled_embeddings'] = pd.NA
374
+ # plot_ablation_study(pd.concat([
375
+ # reports['ablation'],
376
+ # reports['test'],
377
+ # ]))
378
+
379
+ # # Plot hyperparameter optimization results to markdown
380
+ # print(reports['hparam'][['split_type', 'hidden_dim', 'learning_rate', 'dropout', 'use_smote', 'smote_k_neighbors']].to_markdown(index=False))
381
 
382
 
383
  if __name__ == '__main__':
src/run_experiments.py CHANGED
@@ -305,7 +305,7 @@ def main(
305
 
306
  # Start the experiment
307
  experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
308
- optuna_reports = pdp.hyperparameter_tuning_and_training(
309
  protein2embedding=protein2embedding,
310
  cell2embedding=cell2embedding,
311
  smiles2fp=smiles2fp,
 
305
 
306
  # Start the experiment
307
  experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
308
+ optuna_reports = pdp.hyperparameter_tuning_and_training(
309
  protein2embedding=protein2embedding,
310
  cell2embedding=cell2embedding,
311
  smiles2fp=smiles2fp,