import os
import csv
import time
import random

import numpy as np
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader
from dateutil import parser
from joblib import Parallel, delayed, cpu_count
from tqdm import tqdm

from .fast_dataloader import DataLoaderX

def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs):
"""
Parallel map using joblib.
Parameters
----------
pickleable_fn : callable
Function to map over data.
data : iterable
Data over which we want to parallelize the function call.
n_jobs : int, optional
The maximum number of concurrently running jobs. By default, it is one less than
the number of CPUs.
    verbose : int, optional
        The verbosity level. If nonzero, progress messages are printed.
        The frequency of the messages increases with the verbosity level;
        above 10 every iteration is reported, and above 50 the output is
        sent to stdout.
    desc : str, optional
        Description displayed on the tqdm progress bar.
    kwargs
        Additional keyword arguments for :attr:`pickleable_fn`.
Returns
-------
list
The i-th element of the list corresponds to the output of applying
:attr:`pickleable_fn` to :attr:`data[i]`.
"""
if n_jobs is None:
n_jobs = cpu_count() - 1
    results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)(
        delayed(pickleable_fn)(*d, **kwargs) for d in tqdm(data, desc=desc)
    )
    return results
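
# Minimal usage sketch (illustrative only): each element of `data` is a tuple of
# positional arguments for `pickleable_fn`, so a one-argument function is mapped
# over 1-tuples.
#
#   lengths = pmap_multi(len, [("abc",), ("de",)], n_jobs=2, desc="lengths")
#   # -> [3, 2]
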
def build_training_clusters(params, debug):
    # cluster ids reserved for the validation and test splits
    with open(params['VAL']) as f:
        val_ids = set([int(l) for l in f.readlines()])
    with open(params['TEST']) as f:
        test_ids = set([int(l) for l in f.readlines()])
    if debug:
        val_ids = set()
        test_ids = set()
    # read & clean list.csv: keep entries at or below the resolution cutoff that
    # were deposited before the date cutoff; each row becomes [CHAINID, HASH, CLUSTER]
    with open(params['LIST'], 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        rows = [[r[0], r[3], int(r[4])] for r in reader
                if float(r[2]) <= params['RESCUT'] and
                parser.parse(r[1]) <= parser.parse(params['DATCUT'])]
# compile training and validation sets
train = {}
valid = {}
test = {}
if debug:
rows = rows[:20]
    for r in rows:
        if r[2] in val_ids:
            valid.setdefault(r[2], []).append(r[:2])
        elif r[2] in test_ids:
            test.setdefault(r[2], []).append(r[:2])
        else:
            train.setdefault(r[2], []).append(r[:2])
if debug:
valid=train
return train, valid, test
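
# Data-contract sketch (matching how MPNNDataset below builds `params`):
# list.csv is read above as r[0]=chain id ("pdbid_chain"), r[1]=deposition date,
# r[2]=resolution, r[3]=sequence hash, r[4]=cluster id, and VAL/TEST hold one
# cluster id per line. Each returned split maps cluster id -> [[chain_id, hash], ...].
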
def loader_pdb(item,params):
pdbid,chid = item[0].split('_')
PREFIX = "%s/pdb/%s/%s"%(params['DIR'],pdbid[1:3],pdbid)
    # load metadata; a bare {'seq': np.zeros(5)} is the failure sentinel
    # used throughout this loader
    if not os.path.isfile(PREFIX+".pt"):
        return {'seq': np.zeros(5)}
meta = torch.load(PREFIX+".pt")
asmb_ids = meta['asmb_ids']
asmb_chains = meta['asmb_chains']
chids = np.array(meta['chains'])
# find candidate assemblies which contain chid chain
asmb_candidates = set([a for a,b in zip(asmb_ids,asmb_chains)
if chid in b.split(',')])
    # if the chain is missing from all assemblies,
    # then return this chain alone
if len(asmb_candidates)<1:
chain = torch.load("%s_%s.pt"%(PREFIX,chid))
L = len(chain['seq'])
return {'seq' : chain['seq'],
'xyz' : chain['xyz'],
'idx' : torch.zeros(L).int(),
'masked' : torch.Tensor([0]).int(),
'label' : item[0]}
    # randomly pick one assembly from the candidates
    asmb_i = random.choice(list(asmb_candidates))
    # indices of the transforms belonging to the selected assembly
    idx = np.where(np.array(asmb_ids)==asmb_i)[0]
    # load all chains that participate in the selected assembly
    # (asmb_chains entries are comma-separated, as parsed above)
    chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))
              for i in idx for c in asmb_chains[i].split(',')
              if c in meta['chains']}
# generate assembly
asmb = {}
for k in idx:
        # pick the k-th transform: rotations u [n,3,3] and translations r [n,3]
        xform = meta['asmb_xform%d'%k]
        u = xform[:,:3,:3]
        r = xform[:,:3,3]
# select chains which k-th xform should be applied to
s1 = set(meta['chains'])
s2 = set(asmb_chains[k].split(','))
chains_k = s1&s2
        # apply all n copies of the k-th transform to each selected chain
        for c in chains_k:
            try:
                xyz = chains[c]['xyz']  # [L,14,3]
                xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:,None,None,:]
                asmb.update({(c,k,i):xyz_i for i,xyz_i in enumerate(xyz_ru)})
            except KeyError:
                return {'seq': np.zeros(5)}
    # select chains whose sequence identity to chid exceeds the HOMO cutoff
    seqid = meta['tm'][chids==chid][0,:,1]
    homo = set([ch_j for seqid_j,ch_j in zip(seqid,chids)
                if seqid_j>params['HOMO']])
# stack all chains in the assembly together
    seq,xyz,idx,masked = "",[],[],[]
    for counter,(k,v) in enumerate(asmb.items()):
        seq += chains[k[0]]['seq']
        xyz.append(v)
        idx.append(torch.full((v.shape[0],),counter))
        if k[0] in homo:
            masked.append(counter)
return {'seq' : seq,
'xyz' : torch.cat(xyz,dim=0),
'idx' : torch.cat(idx,dim=0),
'masked' : torch.Tensor(masked).int(),
'label' : item[0]}
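
# Shape sketch of loader_pdb's output (assuming [L,14,3] per-chain coordinate
# files, as the comments below indicate): 'seq' is the concatenated sequence,
# 'xyz' is [L_total,14,3], 'idx' labels each residue with its chain-copy index,
# and 'masked' lists the chain-copy indices homologous to the query chain.
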
def get_pdbs(data, max_length=10000, num_units=1000000):
init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G','H', 'I', 'J','K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V','W','X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g','h', 'i', 'j','k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v','w','x', 'y', 'z']
extra_alphabet = [str(item) for item in list(np.arange(300))]
chain_alphabet = init_alphabet + extra_alphabet
    if 'label' in data:
        my_dict = {}
        concat_seq = ''
        mask_list = []
        visible_list = []
        if len(list(np.unique(data['idx']))) < 352:  # 352 == len(chain_alphabet)
            for idx in list(np.unique(data['idx'])):
                letter = chain_alphabet[idx]
                # res ends up with shape [1, L_chain]: np.argwhere on a torch
                # tensor goes through tensor.nonzero() ([L_chain, 1]) and a numpy
                # transpose, so residues are trimmed below by slicing along dim 1
                res = np.argwhere(data['idx']==idx)
                initial_sequence = "".join(list(np.array(list(data['seq']))[res][0,]))
                # trim poly-histidine purification tags (HHHHHH) at or near either terminus
if initial_sequence[-6:] == "HHHHHH":
res = res[:,:-6]
if initial_sequence[0:6] == "HHHHHH":
res = res[:,6:]
if initial_sequence[-7:-1] == "HHHHHH":
res = res[:,:-7]
if initial_sequence[-8:-2] == "HHHHHH":
res = res[:,:-8]
if initial_sequence[-9:-3] == "HHHHHH":
res = res[:,:-9]
if initial_sequence[-10:-4] == "HHHHHH":
res = res[:,:-10]
if initial_sequence[1:7] == "HHHHHH":
res = res[:,7:]
if initial_sequence[2:8] == "HHHHHH":
res = res[:,8:]
if initial_sequence[3:9] == "HHHHHH":
res = res[:,9:]
if initial_sequence[4:10] == "HHHHHH":
res = res[:,10:]
                # keep only chains with at least 4 residues after trimming
                if res.shape[1] < 4:
                    pass
                else:
my_dict['seq_chain_'+letter]= "".join(list(np.array(list(data['seq']))[res][0,]))
concat_seq += my_dict['seq_chain_'+letter]
if idx in data['masked']:
mask_list.append(letter)
else:
visible_list.append(letter)
coords_dict_chain = {}
all_atoms = np.array(data['xyz'][res,])[0,] #[L, 14, 3]
coords_dict_chain['N_chain_'+letter]=all_atoms[:,0,:].tolist()
coords_dict_chain['CA_chain_'+letter]=all_atoms[:,1,:].tolist()
coords_dict_chain['C_chain_'+letter]=all_atoms[:,2,:].tolist()
coords_dict_chain['O_chain_'+letter]=all_atoms[:,3,:].tolist()
my_dict['coords_chain_'+letter]=coords_dict_chain
my_dict['name']= data['label']
my_dict['masked_list']= mask_list
my_dict['visible_list']= visible_list
my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
my_dict['seq'] = concat_seq
if len(concat_seq) <= max_length:
return my_dict
return None
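
# Output sketch: get_pdbs reshapes a loader_pdb assembly into a per-chain dict
# with 'seq_chain_X' and 'coords_chain_X' (N/CA/C/O coordinate lists) for each
# chain letter, plus 'name', 'masked_list', 'visible_list', 'num_of_chains' and
# the concatenated 'seq'; assemblies longer than max_length come back as None.
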
def safe_iter(ID, split_dict, params, alphabet_set, max_length=1000):
sel_idx = np.random.randint(0, len(split_dict[ID]))
out = loader_pdb(split_dict[ID][sel_idx], params)
entry = get_pdbs(out)
if entry is None:
return None
    # reject entries containing non-canonical residues or over-length sequences
    seq = entry['seq']
    bad_chars = set(seq).difference(alphabet_set)
    if len(bad_chars) != 0:
        return None
    if len(seq) > max_length:
        return None
masked_chains = entry['masked_list']
visible_chains = entry['visible_list']
all_chains = masked_chains + visible_chains
visible_temp_dict = {}
masked_temp_dict = {}
for step, letter in enumerate(all_chains):
chain_seq = entry[f'seq_chain_{letter}']
if letter in visible_chains:
visible_temp_dict[letter] = chain_seq
elif letter in masked_chains:
masked_temp_dict[letter] = chain_seq
    # if a visible chain has the same sequence as a masked chain (e.g. in a
    # homo-oligomer), move it to the masked set as well
    for km, vm in masked_temp_dict.items():
        for kv, vv in visible_temp_dict.items():
if vm == vv:
if kv not in masked_chains:
masked_chains.append(kv)
if kv in visible_chains:
visible_chains.remove(kv)
all_chains = masked_chains + visible_chains
random.shuffle(all_chains)
x_chain_list = []
chain_mask_list = []
chain_seq_list = []
chain_encoding_list = []
c = 1
for step, letter in enumerate(all_chains):
if letter in visible_chains:
chain_seq = entry[f'seq_chain_{letter}']
chain_length = len(chain_seq)
chain_coords = entry[f'coords_chain_{letter}'] #this is a dictionary
chain_mask = np.zeros(chain_length) #0.0 for visible chains
x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
x_chain_list.append(x_chain)
chain_mask_list.append(chain_mask)
chain_seq_list.append(chain_seq)
chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
c+=1
elif letter in masked_chains:
chain_seq = entry[f'seq_chain_{letter}']
chain_length = len(chain_seq)
chain_coords = entry[f'coords_chain_{letter}'] #this is a dictionary
            chain_mask = np.ones(chain_length) #1.0 for masked chains
            x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
x_chain_list.append(x_chain)
chain_mask_list.append(chain_mask)
chain_seq_list.append(chain_seq)
chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
c+=1
chain_mask_all = torch.from_numpy(np.concatenate(chain_mask_list))
chain_encoding_all = torch.from_numpy(np.concatenate(chain_encoding_list))
x_chain_all = torch.from_numpy(np.concatenate(x_chain_list))
data = {
"title":entry['name'],
"seq":''.join(chain_seq_list), #len(seq)=n
"chain_mask":chain_mask_all,
"chain_encoding":chain_encoding_all,
"CA":x_chain_all[:,1], # [n,3]
"C":x_chain_all[:,2],
"O":x_chain_all[:,3],
"N":x_chain_all[:,0]} # [n,]
return data
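
# Minimal usage sketch (illustrative): turn one cluster id into a training
# example, using the same params dict that MPNNDataset.__init__ builds below.
#
#   train, valid, test = build_training_clusters(params, debug=False)
#   alphabet_set = set('ACDEFGHIKLMNPQRSTVWYX')
#   ID = list(train.keys())[0]
#   sample = safe_iter(ID, train, params, alphabet_set)
#   # sample is None on failure, otherwise a dict with seq, chain_mask,
#   # chain_encoding and N/CA/C/O coordinate tensors
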
class MPNNDataset(data.Dataset):
def __init__(self, data_path='/gaozhangyang/drug_dataset/proteinmpnn_data/pdb_2021aug02', rescut=3.5, split='train'):
self.data_path = data_path
self.rescut = rescut
self.params = {
"LIST" : f"{self.data_path}/list.csv",
"VAL" : f"{self.data_path}/valid_clusters.txt",
"TEST" : f"{self.data_path}/test_clusters.txt",
"DIR" : f"{self.data_path}",
"DATCUT" : "2030-Jan-01",
"RESCUT" : self.rescut, #resolution cutoff for PDBs
"HOMO" : 0.70 #min seq.id. to detect homo chains
}
        split_path = "/gaozhangyang/experiments/OpenCPD/data/mpnn_data/split.pt"
        if not os.path.exists(split_path):
            train, valid, test = build_training_clusters(self.params, False)
            splits = {"train": train, "valid": valid, "test": test}
            torch.save(splits, split_path)
        else:
            splits = torch.load(split_path)
        self.split_dict = splits[split]
        alphabet = 'ACDEFGHIKLMNPQRSTVWYX'  # 20 canonical residues + X for unknown
        self.alphabet_set = set(alphabet)
self.IDs = list(self.split_dict.keys())
# self.data = self.preprocess()
    def cache_split(self):
        train, valid, test = build_training_clusters(self.params, False)
        return {"train": train, "valid": valid, "test": test}
    @classmethod
    def safe_iter(cls, ID, split_dict, params, alphabet_set, max_length=1000):
        # deterministic choice of the cluster member; the module-level
        # safe_iter above samples one at random instead
        # sel_idx = np.random.randint(0, len(split_dict[ID]))
        sel_idx = 0
out = loader_pdb(split_dict[ID][sel_idx], params)
entry = get_pdbs(out)
if entry is None:
return None
        # reject entries containing non-canonical residues or over-length sequences
        seq = entry['seq']
        bad_chars = set(seq).difference(alphabet_set)
        if len(bad_chars) != 0:
            return None
        if len(seq) > max_length:
            return None
masked_chains = entry['masked_list']
visible_chains = entry['visible_list']
all_chains = masked_chains + visible_chains
visible_temp_dict = {}
masked_temp_dict = {}
for step, letter in enumerate(all_chains):
chain_seq = entry[f'seq_chain_{letter}']
if letter in visible_chains:
visible_temp_dict[letter] = chain_seq
elif letter in masked_chains:
masked_temp_dict[letter] = chain_seq
        # if a visible chain has the same sequence as a masked chain (e.g. in a
        # homo-oligomer), move it to the masked set as well
        for km, vm in masked_temp_dict.items():
            for kv, vv in visible_temp_dict.items():
if vm == vv:
if kv not in masked_chains:
masked_chains.append(kv)
if kv in visible_chains:
visible_chains.remove(kv)
all_chains = masked_chains + visible_chains
random.shuffle(all_chains)
x_chain_list = []
chain_mask_list = []
chain_seq_list = []
chain_encoding_list = []
c = 1
for step, letter in enumerate(all_chains):
if letter in visible_chains:
chain_seq = entry[f'seq_chain_{letter}']
chain_length = len(chain_seq)
chain_coords = entry[f'coords_chain_{letter}'] #this is a dictionary
chain_mask = np.zeros(chain_length) #0.0 for visible chains
x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
x_chain_list.append(x_chain)
chain_mask_list.append(chain_mask)
chain_seq_list.append(chain_seq)
chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
c+=1
elif letter in masked_chains:
chain_seq = entry[f'seq_chain_{letter}']
chain_length = len(chain_seq)
chain_coords = entry[f'coords_chain_{letter}'] #this is a dictionary
                chain_mask = np.ones(chain_length) #1.0 for masked chains
                x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1) #[chain_length,4,3]
x_chain_list.append(x_chain)
chain_mask_list.append(chain_mask)
chain_seq_list.append(chain_seq)
chain_encoding_list.append(c*np.ones(np.array(chain_mask).shape[0]))
c+=1
chain_mask_all = np.concatenate(chain_mask_list)
chain_encoding_all = np.concatenate(chain_encoding_list)
x_chain_all = np.concatenate(x_chain_list)
data = {
"title":entry['name']+str(int(chain_mask_all.sum())),
"seq":''.join(chain_seq_list), #len(seq)=n
"chain_mask":chain_mask_all,
"chain_encoding":chain_encoding_all,
"CA":x_chain_all[:,1], # [n,3]
"C":x_chain_all[:,2],
"O":x_chain_all[:,3],
"N":x_chain_all[:,0]} # [n,]
return data
    def preprocess(self):
        # eagerly parse every cluster in parallel via pmap_multi
        data = pmap_multi(self.safe_iter, [(ID,) for ID in self.IDs],
                          split_dict=self.split_dict, params=self.params,
                          alphabet_set=self.alphabet_set)
        return data
def __len__(self):
# return len(self.data)
return len(self.IDs)
    def __getitem__(self, index):
        ID = self.IDs[index]
        # may return None for entries that fail to load; collate_fn below keeps
        # the raw list so the training loop can skip them
        out = self.safe_iter(ID, split_dict=self.split_dict, params=self.params, alphabet_set=self.alphabet_set)
        return out
def collate_fn(batch):
    # identity collate: proteins vary in length, so keep the raw list of
    # (possibly None) sample dicts and let the consumer pad or skip them
    return batch
if __name__ == "__main__":
    dataset = MPNNDataset()
    loader = DataLoaderX(local_rank=0, dataset=dataset, collate_fn=collate_fn, batch_size=4)
    # loader = DataLoader(dataset=dataset, collate_fn=collate_fn, batch_size=4, prefetch_factor=4, num_workers=4)
    for batch in tqdm(loader):
        for one in batch:
            if one is not None:
                for key, val in one.items():
                    if isinstance(val, torch.Tensor):
                        result = val.to('cuda:0')
        time.sleep(2)  # simulated per-batch work, to exercise the prefetching loader