from torch.utils.data.dataloader import DataLoader
from massspecgym.data.data_module import MassSpecDataModule
from flare.data.datasets import ContrastiveDataset
from functools import partial
from massspecgym.models.base import Stage

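class TestDataModule(MassSpecDataModule):
    """Data module that exposes the entire ``dataset`` as the test split.

    Only the test stage is supported.
    """

    def __init__(
            self,
            collate_fn,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.collate_fn = collate_fn

    def prepare_data(self):
        # No data preparation is performed for test-only use.
        pass

    def setup(self, stage=None):
        if stage == "test":
            # Use the full dataset as the test set (no split is applied).
            self.test_dataset = self.dataset
        else:
            raise ValueError(f"Data module supports test set only, got stage={stage!r}")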

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        return None

    def val_dataloader(self):
        return None

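class ContrastiveDataModule(MassSpecDataModule):
    """Data module that wraps the train and validation splits in ``ContrastiveDataset``
    and collates their batches with a stage-aware ``collate_fn``."""

    def __init__(
            self,
            collate_fn,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.collate_fn = collate_fn
        self.regularization_flag = False
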
    def train_dataloader(self):
        self.train_contrastive_dataset = ContrastiveDataset(self.train_dataset)

        return DataLoader(
            self.train_contrastive_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,
            # partial() fixes the `stage` keyword so the collate function
            # can tell training batches apart from validation batches.
            collate_fn=partial(self.collate_fn, stage=Stage.TRAIN),
        )

    def val_dataloader(self):
        self.val_contrastive_dataset = ContrastiveDataset(self.val_dataset)

        return DataLoader(
            self.val_contrastive_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,
            collate_fn=partial(self.collate_fn, stage=Stage.VAL),
        )

    def test_dataloader(self):
        # Test batches use the underlying dataset's own collate_fn,
        # not the contrastive collate_fn passed to this module.
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,
            collate_fn=self.dataset.collate_fn,
        )