# MadSBM/src/madsbm/wt_peptide/dataloader.py
# Author: Shrey Goel (initial commit, 94c2704)
import torch
import pandas as pd
import lightning.pytorch as pl
from omegaconf import OmegaConf
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from functools import partial
from src.utils.model_utils import _print
# NOTE(review): module-level side effect — loads a hard-coded absolute config
# path at import time, and `config` is never used in this module. Consider
# removing it or passing the config in from the caller; verify no other module
# does `from ...dataloader import config` before deleting.
config = OmegaConf.load('/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml')
# class DNADataset(Dataset):
# def __init__(self, config, data_path):
# self.config = config
# self.data = pd.read_csv(data_path)
# self.custom_tokenizer = CustomDNATokenizer(config.model.dna_model_path)
# def __len__(self):
# return len(self.data)
# def __getitem__(self, idx):
# sequence = self.data.iloc[idx]["Sequence"]
# seq = sequence.upper()
# tokenized = self.custom_tokenizer(seq, max_length=self.config.data.max_seq_len)
# return {
# "input_ids": tokenized["input_ids"].squeeze(0),
# "attention_mask": tokenized["attention_mask"].squeeze(0)
# }
def collate_fn(batch, pad_id=None):
    """Unwrap a single pre-built batch into model-ready tensors.

    The datasets used here store one *pre-built batch per example* (dynamic
    batching), so the DataLoader runs with ``batch_size=1`` and this collate
    simply unwraps that single element and converts its fields to tensors.

    Args:
        batch: list containing exactly one dict with ``input_ids`` and
            ``attention_mask`` entries (nested lists or tensors).
        pad_id: unused; kept for signature compatibility with padding-based
            collate functions (cf. the commented-out DNADataset path above).

    Returns:
        dict with ``input_ids`` and ``attention_mask`` as ``torch.Tensor``s.

    Raises:
        ValueError: if the DataLoader delivered more than one element —
            silently dropping batch[1:] would hide a misconfigured batch size.
    """
    if len(batch) != 1:
        raise ValueError(
            f'collate_fn expects batch_size=1 for dynamic batching, got {len(batch)}'
        )
    sample = batch[0]
    # as_tensor avoids the copy (and the "copy construct from a tensor"
    # UserWarning) when the stored fields are already tensors.
    return {
        'input_ids': torch.as_tensor(sample['input_ids']),
        'attention_mask': torch.as_tensor(sample['attention_mask'])
    }
class PeptideDataModule(pl.LightningDataModule):
    """Lightning data module serving pre-batched ("dynamic batching") peptide datasets.

    Each dataset example is already a full batch, so the DataLoader batch size
    must be exactly 1 and the module-level ``collate_fn`` just unwraps it.
    """

    def __init__(self, config, train_dataset, val_dataset, test_dataset, tokenizer,
                 collate_fn=collate_fn, num_workers=8):
        """Store datasets and loader settings.

        Args:
            config: OmegaConf config; only ``config.data.batch_size`` is read.
            train_dataset / val_dataset / test_dataset: map-style datasets of
                pre-built batches.
            tokenizer: kept on the module for downstream use.
            collate_fn: collate callable; defaults to the module-level one.
            num_workers: DataLoader worker count (default 8, the previous
                hard-coded value).

        Raises:
            ValueError: if ``config.data.batch_size`` is not 1.
        """
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.batch_size = config.data.batch_size
        # `assert` is stripped under `python -O`; raise so the check always runs.
        if self.batch_size != 1:
            raise ValueError(
                f'Batch size = {self.batch_size}. Needs to be 1 for dynamic batching'
            )

    def _make_dataloader(self, dataset):
        """Shared DataLoader construction for the train/val/test splits."""
        # NOTE(review): shuffle=False is preserved from the original even for
        # training — presumably because batches are pre-built offline; confirm.
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=self.num_workers,
                          shuffle=False,
                          pin_memory=True)

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset)
def get_datasets(config):
    """Helper method to grab datasets to quickly init data module in main.py"""
    # Load each split from the paths named in the config (same load order as
    # before: train, test, val).
    train = load_from_disk(config.data.train)
    test = load_from_disk(config.data.test)
    val = load_from_disk(config.data.val)
    return dict(train=train, val=val, test=test)