File size: 3,074 Bytes
9627ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import inspect
import importlib
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class DInterface_base(pl.LightningDataModule):
    """Lightning data module that builds dataset objects and their loaders.

    The ``*_dataloader`` methods read ``self.trainset``, ``self.valset`` and
    ``self.testset``. This class does not assign those attributes itself
    (the ``setup`` below is commented out), so they must be set elsewhere
    before a loader is requested -- presumably by a subclass ``setup()``;
    confirm against the subclasses.
    """

    def __init__(self, num_workers=8,
                 dataset='',
                 **kwargs):
        """Store loader settings and the extra dataset keyword arguments.

        Args:
            num_workers: Worker process count handed to every ``DataLoader``.
            dataset: Snake-case module name used by ``load_data_module``.
            **kwargs: Extra options. ``batch_size`` (default 4), ``task_name``
                and ``finetune_type`` are read here; the full dict is kept in
                ``self.kwargs`` and filtered against the dataset constructor
                signature in ``instancialize``.
        """
        super().__init__()
        self.save_hyperparameters()
        self.num_workers = num_workers
        self.dataset = dataset
        self.kwargs = kwargs
        self.batch_size = kwargs.get('batch_size', 4)
        self.task_name = kwargs.get("task_name")
        self.finetune_type = kwargs.get("finetune_type")
        print("batch_size", self.batch_size)
        print("task_name", self.task_name)

    # def setup(self, stage=None):
    #     # Assign train/val datasets for use in dataloaders
    #     if stage == 'fit' or stage is None:
    #         self.trainset = self.instancialize(split = 'train')
    #         self.valset = self.instancialize(split='valid')

    #     # Assign test dataset for use in dataloader(s)
    #     if stage == 'test' or stage is None:
    #         self.testset = self.instancialize(split='test')

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over ``self.trainset``."""
        extra = {}
        # DataLoader rejects an explicit prefetch_factor when loading happens
        # in the main process, so only request it when workers are used.
        if self.num_workers > 0:
            extra['prefetch_factor'] = 3
        return DataLoader(self.trainset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True, **extra)

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over ``self.valset``."""
        return DataLoader(self.valset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over ``self.testset``."""
        return DataLoader(self.testset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def load_data_module(self):
        """Import the dataset class named by ``self.dataset``.

        The module ``.<dataset>`` is imported relative to this package and the
        class whose name is the CamelCase form of ``<dataset>`` is stored in
        ``self.data_module``.

        Raises:
            ValueError: If the module cannot be imported or does not define
                the expected class. The underlying error is chained as the
                cause.
        """
        name = self.dataset
        # Change the `snake_case.py` file name to `CamelCase` class name.
        # Please always name your dataset file name as `snake_case.py` and
        # class name corresponding `CamelCase`.
        camel_name = ''.join([i.capitalize() for i in name.split('_')])
        try:
            module = importlib.import_module('.' + name, package=__package__)
            self.data_module = getattr(module, camel_name)
        except Exception as err:
            # `Exception` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate; chaining keeps the real failure visible
            # while callers still receive ValueError.
            raise ValueError(
                f'Invalid Dataset File Name or Invalid Class Name data.{name}.{camel_name}') from err

    def instancialize(self, **other_args):
        """Instantiate the dataset class for the requested split.

        ``split == 'train'`` selects ``data.AF2DB_dataset.Af2dbDataset``; any
        other value selects ``data.CASP15_dataset.CASP15Dataset``. The chosen
        class is stored in ``self.data_module``, replacing whatever
        ``load_data_module`` may have put there.

        Constructor arguments are the entries of ``self.kwargs`` whose names
        appear in the class's ``__init__`` signature, overridden by
        ``other_args`` (which are always forwarded, ``split`` included).

        Args:
            **other_args: Must contain ``split``; any further entries
                override the matching values from ``self.kwargs``.

        Returns:
            A new instance of the selected dataset class.

        Raises:
            KeyError: If ``split`` is not supplied.
        """
        if 'split' not in other_args:
            raise KeyError("instancialize() requires a 'split' keyword argument")

        if other_args['split'] == 'train':
            module_name, class_name = '.AF2DB_dataset', 'Af2dbDataset'
        else:
            module_name, class_name = '.CASP15_dataset', 'CASP15Dataset'
        self.data_module = getattr(
            importlib.import_module(module_name, package='data'), class_name)

        # Skip the first parameter (`self`) of the constructor signature.
        accepted = list(inspect.signature(self.data_module.__init__).parameters)[1:]
        init_args = {arg: self.kwargs[arg] for arg in accepted if arg in self.kwargs}
        init_args.update(other_args)
        return self.data_module(**init_args)