# NOTE(review): removed non-code page chrome captured during extraction
# ("Spaces: Running on Zero" — Hugging Face Spaces UI text, not source code).
import inspect
import importlib
import pytorch_lightning as pl
from torch.utils.data import DataLoader
class DInterface_base(pl.LightningDataModule):
    """Data-module base that resolves dataset classes by name at runtime.

    All keyword arguments are stored in ``self.hparams`` via
    ``save_hyperparameters()``. Expected keys:
        batch_size  -- required; passed to every DataLoader.
        num_workers -- optional (default 0); DataLoader worker count.
        dataset     -- optional; snake_case module name used by
                       :meth:`load_data_module` to resolve a CamelCase class.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        # Fix: the attributes below were read by other methods
        # (train_dataloader, load_data_module, instancialize) but never
        # assigned — save_hyperparameters() only fills self.hparams — so
        # first use raised AttributeError.
        self.kwargs = kwargs
        self.batch_size = self.hparams.batch_size
        self.num_workers = kwargs.get('num_workers', 0)
        self.dataset = kwargs.get('dataset', '')
        print("batch_size", self.batch_size)
        self.load_data_module()

    def setup(self, stage=None):
        """Create the datasets needed for the given stage ('fit'/'test'/None)."""
        # Assign train/val datasets for use in dataloaders.
        if stage == 'fit' or stage is None:
            self.trainset = self.instancialize(split='train')
            self.valset = self.instancialize(split='valid')
        # Assign test dataset for use in dataloader(s).
        if stage == 'test' or stage is None:
            self.testset = self.instancialize(split='test')

    def train_dataloader(self):
        """Shuffled training loader."""
        # Fix: DataLoader rejects prefetch_factor when num_workers == 0,
        # so only pass it when worker processes exist.
        extra = {'prefetch_factor': 3} if self.num_workers > 0 else {}
        return DataLoader(self.trainset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True, **extra)

    def val_dataloader(self):
        """Deterministic (unshuffled) validation loader."""
        return DataLoader(self.valset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Deterministic (unshuffled) test loader."""
        return DataLoader(self.testset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def load_data_module(self):
        """Import ``.<dataset>`` from this package and bind its CamelCase
        class to ``self.data_module``.

        Raises:
            ValueError: if the module or class cannot be resolved.
        """
        name = self.dataset
        # Change the `snake_case.py` file name to `CamelCase` class name.
        # Please always name your model file name as `snake_case.py` and
        # class name corresponding `CamelCase`.
        camel_name = ''.join([part.capitalize() for part in name.split('_')])
        try:
            self.data_module = getattr(importlib.import_module(
                '.' + name, package=__package__), camel_name)
        except (ImportError, AttributeError, ValueError) as err:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and hid the root cause; narrow
            # the catch and chain the original exception.
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from err

    def instancialize(self, **other_args):
        """ Instancialize a model using the corresponding parameters
        from self.hparams dictionary. You can also input any args
        to overwrite the corresponding value in self.kwargs.
        """
        # NOTE(review): this overrides whatever load_data_module() resolved —
        # the dataset class is hard-coded per split here; confirm that is
        # intended rather than using self.data_module from __init__.
        if other_args['split'] == 'train':
            self.data_module = getattr(importlib.import_module(
                '.AF2DB_dataset', package='data'), 'Af2dbDataset')
        else:
            self.data_module = getattr(importlib.import_module(
                '.CASP15_dataset', package='data'), 'CASP15Dataset')
        # Forward only the constructor parameters that appear in self.kwargs,
        # letting explicit other_args take precedence.
        class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
        init_args = {k: self.kwargs[k] for k in class_args if k in self.kwargs}
        init_args.update(other_args)
        return self.data_module(**init_args)