Spaces:
Sleeping
Sleeping
| from functools import partial | |
| from numbers import Number | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Sequence, Union, Literal | |
| from lightning import LightningDataModule | |
| import pandas as pd | |
| from sklearn.preprocessing import LabelEncoder | |
| from torch.utils.data import Dataset, DataLoader | |
| from deepscreen.data.utils import label_transform, collate_fn, SafeBatchSampler | |
| from deepscreen.utils import get_logger | |
| log = get_logger(__name__) | |
| # TODO: save a list of corrupted records | |
| class DTIDataset(Dataset): | |
| def __init__( | |
| self, | |
| task: Literal['regression', 'binary', 'multiclass'], | |
| n_class: Optional[int], | |
| data_path: str | Path, | |
| drug_featurizer: callable, | |
| protein_featurizer: callable, | |
| thresholds: Optional[Union[Number, Sequence[Number]]] = None, | |
| discard_intermediate: Optional[bool] = False, | |
| ): | |
| df = pd.read_csv( | |
| data_path, | |
| engine='python', | |
| header=0, | |
| usecols=lambda x: x in ['X1', 'ID1', 'X2', 'ID2', 'Y', 'U'], | |
| dtype={ | |
| 'X1': 'str', | |
| 'ID1': 'str', | |
| 'X2': 'str', | |
| 'ID2': 'str', | |
| 'Y': 'float32', | |
| 'U': 'str', | |
| }, | |
| ) | |
| # Read the whole data table | |
| # if 'ID1' in df: | |
| # self.x1_to_id1 = dict(zip(df['X1'], df['ID1'])) | |
| # if 'ID2' in df: | |
| # self.x2_to_id2 = dict(zip(df['X2'], df['ID2'])) | |
| # self.id2_to_indexes = dict(zip(df['ID2'], range(len(df['ID2'])))) | |
| # self.x2_to_indexes = dict(zip(df['X2'], range(len(df['X2'])))) | |
| # # train and eval mode data processing (fully labelled) | |
| # if 'Y' in df.columns and df['Y'].notnull().all(): | |
| log.info(f"Processing data file: {data_path}") | |
| # Forward-fill all non-label columns | |
| df.loc[:, df.columns != 'Y'] = df.loc[:, df.columns != 'Y'].ffill(axis=0) | |
| if 'Y' in df: | |
| log.info(f"Performing pre-transformation target validation.") | |
| # TODO: check sklearn.utils.multiclass.check_classification_targets | |
| match task: | |
| case 'regression': | |
| assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \ | |
| f"""`Y` must be numeric for `regression` task, | |
| but it has {set(df['Y'].apply(type))}.""" | |
| case 'binary': | |
| if all(df['Y'].isin([0, 1])): | |
| assert not thresholds, \ | |
| f"""`Y` is already 0 or 1 for `binary` (classification) `task`, | |
| but still got `thresholds` {thresholds}. | |
| Double check your choices of `task` and `thresholds` and records in the `Y` column.""" | |
| else: | |
| assert thresholds, \ | |
| f"""`Y` must be 0 or 1 for `binary` (classification) `task`, | |
| but it has {pd.unique(df['Y'])}. | |
| You must set `thresholds` to discretize continuous labels.""" | |
| case 'multiclass': | |
| assert n_class >= 3, f'`n_class` for `multiclass` (classification) `task` must be at least 3.' | |
| if all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)): | |
| assert not thresholds, \ | |
| f"""`Y` is already non-negative integers for | |
| `multiclass` (classification) `task`, but still got `thresholds` {thresholds}. | |
| Double check your choice of `task`, `thresholds` and records in the `Y` column.""" | |
| else: | |
| assert thresholds, \ | |
| f"""`Y` must be non-negative integers for | |
| `multiclass` (classification) 'task',but it has {pd.unique(df['Y'])}. | |
| You must set `thresholds` to discretize continuous labels.""" | |
| if 'U' in df.columns: | |
| units = df['U'] | |
| else: | |
| units = None | |
| log.warning("Units ('U') not in the data table. " | |
| "Assuming all labels to be discrete or in p-scale (-log10[M]).") | |
| # Transform labels | |
| df['Y'] = label_transform(labels=df['Y'], units=units, thresholds=thresholds, | |
| discard_intermediate=discard_intermediate) | |
| # Filter out rows with a NaN in Y (missing values) | |
| df.dropna(subset=['Y'], inplace=True) | |
| log.info(f"Performing post-transformation target validation.") | |
| match task: | |
| case 'regression': | |
| df['Y'] = df['Y'].astype('float32') | |
| assert all(df['Y'].apply(lambda x: isinstance(x, Number))), \ | |
| f"""`Y` must be numeric for `regression` task, | |
| but after transformation it still has {set(df['Y'].apply(type))}. | |
| Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
| case 'binary': | |
| df['Y'] = df['Y'].astype('int') | |
| assert all(df['Y'].isin([0, 1])), \ | |
| f"""`Y` must be 0 or 1 for `binary` (classification) `task`, " | |
| but after transformation it still has {pd.unique(df['Y'])}. | |
| Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
| case 'multiclass': | |
| df['Y'] = df['Y'].astype('int') | |
| assert all(df['Y'].apply(lambda x: x.is_integer() and x >= 0)), \ | |
| f"""Y must be non-negative integers for task `multiclass` (classification) | |
| but after transformation it still has {pd.unique(df['Y'])}. | |
| Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
| target_n_unique = df['Y'].nunique() | |
| assert target_n_unique == n_class, \ | |
| f"""You have set `n_class` for `multiclass` (classification) `task` to {n_class}, | |
| but after transformation Y still has {target_n_unique} unique labels. | |
| Double check your choices of `task` and `thresholds` and records in the `Y` and `U` columns.""" | |
| # Indexed protein/FASTA for retrieval metrics | |
| df['IDX'] = LabelEncoder().fit_transform(df['X2']) | |
| self.df = df | |
| self.drug_featurizer = drug_featurizer if drug_featurizer is not None else (lambda x: x) | |
| self.protein_featurizer = protein_featurizer if protein_featurizer is not None else (lambda x: x) | |
| def __len__(self): | |
| return len(self.df.index) | |
| def __getitem__(self, i): | |
| sample = self.df.loc[i] | |
| return { | |
| 'N': i, | |
| 'X1': self.drug_featurizer(sample['X1']), | |
| 'ID1': sample.get('ID1', sample['X1']), | |
| 'X2': self.protein_featurizer(sample['X2']), | |
| 'ID2': sample.get('ID2', sample['X2']), | |
| 'Y': sample.get('Y'), | |
| 'IDX': sample['IDX'], | |
| } | |
class DTIDataModule(LightningDataModule):
    """
    DTI DataModule

    A DataModule implements 5 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc.
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc.
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html

    Two configurations are accepted by `__init__`:

    * Training: `train_val_test_split` holds either 2 or 3 data file paths (datasets are
      loaded immediately), or 2 or 3 numbers together with `data_file` and `split`
      (the file is loaded and split later, in `setup`).
    * Testing/predicting: only `data_file` is given; it is loaded immediately and used
      as both the test and the predict dataset.

    Any other combination raises `ValueError`.
    """

    def __init__(
            self,
            task: Literal['regression', 'binary', 'multiclass'],
            n_class: Optional[int],
            batch_size: int,
            # train: bool,
            drug_featurizer: callable,
            protein_featurizer: callable,
            collator: callable = collate_fn,
            data_dir: str = "data/",
            data_file: Optional[str] = None,
            train_val_test_split: Optional[Union[Sequence[Number | str]]] = None,
            split: Optional[callable] = None,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            discard_intermediate: Optional[bool] = False,
            num_workers: int = 0,
            pin_memory: bool = False,
    ):
        super().__init__()

        self.train_data: Optional[Dataset] = None
        self.val_data: Optional[Dataset] = None
        self.test_data: Optional[Dataset] = None
        self.predict_data: Optional[Dataset] = None
        self.split = split
        self.collator = collator
        # Dataset factory with everything but the data path bound.
        self.dataset = partial(
            DTIDataset,
            task=task,
            n_class=n_class,
            drug_featurizer=drug_featurizer,
            protein_featurizer=protein_featurizer,
            thresholds=thresholds,
            discard_intermediate=discard_intermediate
        )

        if train_val_test_split:
            # TODO test behavior for trainer.test and predict when this is passed
            if len(train_val_test_split) not in [2, 3]:
                raise ValueError('Length of `train_val_test_split` must be 2 (for training without testing) or 3.')

            if all([data_file, split]):
                # Numeric split of a single data file; the split itself is deferred to `setup`.
                if not all(isinstance(item, Number) for item in train_val_test_split):
                    raise ValueError('`train_val_test_split` must be a sequence numbers '
                                     '(float for percentages and int for sample numbers) '
                                     'if both `data_file` and `split` have been specified.')
            elif all(isinstance(item, str) for item in train_val_test_split) and not any([data_file, split]):
                # Pre-split data files. A dedicated loop variable is used so that the `split`
                # argument is not rebound before `save_hyperparameters` runs below.
                split_paths = [self._resolve_path(data_dir, item) for item in train_val_test_split]

                self.train_data = self.dataset(data_path=split_paths[0])
                self.val_data = self.dataset(data_path=split_paths[1])
                if len(train_val_test_split) == 3:
                    self.test_data = self.dataset(data_path=split_paths[2])
            else:
                raise ValueError('For training, you must specify either `data_file`, `split`, '
                                 'and `train_val_test_split` as a sequence of numbers or '
                                 'solely `train_val_test_split` as a sequence of data file paths.')
        elif data_file and not any([split, train_val_test_split]):
            # NOTE(review): `data_file` is rebound to the resolved path here, as before;
            # presumably that is the value `save_hyperparameters` records - confirm if relied upon.
            data_file = self._resolve_path(data_dir, data_file)
            self.test_data = self.predict_data = self.dataset(data_path=data_file)
        else:
            raise ValueError("For training, you must specify `train_val_test_split`. "
                             "For testing/predicting, you must specify only `data_file` without "
                             "`train_val_test_split` or `split`.")

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)  # ignore=['split']

    @staticmethod
    def _resolve_path(data_dir, path) -> Path:
        """Return `path` as a `Path`, placed under `data_dir` unless it is already absolute."""
        path = Path(path)
        return path if path.is_absolute() else Path(data_dir, path)

    def _build_dataloader(self, dataset, *, shuffle: bool, drop_last: bool) -> DataLoader:
        """Build a `DataLoader` over `dataset` batched by a `SafeBatchSampler`."""
        return DataLoader(
            dataset=dataset,
            batch_sampler=SafeBatchSampler(
                data_source=dataset,
                batch_size=self.hparams.batch_size,
                drop_last=drop_last,
                shuffle=shuffle,
            ),
            # batch_size=self.hparams.batch_size,
            # shuffle=shuffle,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=self.collator,
            persistent_workers=self.hparams.num_workers > 0,
        )

    def prepare_data(self):
        """
        Download data if needed.
        Do not use it to assign state (e.g., self.x = x).
        """

    def setup(self, stage: Optional[str] = None, encoding: str = None):
        """
        Load data. Set variables: `self.train_data`, `self.val_data`, `self.test_data`.
        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute data splitting twice.

        `stage` and `encoding` are accepted but not used here.
        """
        # TODO test SafeBatchSampler (which skips samples with any None without introducing variable batch size)
        # load and split datasets only if not loaded in initialization
        loaded = (self.train_data, self.val_data, self.test_data, self.predict_data)
        # Compare against None rather than relying on truthiness: a dataset's truth value
        # follows its length, so an empty dataset would otherwise look like "not loaded".
        if all(data is None for data in loaded):
            # assumes `self.split` returns one subset per entry of `lengths` - TODO confirm
            subsets = list(self.split(
                dataset=self.dataset(data_path=Path(self.hparams.data_dir, self.hparams.data_file)),
                lengths=self.hparams.train_val_test_split
            ))
            # A two-way split (training without testing) yields no test subset.
            if len(subsets) == 2:
                subsets.append(None)
            self.train_data, self.val_data, self.test_data = subsets

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        # Dropping the last batch prevents problems caused by variable batch sizes in training, e.g.,
        # batch_size=1 in BatchNorm, and shuffling ensures the model be trained on all samples over epochs.
        return self._build_dataloader(self.train_data, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling, last batch kept)."""
        return self._build_dataloader(self.val_data, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Return the test dataloader (no shuffling, last batch kept)."""
        return self._build_dataloader(self.test_data, shuffle=False, drop_last=False)

    def predict_dataloader(self):
        """Return the prediction dataloader (no shuffling, last batch kept)."""
        return self._build_dataloader(self.predict_data, shuffle=False, drop_last=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass