Spaces:
Running
on
Zero
Running
on
Zero
| import copy | |
| import random | |
| import os.path as osp | |
| import torch | |
| import torch.utils.data as data | |
| import pdb | |
| from .cath_dataset import CATHDataset | |
| from .alphafold_dataset import AlphaFoldDataset | |
| from .ts_dataset import TSDataset | |
| from .casp_dataset import CASPDataset | |
| from .mpnn_dataset import MPNNDataset | |
| from .featurizer import (featurize_AF, featurize_GTrans, featurize_GVP, | |
| featurize_ProteinMPNN, featurize_Inversefolding) | |
| from .fast_dataloader import DataLoaderX | |
class GTransDataLoader(torch.utils.data.DataLoader):
    """``torch.utils.data.DataLoader`` that also keeps its collate function.

    All constructor arguments are forwarded unchanged to ``DataLoader``; the
    collate function (possibly ``None``) is additionally stored on the
    instance as ``self.featurizer``.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0,
                 collate_fn=None, **kwargs):
        # Forward by keyword so the call does not depend on DataLoader's
        # positional parameter order.
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            **kwargs,
        )
        self.featurizer = collate_fn
class BatchSampler(data.Sampler):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design.

    A `torch.utils.data.Sampler` which samples batches according to a
    maximum number of graph nodes.

    Indices are packed greedily in (optionally shuffled) order: a batch is
    closed as soon as adding the next index would push its node total past
    ``max_nodes``. Indices whose own node count exceeds ``max_nodes`` are
    dropped entirely.

    NOTE(review): batches are formed once, in ``__init__``. ``__iter__`` and
    ``__len__`` only rebuild them when the batch list is empty, so with
    ``shuffle=True`` every epoch yields the same batches in the same order.
    If per-epoch reshuffling is wanted, call ``_form_batches()`` between epochs.

    :param node_counts: array of node counts in the dataset to sample from
    :param max_nodes: the maximum number of nodes in any batch,
                      including batches of a single element
    :param shuffle: if `True`, indices are shuffled before being packed
    '''

    def __init__(self, node_counts, max_nodes=3000, shuffle=True):
        self.node_counts = node_counts
        # Only indices that fit in a batch on their own are eligible.
        self.idx = [i for i, count in enumerate(node_counts) if count <= max_nodes]
        self.shuffle = shuffle
        self.max_nodes = max_nodes
        self._form_batches()

    def _form_batches(self):
        """Rebuild ``self.batches`` by greedily packing ``self.idx`` under ``max_nodes``."""
        self.batches = []
        if self.shuffle:
            random.shuffle(self.idx)
        # One linear pass. Re-slicing the remaining index list per element
        # would copy it every step, which is quadratic in the dataset size.
        batch, n_nodes = [], 0
        for i in self.idx:
            count = self.node_counts[i]
            if batch and n_nodes + count > self.max_nodes:
                self.batches.append(batch)
                batch, n_nodes = [], 0
            batch.append(i)
            n_nodes += count
        if batch:
            self.batches.append(batch)

    def __len__(self):
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        if not self.batches:
            self._form_batches()
        yield from self.batches
class GVPDataLoader(torch.utils.data.DataLoader):
    """DataLoader whose batches are packed by total residue count via ``BatchSampler``.

    A sample's size is ``len(sample['seq'])``. Batches are formed by
    ``BatchSampler`` with its default ``shuffle=True``, so index order is
    shuffled once before packing. Samples longer than ``max_nodes`` are
    skipped by the sampler.

    Args:
        dataset: iterable dataset whose items can be indexed with ``'seq'``.
        num_workers: worker process count passed to ``DataLoader``.
        featurizer: object with a ``collate`` method, used as the collate
            function. Required despite the ``None`` default, which is kept
            only for signature compatibility.
        max_nodes: per-batch residue budget forwarded to ``BatchSampler``.
        **kwargs: extra ``DataLoader`` keyword arguments.

    Raises:
        TypeError: if ``featurizer`` is ``None``.
    """

    def __init__(self, dataset, num_workers=0,
                 featurizer=None, max_nodes=3000, **kwargs):
        if featurizer is None:
            # Without a featurizer there is no collate function; fail clearly
            # instead of with an AttributeError on `None.collate`.
            raise TypeError("GVPDataLoader requires a featurizer with a `collate` method")
        # `item` (not `data`) avoids shadowing the module-level `data` alias.
        node_counts = [len(item['seq']) for item in dataset]
        super().__init__(
            dataset,
            batch_sampler=BatchSampler(node_counts=node_counts, max_nodes=max_nodes),
            num_workers=num_workers,
            collate_fn=featurizer.collate,
            **kwargs,
        )
        self.featurizer = featurizer
def _split_by_mode(base_set, valid_mode='valid', test_mode='test'):
    """Return three shallow copies of *base_set* as ``(train, valid, test)``.

    The train copy keeps the mode *base_set* was built with; the valid and
    test copies are switched with ``change_mode(valid_mode)`` and
    ``change_mode(test_mode)``.
    """
    train_set, valid_set, test_set = (copy.copy(base_set) for _ in range(3))
    valid_set.change_mode(valid_mode)
    test_set.change_mode(test_mode)
    return train_set, valid_set, test_set


def load_data(data_name, method, batch_size, data_root, pdb_path, split_csv, max_nodes=3000, num_workers=8, removeTS=0, test_casp=False, **kwargs):
    """Build train/valid/test ``DataLoaderX`` objects for a named dataset.

    Args:
        data_name: one of ``'CATH4.2'``, ``'TS'``, ``'CATH4.3'``,
            ``'AlphaFold'``, ``'MPNN'``, ``'S350'``, ``'Protherm'``.
        method: model name. ``'GVP'``, ``'ProteinMPNN'`` and ``'ESMIF'``
            override the dataset's default collate function; any other name
            keeps it.
        batch_size: batch size for all three loaders.
        data_root: directory holding the per-dataset subdirectories.
        pdb_path, split_csv, max_nodes: accepted for interface compatibility;
            not used here.
        num_workers: worker count for all three loaders.
        removeTS: forwarded to ``CATHDataset``.
        test_casp: if true, the test set is replaced by ``CASPDataset`` read
            from ``<data_root>/casp15``.
        **kwargs: for ``'AlphaFold'``, ``upid``, ``limit_length`` and
            ``joint_data`` are forwarded to ``AlphaFoldDataset`` when present;
            other keys are ignored.

    Returns:
        ``(train_loader, valid_loader, test_loader)``. The train loader is
        created with ``shuffle=True``, valid/test with ``shuffle=False``.

    Raises:
        ValueError: if ``data_name`` is not recognized.
    """
    if data_name in ('CATH4.2', 'TS'):
        cath_set = CATHDataset(osp.join(data_root, 'cath4.2'), mode='train', test_name='All', removeTS=removeTS)
        train_set, valid_set, test_set = _split_by_mode(cath_set)
        if data_name == 'TS':
            test_set = TSDataset(osp.join(data_root, 'ts'))
        collate_fn = featurize_GTrans
    elif data_name == 'CATH4.3':
        cath_set = CATHDataset(osp.join(data_root, 'cath4.3'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        train_set, valid_set, test_set = _split_by_mode(cath_set)
        collate_fn = featurize_GTrans
    elif data_name == 'AlphaFold':
        # These options were referenced as bare, undefined names (NameError).
        # Read them from **kwargs; any not supplied are omitted so that
        # AlphaFoldDataset's own defaults (if it has any) apply.
        af_kwargs = {key: kwargs[key] for key in ('upid', 'limit_length', 'joint_data') if key in kwargs}
        af_set = AlphaFoldDataset(osp.join(data_root, 'af2db'), mode='train', **af_kwargs)
        train_set, valid_set, test_set = _split_by_mode(af_set)
        collate_fn = featurize_AF
    elif data_name == 'MPNN':
        train_set = MPNNDataset(mode='train')
        valid_set = MPNNDataset(mode='valid')
        test_set = MPNNDataset(mode='test')
        collate_fn = featurize_GTrans
    elif data_name == 'S350':
        cath_set = CATHDataset(osp.join(data_root, 's350'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        # All three splits use the 'train' mode of S350.
        train_set, valid_set, test_set = _split_by_mode(cath_set, valid_mode='train', test_mode='train')
        collate_fn = featurize_GTrans
    elif data_name == 'Protherm':
        cath_set = CATHDataset(osp.join(data_root, 'protherm'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        train_set, valid_set, test_set = _split_by_mode(cath_set)
        collate_fn = featurize_GTrans
    else:
        raise ValueError(f"Unknown data_name: {data_name!r}")

    if test_casp:
        test_set = CASPDataset(osp.join(data_root, 'casp15'))

    # 'AlphaDesign', 'PiFold', 'KWDesign', 'GraphTrans', 'StructGNN' (and any
    # unlisted method) keep the dataset's default collate function.
    if method == 'GVP':
        collate_fn = featurize_GVP().collate
    elif method == 'ProteinMPNN':
        collate_fn = featurize_ProteinMPNN
    elif method == 'ESMIF':
        collate_fn = featurize_Inversefolding

    # NOTE(review): torch's DataLoader rejects a prefetch_factor when
    # num_workers == 0; confirm DataLoaderX tolerates num_workers=0 here.
    loader_kwargs = dict(local_rank=0, batch_size=batch_size, num_workers=num_workers,
                         collate_fn=collate_fn, prefetch_factor=8)
    train_loader = DataLoaderX(dataset=train_set, shuffle=True, **loader_kwargs)
    valid_loader = DataLoaderX(dataset=valid_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoaderX(dataset=test_set, shuffle=False, **loader_kwargs)
    return train_loader, valid_loader, test_loader
def make_cath_loader(test_set, method, batch_size, max_nodes=3000, num_workers=8):
    """Wrap *test_set* in an evaluation loader suited to *method*.

    Args:
        test_set: dataset to load.
        method: model name, matched case-insensitively: ``'pifold'``,
            ``'adesign'``, ``'graphtrans'``, ``'structgnn'``, ``'gca'``,
            ``'gvp'``, ``'proteinmpnn'`` or ``'esmif'``.
        batch_size: samples per batch for the ``GTransDataLoader`` paths; not
            used for ``'gvp'``, which batches by ``max_nodes``.
        max_nodes: per-batch residue budget; only used for ``'gvp'``.
        num_workers: DataLoader worker count.

    Returns:
        A ``GVPDataLoader`` for ``'gvp'``. Otherwise, a ``GTransDataLoader``
        created with ``shuffle=False``.

    Raises:
        ValueError: if *method* is not one of the names above (previously an
            ``UnboundLocalError``).
    """
    # Case-insensitive so the CamelCase names used by load_data also work.
    key = method.lower() if isinstance(method, str) else method

    if key == 'gvp':
        return GVPDataLoader(test_set, num_workers=num_workers, featurizer=featurize_GVP(), max_nodes=max_nodes)

    if key in ('pifold', 'adesign', 'graphtrans', 'structgnn', 'gca'):
        collate_fn = featurize_GTrans
    elif key == 'proteinmpnn':
        collate_fn = featurize_ProteinMPNN
    elif key == 'esmif':
        collate_fn = featurize_Inversefolding
    else:
        raise ValueError(f"Unsupported method for make_cath_loader: {method!r}")

    return GTransDataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)