| import pytorch_lightning as pl |
| from torch.utils.data import DataLoader |
|
|
|
|
class BASEDataModule(pl.LightningDataModule):
    """Base LightningDataModule with lazily-constructed train/val/test datasets.

    Subclasses are expected to provide (not visible in this file — confirm in
    subclasses):
      * ``self.Dataset`` / ``self.DatasetEval``: dataset classes,
      * ``self.cfg``: a config object with ``TRAIN``/``EVAL``/``TEST`` sections,
      * ``self.hparams``: keyword arguments forwarded to the dataset classes.
    """

    def __init__(self, collate_fn):
        """Store shared dataloader options.

        Args:
            collate_fn: batch collation callable forwarded to every DataLoader.
        """
        super().__init__()

        # Options common to all dataloaders; per-loader options (batch size,
        # workers, shuffle) are merged in by each *_dataloader method.
        self.dataloader_options = {"collate_fn": collate_fn}
        # Controls DataLoader(persistent_workers=...) for every loader.
        self.persistent_workers = True
        # When True, test/predict loaders force batch_size=1 (multimodal mode
        # presumably needs per-sample batches — confirm with subclasses).
        self.is_mm = False

        # Datasets are built lazily on first property access (see below).
        self._train_dataset = None
        self._val_dataset = None
        self._test_dataset = None

    def get_sample_set(self, overrides=None):
        """Build an eval dataset from ``self.hparams`` with optional overrides.

        Args:
            overrides: optional mapping of hparams to replace. Defaults to no
                overrides. (Was a mutable default ``{}`` — fixed: a shared dict
                default persists mutations across calls.)

        Returns:
            A ``self.DatasetEval`` instance.
        """
        sample_params = self.hparams.copy()
        if overrides:
            sample_params.update(overrides)
        return self.DatasetEval(**sample_params)

    @property
    def train_dataset(self):
        """Training dataset, created on first access and cached."""
        if self._train_dataset is None:
            self._train_dataset = self.Dataset(split=self.cfg.TRAIN.SPLIT,
                                               **self.hparams)
        return self._train_dataset

    @property
    def val_dataset(self):
        """Validation dataset, created on first access and cached."""
        if self._val_dataset is None:
            params = self.hparams.copy()
            # Eval datasets are built without a code path — presumably codes
            # are only needed for training; confirm in DatasetEval.
            params['code_path'] = None
            params['split'] = self.cfg.EVAL.SPLIT
            self._val_dataset = self.DatasetEval(**params)
        return self._val_dataset

    @property
    def test_dataset(self):
        """Test dataset, created on first access and cached."""
        if self._test_dataset is None:
            params = self.hparams.copy()
            params['code_path'] = None
            params['split'] = self.cfg.TEST.SPLIT
            self._test_dataset = self.DatasetEval(**params)
        return self._test_dataset

    def setup(self, stage=None):
        """Eagerly instantiate the datasets relevant to ``stage``.

        Touching the cached properties here forces dataset construction at
        setup time rather than inside the first dataloader call.
        """
        if stage in (None, "fit"):
            _ = self.train_dataset
            _ = self.val_dataset
        if stage in (None, "test"):
            _ = self.test_dataset

    def _base_options(self, batch_size, num_workers, shuffle):
        """Merge the shared collate options with per-loader settings."""
        options = self.dataloader_options.copy()
        options["batch_size"] = batch_size
        options["num_workers"] = num_workers
        options["shuffle"] = shuffle
        return options

    def train_dataloader(self):
        """Return the training DataLoader.

        NOTE(review): shuffle is False here, which is unusual for training
        (samples are seen in a fixed order every epoch) — confirm this is
        intentional before changing.
        """
        return DataLoader(
            self.train_dataset,
            # Now honors self.persistent_workers instead of hard-coding True.
            persistent_workers=self.persistent_workers,
            **self._base_options(self.cfg.TRAIN.BATCH_SIZE,
                                 self.cfg.TRAIN.NUM_WORKERS,
                                 shuffle=False),
        )

    def predict_dataloader(self):
        """Return the prediction DataLoader (batch_size=1 in multimodal mode)."""
        return DataLoader(
            self.test_dataset,
            persistent_workers=self.persistent_workers,
            **self._base_options(
                1 if self.is_mm else self.cfg.TEST.BATCH_SIZE,
                self.cfg.TEST.NUM_WORKERS,
                shuffle=False),
        )

    def val_dataloader(self):
        """Return the validation DataLoader."""
        return DataLoader(
            self.val_dataset,
            persistent_workers=self.persistent_workers,
            **self._base_options(self.cfg.EVAL.BATCH_SIZE,
                                 self.cfg.EVAL.NUM_WORKERS,
                                 shuffle=False),
        )

    def test_dataloader(self):
        """Return the test DataLoader (batch_size=1 in multimodal mode)."""
        return DataLoader(
            self.test_dataset,
            persistent_workers=self.persistent_workers,
            **self._base_options(
                1 if self.is_mm else self.cfg.TEST.BATCH_SIZE,
                self.cfg.TEST.NUM_WORKERS,
                shuffle=False),
        )
|
|