import inspect
import importlib
import pytorch_lightning as pl
from torch.utils.data import DataLoader
class DInterface_base(pl.LightningDataModule):
    """Base Lightning data module that builds datasets and their DataLoaders.

    The dataloader methods read ``self.trainset``, ``self.valset`` and
    ``self.testset``. None of these attributes is assigned in this class
    (the ``setup`` hook below is commented out), so they must be provided
    before the loaders are requested -- presumably by a subclass ``setup``;
    confirm against the subclasses.
    """

    def __init__(self, num_workers=8,
                 dataset='',
                 **kwargs):
        """Store the loader configuration.

        Args:
            num_workers: Number of worker processes handed to every DataLoader.
            dataset: Snake-case module name of the dataset; used by
                ``load_data_module`` to locate the dataset class.
            **kwargs: Extra options. ``batch_size`` (default 4), ``task_name``
                and ``finetune_type`` are read here; the whole dict is kept in
                ``self.kwargs`` and later filtered by ``instancialize``.
        """
        super().__init__()
        self.save_hyperparameters()
        self.num_workers = num_workers
        self.dataset = dataset
        self.kwargs = kwargs
        self.batch_size = kwargs.get('batch_size', 4)
        self.task_name = kwargs.get("task_name")
        self.finetune_type = kwargs.get("finetune_type")
        print("batch_size", self.batch_size)
        print("task_name", self.task_name)

    # Reference implementation of ``setup`` kept disabled, as in the original:
    # def setup(self, stage=None):
    #     # Assign train/val datasets for use in dataloaders
    #     if stage == 'fit' or stage is None:
    #         self.trainset = self.instancialize(split='train')
    #         self.valset = self.instancialize(split='valid')
    #     # Assign test dataset for use in dataloader(s)
    #     if stage == 'test' or stage is None:
    #         self.testset = self.instancialize(split='test')

    def train_dataloader(self):
        """Return a shuffling DataLoader over ``self.trainset``."""
        loader_kwargs = {
            'batch_size': self.batch_size,
            'num_workers': self.num_workers,
            'shuffle': True,
        }
        # DataLoader raises ValueError when prefetch_factor is given together
        # with num_workers == 0, so only request prefetching when worker
        # processes are actually used.
        if self.num_workers > 0:
            loader_kwargs['prefetch_factor'] = 3
        return DataLoader(self.trainset, **loader_kwargs)

    def val_dataloader(self):
        """Return a non-shuffling DataLoader over ``self.valset``."""
        return DataLoader(self.valset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Return a non-shuffling DataLoader over ``self.testset``."""
        return DataLoader(self.testset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def load_data_module(self):
        """Import the dataset class named by ``self.dataset``.

        The module ``.<dataset>`` is imported relative to this package and the
        class whose name is the CamelCase form of ``<dataset>`` is stored in
        ``self.data_module``.

        Raises:
            ValueError: If the module cannot be imported or does not define
                the expected class. The underlying error is chained as the
                cause.
        """
        name = self.dataset
        # Change the `snake_case.py` file name to `CamelCase` class name.
        # Please always name your model file name as `snake_case.py` and
        # class name corresponding `CamelCase`.
        camel_name = ''.join(part.capitalize() for part in name.split('_'))
        try:
            self.data_module = getattr(importlib.import_module(
                '.' + name, package=__package__), camel_name)
        except Exception as err:
            # Not a bare ``except``: KeyboardInterrupt / SystemExit must
            # propagate instead of being reported as a naming problem.
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}'
            ) from err

    def instancialize(self, **other_args):
        """Instantiate the dataset for the requested split.

        The dataset class is selected from ``other_args['split']``:
        ``'train'`` uses ``data.AF2DB_dataset.Af2dbDataset``; any other value
        uses ``data.CASP15_dataset.CASP15Dataset``. Constructor arguments are
        the entries of ``self.kwargs`` whose names appear in the class's
        ``__init__`` signature, overridden by ``other_args``.

        Args:
            **other_args: Must contain ``split``; every entry is forwarded to
                the dataset constructor and takes precedence over
                ``self.kwargs``.

        Returns:
            A new instance of the selected dataset class.

        Raises:
            KeyError: If ``split`` is not supplied.
        """
        if other_args['split'] == 'train':
            self.data_module = getattr(importlib.import_module(
                '.AF2DB_dataset', package='data'), 'Af2dbDataset')
        else:
            self.data_module = getattr(importlib.import_module(
                '.CASP15_dataset', package='data'), 'CASP15Dataset')

        # Skip the first parameter (``self``) of the constructor signature.
        class_args = list(
            inspect.signature(self.data_module.__init__).parameters)[1:]
        # Keep only the stored options the constructor declares by name.
        init_args = {
            arg: self.kwargs[arg] for arg in class_args if arg in self.kwargs
        }
        init_args.update(other_args)
        return self.data_module(**init_args)
|