import logging

from src.datamodules.base import BaseDataModule
from src.datasets import GRIDNET

# Module-level logger, named after this module per project convention.
log = logging.getLogger(__name__)
class GRIDNETDataModule(BaseDataModule):
    """LightningDataModule wrapping the GRIDNET dataset.

    A DataModule standardizes data handling through five hooks:

    ``prepare_data``
        One-time work on a single GPU/TPU (not replicated across DDP
        processes): downloading, pre-processing, splitting, saving to disk.
    ``setup``
        Per-process work in DDP: loading data, assigning state variables.
    ``train_dataloader`` / ``val_dataloader`` / ``test_dataloader``
        Return the dataloader for the corresponding split.
    ``teardown``
        Per-process cleanup after fit or test.

    This lets a full dataset be shared without explaining how to download,
    split, transform, and process the data.

    See the docs:
    https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
    """

    # Dataset class instantiated by the BaseDataModule machinery.
    _DATASET_CLASS = GRIDNET
if __name__ == "__main__":
    # Smoke test: load the datamodule config and instantiate it directly.
    import hydra
    import omegaconf
    import pyrootutils

    # setup_root returns a pathlib.Path; keep it as a Path and join with
    # "/" instead of string concatenation (OmegaConf.load accepts Path).
    root = pyrootutils.setup_root(__file__, pythonpath=True)
    cfg = omegaconf.OmegaConf.load(
        root / "configs" / "datamodule" / "semantic" / "gridnet.yaml"
    )
    # Config consumers expect a plain string, so convert at the boundary.
    cfg.data_dir = str(root / "data")
    _ = hydra.utils.instantiate(cfg)