Commit
·
82509b6
1
Parent(s):
cf7560f
Updated experiment scripts
Browse files- src/{run_experiments_aminoacid_counts.py → run_experiments_aminoacidcnt.py} +0 -0
- src/{run_experiments_cells_onehot.py → run_experiments_cellsonehot.py} +0 -0
- src/run_experiments_cellsonehot_aminoacidcnt.py +168 -0
- src/{run_experiments.py → run_experiments_pytorch.py} +3 -3
- src/run_experiments_xgboost.py +1 -1
src/{run_experiments_aminoacid_counts.py → run_experiments_aminoacidcnt.py}
RENAMED
|
File without changes
|
src/{run_experiments_cells_onehot.py → run_experiments_cellsonehot.py}
RENAMED
|
File without changes
|
src/run_experiments_cellsonehot_aminoacidcnt.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
import warnings
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Literal
|
| 7 |
+
|
| 8 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 9 |
+
|
| 10 |
+
import protac_degradation_predictor as pdp
|
| 11 |
+
|
| 12 |
+
import pytorch_lightning as pl
|
| 13 |
+
from rdkit import Chem
|
| 14 |
+
from rdkit.Chem import AllChem
|
| 15 |
+
from rdkit import DataStructs
|
| 16 |
+
from jsonargparse import CLI
|
| 17 |
+
import pandas as pd
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import numpy as np
|
| 20 |
+
from sklearn.preprocessing import OrdinalEncoder
|
| 21 |
+
from sklearn.model_selection import (
|
| 22 |
+
StratifiedKFold,
|
| 23 |
+
StratifiedGroupKFold,
|
| 24 |
+
)
|
| 25 |
+
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
|
| 26 |
+
from sklearn.feature_extraction.text import CountVectorizer
|
| 27 |
+
|
| 28 |
+
# Silence two known-noisy warnings: the first pattern targets a Matplotlib
# UserWarning, the second a PyTorch Lightning UserWarning.
for _noisy_pattern in (".*FixedLocator*", ".*does not have many workers.*"):
    warnings.filterwarnings("ignore", _noisy_pattern)


# Send every log record (DEBUG and above) from the root logger to stdout,
# prefixed with timestamp, logger name and level.
root = logging.getLogger()
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# Both the logger and its handler filter by level, so both are lowered to DEBUG.
for _log_target in (root, handler):
    _log_target.setLevel(logging.DEBUG)
root.addHandler(handler)
|
| 42 |
+
|
| 43 |
+
def main(
    active_col: str = 'Active (Dmax 0.6, pDC50 6.0)',
    n_trials: int = 100,
    fast_dev_run: bool = False,
    test_split: float = 0.1,
    cv_n_splits: int = 5,
    max_epochs: int = 100,
    force_study: bool = False,
    experiments: str | Literal['all', 'standard', 'e3_ligase', 'similarity', 'target'] = 'all',
):
    """ Run experiments with one-hot cell embeddings and amino acid count protein embeddings.

    Cell lines are represented by one-hot vectors and proteins by per-character
    counts of their amino acid sequence. For each requested split type, the
    pre-computed train/validation and test CSV files are read from
    `../data/studies`, hyperparameters are tuned with cross-validation, and the
    resulting reports are written as CSV files to `../reports`.

    Args:
        active_col (str): Name of the column containing the active values.
        n_trials (int): Number of hyperparameter optimization trials.
        fast_dev_run (bool): Whether to run a fast development run.
        test_split (float): Percentage of data to use for testing.
        cv_n_splits (int): Number of cross-validation splits.
        max_epochs (int): Maximum number of epochs to train the model.
        force_study (bool): Whether to force the creation of a new study.
        experiments (str): Type of experiments to run. Options are 'all', 'standard', 'e3_ligase', 'similarity', 'target'.
            Note that 'all' runs 'standard', 'similarity' and 'target' only.

    Raises:
        ValueError: If a UniProt ID or a protein sequence in the curated
            dataset is not a string, or if the split type is not supported.
    """
    pl.seed_everything(42)

    # Make directory ../reports if it does not exist (no check-then-create race)
    os.makedirs('../reports', exist_ok=True)

    # Get one-hot encoded embeddings for cell lines: only the keys of the
    # loaded dictionary are used, the stored embeddings are replaced.
    cell2embedding = pdp.load_cell2embedding('../data/cell2embedding.pkl')
    cell_names = list(cell2embedding.keys())
    onehotenc = OneHotEncoder(sparse_output=False)
    cell_embeddings = onehotenc.fit_transform(
        np.array(cell_names).reshape(-1, 1)
    )
    cell2embedding = dict(zip(cell_names, cell_embeddings))

    # Map every UniProt ID (POI and E3 ligase) to its amino acid sequence.
    # The pre-trained embeddings in ../data/uniprot2embedding.h5 are not
    # needed by this script, so they are no longer loaded.
    protac_df = pdp.load_curated_dataset()
    protein2seq = protac_df.set_index('Uniprot')['POI Sequence'].to_dict()
    e32seq = protac_df.set_index('E3 Ligase Uniprot')['E3 Ligase Sequence'].to_dict()
    protein2seq.update(e32seq)

    # CountVectorizer needs strings; fail early with a clear message otherwise
    # (e.g., missing IDs or sequences read as NaN).
    if not all(isinstance(k, str) for k in protein2seq):
        raise ValueError("All keys in `protein2embedding` must be strings.")
    if not all(isinstance(seq, str) for seq in protein2seq.values()):
        raise ValueError("All sequences in `protein2embedding` must be strings.")

    # Get count vectorized embeddings for proteins. The vectorizer is fit on
    # the SEQUENCES (dictionary values), not on the UniProt IDs (keys), so
    # each embedding holds the per-character counts of the protein sequence.
    countvec = CountVectorizer(ngram_range=(1, 1), analyzer='char')
    protein_embeddings = countvec.fit_transform(
        list(protein2seq.values())
    ).toarray()
    # Keys and values of a dict iterate in the same order, so row i of the
    # count matrix belongs to the i-th UniProt ID.
    protein2embedding = dict(zip(protein2seq, protein_embeddings))

    studies_dir = '../data/studies'
    train_val_perc = f'{int((1 - test_split) * 100)}'
    test_perc = f'{int(test_split * 100)}'
    active_name = active_col.replace(' ', '_').replace('(', '').replace(')', '').replace(',', '')

    if experiments == 'all':
        split_types = ['standard', 'similarity', 'target']
    else:
        split_types = [experiments]

    # Cross-Validation Training
    # Reports are collected per report name; nothing reads them after the loop.
    reports = defaultdict(list)
    for split_type in split_types:

        train_val_filename = f'{split_type}_train_val_{train_val_perc}split_{active_name}.csv'
        test_filename = f'{split_type}_test_{test_perc}split_{active_name}.csv'

        train_val_df = pd.read_csv(os.path.join(studies_dir, train_val_filename))
        test_df = pd.read_csv(os.path.join(studies_dir, test_filename))

        # Get SMILES and precompute fingerprints dictionary
        unique_smiles = pd.concat([train_val_df, test_df])['Smiles'].unique().tolist()
        smiles2fp = {s: np.array(pdp.get_fingerprint(s)) for s in unique_smiles}

        # Get the CV object
        if split_type == 'standard':
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = None
        elif split_type == 'e3_ligase':
            # NOTE(review): StratifiedKFold.split() ignores its `groups`
            # argument - confirm whether StratifiedGroupKFold was intended.
            kf = StratifiedKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['E3 Group'].to_numpy()
        elif split_type == 'similarity':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Tanimoto Group'].to_numpy()
        elif split_type == 'target':
            kf = StratifiedGroupKFold(n_splits=cv_n_splits, shuffle=True, random_state=42)
            group = train_val_df['Uniprot Group'].to_numpy()
        else:
            # Without this branch `kf` and `group` would be unbound below.
            raise ValueError(
                f"Unsupported experiment type '{split_type}'. Options are "
                "'all', 'standard', 'e3_ligase', 'similarity', 'target'."
            )

        # Start the experiment
        experiment_name = f'{split_type}_{active_name}_test_split_{test_split}'
        optuna_reports = pdp.hyperparameter_tuning_and_training(
            protein2embedding=protein2embedding,
            cell2embedding=cell2embedding,
            smiles2fp=smiles2fp,
            train_val_df=train_val_df,
            test_df=test_df,
            kf=kf,
            groups=group,
            split_type=split_type,
            n_models_for_test=3,
            fast_dev_run=fast_dev_run,
            n_trials=n_trials,
            max_epochs=max_epochs,
            logger_save_dir='../logs',
            logger_name=f'cellsonehot_aminoacidcnt_{experiment_name}',
            active_label=active_col,
            study_filename=f'../reports/study_cellsonehot_aminoacidcnt_{experiment_name}.pkl',
            force_study=force_study,
        )

        # Save the reports to file
        for report_name, report in optuna_reports.items():
            report.to_csv(f'../reports/cellsonehot_aminoacidcnt_{report_name}_{experiment_name}.csv', index=False)
            reports[report_name].append(report.copy())
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# Script entry point: hand `main` to jsonargparse's CLI so that it can be
# driven from the command line when this file is executed directly.
if __name__ == '__main__':
    cli = CLI(main)
|
src/{run_experiments.py → run_experiments_pytorch.py}
RENAMED
|
@@ -346,15 +346,15 @@ def main(
|
|
| 346 |
n_trials=n_trials,
|
| 347 |
max_epochs=max_epochs,
|
| 348 |
logger_save_dir='../logs',
|
| 349 |
-
logger_name=f'{experiment_name}',
|
| 350 |
active_label=active_col,
|
| 351 |
-
study_filename=f'../reports/
|
| 352 |
force_study=force_study,
|
| 353 |
)
|
| 354 |
|
| 355 |
# Save the reports to file
|
| 356 |
for report_name, report in optuna_reports.items():
|
| 357 |
-
report.to_csv(f'../reports/{report_name}_{experiment_name}.csv', index=False)
|
| 358 |
reports[report_name].append(report.copy())
|
| 359 |
|
| 360 |
|
|
|
|
| 346 |
n_trials=n_trials,
|
| 347 |
max_epochs=max_epochs,
|
| 348 |
logger_save_dir='../logs',
|
| 349 |
+
logger_name=f'pytorch_{experiment_name}',
|
| 350 |
active_label=active_col,
|
| 351 |
+
study_filename=f'../reports/study_pytorch_{experiment_name}.pkl',
|
| 352 |
force_study=force_study,
|
| 353 |
)
|
| 354 |
|
| 355 |
# Save the reports to file
|
| 356 |
for report_name, report in optuna_reports.items():
|
| 357 |
+
report.to_csv(f'../reports/pytorch_{report_name}_{experiment_name}.csv', index=False)
|
| 358 |
reports[report_name].append(report.copy())
|
| 359 |
|
| 360 |
|
src/run_experiments_xgboost.py
CHANGED
|
@@ -324,7 +324,7 @@ def main(
|
|
| 324 |
group = train_val_df['Uniprot Group'].to_numpy()
|
| 325 |
|
| 326 |
# Start the experiment
|
| 327 |
-
experiment_name = f'{active_name}_test_split_{test_split}
|
| 328 |
optuna_reports = pdp.xgboost_hyperparameter_tuning_and_training(
|
| 329 |
protein2embedding=protein2embedding,
|
| 330 |
cell2embedding=cell2embedding,
|
|
|
|
| 324 |
group = train_val_df['Uniprot Group'].to_numpy()
|
| 325 |
|
| 326 |
# Start the experiment
|
| 327 |
+
experiment_name = f'{split_type}_{active_name}_test_split_{test_split}'
|
| 328 |
optuna_reports = pdp.xgboost_hyperparameter_tuning_and_training(
|
| 329 |
protein2embedding=protein2embedding,
|
| 330 |
cell2embedding=cell2embedding,
|