import torch
import pandas as pd
import lightning.pytorch as pl
from omegaconf import OmegaConf
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from functools import partial
from src.utils.model_utils import _print

# NOTE(review): import-time side effect — loads a hard-coded, machine-specific
# config path. Nothing below reads this module-level `config`; the data module
# receives its own config via __init__. Presumably kept because another module
# does `from ... import config` — TODO confirm, then remove.
config = OmegaConf.load('/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml')


def collate_fn(batch, pad_id=None):
    """Collate a DataLoader batch of size 1 into tensors.

    Only ``batch[0]`` is read: ``PeptideDataModule`` enforces
    ``batch_size == 1`` (examples are pre-batched for dynamic batching),
    so every DataLoader batch holds exactly one example.

    Args:
        batch: list containing a single mapping with ``input_ids`` and
            ``attention_mask`` (tensor-convertible sequences).
        pad_id: unused; kept for signature compatibility with padding
            collate functions (no padding is needed at batch size 1).

    Returns:
        dict with ``input_ids`` and ``attention_mask`` as tensors.
    """
    example = batch[0]
    return {
        'input_ids': torch.tensor(example['input_ids']),
        'attention_mask': torch.tensor(example['attention_mask']),
    }


class PeptideDataModule(pl.LightningDataModule):
    """Lightning data module serving pre-batched peptide datasets.

    The datasets are expected to be pre-batched on disk (dynamic
    batching), so the DataLoader batch size must be exactly 1.
    """

    def __init__(self, config, train_dataset, val_dataset, test_dataset,
                 tokenizer, collate_fn=collate_fn):
        """Store datasets/tokenizer and validate the configured batch size.

        Args:
            config: config object exposing ``config.data.batch_size``.
            train_dataset / val_dataset / test_dataset: map-style datasets
                (e.g. loaded via ``datasets.load_from_disk``).
            tokenizer: tokenizer kept for downstream use.
            collate_fn: collate callable; defaults to the module-level
                size-1 ``collate_fn`` above.

        Raises:
            ValueError: if ``config.data.batch_size`` is not 1.
        """
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size
        # Raise instead of assert: asserts are stripped under `python -O`,
        # and a wrong batch size would silently break dynamic batching.
        if self.batch_size != 1:
            raise ValueError(
                f'Batch size = {self.batch_size}. Needs to be 1 for dynamic batching'
            )

    def _make_dataloader(self, dataset):
        # Shared DataLoader settings for all three splits. shuffle=False on
        # purpose: examples are pre-batched/ordered on disk. The original
        # no-op `partial(self.collate_fn)` wrapper is dropped — it bound no
        # arguments and was equivalent to passing the callable directly.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn,
            num_workers=8,
            shuffle=False,
            pin_memory=True,
        )

    def train_dataloader(self):
        """DataLoader over the training split."""
        return self._make_dataloader(self.train_dataset)

    def val_dataloader(self):
        """DataLoader over the validation split."""
        return self._make_dataloader(self.val_dataset)

    def test_dataloader(self):
        """DataLoader over the test split."""
        return self._make_dataloader(self.test_dataset)


def get_datasets(config):
    """Helper method to grab datasets to quickly init data module in main.py.

    Args:
        config: config exposing ``config.data.train`` / ``.val`` / ``.test``
            paths to datasets previously saved with ``save_to_disk``.

    Returns:
        dict mapping split name ('train', 'val', 'test') to its dataset.
    """
    return {
        "train": load_from_disk(config.data.train),
        "val": load_from_disk(config.data.val),
        "test": load_from_disk(config.data.test),
    }