# Source: Honzus24, commit 7968cb0 ("initial commit").
import inspect
import importlib
import pytorch_lightning as pl
from torch.utils.data import DataLoader
class DInterface_base(pl.LightningDataModule):
    """Lightning data module that builds the train/valid/test datasets and loaders.

    Keyword arguments passed to the constructor are recorded with
    ``save_hyperparameters``; ``batch_size`` is read back from ``self.hparams``.

    NOTE(review): ``self.dataset``, ``self.num_workers`` and ``self.kwargs`` are
    read by the methods below but never assigned in this class. They are
    presumably provided by a subclass (``self.dataset`` must already exist when
    ``load_data_module`` runs from ``__init__``) -- confirm.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = self.hparams.batch_size
        print("batch_size", self.batch_size)
        self.load_data_module()

    def setup(self, stage=None):
        """Create the datasets required for ``stage``.

        Args:
            stage: Lightning stage name ('fit', 'validate', 'test', ...) or
                ``None`` to build every split.
        """
        # Assign train dataset for use in the training dataloader.
        if stage in ('fit', None):
            self.trainset = self.instancialize(split='train')
        # The validation set is needed both while fitting and for
        # trainer.validate() (stage 'validate'); without it val_dataloader
        # would fail on a missing ``self.valset``.
        if stage in ('fit', 'validate', None):
            self.valset = self.instancialize(split='valid')
        # Assign test dataset for use in the test dataloader.
        if stage in ('test', None):
            self.testset = self.instancialize(split='test')

    def train_dataloader(self):
        """Return a shuffled DataLoader over ``self.trainset``."""
        loader_kwargs = {}
        # DataLoader rejects ``prefetch_factor`` when num_workers == 0
        # (single-process loading), so only request prefetching with workers.
        if self.num_workers > 0:
            loader_kwargs['prefetch_factor'] = 3
        return DataLoader(self.trainset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True,
                          **loader_kwargs)

    def val_dataloader(self):
        """Return an unshuffled DataLoader over ``self.valset``."""
        return DataLoader(self.valset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled DataLoader over ``self.testset``."""
        return DataLoader(self.testset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def load_data_module(self):
        """Resolve the dataset class named by ``self.dataset``.

        Imports the sibling module ``.<self.dataset>`` (relative to this file's
        package) and stores the class whose name is the CamelCase form of
        ``self.dataset`` in ``self.data_module``. Dataset files should be named
        ``snake_case.py`` with the matching ``CamelCase`` class inside.

        Raises:
            ValueError: if the module cannot be imported or lacks the class.
        """
        name = self.dataset
        # snake_case file name -> CamelCase class name, e.g. 'my_data' -> 'MyData'.
        camel_name = ''.join(part.capitalize() for part in name.split('_'))
        try:
            module = importlib.import_module('.' + name, package=__package__)
            self.data_module = getattr(module, camel_name)
        except (ImportError, AttributeError) as err:
            # Chain the cause so e.g. a missing dependency inside the dataset
            # module is still visible in the traceback.
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from err

    def instancialize(self, **other_args):
        """Instantiate the dataset class for one split.

        The class is chosen by split, replacing whatever ``load_data_module``
        resolved: ``split='train'`` uses ``data.AF2DB_dataset.Af2dbDataset``,
        every other split uses ``data.CASP15_dataset.CASP15Dataset``. The chosen
        class is also stored in ``self.data_module``.

        Constructor arguments are taken from ``self.kwargs`` for every parameter
        name the class's ``__init__`` declares; ``other_args`` are applied last
        and override those values.

        Args:
            **other_args: extra constructor arguments; must include ``split``.

        Returns:
            The constructed dataset instance.

        Raises:
            KeyError: if ``split`` is not supplied.
        """
        # NOTE(review): the dataset choice is hard-coded per split and ignores
        # self.dataset -- confirm this is intended.
        if other_args['split'] == 'train':
            module_name, class_name = '.AF2DB_dataset', 'Af2dbDataset'
        else:
            module_name, class_name = '.CASP15_dataset', 'CASP15Dataset'
        self.data_module = getattr(
            importlib.import_module(module_name, package='data'), class_name)
        # Constructor parameter names, skipping the leading ``self``.
        class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
        available = self.kwargs
        args1 = {arg: available[arg] for arg in class_args if arg in available}
        args1.update(other_args)
        return self.data_module(**args1)