| | import inspect |
| | import importlib |
| | import pytorch_lightning as pl |
| | from torch.utils.data import DataLoader |
| |
|
| |
|
class DInterface_base(pl.LightningDataModule):
    """Base Lightning data module.

    Builds train/val/test ``DataLoader``s from ``self.trainset`` /
    ``self.valset`` / ``self.testset`` (expected to be assigned by a
    subclass's ``setup``) and resolves dataset classes dynamically by
    module name.

    Args:
        num_workers: worker processes per DataLoader (default 8).
        dataset: snake_case dataset module name consumed by
            ``load_data_module`` (default '').
        **kwargs: extra options. ``batch_size`` (default 4),
            ``task_name`` and ``finetune_type`` are read here; the full
            dict is kept in ``self.kwargs`` and filtered against the
            dataset constructor's signature by ``instancialize``.
    """

    def __init__(self, num_workers=8,
                 dataset='',
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.num_workers = num_workers
        self.dataset = dataset
        self.kwargs = kwargs
        self.batch_size = kwargs.get('batch_size', 4)
        # task_name / finetune_type default to None when absent from kwargs.
        self.task_name = kwargs.get("task_name")
        self.finetune_type = kwargs.get("finetune_type")
        print("batch_size", self.batch_size)
        print("task_name", self.task_name)

    def train_dataloader(self):
        """Shuffled training loader over ``self.trainset``.

        NOTE(review): ``prefetch_factor=3`` is only valid when
        ``num_workers > 0``; with ``num_workers=0`` DataLoader raises.
        """
        return DataLoader(self.trainset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True,
                          prefetch_factor=3)

    def val_dataloader(self):
        """Unshuffled validation loader over ``self.valset``."""
        return DataLoader(self.valset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Unshuffled test loader over ``self.testset``."""
        return DataLoader(self.testset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=False)

    def load_data_module(self):
        """Resolve ``self.dataset`` to its dataset class.

        The snake_case module name is converted to a CamelCase class name
        (``my_data`` -> ``MyData``) and imported relative to this package.

        Raises:
            ValueError: if the module or the class cannot be found.
        """
        name = self.dataset
        camel_name = ''.join([i.capitalize() for i in name.split('_')])
        try:
            self.data_module = getattr(importlib.import_module(
                '.' + name, package=__package__), camel_name)
        # Narrowed from a bare ``except`` so unrelated errors (e.g.
        # KeyboardInterrupt) propagate, and chained so the root cause
        # is preserved in the traceback.
        except (ImportError, AttributeError) as err:
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from err

    def instancialize(self, **other_args):
        """ Instancialize a model using the corresponding parameters
        from self.hparams dictionary. You can also input any args
        to overwrite the corresponding value in self.kwargs.

        NOTE(review): this overwrites ``self.data_module`` based on
        ``other_args['split']`` ('train' -> Af2dbDataset, anything else
        -> CASP15Dataset), ignoring whatever ``load_data_module``
        resolved. Raises KeyError if 'split' is not supplied.
        """
        if other_args['split'] == 'train':
            self.data_module = getattr(importlib.import_module(
                '.AF2DB_dataset', package='data'), 'Af2dbDataset')
        else:
            self.data_module = getattr(importlib.import_module(
                '.CASP15_dataset', package='data'), 'CASP15Dataset')

        # Keep only the kwargs the dataset constructor actually accepts
        # (skip 'self'), then let explicit other_args override them.
        class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
        inkeys = self.kwargs.keys()
        args1 = {}
        for arg in class_args:
            if arg in inkeys:
                args1[arg] = self.kwargs[arg]
        args1.update(other_args)
        return self.data_module(**args1)