import pandas as pd
import torch
import pickle
import os
from torch.utils.data import DataLoader, Dataset
from fuson_plm.utils.logging import log_update


def custom_collate_fn(batch):
    """
    Custom collate function to handle batches containing both strings and tensors.

    Args:
        batch (list): List of (sequence, embedding, labels) tuples returned by
            DisorderDataset.__getitem__.

    Returns:
        tuple: (sequences, embeddings, labels)
            - sequences: list of strings
            - embeddings: tensor of shape (batch_size, embedding_dim)
            - labels: tensor of shape (batch_size, sequence_length)
    """
    sequences, embeddings, labels = zip(*batch)

    # Stack the per-item tensors into batch tensors. Stacking requires every item
    # to have the same shape, so batch sizes greater than 1 only work when all
    # sequences in the batch have equal length.
    embeddings = torch.stack(embeddings, dim=0)
    labels = torch.stack(labels, dim=0)

    # The default collate would leave the strings as a tuple; return a list instead.
    sequences = list(sequences)

    return sequences, embeddings, labels
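
# A minimal sketch of what custom_collate_fn produces, using hypothetical shapes
# (the embedding dimension 1280 is illustrative, not taken from the cached file):
#
#     batch = [
#         ("MKV", torch.randn(1280), torch.tensor([0., 1., 1.])),
#         ("ACD", torch.randn(1280), torch.tensor([1., 0., 0.])),
#     ]
#     seqs, embs, labs = custom_collate_fn(batch)
#     # seqs == ["MKV", "ACD"]; embs.shape == (2, 1280); labs.shape == (2, 3)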


class DisorderDataset(Dataset):
    def __init__(self, csv_file_path, cached_embeddings_path=None, max_length=4405):
        super(DisorderDataset, self).__init__()
        # Read Label as a string so binary per-residue masks keep any leading zeros.
        self.dataset = pd.read_csv(csv_file_path, dtype={'Sequence': str, 'Label': str})
        self.cached_embeddings_path = cached_embeddings_path
        # Stored for reference; this dataset does not truncate sequences itself.
        self.max_length = max_length

        self.embeddings = self._retrieve_embeddings()

    def __len__(self):
        return len(self.dataset)

    def _retrieve_embeddings(self):
        try:
            with open(self.cached_embeddings_path, "rb") as f:
                embeddings = pickle.load(f)
        except (TypeError, OSError, pickle.UnpicklingError) as e:
            raise RuntimeError(
                f"Error: failed to load embeddings from {self.cached_embeddings_path}"
            ) from e

        # Keep only the embeddings for sequences that appear in this dataset.
        seqs = set(self.dataset['Sequence'].tolist())
        embeddings = {k: v for k, v in embeddings.items() if k in seqs}
        return embeddings

    def __getitem__(self, idx):
        sequence = self.dataset.iloc[idx]['Sequence']
        embedding = self.embeddings[sequence]
        embedding = torch.tensor(embedding, dtype=torch.float32)

        # The label is a string of per-residue 0/1 characters, e.g. "00110".
        label_str = self.dataset.iloc[idx]['Label']
        labels = list(map(int, label_str))
        labels = torch.tensor(labels, dtype=torch.float)
        assert len(labels) == len(sequence), (
            f"Label length {len(labels)} != sequence length {len(sequence)} at index {idx}"
        )

        return sequence, embedding, labels
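
# A minimal sketch of the cached-embeddings file DisorderDataset expects: a pickled
# dict mapping each sequence string to its embedding. The numpy array shape and
# file name below are hypothetical; only the dict structure is assumed by
# _retrieve_embeddings and __getitem__.
#
#     import numpy as np
#     cache = {"MKV": np.random.rand(1280), "ACD": np.random.rand(1280)}
#     with open("cached_embeddings.pkl", "wb") as f:
#         pickle.dump(cache, f)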


def get_dataloader(data_path, cached_embeddings_path, max_length=4405, batch_size=1, shuffle=True):
    """
    Creates a DataLoader for the dataset.

    Args:
        data_path (str): Path to the CSV file (train, val, or test).
        cached_embeddings_path (str): Path to the pickled sequence-to-embedding dict.
        max_length (int): Maximum allowed sequence length.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle the data.

    Returns:
        DataLoader: DataLoader object.
    """
    dataset = DisorderDataset(data_path, cached_embeddings_path=cached_embeddings_path, max_length=max_length)
    return DataLoader(dataset, batch_size=batch_size, collate_fn=custom_collate_fn, shuffle=shuffle)
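
# A hedged usage sketch for get_dataloader; the CSV and pickle paths here are
# placeholders, not files shipped with the repo. batch_size=1 matches the default
# and sidesteps the equal-length requirement custom_collate_fn imposes on larger
# batches:
#
#     train_loader = get_dataloader("splits/train.csv", "cached_embeddings.pkl")
#     for sequences, embeddings, labels in train_loader:
#         print(sequences[0], embeddings.shape, labels.shape)
#         break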


def check_dataloaders(train_loader, test_loader, max_length=512, checkpoint_dir=''):
    log_update('\nBuilt train and test dataloaders')
    log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Testing DataLoader: {len(test_loader.dataset)}")
    dataloader_overlaps = check_dataloader_overlap(train_loader, test_loader)
    if len(dataloader_overlaps) == 0:
        log_update("\tDataloaders are clean (no overlaps)")
    else:
        log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")

    # os.makedirs with exist_ok avoids a race between the existence check and the
    # mkdir call, and os.path.join handles an empty checkpoint_dir correctly.
    os.makedirs(os.path.join(checkpoint_dir, 'batch_diversity'), exist_ok=True)

    max_length_violators = []
    for name, dataloader in {'train': train_loader, 'test': test_loader}.items():
        if not check_max_length(dataloader, max_length):
            max_length_violators.append(name)

    if len(max_length_violators) == 0:
        log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
    else:
        log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")


def check_dataloader_overlap(train_loader, test_loader):
    train_seqs = set()
    test_seqs = set()
    # Collect every sequence in every batch (not just the first item per batch),
    # so the check remains correct for batch sizes greater than 1.
    for sequences, _, _ in train_loader:
        train_seqs.update(sequences)
    for sequences, _, _ in test_loader:
        test_seqs.update(sequences)

    return train_seqs.intersection(test_seqs)


def check_max_length(dataloader, max_length):
    # Check every sequence in each batch, not just the first one.
    for sequences, _, _ in dataloader:
        if any(len(seq) > max_length for seq in sequences):
            return False

    return True
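
# A brief usage sketch of the checks above, assuming train/test CSVs and a cached
# embeddings pickle already exist at these placeholder paths:
#
#     train_loader = get_dataloader("splits/train.csv", "cached_embeddings.pkl")
#     test_loader = get_dataloader("splits/test.csv", "cached_embeddings.pkl", shuffle=False)
#     check_dataloaders(train_loader, test_loader, max_length=512, checkpoint_dir="checkpoints")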