# MemDLM/src/guidance/dataloader.py
import torch
import pandas as pd
import lightning.pytorch as pl
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader


class MembraneDataset(Dataset):
    """Per-residue dataset: reads protein sequences from a CSV and tokenizes them with ESM."""

    def __init__(self, config, data_path):
        self.config = config
        self.data = pd.read_csv(data_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.lm.pretrained_esm)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]["Sequence"]
        # Tokenize the uppercased sequence, padding/truncating to a fixed length
        tokens = self.tokenizer(
            sequence.upper(),
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.config.data.max_seq_len,
        )
        labels = self.get_labels(sequence)
        return {
            "input_ids": tokens['input_ids'],
            "attention_mask": tokens['attention_mask'],
            "labels": labels,
        }

    def get_labels(self, sequence):
        max_len = self.config.data.max_seq_len
        # Per-residue binary labels: lowercase residues are the positive class
        labels = torch.tensor([1 if residue.islower() else 0 for residue in sequence], dtype=torch.float)
        if len(labels) < max_len:
            # Pad with the configured label pad value up to the tokenizer length
            padded_labels = torch.cat(
                [labels, torch.full(size=(max_len - len(labels),), fill_value=self.config.model.label_pad_value)]
            )
        else:
            # Truncate labels for sequences longer than the tokenizer length
            padded_labels = labels[:max_len]
        return padded_labels
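

# Illustrative note (an assumption inferred from the islower() check above):
# lowercase residues appear to mark the positive class, so a sequence like
# "MKv" yields labels [0., 0., 1.] before padding to max_seq_len. Also note
# that ESM tokenizers typically prepend a CLS token, so these labels sit one
# position to the left of their residues in input_ids unless shifted downstream.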


def collate_fn(batch):
    # Each item's input_ids/attention_mask has shape (1, max_len); drop that
    # leading dimension before stacking into a batch
    input_ids = torch.stack([item['input_ids'].squeeze(0) for item in batch])
    masks = torch.stack([item['attention_mask'].squeeze(0) for item in batch])
    labels = torch.stack([item['labels'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': masks,
        'labels': labels,
    }


class MembraneDataModule(pl.LightningDataModule):
    def __init__(self, config, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)


def get_datasets(config):
    """Helper to build the train/val/test datasets so the data module can be initialized quickly in main.py."""
    train_dataset = MembraneDataset(config, config.data.train)
    val_dataset = MembraneDataset(config, config.data.val)
    test_dataset = MembraneDataset(config, config.data.test)
    return {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset,
    }
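

# ---------------------------------------------------------------------------
# Usage sketch (not from the original module): a minimal way to wire these
# pieces together. The config fields match those referenced above, but the
# checkpoint name, CSV paths, batch size, and pad value are placeholder
# assumptions, not values from the repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        lm=SimpleNamespace(pretrained_esm="facebook/esm2_t6_8M_UR50D"),  # placeholder ESM checkpoint
        data=SimpleNamespace(
            train="data/train.csv",  # placeholder paths; each CSV needs a "Sequence" column
            val="data/val.csv",
            test="data/test.csv",
            max_seq_len=512,
            batch_size=8,
        ),
        model=SimpleNamespace(label_pad_value=-100.0),  # common ignore-index convention
    )

    datasets = get_datasets(config)
    datamodule = MembraneDataModule(
        config,
        train_dataset=datasets["train"],
        val_dataset=datasets["val"],
        test_dataset=datasets["test"],
    )
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["labels"].shape)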