import os
import csv
import time
import random

import numpy as np
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader
from dateutil import parser
from joblib import Parallel, delayed, cpu_count
from tqdm import tqdm

from .fast_dataloader import DataLoaderX


def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs):
    """
    Parallel map using joblib.

    Parameters
    ----------
    pickleable_fn : callable
        Function to map over data.
    data : iterable
        Data over which we want to parallelize the function call.
    n_jobs : int, optional
        The maximum number of concurrently running jobs. By default, it is one
        less than the number of CPUs.
    verbose : int, optional
        The verbosity level. If nonzero, the function prints progress messages.
        The frequency of the messages increases with the verbosity level.
        If above 10, it reports all iterations. If above 50, it sends the
        output to stdout.
    kwargs
        Additional keyword arguments for :attr:`pickleable_fn`.

    Returns
    -------
    list
        The i-th element of the list corresponds to the output of applying
        :attr:`pickleable_fn` to :attr:`data[i]`.
    """
    if n_jobs is None:
        n_jobs = cpu_count() - 1
    results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)(
        delayed(pickleable_fn)(*d, **kwargs)
        for i, d in tqdm(enumerate(data), desc=desc)
    )
    return results
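
# A minimal usage sketch for `pmap_multi` (the `_square` helper is hypothetical,
# added only for illustration): each element of `data` must be a tuple of
# positional arguments, because the worker call unpacks it with `*d`.
def _square(x):
    """Toy pickleable function for the pmap_multi example."""
    return x * x


def _pmap_multi_example():
    # Expected result: [1, 4, 9]; n_jobs=2 assumes at least two CPU cores.
    return pmap_multi(_square, [(1,), (2,), (3,)], n_jobs=2, desc='demo')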

def build_training_clusters(params, debug):
    val_ids = set([int(l) for l in open(params['VAL']).readlines()])
    test_ids = set([int(l) for l in open(params['TEST']).readlines()])
    if debug:
        val_ids = []
        test_ids = []

    # read & clean list.csv; keep rows passing the resolution and date cutoffs
    with open(params['LIST'], 'r') as f:
        reader = csv.reader(f)
        next(reader)
        rows = [[r[0], r[3], int(r[4])] for r in reader
                if float(r[2]) <= params['RESCUT'] and
                parser.parse(r[1]) <= parser.parse(params['DATCUT'])]

    # compile training, validation and test sets, keyed by cluster id
    train = {}
    valid = {}
    test = {}
    if debug:
        rows = rows[:20]
    for r in rows:
        if r[2] in val_ids:
            if r[2] in valid.keys():
                valid[r[2]].append(r[:2])
            else:
                valid[r[2]] = [r[:2]]
        elif r[2] in test_ids:
            if r[2] in test.keys():
                test[r[2]].append(r[:2])
            else:
                test[r[2]] = [r[:2]]
        else:
            if r[2] in train.keys():
                train[r[2]].append(r[:2])
            else:
                train[r[2]] = [r[:2]]
    if debug:
        valid = train
    return train, valid, test


def loader_pdb(item, params):
    pdbid, chid = item[0].split('_')
    PREFIX = "%s/pdb/%s/%s" % (params['DIR'], pdbid[1:3], pdbid)

    # load metadata
    if not os.path.isfile(PREFIX + ".pt"):
        return {'seq': np.zeros(5)}
    meta = torch.load(PREFIX + ".pt")
    asmb_ids = meta['asmb_ids']
    asmb_chains = meta['asmb_chains']
    chids = np.array(meta['chains'])

    # find candidate assemblies which contain the chid chain
    asmb_candidates = set([a for a, b in zip(asmb_ids, asmb_chains)
                           if chid in b.split(',')])

    # if the chain is missing from all the assemblies,
    # return this chain alone
    if len(asmb_candidates) < 1:
        chain = torch.load("%s_%s.pt" % (PREFIX, chid))
        L = len(chain['seq'])
        return {'seq': chain['seq'],
                'xyz': chain['xyz'],
                'idx': torch.zeros(L).int(),
                'masked': torch.Tensor([0]).int(),
                'label': item[0]}

    # randomly pick one assembly from the candidates
    asmb_i = random.choice(list(asmb_candidates))

    # indices of the selected transforms
    idx = np.where(np.array(asmb_ids) == asmb_i)[0]

    # load relevant chains (asmb_chains[i] is a comma-separated string, so this
    # iterates over its characters; commas are skipped by the membership test)
    chains = {c: torch.load("%s_%s.pt" % (PREFIX, c))
              for i in idx for c in asmb_chains[i]
              if c in meta['chains']}

    # generate the assembly
    asmb = {}
    for k in idx:
        # pick the k-th xform: u holds the rotations, r the translations
        xform = meta['asmb_xform%d' % k]
        u = xform[:, :3, :3]
        r = xform[:, :3, 3]

        # select chains to which the k-th xform should be applied
        s1 = set(meta['chains'])
        s2 = set(asmb_chains[k].split(','))
        chains_k = s1 & s2

        # transform the selected chains
        for c in chains_k:
            try:
                xyz = chains[c]['xyz']
                xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None, None, :]
                asmb.update({(c, k, i): xyz_i for i, xyz_i in enumerate(xyz_ru)})
            except KeyError:
                return {'seq': np.zeros(5)}

    # select chains which share considerable sequence similarity to chid
    seqid = meta['tm'][chids == chid][0, :, 1]
    homo = set([ch_j for seqid_j, ch_j in zip(seqid, chids)
                if seqid_j > params['HOMO']])

    # stack all chains in the assembly together
    seq, xyz, idx, masked = "", [], [], []
    seq_list = []
    for counter, (k, v) in enumerate(asmb.items()):
        seq += chains[k[0]]['seq']
        seq_list.append(chains[k[0]]['seq'])
        xyz.append(v)
        idx.append(torch.full((v.shape[0],), counter))
        if k[0] in homo:
            masked.append(counter)

    return {'seq': seq,
            'xyz': torch.cat(xyz, dim=0),
            'idx': torch.cat(idx, dim=0),
            'masked': torch.Tensor(masked).int(),
            'label': item[0]}
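
# A self-contained sketch (toy tensors, not real PDB data) of the einsum used
# in `loader_pdb` above: 'bij,raj->brai' applies each of the b rotation
# matrices to every [residue, atom] coordinate, and the translations are then
# broadcast over residues and atoms.
def _xform_example():
    u = torch.eye(3).repeat(2, 1, 1)              # [2, 3, 3] identity rotations
    t = torch.tensor([[0., 0., 0.],
                      [1., 2., 3.]])              # [2, 3] translations
    xyz = torch.rand(5, 4, 3)                     # [residues=5, atoms=4, 3]
    out = torch.einsum('bij,raj->brai', u, xyz) + t[:, None, None, :]
    assert out.shape == (2, 5, 4, 3)
    assert torch.allclose(out[0], xyz)            # identity + zero shift
    assert torch.allclose(out[1], xyz + t[1])     # identity + [1, 2, 3]
    return out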

def get_pdbs(data, max_length=10000, num_units=1000000):
    init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
                     'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
                     'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
                     'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    extra_alphabet = [str(item) for item in list(np.arange(300))]
    chain_alphabet = init_alphabet + extra_alphabet

    data = {k: v for k, v in data.items()}  # shallow copy
    if 'label' in data:
        my_dict = {}
        concat_seq = ''
        mask_list = []
        visible_list = []
        if len(list(np.unique(data['idx']))) < 352:
            for idx in list(np.unique(data['idx'])):
                letter = chain_alphabet[idx]
                # data['idx'] is a torch tensor, so np.argwhere yields `res`
                # with shape [1, L]; hence the axis-1 slicing below
                res = np.argwhere(data['idx'] == idx)
                initial_sequence = "".join(list(np.array(list(data['seq']))[res][0,]))
                # trim His-tags (expression tags) near either terminus
                if initial_sequence[-6:] == "HHHHHH":
                    res = res[:, :-6]
                if initial_sequence[0:6] == "HHHHHH":
                    res = res[:, 6:]
                if initial_sequence[-7:-1] == "HHHHHH":
                    res = res[:, :-7]
                if initial_sequence[-8:-2] == "HHHHHH":
                    res = res[:, :-8]
                if initial_sequence[-9:-3] == "HHHHHH":
                    res = res[:, :-9]
                if initial_sequence[-10:-4] == "HHHHHH":
                    res = res[:, :-10]
                if initial_sequence[1:7] == "HHHHHH":
                    res = res[:, 7:]
                if initial_sequence[2:8] == "HHHHHH":
                    res = res[:, 8:]
                if initial_sequence[3:9] == "HHHHHH":
                    res = res[:, 9:]
                if initial_sequence[4:10] == "HHHHHH":
                    res = res[:, 10:]
                if res.shape[1] < 4:
                    pass
                else:
                    my_dict['seq_chain_' + letter] = "".join(list(np.array(list(data['seq']))[res][0,]))
                    concat_seq += my_dict['seq_chain_' + letter]
                    if idx in data['masked']:
                        mask_list.append(letter)
                    else:
                        visible_list.append(letter)
                    coords_dict_chain = {}
                    all_atoms = np.array(data['xyz'][res,])[0,]  # [L, 14, 3]
                    coords_dict_chain['N_chain_' + letter] = all_atoms[:, 0, :].tolist()
                    coords_dict_chain['CA_chain_' + letter] = all_atoms[:, 1, :].tolist()
                    coords_dict_chain['C_chain_' + letter] = all_atoms[:, 2, :].tolist()
                    coords_dict_chain['O_chain_' + letter] = all_atoms[:, 3, :].tolist()
                    my_dict['coords_chain_' + letter] = coords_dict_chain
            my_dict['name'] = data['label']
            my_dict['masked_list'] = mask_list
            my_dict['visible_list'] = visible_list
            my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
            my_dict['seq'] = concat_seq
            if len(concat_seq) <= max_length:
                return my_dict
    return None


def safe_iter(ID, split_dict, params, alphabet_set, max_length=1000):
    # standalone variant of MPNNDataset.safe_iter: samples a random cluster
    # member and returns torch tensors instead of numpy arrays
    sel_idx = np.random.randint(0, len(split_dict[ID]))
    out = loader_pdb(split_dict[ID][sel_idx], params)
    entry = get_pdbs(out)
    if entry is None:
        return None

    seq = entry['seq']
    bad_chars = set([s for s in seq]).difference(alphabet_set)
    if len(bad_chars) != 0:
        return None
    if len(entry['seq']) > max_length:
        return None

    masked_chains = entry['masked_list']
    visible_chains = entry['visible_list']
    all_chains = masked_chains + visible_chains

    # promote visible chains whose sequence duplicates a masked chain
    visible_temp_dict = {}
    masked_temp_dict = {}
    for step, letter in enumerate(all_chains):
        chain_seq = entry[f'seq_chain_{letter}']
        if letter in visible_chains:
            visible_temp_dict[letter] = chain_seq
        elif letter in masked_chains:
            masked_temp_dict[letter] = chain_seq
    for km, vm in masked_temp_dict.items():
        for kv, vv in visible_temp_dict.items():
            if vm == vv:
                if kv not in masked_chains:
                    masked_chains.append(kv)
                if kv in visible_chains:
                    visible_chains.remove(kv)
    all_chains = masked_chains + visible_chains
    random.shuffle(all_chains)

    x_chain_list = []
    chain_mask_list = []
    chain_seq_list = []
    chain_encoding_list = []
    c = 1
    for step, letter in enumerate(all_chains):
        chain_seq = entry[f'seq_chain_{letter}']
        chain_length = len(chain_seq)
        chain_coords = entry[f'coords_chain_{letter}']  # a dictionary
        if letter in visible_chains:
            chain_mask = np.zeros(chain_length)  # 0.0 for visible chains
        else:
            chain_mask = np.ones(chain_length)   # 1.0 for masked chains
        x_chain = np.stack([chain_coords[name] for name in
                            [f'N_chain_{letter}', f'CA_chain_{letter}',
                             f'C_chain_{letter}', f'O_chain_{letter}']], 1)  # [chain_length, 4, 3]
        x_chain_list.append(x_chain)
        chain_mask_list.append(chain_mask)
        chain_seq_list.append(chain_seq)
        chain_encoding_list.append(c * np.ones(chain_length))
        c += 1

    chain_mask_all = torch.from_numpy(np.concatenate(chain_mask_list))
    chain_encoding_all = torch.from_numpy(np.concatenate(chain_encoding_list))
    x_chain_all = torch.from_numpy(np.concatenate(x_chain_list))
    data = {"title": entry['name'],
            "seq": ''.join(chain_seq_list),  # len(seq) = n
            "chain_mask": chain_mask_all,    # [n,]
            "chain_encoding": chain_encoding_all,
            "CA": x_chain_all[:, 1],         # [n, 3]
            "C": x_chain_all[:, 2],
            "O": x_chain_all[:, 3],
            "N": x_chain_all[:, 0]}
    return data
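
# A toy check (hypothetical chain letters and sequences, not real data) of the
# duplicate-sequence rule in `safe_iter` above: a visible chain whose sequence
# exactly matches a masked chain is promoted to the masked set, so identical
# homomer copies are designed together.
def _dedup_example():
    masked = {'A': 'MKV'}
    visible = {'B': 'MKV', 'C': 'GGS'}
    masked_chains, visible_chains = list(masked), list(visible)
    for vm in masked.values():
        for kv, vv in visible.items():
            if vm == vv and kv not in masked_chains:
                masked_chains.append(kv)
                visible_chains.remove(kv)
    assert masked_chains == ['A', 'B'] and visible_chains == ['C']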

class MPNNDataset(data.Dataset):
    def __init__(self,
                 data_path='/gaozhangyang/drug_dataset/proteinmpnn_data/pdb_2021aug02',
                 rescut=3.5, split='train'):
        self.data_path = data_path
        self.rescut = rescut
        self.params = {
            "LIST": f"{self.data_path}/list.csv",
            "VAL": f"{self.data_path}/valid_clusters.txt",
            "TEST": f"{self.data_path}/test_clusters.txt",
            "DIR": f"{self.data_path}",
            "DATCUT": "2030-Jan-01",
            "RESCUT": self.rescut,  # resolution cutoff for PDBs
            "HOMO": 0.70            # min seq. id. to detect homo chains
        }

        split_path = "/gaozhangyang/experiments/OpenCPD/data/mpnn_data/split.pt"
        if not os.path.exists(split_path):
            train, valid, test = build_training_clusters(self.params, False)
            splits = {"train": train, "valid": valid, "test": test}
            torch.save(splits, split_path)
        else:
            splits = torch.load(split_path)

        self.split_dict = splits[split]
        alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
        self.alphabet_set = set(alphabet)
        self.IDs = list(self.split_dict.keys())
        # self.data = self.preprocess()

    def cache_split(self):
        train, valid, test = build_training_clusters(self.params, False)
        return {"train": train, "valid": valid, "test": test}

    @classmethod
    def safe_iter(cls, ID, split_dict, params, alphabet_set, max_length=1000):
        # deterministic variant: always takes the first cluster member, returns
        # numpy arrays, and tags the title with the number of masked residues
        # sel_idx = np.random.randint(0, len(split_dict[ID]))
        sel_idx = 0
        out = loader_pdb(split_dict[ID][sel_idx], params)
        entry = get_pdbs(out)
        if entry is None:
            return None

        seq = entry['seq']
        bad_chars = set([s for s in seq]).difference(alphabet_set)
        if len(bad_chars) != 0:
            return None
        if len(entry['seq']) > max_length:
            return None

        masked_chains = entry['masked_list']
        visible_chains = entry['visible_list']
        all_chains = masked_chains + visible_chains

        # promote visible chains whose sequence duplicates a masked chain
        visible_temp_dict = {}
        masked_temp_dict = {}
        for step, letter in enumerate(all_chains):
            chain_seq = entry[f'seq_chain_{letter}']
            if letter in visible_chains:
                visible_temp_dict[letter] = chain_seq
            elif letter in masked_chains:
                masked_temp_dict[letter] = chain_seq
        for km, vm in masked_temp_dict.items():
            for kv, vv in visible_temp_dict.items():
                if vm == vv:
                    if kv not in masked_chains:
                        masked_chains.append(kv)
                    if kv in visible_chains:
                        visible_chains.remove(kv)
        all_chains = masked_chains + visible_chains
        random.shuffle(all_chains)

        x_chain_list = []
        chain_mask_list = []
        chain_seq_list = []
        chain_encoding_list = []
        c = 1
        for step, letter in enumerate(all_chains):
            chain_seq = entry[f'seq_chain_{letter}']
            chain_length = len(chain_seq)
            chain_coords = entry[f'coords_chain_{letter}']  # a dictionary
            if letter in visible_chains:
                chain_mask = np.zeros(chain_length)  # 0.0 for visible chains
            else:
                chain_mask = np.ones(chain_length)   # 1.0 for masked chains
            x_chain = np.stack([chain_coords[name] for name in
                                [f'N_chain_{letter}', f'CA_chain_{letter}',
                                 f'C_chain_{letter}', f'O_chain_{letter}']], 1)  # [chain_length, 4, 3]
            x_chain_list.append(x_chain)
            chain_mask_list.append(chain_mask)
            chain_seq_list.append(chain_seq)
            chain_encoding_list.append(c * np.ones(chain_length))
            c += 1

        chain_mask_all = np.concatenate(chain_mask_list)
        chain_encoding_all = np.concatenate(chain_encoding_list)
        x_chain_all = np.concatenate(x_chain_list)
        data = {"title": entry['name'] + str(int(chain_mask_all.sum())),
                "seq": ''.join(chain_seq_list),  # len(seq) = n
                "chain_mask": chain_mask_all,    # [n,]
                "chain_encoding": chain_encoding_all,
                "CA": x_chain_all[:, 1],         # [n, 3]
                "C": x_chain_all[:, 2],
                "O": x_chain_all[:, 3],
                "N": x_chain_all[:, 0]}
        return data

    def preprocess(self):
        data = pmap_multi(self.safe_iter, [(ID,) for ID in self.IDs],
                          split_dict=self.split_dict, params=self.params,
                          alphabet_set=self.alphabet_set)
        return data

    def __len__(self):
        # return len(self.data)
        return len(self.IDs)

    def __getitem__(self, index):
        ID = self.IDs[index]
        out = self.safe_iter(ID, split_dict=self.split_dict,
                             params=self.params,
                             alphabet_set=self.alphabet_set)
        return out
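
# `collate_fn` below returns the raw list of entries. A padded-batch variant is
# sketched here (an assumption for illustration, not part of the original
# pipeline): it drops None entries and zero-pads backbone coordinates to a
# common length.
def collate_fn_padded(batch):
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None
    max_len = max(len(b['seq']) for b in batch)
    coords = torch.zeros(len(batch), max_len, 4, 3)
    mask = torch.zeros(len(batch), max_len)
    for i, b in enumerate(batch):
        n = len(b['seq'])
        # stack N/CA/C/O into [n, 4, 3]; entries may be numpy or torch
        xyz = np.stack([np.asarray(b['N']), np.asarray(b['CA']),
                        np.asarray(b['C']), np.asarray(b['O'])], axis=1)
        coords[i, :n] = torch.as_tensor(xyz, dtype=torch.float32)
        mask[i, :n] = 1.0
    return {'coords': coords, 'mask': mask,
            'seqs': [b['seq'] for b in batch]}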

def collate_fn(batch):
    # entries may be None (filtered out by safe_iter); return the raw list and
    # let the training loop decide how to batch variable-length proteins
    return batch


if __name__ == "__main__":
    dataset = MPNNDataset()
    loader = DataLoaderX(local_rank=0, dataset=dataset,
                         collate_fn=collate_fn, batch_size=4)
    # loader = DataLoader(dataset=dataset, collate_fn=collate_fn,
    #                     batch_size=4, prefetch_factor=4, num_workers=4)
    for batch in tqdm(loader):
        for one in batch:
            if one is not None:
                for key, val in one.items():
                    if isinstance(val, torch.Tensor):
                        result = val.to('cuda:0')
        time.sleep(2)
    print()