| import logging |
| from src.datamodules.base import BaseDataModule |
| from src.datasets import GRIDNET |
|
|
|
|
# Module-level logger, named after this module so records can be filtered per module.
log = logging.getLogger(name=__name__)
|
|
|
|
class GRIDNETDataModule(BaseDataModule):
    """LightningDataModule for GRIDNET dataset.

    This subclass defines no methods of its own; every hook is inherited
    from ``BaseDataModule``. The only thing it specialises is the
    ``_DATASET_CLASS`` attribute, which is set to ``GRIDNET``.
    NOTE(review): presumably ``BaseDataModule`` reads ``_DATASET_CLASS`` to
    build the train/val/test datasets; confirm against the base class.

    A DataModule implements 6 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc...
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc...
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
    """

    # Dataset class this datamodule is specialised for. It is the only
    # difference from BaseDataModule (assumed to be consumed there; verify).
    _DATASET_CLASS = GRIDNET
|
|
|
|
|
|
if __name__ == "__main__":
    # Script-only dependencies. They are imported lazily so that importing
    # this module does not require them.
    import hydra
    import omegaconf
    import pyrootutils

    # Resolve the project root through pyrootutils. pythonpath=True is
    # passed so the root is put on the Python path.
    project_root = str(pyrootutils.setup_root(__file__, pythonpath=True))

    # Load the GRIDNET datamodule config and point it at <root>/data.
    datamodule_cfg = omegaconf.OmegaConf.load(
        f"{project_root}/configs/datamodule/semantic/gridnet.yaml"
    )
    datamodule_cfg.data_dir = f"{project_root}/data"

    # Smoke check: instantiating the config builds the datamodule. The
    # result is bound to `_` only to keep it alive; it is not used further.
    _ = hydra.utils.instantiate(datamodule_cfg)
|
|