| | import torch |
| | import pandas as pd |
| | import lightning.pytorch as pl |
| |
|
| | from omegaconf import OmegaConf |
| | from datasets import load_from_disk |
| | from torch.utils.data import DataLoader |
| | from torch.nn.utils.rnn import pad_sequence |
| | from functools import partial |
| | from src.utils.model_utils import _print |
| |
|
# NOTE(review): hardcoded absolute cluster path — consider taking the config
# path from the CLI/environment. Also, this module-level `config` appears
# unused below (every consumer takes `config` as a parameter, which shadows
# this name) — verify before relying on it.
config = OmegaConf.load('/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml')
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
| |
|
def collate_fn(batch, pad_id=None):
    """Collate a single pre-batched dataset example into model-ready tensors.

    The datasets used here are "dynamically batched": each dataset item
    already contains a full batch of token ids, so the DataLoader batch
    size must be 1 and we simply unwrap the single element.

    Args:
        batch: list containing exactly one example dict with keys
            'input_ids' and 'attention_mask' (nested lists of ints).
        pad_id: unused; kept for backward compatibility with callers that
            pass a tokenizer pad id.

    Returns:
        dict with 'input_ids' and 'attention_mask' as torch tensors.

    Raises:
        ValueError: if `batch` does not contain exactly one element.
            (The original silently discarded every element past the first.)
    """
    if len(batch) != 1:
        raise ValueError(
            f'collate_fn expects a single pre-batched example, got {len(batch)}'
        )
    example = batch[0]
    return {
        'input_ids': torch.tensor(example['input_ids']),
        'attention_mask': torch.tensor(example['attention_mask'])
    }
| |
|
| |
|
class PeptideDataModule(pl.LightningDataModule):
    """Lightning data module serving pre-batched ("dynamically batched") peptide datasets.

    Each dataset item already holds a full batch, so the DataLoader batch
    size is pinned to 1 and the collate function unwraps the single item.
    """

    def __init__(self, config, train_dataset, val_dataset, test_dataset, tokenizer, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size
        # Validate with an explicit raise instead of `assert`: asserts are
        # stripped under `python -O`, which would silently allow a bad batch
        # size through.
        if self.batch_size != 1:
            raise ValueError(
                f'Batch size = {self.batch_size}. Needs to be 1 for dynamic batching'
            )

    def _make_dataloader(self, dataset):
        """Build a DataLoader with the settings shared by all three splits."""
        # shuffle=False even for training — presumably the pre-batched items
        # were ordered deliberately upstream; TODO confirm this is intended.
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,  # no-op partial() wrapper removed
                          num_workers=8,
                          shuffle=False,
                          pin_memory=True)

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset)
| | |
| |
|
def get_datasets(config):
    """Helper method to grab datasets to quickly init data module in main.py

    Loads the train/test/val splits from the on-disk paths named in
    `config.data` and returns them keyed by split name.
    """
    # Load in the same order as before (train, test, val) so any on-disk
    # side effects happen identically.
    loaded = {}
    for split in ('train', 'test', 'val'):
        loaded[split] = load_from_disk(getattr(config.data, split))

    return {
        'train': loaded['train'],
        'val': loaded['val'],
        'test': loaded['test']
    }
| |
|
| |
|