import logging

from src.datamodules.base import BaseDataModule
from src.datasets import GRIDNET

log = logging.getLogger(__name__)


class GRIDNETDataModule(BaseDataModule):
    """LightningDataModule for the GRIDNET dataset.

    This subclass only binds ``_DATASET_CLASS`` to ``GRIDNET``; everything
    else is inherited from ``BaseDataModule``.
    NOTE(review): presumably ``BaseDataModule`` reads ``_DATASET_CLASS`` to
    build its datasets -- confirm against the base class.

    A DataModule exposes the following hooks:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc...
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc...
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to
    download, split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
    """

    # Dataset class associated with this datamodule.
    _DATASET_CLASS = GRIDNET


if __name__ == "__main__":
    # Manual smoke test: load the GRIDNET datamodule config and instantiate
    # whatever object it describes. Script-only dependencies are imported
    # here so that importing this module does not require them.
    import hydra
    import omegaconf
    import pyrootutils

    # Locate the project root relative to this file (pythonpath=True is
    # passed through to pyrootutils unchanged).
    project_root = str(pyrootutils.setup_root(__file__, pythonpath=True))

    config_path = project_root + "/configs/datamodule/semantic/gridnet.yaml"
    cfg = omegaconf.OmegaConf.load(config_path)

    # Point the config at the data directory under the project root.
    cfg.data_dir = project_root + "/data"

    # The instance is discarded; only successful construction matters here.
    _ = hydra.utils.instantiate(cfg)