Commit ·
8aec0bb
1
Parent(s): 1ee75b1
Started working on cell line one-hot encoding experiments
Browse files
protac_degradation_predictor/protac_dataset.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
from typing import Literal, List, Tuple, Optional, Dict
|
| 2 |
from collections import defaultdict
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from .data_utils import (
|
| 5 |
get_fingerprint,
|
|
@@ -24,15 +26,16 @@ class PROTAC_Dataset(Dataset):
|
|
| 24 |
def __init__(
|
| 25 |
self,
|
| 26 |
protac_df: pd.DataFrame,
|
| 27 |
-
protein2embedding: Dict,
|
| 28 |
-
cell2embedding: Dict,
|
| 29 |
-
smiles2fp: Dict,
|
| 30 |
use_smote: bool = False,
|
| 31 |
oversampler: Optional[SMOTE | ADASYN] = None,
|
| 32 |
active_label: str = 'Active',
|
| 33 |
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
|
| 34 |
scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
|
| 35 |
use_single_scaler: Optional[bool] = None,
|
|
|
|
| 36 |
):
|
| 37 |
""" Initialize the PROTAC dataset
|
| 38 |
|
|
@@ -47,6 +50,7 @@ class PROTAC_Dataset(Dataset):
|
|
| 47 |
disabled_embeddings (list): The list of embeddings to disable, i.e., return a zero vector
|
| 48 |
scaler (StandardScaler | dict): The scaler to use for the embeddings
|
| 49 |
use_single_scaler (bool): Whether to use a single scaler for all features
|
|
|
|
| 50 |
"""
|
| 51 |
# Filter out examples with NaN in active_label column
|
| 52 |
self.data = protac_df # [~protac_df[active_label].isna()]
|
|
@@ -84,6 +88,22 @@ class PROTAC_Dataset(Dataset):
|
|
| 84 |
self.oversampler = oversampler
|
| 85 |
if self.use_smote:
|
| 86 |
self.apply_smote()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
def apply_smote(self):
|
| 89 |
# Prepare the dataset for SMOTE
|
|
@@ -269,6 +289,17 @@ class PROTAC_Dataset(Dataset):
|
|
| 269 |
else:
|
| 270 |
cell_emb = self.data['Cell Line Identifier'].iloc[idx]
|
| 271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
elem = {
|
| 273 |
'smiles_emb': smiles_emb,
|
| 274 |
'poi_emb': poi_emb,
|
|
@@ -293,6 +324,7 @@ def get_datasets(
|
|
| 293 |
scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
|
| 294 |
use_single_scaler: Optional[bool] = None,
|
| 295 |
apply_scaling: bool = False,
|
|
|
|
| 296 |
) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
|
| 297 |
""" Get the datasets for training the PROTAC model.
|
| 298 |
|
|
@@ -323,6 +355,7 @@ def get_datasets(
|
|
| 323 |
disabled_embeddings=disabled_embeddings,
|
| 324 |
scaler=scaler,
|
| 325 |
use_single_scaler=use_single_scaler,
|
|
|
|
| 326 |
)
|
| 327 |
val_ds = PROTAC_Dataset(
|
| 328 |
val_df,
|
|
|
|
| 1 |
from typing import Literal, List, Tuple, Optional, Dict
|
| 2 |
from collections import defaultdict
|
| 3 |
+
import random
|
| 4 |
+
import logging
|
| 5 |
|
| 6 |
from .data_utils import (
|
| 7 |
get_fingerprint,
|
|
|
|
| 26 |
def __init__(
|
| 27 |
self,
|
| 28 |
protac_df: pd.DataFrame,
|
| 29 |
+
protein2embedding: Dict[str, np.ndarray],
|
| 30 |
+
cell2embedding: Dict[str, np.ndarray],
|
| 31 |
+
smiles2fp: Dict[str, np.ndarray],
|
| 32 |
use_smote: bool = False,
|
| 33 |
oversampler: Optional[SMOTE | ADASYN] = None,
|
| 34 |
active_label: str = 'Active',
|
| 35 |
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
|
| 36 |
scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
|
| 37 |
use_single_scaler: Optional[bool] = None,
|
| 38 |
+
shuffle_embedding_prob: float = 0.0,
|
| 39 |
):
|
| 40 |
""" Initialize the PROTAC dataset
|
| 41 |
|
|
|
|
| 50 |
disabled_embeddings (list): The list of embeddings to disable, i.e., return a zero vector
|
| 51 |
scaler (StandardScaler | dict): The scaler to use for the embeddings
|
| 52 |
use_single_scaler (bool): Whether to use a single scaler for all features
|
| 53 |
+
shuffle_embedding_prob (float): The probability of shuffling the embeddings. Used for testing whether embeddings act as "barcodes". Defaults to 0.0, i.e., no shuffling.
|
| 54 |
"""
|
| 55 |
# Filter out examples with NaN in active_label column
|
| 56 |
self.data = protac_df # [~protac_df[active_label].isna()]
|
|
|
|
| 88 |
self.oversampler = oversampler
|
| 89 |
if self.use_smote:
|
| 90 |
self.apply_smote()
|
| 91 |
+
|
| 92 |
+
if shuffle_embedding_prob > 0.0:
|
| 93 |
+
self.shuffle_embedding_prob = shuffle_embedding_prob
|
| 94 |
+
# Set random seed
|
| 95 |
+
random.seed(42)
|
| 96 |
+
if self.protein_emb_dim != self.cell_emb_dim:
|
| 97 |
+
logging.warning('Protein and cell embeddings have different dimensions. Shuffling will be on POI and E3 embeddings only.')
|
| 98 |
+
|
| 99 |
+
def get_smiles_emb_dim(self):
|
| 100 |
+
return self.smiles_emb_dim
|
| 101 |
+
|
| 102 |
+
def get_protein_emb_dim(self):
|
| 103 |
+
return self.protein_emb_dim
|
| 104 |
+
|
| 105 |
+
def get_cell_emb_dim(self):
|
| 106 |
+
return self.cell_emb_dim
|
| 107 |
|
| 108 |
def apply_smote(self):
|
| 109 |
# Prepare the dataset for SMOTE
|
|
|
|
| 289 |
else:
|
| 290 |
cell_emb = self.data['Cell Line Identifier'].iloc[idx]
|
| 291 |
|
| 292 |
+
# Shuffle the embeddings if the probability is met
|
| 293 |
+
if random.random() < self.shuffle_embedding_prob:
|
| 294 |
+
if self.protein_emb_dim == self.cell_emb_dim:
|
| 295 |
+
# Randomly shuffle the embeddings for POI, cell, and E3
|
| 296 |
+
embeddings = np.vstack([poi_emb, e3_emb, cell_emb])
|
| 297 |
+
np.random.shuffle(embeddings)
|
| 298 |
+
poi_emb, e3_emb, cell_emb = embeddings
|
| 299 |
+
else:
|
| 300 |
+
# Swap POI and E3 embeddings only, because of different dimensions
|
| 301 |
+
poi_emb, e3_emb = e3_emb, poi_emb
|
| 302 |
+
|
| 303 |
elem = {
|
| 304 |
'smiles_emb': smiles_emb,
|
| 305 |
'poi_emb': poi_emb,
|
|
|
|
| 324 |
scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
|
| 325 |
use_single_scaler: Optional[bool] = None,
|
| 326 |
apply_scaling: bool = False,
|
| 327 |
+
shuffle_embedding_prob: float = 0.0,
|
| 328 |
) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
|
| 329 |
""" Get the datasets for training the PROTAC model.
|
| 330 |
|
|
|
|
| 355 |
disabled_embeddings=disabled_embeddings,
|
| 356 |
scaler=scaler,
|
| 357 |
use_single_scaler=use_single_scaler,
|
| 358 |
+
shuffle_embedding_prob=shuffle_embedding_prob,
|
| 359 |
)
|
| 360 |
val_ds = PROTAC_Dataset(
|
| 361 |
val_df,
|
protac_degradation_predictor/pytorch_models.py
CHANGED
|
@@ -23,7 +23,8 @@ from torchmetrics import (
|
|
| 23 |
MetricCollection,
|
| 24 |
)
|
| 25 |
from imblearn.over_sampling import SMOTE
|
| 26 |
-
from sklearn.preprocessing import StandardScaler
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class PROTAC_Predictor(nn.Module):
|
|
@@ -402,9 +403,9 @@ class PROTAC_Model(pl.LightningModule):
|
|
| 402 |
|
| 403 |
# TODO: Use some sort of **kwargs to pass all the parameters to the model...
|
| 404 |
def train_model(
|
| 405 |
-
protein2embedding: Dict,
|
| 406 |
-
cell2embedding: Dict,
|
| 407 |
-
smiles2fp: Dict,
|
| 408 |
train_df: pd.DataFrame,
|
| 409 |
val_df: pd.DataFrame,
|
| 410 |
test_df: Optional[pd.DataFrame] = None,
|
|
@@ -414,10 +415,6 @@ def train_model(
|
|
| 414 |
dropout: float = 0.2,
|
| 415 |
max_epochs: int = 50,
|
| 416 |
use_batch_norm: bool = False,
|
| 417 |
-
smiles_emb_dim: int = config.fingerprint_size,
|
| 418 |
-
poi_emb_dim: int = config.protein_embedding_size,
|
| 419 |
-
e3_emb_dim: int = config.protein_embedding_size,
|
| 420 |
-
cell_emb_dim: int = config.cell_embedding_size,
|
| 421 |
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
|
| 422 |
smote_k_neighbors:int = 5,
|
| 423 |
use_smote: bool = True,
|
|
@@ -431,29 +428,61 @@ def train_model(
|
|
| 431 |
checkpoint_model_name: str = 'protac',
|
| 432 |
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
|
| 433 |
return_predictions: bool = False,
|
|
|
|
|
|
|
|
|
|
| 434 |
) -> tuple:
|
| 435 |
""" Train a PROTAC model using the given datasets and hyperparameters.
|
| 436 |
|
| 437 |
Args:
|
| 438 |
-
protein2embedding (dict):
|
| 439 |
-
cell2embedding (dict):
|
| 440 |
-
smiles2fp (dict):
|
| 441 |
-
train_df (pd.DataFrame): The training
|
| 442 |
-
val_df (pd.DataFrame): The validation
|
| 443 |
-
test_df (pd.DataFrame): The test
|
| 444 |
-
hidden_dim (int): The hidden dimension of the model
|
| 445 |
-
batch_size (int): The batch size
|
| 446 |
-
learning_rate (float): The learning rate
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
|
| 454 |
Returns:
|
| 455 |
tuple: The trained model, the trainer, and the metrics over the validation and test sets.
|
| 456 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
train_ds, val_ds, test_ds = get_datasets(
|
| 458 |
train_df,
|
| 459 |
val_df,
|
|
@@ -465,7 +494,14 @@ def train_model(
|
|
| 465 |
smote_k_neighbors=smote_k_neighbors,
|
| 466 |
active_label=active_label,
|
| 467 |
disabled_embeddings=disabled_embeddings,
|
|
|
|
| 468 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
loggers = [
|
| 470 |
pl.loggers.TensorBoardLogger(
|
| 471 |
save_dir=logger_save_dir,
|
|
|
|
| 23 |
MetricCollection,
|
| 24 |
)
|
| 25 |
from imblearn.over_sampling import SMOTE
|
| 26 |
+
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
| 27 |
+
from sklearn.feature_extraction.text import CountVectorizer
|
| 28 |
|
| 29 |
|
| 30 |
class PROTAC_Predictor(nn.Module):
|
|
|
|
| 403 |
|
| 404 |
# TODO: Use some sort of **kwargs to pass all the parameters to the model...
|
| 405 |
def train_model(
|
| 406 |
+
protein2embedding: Dict[str, np.ndarray],
|
| 407 |
+
cell2embedding: Dict[str, np.ndarray],
|
| 408 |
+
smiles2fp: Dict[str, np.ndarray],
|
| 409 |
train_df: pd.DataFrame,
|
| 410 |
val_df: pd.DataFrame,
|
| 411 |
test_df: Optional[pd.DataFrame] = None,
|
|
|
|
| 415 |
dropout: float = 0.2,
|
| 416 |
max_epochs: int = 50,
|
| 417 |
use_batch_norm: bool = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
join_embeddings: Literal['beginning', 'concat', 'sum'] = 'sum',
|
| 419 |
smote_k_neighbors:int = 5,
|
| 420 |
use_smote: bool = True,
|
|
|
|
| 428 |
checkpoint_model_name: str = 'protac',
|
| 429 |
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
|
| 430 |
return_predictions: bool = False,
|
| 431 |
+
shuffle_embedding_prob: float = 0.0,
|
| 432 |
+
use_cells_one_hot: bool = False,
|
| 433 |
+
use_amino_acid_count: bool = False,
|
| 434 |
) -> tuple:
|
| 435 |
""" Train a PROTAC model using the given datasets and hyperparameters.
|
| 436 |
|
| 437 |
Args:
|
| 438 |
+
protein2embedding (dict): A dictionary mapping protein identifiers to embeddings.
|
| 439 |
+
cell2embedding (dict): A dictionary mapping cell line identifiers to embeddings.
|
| 440 |
+
smiles2fp (dict): A dictionary mapping SMILES strings to fingerprints.
|
| 441 |
+
train_df (pd.DataFrame): The training dataframe.
|
| 442 |
+
val_df (pd.DataFrame): The validation dataframe.
|
| 443 |
+
test_df (Optional[pd.DataFrame]): The test dataframe.
|
| 444 |
+
hidden_dim (int): The hidden dimension of the model
|
| 445 |
+
batch_size (int): The batch size
|
| 446 |
+
learning_rate (float): The learning rate
|
| 447 |
+
dropout (float): The dropout rate
|
| 448 |
+
max_epochs (int): The maximum number of epochs
|
| 449 |
+
use_batch_norm (bool): Whether to use batch normalization
|
| 450 |
+
join_embeddings (Literal['beginning', 'concat', 'sum']): How to join the embeddings
|
| 451 |
+
smote_k_neighbors (int): The number of neighbors to use in SMOTE
|
| 452 |
+
use_smote (bool): Whether to use SMOTE
|
| 453 |
+
apply_scaling (bool): Whether to apply scaling to the embeddings
|
| 454 |
+
active_label (str): The name of the active label. Default: 'Active'
|
| 455 |
+
fast_dev_run (bool): Whether to run a fast development run (see PyTorch Lightning documentation)
|
| 456 |
+
use_logger (bool): Whether to use a logger
|
| 457 |
+
logger_save_dir (str): The directory to save the logs
|
| 458 |
+
logger_name (str): The name of the logger
|
| 459 |
+
enable_checkpointing (bool): Whether to enable checkpointing
|
| 460 |
+
checkpoint_model_name (str): The name of the model for checkpointing
|
| 461 |
+
disabled_embeddings (list): List of disabled embeddings. Can be 'poi', 'e3', 'cell', 'smiles'
|
| 462 |
+
return_predictions (bool): Whether to return predictions on the validation and test sets
|
| 463 |
|
| 464 |
Returns:
|
| 465 |
tuple: The trained model, the trainer, and the metrics over the validation and test sets.
|
| 466 |
"""
|
| 467 |
+
if use_cells_one_hot:
|
| 468 |
+
# Get one-hot encoded embeddings for cell lines
|
| 469 |
+
onehotenc = OneHotEncoder(sparse_output=False)
|
| 470 |
+
cell_embeddings = onehotenc.fit_transform(
|
| 471 |
+
np.array(list(cell2embedding.keys()))
|
| 472 |
+
)
|
| 473 |
+
cell2embedding = {k: v for k, v in zip(cell2embedding.keys(), cell_embeddings)}
|
| 474 |
+
|
| 475 |
+
if use_amino_acid_count:
|
| 476 |
+
# Get count vectorized embeddings for proteins
|
| 477 |
+
# NOTE: Check that the protein2embedding is a dictionary of strings
|
| 478 |
+
if not all(isinstance(k, str) for k in protein2embedding.keys()):
|
| 479 |
+
raise ValueError("All keys in `protein2embedding` must be strings.")
|
| 480 |
+
countvec = CountVectorizer(ngram_range=(1,1), analyzer='char')
|
| 481 |
+
protein_embeddings = countvec.fit_transform(
|
| 482 |
+
list(protein2embedding.keys())
|
| 483 |
+
)
|
| 484 |
+
protein2embedding = {k: v for k, v in zip(protein2embedding.keys(), protein_embeddings)}
|
| 485 |
+
|
| 486 |
train_ds, val_ds, test_ds = get_datasets(
|
| 487 |
train_df,
|
| 488 |
val_df,
|
|
|
|
| 494 |
smote_k_neighbors=smote_k_neighbors,
|
| 495 |
active_label=active_label,
|
| 496 |
disabled_embeddings=disabled_embeddings,
|
| 497 |
+
shuffle_embedding_prob=shuffle_embedding_prob,
|
| 498 |
)
|
| 499 |
+
# NOTE: The embeddings dimensions should already match in all sets
|
| 500 |
+
smiles_emb_dim = train_ds.get_smiles_emb_dim()
|
| 501 |
+
poi_emb_dim = train_ds.get_protein_emb_dim()
|
| 502 |
+
e3_emb_dim = train_ds.get_protein_emb_dim()
|
| 503 |
+
cell_emb_dim = train_ds.get_cell_emb_dim()
|
| 504 |
+
|
| 505 |
loggers = [
|
| 506 |
pl.loggers.TensorBoardLogger(
|
| 507 |
save_dir=logger_save_dir,
|
src/run_experiments.py
CHANGED
|
@@ -238,10 +238,15 @@ def main(
|
|
| 238 |
""" Train a PROTAC model using the given datasets and hyperparameters.
|
| 239 |
|
| 240 |
Args:
|
| 241 |
-
|
| 242 |
-
n_trials (int): The number of hyperparameter
|
| 243 |
-
n_splits (int): The number of cross-validation splits.
|
| 244 |
fast_dev_run (bool): Whether to run a fast development run.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
"""
|
| 246 |
pl.seed_everything(42)
|
| 247 |
|
|
|
|
| 238 |
""" Train a PROTAC model using the given datasets and hyperparameters.
|
| 239 |
|
| 240 |
Args:
|
| 241 |
+
active_col (str): The column containing the active/inactive information. Must be in the format 'Active (Dmax N, pDC50 M)'.
|
| 242 |
+
n_trials (int): The number of hyperparameter tuning trials to run.
|
|
|
|
| 243 |
fast_dev_run (bool): Whether to run a fast development run.
|
| 244 |
+
test_split (float): The percentage of the active PROTACs to use as the test set.
|
| 245 |
+
cv_n_splits (int): The number of cross-validation splits to use.
|
| 246 |
+
max_epochs (int): The maximum number of epochs to train the model.
|
| 247 |
+
run_sklearn (bool): Whether to run sklearn models.
|
| 248 |
+
force_study (bool): Whether to force the creation of a new Optuna study.
|
| 249 |
+
experiments (str): The type of experiments to run.
|
| 250 |
"""
|
| 251 |
pl.seed_everything(42)
|
| 252 |
|
src/run_experiments_cells_onehot.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
import warnings
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Literal
|
| 7 |
+
|
| 8 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 9 |
+
|
| 10 |
+
import protac_degradation_predictor as pdp
|
| 11 |
+
|
| 12 |
+
import pytorch_lightning as pl
|
| 13 |
+
from rdkit import Chem
|
| 14 |
+
from rdkit.Chem import AllChem
|
| 15 |
+
from rdkit import DataStructs
|
| 16 |
+
from jsonargparse import CLI
|
| 17 |
+
import pandas as pd
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import numpy as np
|
| 20 |
+
from sklearn.preprocessing import OrdinalEncoder
|
| 21 |
+
from sklearn.model_selection import (
|
| 22 |
+
StratifiedKFold,
|
| 23 |
+
StratifiedGroupKFold,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# Ignore UserWarning from Matplotlib
|
| 27 |
+
warnings.filterwarnings("ignore", ".*FixedLocator*")
|
| 28 |
+
# Ignore UserWarning from PyTorch Lightning
|
| 29 |
+
warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
root = logging.getLogger()
|
| 33 |
+
root.setLevel(logging.DEBUG)
|
| 34 |
+
|
| 35 |
+
handler = logging.StreamHandler(sys.stdout)
|
| 36 |
+
handler.setLevel(logging.DEBUG)
|
| 37 |
+
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 38 |
+
handler.setFormatter(formatter)
|
| 39 |
+
root.addHandler(handler)
|
| 40 |
+
|
| 41 |
+
def main(
|
| 42 |
+
active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
|
| 43 |
+
n_trials: int = 100,
|
| 44 |
+
fast_dev_run: bool = False,
|
| 45 |
+
test_split: float = 0.1,
|
| 46 |
+
cv_n_splits: int = 5,
|
| 47 |
+
max_epochs: int = 100,
|
| 48 |
+
force_study: bool = False,
|
| 49 |
+
experiments: str | Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
|
| 50 |
+
):
|
| 51 |
+
""" Run experiments with the cells one-hot encoding model.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
active_col (str): Name of the column containing the active values.
|
| 55 |
+
n_trials (int): Number of hyperparameter optimization trials.
|
| 56 |
+
fast_dev_run (bool): Whether to run a fast development run.
|
| 57 |
+
test_split (float): Percentage of data to use for testing.
|
| 58 |
+
cv_n_splits (int): Number of cross-validation splits.
|
| 59 |
+
max_epochs (int): Maximum number of epochs to train the model.
|
| 60 |
+
force_study (bool): Whether to force the creation of a new study.
|
| 61 |
+
experiments (str): Type of experiments to run. Options are 'all', 'standard', 'e3_ligase', 'similarity', 'target'.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
# Make directory ../reports if it does not exist
|
| 65 |
+
if not os.path.exists('../reports'):
|
| 66 |
+
os.makedirs('../reports')
|
| 67 |
+
|
| 68 |
+
# Load embedding dictionaries
|
| 69 |
+
protein2embedding = pdp.load_protein2embedding('../data/uniprot2embedding.h5')
|
| 70 |
+
cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
|
| 71 |
+
|
| 72 |
+
studies_dir = '../data/studies'
|
| 73 |
+
train_val_perc = f'{int((1 - test_split) * 100)}'
|
| 74 |
+
test_perc = f'{int(test_split * 100)}'
|
| 75 |
+
active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')
|
| 76 |
+
|
| 77 |
+
if experiments == 'all':
|
| 78 |
+
experiments = ['standard', 'similarity', 'target']
|
| 79 |
+
|
| 80 |
+
# Cross-Validation Training
|
| 81 |
+
reports = defaultdict(list)
|
| 82 |
+
for split_type in experiments:
|
| 83 |
+
|
| 84 |
+
train_val_filename = f'{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
|
| 85 |
+
test_filename = f'{split_type}_test_{test_perc}split_{active_name}.csv'
|
| 86 |
+
|
| 87 |
+
train_val_df = pd.read_csv(os.path.join(studies_dir, train_val_filename))
|
| 88 |
+
test_df = pd.read_csv(os.path.join(studies_dir, test_filename))
|
| 89 |
+
|
| 90 |
+
# Get SMILES and precompute fingerprints dictionary
|
| 91 |
+
unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
|
| 92 |
+
smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}
|
| 93 |
+
|
| 94 |
+
# Get the CV object
|
| 95 |
+
if split_type == 'standard':
|
| 96 |
+
kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
| 97 |
+
group = None
|
| 98 |
+
elif split_type == 'e3_ligase':
|
| 99 |
+
kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
| 100 |
+
group = train_val_df['E3 Group'].to_numpy()
|
| 101 |
+
elif split_type == 'similarity':
|
| 102 |
+
kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
| 103 |
+
group = train_val_df['Tanimoto Group'].to_numpy()
|
| 104 |
+
elif split_type == 'target':
|
| 105 |
+
kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
|
| 106 |
+
group = train_val_df['Uniprot Group'].to_numpy()
|
| 107 |
+
|
| 108 |
+
# Start the experiment
|
| 109 |
+
experiment_name = f'{active_name}_test_split_{test_split}_{split_type}'
|
| 110 |
+
optuna_reports = pdp.hyperparameter_tuning_and_training(
|
| 111 |
+
protein2embedding=protein2embedding,
|
| 112 |
+
cell2embedding=cell2embedding,
|
| 113 |
+
smiles2fp=smiles2fp,
|
| 114 |
+
train_val_df=train_val_df,
|
| 115 |
+
test_df=test_df,
|
| 116 |
+
kf=kf,
|
| 117 |
+
groups=group,
|
| 118 |
+
split_type=split_type,
|
| 119 |
+
n_models_for_test=3,
|
| 120 |
+
fast_dev_run=fast_dev_run,
|
| 121 |
+
n_trials=n_trials,
|
| 122 |
+
max_epochs=max_epochs,
|
| 123 |
+
logger_save_dir='../logs',
|
| 124 |
+
logger_name=f'logs_{experiment_name}',
|
| 125 |
+
active_label=active_col,
|
| 126 |
+
study_filename=f'../reports/study_cellsonehot_{experiment_name}.pkl',
|
| 127 |
+
force_study=force_study,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Save the reports to file
|
| 131 |
+
for report_name, report in optuna_reports.items():
|
| 132 |
+
report.to_csv(f'../reports/cellsonehot_{report_name}_{experiment_name}.csv', index=False)
|
| 133 |
+
reports[report_name].append(report.copy())
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == '__main__':
|
| 137 |
+
cli = CLI(main)
|
src/{run_xgboost_experiments.py → run_experiments_xgboost.py}
RENAMED
|
File without changes
|