Spaces:
Sleeping
Sleeping
| import re | |
| from functools import partial | |
| from numbers import Number | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Sequence, Union, Literal | |
| from lightning import LightningDataModule | |
| import pandas as pd | |
| from pandarallel import pandarallel | |
| from rdkit import Chem | |
| #import swifter | |
| from sklearn.preprocessing import LabelEncoder | |
| from torch.utils.data import Dataset, DataLoader | |
| from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler | |
| from deepscreen.utils import get_logger | |
# Module-level logger shared by the dataset and datamodule code in this file.
log = get_logger(__name__)
# Initialize pandarallel once, at import time, so that `Series.parallel_apply`
# (used for label/sequence validation in this module) is available; a progress
# bar is displayed for each parallel call.
pandarallel.initialize(progress_bar=True)
# Regex matching any character that is NOT allowed in a SMILES string; every
# match is reported as an invalid character by `validate_seq_str`.
SMILES_PAT = r"[^A-Za-z0-9=#:+\-\[\]<>()/\\@%,.*]"
# Regex matching any character that is NOT allowed in a FASTA sequence
# (only uppercase letters, '*' and '-' are permitted).
FASTA_PAT = r"[^A-Z*\-]"
def validate_seq_str(seq, regex):
    """Report the characters of ``seq`` that match the disallowed-character ``regex``.

    Args:
        seq: The string to check (e.g. a SMILES or FASTA sequence).
        regex: A pattern matching *disallowed* characters, such as
            ``SMILES_PAT`` or ``FASTA_PAT``.

    Returns:
        ``'Empty string'`` if ``seq`` is empty (or otherwise falsy);
        ``None`` if no disallowed character occurs in ``seq``;
        otherwise a ``', '``-separated string of the distinct offending
        characters, sorted so that the report is identical across runs.
    """
    if not seq:
        return 'Empty string'
    err_charset = set(re.findall(regex, seq))
    if not err_charset:
        return None
    # Sort: iteration order of a set of str depends on hash randomization,
    # which would make the error message differ from run to run.
    return ', '.join(sorted(err_charset))
| # TODO: save a list of corrupted records | |
def rdkit_canonicalize(smiles):
    """Return the RDKit canonical SMILES for ``smiles``, or ``smiles`` unchanged on failure.

    This is a best-effort step: a warning is logged and the input is returned
    as-is whenever RDKit cannot parse or canonicalize it.

    Args:
        smiles: The SMILES string to canonicalize.

    Returns:
        The canonical SMILES produced by ``Chem.MolToSmiles``, or the original
        ``smiles`` if parsing/canonicalization failed.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles reports a parse failure by returning None rather than
            # raising; passing None on to MolToSmiles would only produce an
            # unhelpful argument-type error.
            log.warning(f'Failed to canonicalize SMILES using RDKIT because it could not be parsed. '
                        f'Returning original SMILES: {smiles}')
            return smiles
        return Chem.MolToSmiles(mol)
    except Exception as e:
        # Deliberately broad: canonicalization is optional, never fatal.
        log.warning(f'Failed to canonicalize SMILES using RDKIT due to {str(e)}. Returning original SMILES: {smiles}')
        return smiles
| class DTIDataset(Dataset): | |
| def __init__( | |
| self, | |
| task: Literal['regression', 'binary', 'multiclass'], | |
| num_classes: Optional[int], | |
| data_path: str | Path, | |
| drug_featurizer: callable, | |
| protein_featurizer: callable, | |
| thresholds: Optional[Union[Number, Sequence[Number]]] = None, | |
| discard_intermediate: Optional[bool] = False, | |
| query: Optional[str] = 'X2' | |
| ): | |
| df = pd.read_csv( | |
| data_path, | |
| engine='python', | |
| header=0, | |
| usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'], | |
| dtype={ | |
| 'X1': 'str', | |
| 'ID1': 'str', | |
| 'X2': 'str', | |
| 'ID2': 'str', | |
| 'Y': 'float32', | |
| 'U': 'str', | |
| }, | |
| ) | |
| # Read the whole data table | |
| # if 'ID1' in df: | |
| # self.x1_to_id1 = dict(zip(df['X1'], df['ID1'])) | |
| # if 'ID2' in df: | |
| # self.x2_to_id2 = dict(zip(df['X2'], df['ID2'])) | |
| # self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2'])))) | |
| # self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2'])))) | |
| # # train and eval mode data processing (fully labelled) | |
| # if 'Y' in df.columns and df['Y'].notnull().all(): | |
| log.info(f"Processing data file: {data_path}") | |
| # Forward-fill all non-label columns | |
| df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0) | |
| # Fill NAs in string cols with an empty string to prevent wrong type inference by pytorch collator | |
| for col in df.columns: | |
| if df[col].dtype == 'object': | |
| df[col] = df[col].fillna('') | |
| # TODO potentially allow running through the whole data validation process | |
| # error = False | |
| if 'Y' in df: | |
| log.info(f"Validating labels (`Y`)...") | |
| # TODO: check sklearn.utils.multiclass.check_classification_targets | |
| match task: | |
| case 'regression': | |
| assert all(df['Y'].parallel_apply(lambda x: isinstance(x, Number))), \ | |
| f"""`Y` must be numeric for `regression` task, | |
| but it has {set(df['Y'].parallel_apply(type))}.""" | |
| case 'binary': | |
| if all(df['Y'].isin([0, 1])): | |
| assert not thresholds, \ | |
| f"""`Y` is already 0 or 1 for `binary` (classification) `task`, | |
| but still got `thresholds` ({thresholds}). | |
| Double check your choices of `task` and `thresholds`, and records in the `Y` column.""" | |
| else: | |
| assert thresholds, \ | |
| f"""`Y` must be 0 or 1 for `binary` (classification) `task`, | |
| but it has {pd.unique(df['Y'])}. | |
| You may set `thresholds` to discretize continuous labels.""" # TODO print err idx instead | |
| case 'multiclass': | |
| assert num_classes >= 3, f'`num_classes` for `task=multiclass` must be at least 3.' | |
| if all(df['Y'].parallel_apply(lambda x: x.is_integer() and x >= 0)): | |
| assert not thresholds, \ | |
| f"""`Y` is already non-negative integers for | |
| `multiclass` (classification) `task`, but still got `thresholds` ({thresholds}). | |
| Double check your choice of `task`, `thresholds` and records in the `Y` column.""" | |
| else: | |
| assert thresholds, \ | |
| f"""`Y` must be non-negative integers for | |
| `multiclass` (classification) 'task',but it has {pd.unique(df['Y'])}. | |
| You must set `thresholds` to discretize continuous labels.""" # TODO print err idx instead | |
| if 'U' in df.columns: | |
| units = df['U'] | |
| else: | |
| units = None | |
| log.warning("Units ('U') not in the data table. " | |
| "Assuming all labels to be discrete or in p-scale (-log10[M]).") | |
| # Transform labels | |
| df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds, | |
| discard_intermediate=discard_intermediate) | |
| # Filter out rows with a NaN in Y (missing values) | |
| df.dropna(subset=['Y'], inplace=True) | |
| match task: | |
| case 'regression': | |
| df['Y'] = df['Y'].astype('float32') | |
| assert all(df['Y'].parallel_apply(lambda x: isinstance(x, Number))), \ | |
| f"""`Y` must be numeric for `regression` task, | |
| but after transformation it still has {set(df['Y'].parallel_apply(type))}. | |
| Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
| # TODO print err idx instead | |
| case 'binary': | |
| df['Y'] = df['Y'].astype('int') | |
| assert all(df['Y'].isin([0, 1])), \ | |
| f"""`Y` must be 0 or 1 for `task=binary`, " | |
| but after transformation it still has {pd.unique(df['Y'])}. | |
| Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
| # TODO print err idx instead | |
| case 'multiclass': | |
| df['Y'] = df['Y'].astype('int') | |
| assert all(df['Y'].parallel_apply(lambda x: x.is_integer() and x >= 0)), \ | |
| f"""Y must be non-negative integers for `task=multiclass` | |
| but after transformation it still has {pd.unique(df['Y'])}. | |
| Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
| # TODO print err idx instead | |
| target_n_unique = df['Y'].nunique() | |
| assert target_n_unique == num_classes, \ | |
| f"""You have set `num_classes` for `task=multiclass` to {num_classes}, | |
| but after transformation Y still has {target_n_unique} unique labels. | |
| Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
| log.info("Validating SMILES (`X1`)...") | |
| df['X1_ERR'] = df['X1'].parallel_apply(validate_seq_str, regex=SMILES_PAT) | |
| if not df['X1_ERR'].isna().all(): | |
| raise Exception(f"Encountered invalid SMILES:\n{df[~df['X1_ERR'].isna()][['X1', 'X1_ERR']]}") | |
| df['X1^'] = df['X1'].parallel_apply(rdkit_canonicalize) | |
| log.info("Validating FASTA (`X2`)...") | |
| df['X2'] = df['X2'].str.upper() | |
| df['X2_ERR'] = df['X2'].parallel_apply(validate_seq_str, regex=FASTA_PAT) | |
| if not df['X2_ERR'].isna().all(): | |
| raise Exception(f"Encountered invalid FASTA:\n{df[~df['X2_ERR'].isna()][['X2', 'X2_ERR']]}") | |
| # FASTA/SMILES indices as query for retrieval metrics like enrichment factor and hit rate | |
| if query: | |
| df['ID^'] = LabelEncoder().fit_transform(df[query]) | |
| self.df = df | |
| self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x) | |
| self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x) | |
| def __len__(self): | |
| return len(self.df.index) | |
| def __getitem__(self, i): | |
| sample = self.df.loc[i] | |
| sample_dict = { | |
| 'N': i, | |
| 'X1': sample['X1'], | |
| 'X1^': self.drug_featurizer(sample['X1^']), | |
| # 'ID1': sample.get('ID1'), | |
| 'X2': sample['X2'], | |
| 'X2^': self.protein_featurizer(sample['X2']), | |
| # 'ID2': sample.get('ID2'), | |
| # 'Y': sample.get('Y'), | |
| # 'ID^': sample.get('ID^'), | |
| } | |
| optional_keys = ['ID1', 'ID2', 'ID^', 'Y'] | |
| sample_dict.update({key: sample[key] for key in optional_keys if sample.get(key) is not None}) | |
| return sample_dict | |
class DTIDataModule(LightningDataModule):
    """Lightning DataModule that serves :class:`DTIDataset` instances.

    Datasets are configured in one of three ways (checked in ``setup``):

    * ``data_file`` + ``split`` + ``train_val_test_split`` as three numbers or
      ``None``: the file (under ``data_dir``) is loaded once and divided by
      ``split(dataset=..., lengths=...)``. A ``None`` entry disables that subset.
    * ``train_val_test_split`` as three file paths or ``None``, with neither
      ``data_file`` nor ``split``: each path is loaded as its own dataset.
      Relative paths are resolved against ``data_dir``.
    * ``data_file`` alone: the file is loaded and used for both testing and
      prediction.

    The dataset options (``task``, ``num_classes``, featurizers, ``thresholds``,
    ``discard_intermediate``, ``query``) are forwarded to :class:`DTIDataset`.
    ``batch_size``, ``collator``, ``num_workers`` and ``pin_memory`` configure
    the dataloaders.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
            self,
            task: Literal['regression', 'binary', 'multiclass'],
            num_classes: Optional[int],
            batch_size: int,
            drug_featurizer: callable,
            protein_featurizer: callable,
            collator: callable = collate_fn,
            data_dir: str = "data/",
            data_file: Optional[str] = None,
            train_val_test_split: Optional[Union[Sequence[Number | str]]] = None,
            split: Optional[callable] = None,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            discard_intermediate: Optional[bool] = False,
            query: Optional[str] = 'X2',
            num_workers: int = 0,
            pin_memory: bool = False,
    ):
        super().__init__()

        self.train_data: Optional[Dataset] = None
        self.val_data: Optional[Dataset] = None
        self.test_data: Optional[Dataset] = None
        self.predict_data: Optional[Dataset] = None

        self.split = split
        self.collator = collator
        # Dataset factory with every option bound except `data_path`.
        self.dataset = partial(
            DTIDataset,
            task=task,
            num_classes=num_classes,
            drug_featurizer=drug_featurizer,
            protein_featurizer=protein_featurizer,
            thresholds=thresholds,
            discard_intermediate=discard_intermediate,
            query=query
        )

        # this line allows to access init params with 'self.hparams' ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)  # ignore=['split']

    def prepare_data(self):
        """
        Download data if needed.
        Do not use it to assign state (e.g., self.x = x).
        """

    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
        """Load and, if configured, split the datasets.

        Sets ``self.train_data``, ``self.val_data``, ``self.test_data`` and/or
        ``self.predict_data``. Lightning calls this for ``trainer.fit()``,
        ``trainer.test()``, etc. Loading therefore happens only on the first
        call, and later calls return immediately. ``stage`` and ``encoding``
        are currently unused.

        Raises:
            ValueError: If ``data_file``, ``split`` and ``train_val_test_split``
                do not form one of the supported configurations.
        """
        # Load and split datasets only if nothing has been loaded yet. Compare with None
        # rather than using truthiness: a Dataset whose __len__ is 0 is falsy.
        loaded = (self.train_data, self.test_data, self.val_data, self.predict_data)
        if any(data is not None for data in loaded):
            return

        split_spec = self.hparams.train_val_test_split
        if split_spec:
            if len(split_spec) != 3:
                raise ValueError('Length of `train_val_test_split` must be 3. '
                                 'Set the second element to None for training without validation. '
                                 'Set the third element to None for training without testing.')
            # Temporarily hold the raw entries (lengths, paths or None) in the dataset slots.
            self.train_data, self.val_data, self.test_data = split_spec

            if all([self.hparams.data_file, self.split]):
                if not all(isinstance(entry, Number) or entry is None for entry in split_spec):
                    raise ValueError('`train_val_test_split` must be a sequence of numbers or None '
                                     '(float for percentages and int for sample numbers) '
                                     'if both `data_file` and `split` have been specified.')
                split_data = list(self.split(
                    dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)),
                    lengths=[entry for entry in split_spec if entry is not None]
                ))
                # Hand the subsets out, in order, to the enabled (non-None) slots only.
                for attr in ('train_data', 'val_data', 'test_data'):
                    if getattr(self, attr) is not None:
                        setattr(self, attr, split_data.pop(0))
            elif (all(isinstance(entry, str) or entry is None for entry in split_spec)
                  and not any([self.hparams.data_file, self.split])):
                for attr in ('train_data', 'val_data', 'test_data'):
                    if getattr(self, attr) is not None:
                        setattr(self, attr, self.dataset(data_path=self._resolve_path(getattr(self, attr))))
            else:
                raise ValueError('For training, you must specify either all of `data_file`, `split`, '
                                 'and `train_val_test_split` as a sequence of numbers or '
                                 'solely `train_val_test_split` as a sequence of data file paths.')
        elif self.hparams.data_file and not any([self.split, self.hparams.train_val_test_split]):
            self.test_data = self.predict_data = self.dataset(
                data_path=self._resolve_path(self.hparams.data_file)
            )
        else:
            raise ValueError("For training, you must specify `train_val_test_split`. "
                             "For testing/predicting, you must specify only `data_file` without "
                             "`train_val_test_split` or `split`.")

    def _resolve_path(self, path) -> Path:
        """Return ``path`` as a ``Path``, resolving relative paths against ``data_dir``."""
        path = Path(path)
        if not path.is_absolute():
            path = Path(self.hparams.data_dir, path)
        return path

    def _make_dataloader(self, dataset, *, shuffle: bool, drop_last: bool) -> DataLoader:
        """Build a DataLoader over ``dataset`` using the shared loader settings."""
        return DataLoader(
            dataset=dataset,
            batch_sampler=SafeBatchSampler(
                data_source=dataset,
                batch_size=self.hparams.batch_size,
                drop_last=drop_last,
                shuffle=shuffle,
            ),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            # Worker processes can only persist when there are any.
            persistent_workers=self.hparams.num_workers > 0,
        )

    def train_dataloader(self):
        # Dropping the last batch prevents problems caused by variable batch sizes in training, e.g.,
        # batch_size=1 in BatchNorm, and shuffling ensures the model be trained on all samples over epochs.
        return self._make_dataloader(self.train_data, shuffle=True, drop_last=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_data, shuffle=False, drop_last=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_data, shuffle=False, drop_last=False)

    def predict_dataloader(self):
        return self._make_dataloader(self.predict_data, shuffle=False, drop_last=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass