# initial commit 7968cb0 (Honzus24)
import copy
import random
import os.path as osp
import torch
import torch.utils.data as data
import pdb
from .cath_dataset import CATHDataset
from .alphafold_dataset import AlphaFoldDataset
from .ts_dataset import TSDataset
from .casp_dataset import CASPDataset
from .mpnn_dataset import MPNNDataset
from .featurizer import (featurize_AF, featurize_GTrans, featurize_GVP,
featurize_ProteinMPNN, featurize_Inversefolding)
from .fast_dataloader import DataLoaderX
class GTransDataLoader(torch.utils.data.DataLoader):
    """DataLoader that also exposes its collate function as ``self.featurizer``.

    Downstream code can re-featurize raw samples through ``loader.featurizer``
    without reaching into DataLoader internals.

    :param dataset: dataset to load from
    :param batch_size: samples per batch (default 1)
    :param shuffle: reshuffle data every epoch (default False)
    :param sampler: optional index sampler (mutually exclusive with shuffle)
    :param batch_sampler: optional batch-index sampler
    :param num_workers: worker subprocess count (0 = load in main process)
    :param collate_fn: function merging a sample list into a batch
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0,
                 collate_fn=None, **kwargs):
        # Pass everything by keyword: relying on DataLoader's positional
        # argument order is fragile across torch versions.
        super(GTransDataLoader, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            **kwargs)
        self.featurizer = collate_fn
class BatchSampler(data.Sampler):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design.

    A `torch.utils.data.Sampler` which samples batches according to a
    maximum number of graph nodes.

    :param node_counts: array of node counts in the dataset to sample from
    :param max_nodes: the maximum number of nodes in any batch,
                      including batches of a single element
    :param shuffle: if `True`, batches in shuffled order
    '''

    def __init__(self, node_counts, max_nodes=3000, shuffle=True):
        self.node_counts = node_counts
        # Keep only samples that fit in a batch by themselves; larger graphs
        # could never be scheduled and would otherwise break packing below.
        self.idx = [i for i in range(len(node_counts))
                    if node_counts[i] <= max_nodes]
        self.shuffle = shuffle
        self.max_nodes = max_nodes
        self._form_batches()

    def _form_batches(self):
        """Greedily pack indices into batches of at most ``max_nodes`` nodes."""
        self.batches = []
        if self.shuffle:
            random.shuffle(self.idx)
        idx = self.idx
        total = len(idx)
        # Advance a cursor instead of re-slicing idx[1:] each step: the old
        # slicing made batch formation accidentally O(n^2) in dataset size.
        pos = 0
        while pos < total:
            batch = []
            n_nodes = 0
            # Every surviving index fits alone (filtered in __init__), so
            # each batch is guaranteed to receive at least one element.
            while pos < total and n_nodes + self.node_counts[idx[pos]] <= self.max_nodes:
                n_nodes += self.node_counts[idx[pos]]
                batch.append(idx[pos])
                pos += 1
            self.batches.append(batch)

    def __len__(self):
        # Lazily (re)build in case batches were cleared externally.
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        if not self.batches:
            self._form_batches()
        for batch in self.batches:
            yield batch
class GVPDataLoader(torch.utils.data.DataLoader):
    """DataLoader for GVP-style featurization.

    Batches are formed by total node count (sequence length) via
    ``BatchSampler`` and collated with ``featurizer.collate``. The featurizer
    is kept on the loader as ``self.featurizer``.

    :param dataset: iterable of samples, each a mapping with a 'seq' entry
    :param num_workers: worker subprocess count
    :param featurizer: object providing a ``collate`` method
    :param max_nodes: node budget per batch
    """

    def __init__(self, dataset, num_workers=0,
                 featurizer=None, max_nodes=3000, **kwargs):
        lengths = [len(sample['seq']) for sample in dataset]
        sampler = BatchSampler(node_counts=lengths, max_nodes=max_nodes)
        super(GVPDataLoader, self).__init__(
            dataset,
            batch_sampler=sampler,
            num_workers=num_workers,
            collate_fn=featurizer.collate,
            **kwargs)
        self.featurizer = featurizer
def _make_splits(full_set, valid_mode='valid', test_mode='test'):
    """Shallow-copy one dataset into (train, valid, test) split views.

    The copies share the underlying data; only the split mode differs.
    """
    train_set, valid_set, test_set = (copy.copy(full_set) for _ in range(3))
    valid_set.change_mode(valid_mode)
    test_set.change_mode(test_mode)
    return train_set, valid_set, test_set


def load_data(data_name, method, batch_size, data_root, pdb_path, split_csv, max_nodes=3000, num_workers=8, removeTS=0, test_casp=False, **kwargs):
    """Construct train/valid/test ``DataLoaderX`` loaders for a dataset.

    :param data_name: dataset key: 'CATH4.2', 'TS', 'CATH4.3', 'AlphaFold',
        'MPNN', 'S350' or 'Protherm'
    :param method: model name; selects the collate/featurize function
    :param batch_size: samples per batch
    :param data_root: root directory containing the dataset folders
    :param pdb_path: unused here; kept for interface compatibility
    :param split_csv: unused here; kept for interface compatibility
    :param max_nodes: unused here; kept for interface compatibility
    :param num_workers: DataLoader worker count
    :param removeTS: if nonzero, drop TS entries from CATH data
    :param test_casp: if True, replace the test set with CASP15
    :param kwargs: extra dataset options; the AlphaFold branch reads
        ``upid``, ``limit_length`` and ``joint_data`` from here
    :raises ValueError: on an unrecognized ``data_name``
    :return: (train_loader, valid_loader, test_loader)
    """
    if data_name in ('CATH4.2', 'TS'):
        cath_set = CATHDataset(osp.join(data_root, 'cath4.2'), mode='train', test_name='All', removeTS=removeTS)
        train_set, valid_set, test_set = _make_splits(cath_set)
        if data_name == 'TS':
            test_set = TSDataset(osp.join(data_root, 'ts'))
        collate_fn = featurize_GTrans
    elif data_name == 'CATH4.3':
        cath_set = CATHDataset(osp.join(data_root, 'cath4.3'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        train_set, valid_set, test_set = _make_splits(cath_set)
        collate_fn = featurize_GTrans
    elif data_name == 'AlphaFold':
        # These options arrive through **kwargs; the previous code referenced
        # them as bare (undefined) names and raised NameError on this branch.
        af_set = AlphaFoldDataset(osp.join(data_root, 'af2db'),
                                  upid=kwargs.get('upid'),
                                  mode='train',
                                  limit_length=kwargs.get('limit_length'),
                                  joint_data=kwargs.get('joint_data'))
        train_set, valid_set, test_set = _make_splits(af_set)
        collate_fn = featurize_AF
    elif data_name == 'MPNN':
        train_set = MPNNDataset(mode='train')
        valid_set = MPNNDataset(mode='valid')
        test_set = MPNNDataset(mode='test')
        collate_fn = featurize_GTrans
    elif data_name == 'S350':
        # Evaluation-only set: all three splits read the 'train' split.
        cath_set = CATHDataset(osp.join(data_root, 's350'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        train_set, valid_set, test_set = _make_splits(cath_set, valid_mode='train', test_mode='train')
        collate_fn = featurize_GTrans
    elif data_name == 'Protherm':
        cath_set = CATHDataset(osp.join(data_root, 'protherm'), mode='train', test_name='All', removeTS=removeTS, version=4.3)
        train_set, valid_set, test_set = _make_splits(cath_set)
        collate_fn = featurize_GTrans
    else:
        # Previously an unknown name fell through to a NameError downstream.
        raise ValueError(f'Unknown data_name: {data_name!r}')

    if test_casp:
        test_set = CASPDataset(osp.join(data_root, 'casp15'))

    # Method-specific featurization overrides the dataset default; the listed
    # methods keep the dataset's own collate function.
    if method in ['AlphaDesign', 'PiFold', 'KWDesign', 'GraphTrans', 'StructGNN']:
        pass
    elif method == 'GVP':
        featurizer = featurize_GVP()
        collate_fn = featurizer.collate
    elif method == 'ProteinMPNN':
        collate_fn = featurize_ProteinMPNN
    elif method == 'ESMIF':
        collate_fn = featurize_Inversefolding

    # NOTE: the stray pdb.set_trace() debugging breakpoint that halted every
    # call here has been removed.
    train_loader = DataLoaderX(local_rank=0, dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn, prefetch_factor=8)
    valid_loader = DataLoaderX(local_rank=0, dataset=valid_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn, prefetch_factor=8)
    test_loader = DataLoaderX(local_rank=0, dataset=test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn, prefetch_factor=8)
    return train_loader, valid_loader, test_loader
def make_cath_loader(test_set, method, batch_size, max_nodes=3000, num_workers=8):
    """Build a test-set DataLoader for CATH-style evaluation.

    :param test_set: dataset of samples to featurize
    :param method: lowercase model name selecting the featurizer
    :param batch_size: samples per batch (ignored for 'gvp')
    :param max_nodes: per-batch node budget, used only by 'gvp'
    :param num_workers: DataLoader worker count
    :raises ValueError: on an unrecognized ``method``
    :return: a DataLoader yielding featurized batches
    """
    # Methods sharing the per-sample-count GTrans-style loading path.
    collate_by_method = {
        'pifold': featurize_GTrans,
        'adesign': featurize_GTrans,
        'graphtrans': featurize_GTrans,
        'structgnn': featurize_GTrans,
        'gca': featurize_GTrans,
        'proteinmpnn': featurize_ProteinMPNN,
        'esmif': featurize_Inversefolding,
    }
    if method == 'gvp':
        # GVP batches by total node count rather than sample count.
        featurizer = featurize_GVP()
        return GVPDataLoader(test_set, num_workers=num_workers,
                             featurizer=featurizer, max_nodes=max_nodes)
    if method in collate_by_method:
        return GTransDataLoader(test_set, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers,
                                collate_fn=collate_by_method[method])
    # Previously an unknown method fell through to an UnboundLocalError on
    # the return value; fail loudly instead.
    raise ValueError(f'Unknown method: {method!r}')