import torch
import lightning.pytorch as pl

from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence


# class DNADataset(Dataset):
#     def __init__(self, config, data_path):
#         self.config = config
#         self.data = pd.read_csv(data_path)
#         self.custom_tokenizer = CustomDNATokenizer(config.model.dna_model_path)

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         sequence = self.data.iloc[idx]["Sequence"]
#         seq = sequence.upper()

#         tokenized = self.custom_tokenizer(seq, max_length=self.config.data.max_seq_len)

#         return {
#             "input_ids": tokenized["input_ids"].squeeze(0),
#             "attention_mask": tokenized["attention_mask"].squeeze(0)
#         }



def collate_fn(batch, pad_id=None):
    # The datasets on disk are stored pre-batched, so each "example" is
    # already a full batch; with batch_size=1 we simply unwrap the single
    # element. `pad_id` is accepted for interface compatibility but unused
    # here, since pre-batched sequences are already padded.
    example = batch[0]
    return {
        'input_ids': torch.tensor(example['input_ids']),
        'attention_mask': torch.tensor(example['attention_mask'])
    }
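

# A minimal alternative sketch for the case where examples are *not*
# pre-batched and batch_size > 1 is desired: pad variable-length sequences
# to the longest in the batch with pad_sequence (imported above). `pad_id`
# is assumed to be the tokenizer's pad token id, and the field names are
# assumed to match the dataset schema used by collate_fn above.
def padded_collate_fn(batch, pad_id=0):
    input_ids = pad_sequence(
        [torch.tensor(ex['input_ids']) for ex in batch],
        batch_first=True,
        padding_value=pad_id)
    attention_mask = pad_sequence(
        [torch.tensor(ex['attention_mask']) for ex in batch],
        batch_first=True,
        padding_value=0)  # padded positions are masked out
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }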


class PeptideDataModule(pl.LightningDataModule):
    def __init__(self, config, train_dataset, val_dataset, test_dataset, tokenizer, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size
        assert self.batch_size == 1, f'Batch size = {self.batch_size}. Needs to be 1 for dynamic batching'

    def _make_dataloader(self, dataset):
        # All three splits share the same settings: the data is pre-batched,
        # so shuffling stays off to preserve the stored batch order.
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          shuffle=False,
                          pin_memory=True)

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset)
    

def get_datasets(config):
    """Helper function to load the train/val/test splits for quick data module init in main.py"""
    return {
        "train": load_from_disk(config.data.train),
        "val": load_from_disk(config.data.val),
        "test": load_from_disk(config.data.test)
    }
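

# A minimal usage sketch, assuming a `tokenizer` and a Lightning `model`
# built elsewhere (as in main.py) and a config with the fields referenced
# above; kept as a comment so importing this module stays side-effect free:
#
#     from omegaconf import OmegaConf
#
#     config = OmegaConf.load('configs/wt_pep.yaml')
#     datasets = get_datasets(config)
#     datamodule = PeptideDataModule(config,
#                                    datasets['train'],
#                                    datasets['val'],
#                                    datasets['test'],
#                                    tokenizer)
#     trainer = pl.Trainer(accelerator='auto')
#     trainer.fit(model, datamodule=datamodule)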