Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,408 Bytes
7968cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import copy
import random
import os.path as osp
import torch
import torch.utils.data as data
import pdb
from .cath_dataset import CATHDataset
from .alphafold_dataset import AlphaFoldDataset
from .ts_dataset import TSDataset
from .casp_dataset import CASPDataset
from .mpnn_dataset import MPNNDataset
from .featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
featurize_ProteinMPNN, featurize_Inversefolding)
from .fast_dataloader import DataLoaderX
class GTransDataLoader(torch.utils.data.DataLoader):
    """A ``torch.utils.data.DataLoader`` that also records its collate function.

    Every constructor argument is forwarded unchanged to ``DataLoader``; the
    only addition is that ``collate_fn`` is kept on the instance as
    ``self.featurizer`` so callers can reach the featurizer through the loader.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0,
                 collate_fn=None, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            **kwargs,
        )
        # Same attribute name GVPDataLoader uses for its featurizer object.
        self.featurizer = collate_fn
class BatchSampler(data.Sampler):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design.

    A `torch.utils.data.Sampler` which samples batches according to a
    maximum number of graph nodes.

    Items whose own node count exceeds ``max_nodes`` are dropped. The remaining
    indices are packed greedily, in order, into batches whose node totals do
    not exceed ``max_nodes``. Batches are formed once at construction and the
    same batches are yielded on every pass; they are only rebuilt if
    ``self.batches`` is empty.

    :param node_counts: array of node counts in the dataset to sample from
    :param max_nodes: the maximum number of nodes in any batch,
                      including batches of a single element
    :param shuffle: if `True`, the item order is shuffled (with the `random`
                    module) before packing
    '''

    def __init__(self, node_counts, max_nodes=3000, shuffle=True):
        self.node_counts = node_counts
        # Only items that fit in a batch on their own are ever sampled.
        self.idx = [i for i in range(len(node_counts))
                    if node_counts[i] <= max_nodes]
        self.shuffle = shuffle
        self.max_nodes = max_nodes
        self._form_batches()

    def _form_batches(self):
        """Greedily pack ``self.idx`` (optionally shuffled) into ``self.batches``.

        Indices are appended to the current batch until the next one would push
        its node total past ``max_nodes``; then a new batch is started. A cursor
        walks the index list instead of re-slicing it, so packing is linear in
        the number of items rather than quadratic.
        """
        self.batches = []
        if self.shuffle:
            random.shuffle(self.idx)
        idx = self.idx
        n_items = len(idx)
        pos = 0
        while pos < n_items:
            batch = []
            n_nodes = 0
            while pos < n_items and n_nodes + self.node_counts[idx[pos]] <= self.max_nodes:
                next_idx = idx[pos]
                pos += 1
                n_nodes += self.node_counts[next_idx]
                batch.append(next_idx)
            self.batches.append(batch)

    def __len__(self):
        """Number of batches (forming them first if none exist)."""
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        """Yield each batch as a list of dataset indices."""
        if not self.batches:
            self._form_batches()
        for batch in self.batches:
            yield batch
class GVPDataLoader(torch.utils.data.DataLoader):
    """DataLoader that batches items by total node count for GVP models.

    Each item's node count is ``len(item['seq'])``. Batches are produced by
    ``BatchSampler`` (with its default ``shuffle=True``) and collated with
    ``featurizer.collate``.

    :param dataset: iterable/indexable dataset whose items support ``item['seq']``
    :param num_workers: worker processes, forwarded to ``DataLoader``
    :param featurizer: object exposing a ``collate`` method (required)
    :param max_nodes: per-batch node budget, forwarded to ``BatchSampler``
    :param kwargs: any other ``DataLoader`` keyword arguments
    :raises ValueError: if ``featurizer`` is ``None``
    """

    def __init__(self, dataset, num_workers=0,
                 featurizer=None, max_nodes=3000, **kwargs):
        if featurizer is None:
            # The None default only exists for signature compatibility; without
            # a featurizer there is no collate function to build batches with.
            raise ValueError("GVPDataLoader requires a featurizer with a `collate` method")
        # Named `item` (not `data`) so it does not read as the module alias.
        node_counts = [len(item['seq']) for item in dataset]
        super().__init__(
            dataset,
            batch_sampler=BatchSampler(node_counts=node_counts, max_nodes=max_nodes),
            num_workers=num_workers,
            collate_fn=featurizer.collate,
            **kwargs,
        )
        self.featurizer = featurizer
def _three_way_split(dataset, valid_mode='valid', test_mode='test'):
    """Return ``(train, valid, test)`` shallow copies of ``dataset``.

    The first copy keeps the dataset's current mode; the other two are switched
    with ``change_mode(valid_mode)`` and ``change_mode(test_mode)``. The copies
    are shallow, so any attribute ``change_mode`` mutates in place (rather than
    rebinding) is shared between them.
    """
    train_set, valid_set, test_set = [copy.copy(dataset) for _ in range(3)]
    valid_set.change_mode(valid_mode)
    test_set.change_mode(test_mode)
    return train_set, valid_set, test_set


def load_data(data_name, method, batch_size, data_root, pdb_path, split_csv, max_nodes=3000, num_workers=8, removeTS=0, test_casp=False, **kwargs):
    """Build train/valid/test loaders for a named dataset and model method.

    :param data_name: one of 'CATH4.2', 'TS', 'CATH4.3', 'AlphaFold', 'MPNN',
        'S350', 'Protherm'
    :param method: model name selecting the collate function. 'GVP',
        'ProteinMPNN' and 'ESMIF' use their own featurizers; any other name
        (e.g. 'PiFold', 'AlphaDesign') keeps the dataset's default collate_fn.
    :param batch_size: batch size for all three loaders
    :param data_root: directory holding the per-dataset subdirectories
    :param pdb_path: unused here; kept for interface compatibility
    :param split_csv: unused here; kept for interface compatibility
    :param max_nodes: unused here; kept for interface compatibility
    :param num_workers: worker processes for each loader
    :param removeTS: forwarded to ``CATHDataset``
    :param test_casp: if True, the test split is replaced by
        ``CASPDataset(<data_root>/casp15)``
    :param kwargs: for 'AlphaFold', optional ``upid``, ``limit_length`` and
        ``joint_data`` forwarded to ``AlphaFoldDataset``; other keys are ignored
    :returns: ``(train_loader, valid_loader, test_loader)``
    :raises ValueError: if ``data_name`` is not recognised
    """
    known = ('CATH4.2', 'TS', 'CATH4.3', 'AlphaFold', 'MPNN', 'S350', 'Protherm')
    if data_name in ('CATH4.2', 'TS'):
        cath_set = CATHDataset(osp.join(data_root, 'cath4.2'), mode='train', test_name='All', removeTS=removeTS)
        train_set, valid_set, test_set = _three_way_split(cath_set)
        if data_name == 'TS':
            test_set = TSDataset(osp.join(data_root, 'ts'))
        collate_fn = featurize_GTrans
    elif data_name == 'CATH4.3':
        cath_set = CATHDataset(osp.join(data_root, 'cath4.3'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        train_set, valid_set, test_set = _three_way_split(cath_set)
        collate_fn = featurize_GTrans
    elif data_name == 'AlphaFold':
        # These used to be undefined free names (NameError). Forward only the
        # ones the caller supplied; if AlphaFoldDataset has no default for a
        # missing one, its constructor will raise a TypeError naming it.
        af_kwargs = {key: kwargs[key] for key in ('upid', 'limit_length', 'joint_data') if key in kwargs}
        af_set = AlphaFoldDataset(osp.join(data_root, 'af2db'), mode='train', **af_kwargs)
        train_set, valid_set, test_set = _three_way_split(af_set)
        collate_fn = featurize_AF
    elif data_name == 'MPNN':
        train_set = MPNNDataset(mode='train')
        valid_set = MPNNDataset(mode='valid')
        test_set = MPNNDataset(mode='test')
        collate_fn = featurize_GTrans
    elif data_name == 'S350':
        cath_set = CATHDataset(osp.join(data_root, 's350'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        # All three splits stay in 'train' mode (unchanged behaviour);
        # presumably S350 has no separate valid/test split -- confirm.
        train_set, valid_set, test_set = _three_way_split(cath_set, valid_mode='train', test_mode='train')
        collate_fn = featurize_GTrans
    elif data_name == 'Protherm':
        cath_set = CATHDataset(osp.join(data_root, 'protherm'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        train_set, valid_set, test_set = _three_way_split(cath_set)
        collate_fn = featurize_GTrans
    else:
        raise ValueError(f"Unknown data_name {data_name!r}; expected one of {', '.join(known)}")

    if test_casp:
        test_set = CASPDataset(osp.join(data_root, 'casp15'))

    # Methods not listed here (e.g. 'AlphaDesign', 'PiFold', 'KWDesign',
    # 'GraphTrans', 'StructGNN') keep the dataset's collate_fn chosen above.
    if method == 'GVP':
        collate_fn = featurize_GVP().collate
    elif method == 'ProteinMPNN':
        collate_fn = featurize_ProteinMPNN
    elif method == 'ESMIF':
        collate_fn = featurize_Inversefolding

    loader_kwargs = dict(local_rank=0, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)
    # torch's DataLoader rejects `prefetch_factor` when num_workers == 0.
    # Assuming DataLoaderX forwards its kwargs to DataLoader, only pass it
    # when worker processes are actually used.
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = 8
    train_loader = DataLoaderX(dataset=train_set, shuffle=True, **loader_kwargs)
    valid_loader = DataLoaderX(dataset=valid_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoaderX(dataset=test_set, shuffle=False, **loader_kwargs)
    return train_loader, valid_loader, test_loader
def make_cath_loader(test_set, method, batch_size, max_nodes=3000, num_workers=8):
    """Build a loader over ``test_set`` for the given model method.

    ``method`` is matched case-insensitively, so both the lowercase keys below
    and the capitalised names used by ``load_data`` ('PiFold', 'GVP',
    'ProteinMPNN', 'ESMIF', ...) are accepted.

    :param test_set: dataset to wrap
    :param method: 'pifold', 'adesign', 'graphtrans', 'structgnn', 'gca',
        'gvp', 'proteinmpnn' or 'esmif' (any letter case)
    :param batch_size: items per batch; not used for 'gvp', which batches by
        node count instead
    :param max_nodes: per-batch node budget, used only for 'gvp'
    :param num_workers: worker processes for the loader
    :returns: a ``GVPDataLoader`` for 'gvp', otherwise a ``GTransDataLoader``
        with ``shuffle=False``
    :raises ValueError: if ``method`` is not one of the names above
    """
    name = method.lower()
    if name == 'gvp':
        # GVP batches by total node count via its own batch sampler.
        return GVPDataLoader(test_set, num_workers=num_workers, featurizer=featurize_GVP(), max_nodes=max_nodes)
    if name in ('pifold', 'adesign', 'graphtrans', 'structgnn', 'gca'):
        collate_fn = featurize_GTrans
    elif name == 'proteinmpnn':
        collate_fn = featurize_ProteinMPNN
    elif name == 'esmif':
        collate_fn = featurize_Inversefolding
    else:
        # Previously this fell through to an UnboundLocalError on return.
        raise ValueError(f"Unsupported method for make_cath_loader: {method!r}")
    return GTransDataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)
|