Spaces:
Sleeping
Sleeping
# from itertools import product
from numbers import Number
from pathlib import Path
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Union

# import numpy as np
import pandas as pd
from lightning import LightningDataModule
from sklearn.base import TransformerMixin
from torch.utils.data import Dataset, DataLoader, random_split

from deepscreen.data.utils.collator import collate_fn
from deepscreen.data.utils.dataset import SingleEntitySingleTargetDataset, BaseEntityDataset
from deepscreen.data.utils.label import label_transform
from deepscreen.data.utils.sampler import SafeBatchSampler


class EntityDataModule(LightningDataModule):
    """
    DTI DataModule.

    A DataModule implements these key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc.
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc.
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Supported argument combinations (validated in ``__init__``):

    * ``train=True`` with both ``data_file`` and ``split``: the single file is
      loaded in :meth:`setup` and divided by ``split`` according to
      ``train_val_test_split`` (3 numbers: floats for fractions, ints for
      sample counts).
    * ``train=True`` with neither ``data_file`` nor ``split``:
      ``train_val_test_split`` holds 3 file names under ``data_dir`` that are
      loaded immediately as the train/val/test datasets.
    * ``train=False`` with only ``data_file``: the file is loaded immediately
      and used as both the test and the predict dataset.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
            self,
            dataset: type[BaseEntityDataset],
            task: Literal['regression', 'binary', 'multiclass'],
            n_classes: Optional[int],
            train: bool,
            batch_size: int,
            num_workers: int = 0,
            thresholds: Optional[Union[Number, Sequence[Number]]] = None,
            pin_memory: bool = False,
            data_dir: str = "data/",
            data_file: Optional[str] = None,
            train_val_test_split: Optional[Union[Sequence[Number], Sequence[str]]] = None,
            split: Optional[Callable] = random_split,
    ):
        """
        Validate the argument combination and eagerly load any named data files.

        Args:
            dataset: Dataset class/factory, called as ``dataset(dataset_path=...)``
                to load one data file.
            task: Task type; stored in ``self.hparams`` (not read by this class).
            n_classes: Stored in ``self.hparams`` (not read by this class).
            train: Selects the training configurations (True) or the
                test/predict configuration (False).
            batch_size: Batch size handed to ``SafeBatchSampler``.
            num_workers: DataLoader worker count; workers are persistent when > 0.
            thresholds: Stored in ``self.hparams`` (not read by this class).
            pin_memory: Forwarded to ``DataLoader``.
            data_dir: Directory that every data file name is resolved against.
            data_file: Name of a single data file under ``data_dir``.
            train_val_test_split: Either 3 numbers (lengths for ``split``) or
                3 file names (train, val, test) under ``data_dir``.
            split: Callable invoked as ``split(dataset=..., lengths=...)`` that
                must return exactly three datasets (train, val, test).

        Raises:
            ValueError: If the arguments match none of the supported combinations.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.split = split

        # setup() and the *_dataloader methods read these exact attribute names;
        # None means "not loaded yet".
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None
        self.data_predict: Optional[Dataset] = None

        base_dir = Path(data_dir)

        if train:
            if data_file and split:
                if not self._is_triplet(train_val_test_split, Number):
                    raise ValueError('`train_val_test_split` must be a sequence of 3 numbers '
                                     '(float for percentages and int for sample numbers) if '
                                     '`data_file` and `split` have been specified.')
                # Loading and splitting of `data_file` are deferred to setup().
            elif not data_file and not split and self._is_triplet(train_val_test_split, str):
                train_file, val_file, test_file = train_val_test_split
                self.data_train = dataset(dataset_path=str(base_dir / train_file))
                self.data_val = dataset(dataset_path=str(base_dir / val_file))
                self.data_test = dataset(dataset_path=str(base_dir / test_file))
            else:
                raise ValueError('For training (train=True), you must specify either '
                                 '`data_file` and `split` with `train_val_test_split` of 3 numbers or '
                                 'solely `train_val_test_split` of 3 data file names.')
        else:
            if data_file and not split and not train_val_test_split:
                self.data_test = self.data_predict = dataset(dataset_path=str(base_dir / data_file))
            else:
                raise ValueError("For testing/predicting (train=False), you must specify only `data_file` without "
                                 "`train_val_test_split` or `split`")

    @staticmethod
    def _is_triplet(values: Any, kind: type) -> bool:
        """Return True if *values* is a non-string sequence of exactly 3 *kind* instances."""
        # A bare string is itself a sequence of str; reject it so "abc" is not
        # mistaken for three file names.
        return (
            values is not None
            and not isinstance(values, str)
            and len(values) == 3
            and all(isinstance(value, kind) for value in values)
        )

    def prepare_data(self):
        """
        Download data if needed.
        Do not use it to assign state (e.g., self.x = x).
        """

    def setup(self, stage: Optional[str] = None, encoding: Optional[str] = None):
        """
        Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute data splitting twice.

        Args:
            stage: Lightning stage name; not used.
            encoding: Not used; kept for interface compatibility.
        """
        # Only the "single data_file + split" configuration leaves every dataset
        # unset after __init__; load and split it exactly once.
        already_loaded = (self.data_train, self.data_val, self.data_test, self.data_predict)
        if all(loaded is None for loaded in already_loaded):
            # Built the same way __init__ builds datasets.
            # NOTE(review): task/n_classes/thresholds are presumably pre-bound on
            # `dataset` (e.g. a partial) — confirm.
            dataset = self.hparams.dataset(
                dataset_path=str(Path(self.hparams.data_dir) / self.hparams.data_file)
            )
            if self.hparams.train:
                self.data_train, self.data_val, self.data_test = self.split(
                    dataset=dataset,
                    lengths=self.hparams.train_val_test_split,
                )
            else:
                self.data_test = self.data_predict = dataset

    def _build_dataloader(self, data: Dataset, shuffle: bool) -> DataLoader:
        """Wrap *data* in a DataLoader using the shared batching and worker settings."""
        return DataLoader(
            dataset=data,
            batch_sampler=SafeBatchSampler(
                data_source=data,
                batch_size=self.hparams.batch_size,
                shuffle=shuffle,
            ),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collate_fn,
            # DataLoader only accepts persistent_workers=True when num_workers > 0.
            persistent_workers=self.hparams.num_workers > 0,
        )

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        return self._build_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Return the (unshuffled) validation DataLoader."""
        return self._build_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Return the (unshuffled) test DataLoader."""
        return self._build_dataloader(self.data_test, shuffle=False)

    def predict_dataloader(self):
        """Return the (unshuffled) prediction DataLoader."""
        return self._build_dataloader(self.data_predict, shuffle=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass