import inspect
import importlib
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class DInterface_base(pl.LightningDataModule):
    """LightningDataModule that builds train/valid/test datasets on demand.

    All keyword arguments are stored via ``save_hyperparameters()``;
    ``batch_size`` is required. ``dataset``, ``num_workers`` and ``kwargs``
    may be set as instance attributes by a subclass; when they are not, they
    are taken from ``self.hparams`` instead.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = self.hparams.batch_size
        print("batch_size", self.batch_size)
        self.load_data_module()

    def setup(self, stage=None):
        """Instantiate the datasets needed for the given Lightning ``stage``.

        ``'fit'`` builds the train and validation sets, ``'validate'`` only
        the validation set, ``'test'`` the test set, and ``None`` all three.
        """
        # Assign train dataset for use in dataloaders.
        if stage in ('fit', None):
            self.trainset = self.instancialize(split='train')
        # Validation set is also needed by trainer.validate() (stage 'validate').
        if stage in ('fit', 'validate', None):
            self.valset = self.instancialize(split='valid')
        # Assign test dataset for use in dataloader(s).
        if stage in ('test', None):
            self.testset = self.instancialize(split='test')

    def _resolve_num_workers(self):
        """Return ``self.num_workers`` if set, else ``hparams['num_workers']`` (default 0)."""
        num_workers = getattr(self, 'num_workers', None)
        if num_workers is None:
            num_workers = self.hparams.get('num_workers', 0)
        return num_workers

    def _make_loader(self, dataset, shuffle, prefetch_factor=None):
        """Build a DataLoader over ``dataset`` using this module's batch size and workers."""
        num_workers = self._resolve_num_workers()
        extra = {}
        # DataLoader only accepts prefetch_factor when worker processes are used;
        # with num_workers == 0 a non-default value raises an error.
        if prefetch_factor is not None and num_workers > 0:
            extra['prefetch_factor'] = prefetch_factor
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=num_workers, shuffle=shuffle, **extra)

    def train_dataloader(self):
        """Return a shuffled loader over ``self.trainset``."""
        return self._make_loader(self.trainset, shuffle=True, prefetch_factor=3)

    def val_dataloader(self):
        """Return an unshuffled loader over ``self.valset``."""
        return self._make_loader(self.valset, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled loader over ``self.testset``."""
        return self._make_loader(self.testset, shuffle=False)

    def load_data_module(self):
        """Import the dataset class named by ``dataset`` into ``self.data_module``.

        The module ``.<dataset>`` is imported relative to this file's package,
        and the class name is its snake_case name converted to CamelCase.

        Raises:
            ValueError: if no dataset name is configured, or the module/class
                cannot be imported.
        """
        name = getattr(self, 'dataset', None)
        if name is None:
            name = self.hparams.get('dataset')
        if name is None:
            raise ValueError(
                "No dataset name configured: set `self.dataset` or pass `dataset=`")
        # Change the `snake_case.py` file name to `CamelCase` class name.
        # Please always name your model file name as `snake_case.py` and
        # class name corresponding `CamelCase`.
        camel_name = ''.join([i.capitalize() for i in name.split('_')])
        try:
            self.data_module = getattr(importlib.import_module(
                '.' + name, package=__package__), camel_name)
        except (ImportError, AttributeError, TypeError) as err:
            # TypeError: relative import attempted without a package context.
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from err

    def instancialize(self, **other_args):
        """Instantiate the dataset class for ``other_args['split']``.

        The train split uses ``data.AF2DB_dataset.Af2dbDataset``; every other
        split uses ``data.CASP15_dataset.CASP15Dataset``. This overwrites
        ``self.data_module``. Constructor arguments are taken from
        ``self.kwargs`` when a subclass set it, otherwise from ``self.hparams``,
        keeping only names in the class's ``__init__`` signature. Entries in
        ``other_args`` override them.
        """
        if other_args['split'] == 'train':
            self.data_module = getattr(importlib.import_module(
                '.AF2DB_dataset', package='data'), 'Af2dbDataset')
        else:
            self.data_module = getattr(importlib.import_module(
                '.CASP15_dataset', package='data'), 'CASP15Dataset')
        # Skip the first parameter (`self`).
        class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
        source = getattr(self, 'kwargs', None)
        if source is None:
            source = self.hparams
        args1 = {arg: source[arg] for arg in class_args if arg in source}
        args1.update(other_args)
        return self.data_module(**args1)