Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import DataLoader | |
| import csv | |
| from dateutil import parser | |
| import numpy as np | |
| import time | |
| import random | |
| import os | |
class StructureDataset():
    """In-memory dataset of parsed PDB entries, filtered by alphabet and length.

    Each entry is a dict with at least a 'seq' key (amino-acid string).
    Entries containing characters outside `alphabet`, or whose sequence is
    longer than `max_length`, are discarded; discard reasons are counted for
    the (currently commented-out) progress report.

    Args:
        pdb_dict_list: iterable of entry dicts, each with a 'seq' string.
        verbose: if True, the (disabled) progress report branch is entered
            every 1000 entries.
        truncate: stop loading once this many entries have been accepted.
        max_length: maximum accepted sequence length.
        alphabet: string of allowed residue characters.
    """

    def __init__(self, pdb_dict_list, verbose=True, truncate=None, max_length=100,
                 alphabet='ACDEFGHIKLMNPQRSTVWYX'):
        alphabet_set = set(alphabet)
        discard_count = {
            'bad_chars': 0,
            'too_long': 0,
            'bad_seq_length': 0
        }
        self.data = []
        start = time.time()
        for i, entry in enumerate(pdb_dict_list):
            seq = entry['seq']
            # Guard-style filtering: bad characters first, then length.
            bad_chars = set(seq).difference(alphabet_set)
            if bad_chars:
                discard_count['bad_chars'] += 1
            elif len(seq) > max_length:
                discard_count['too_long'] += 1
            else:
                self.data.append(entry)
            # Truncate early once the requested number of entries is loaded.
            if truncate is not None and len(self.data) == truncate:
                return
            if verbose and (i + 1) % 1000 == 0:
                elapsed = time.time() - start
                #print('{} entries ({} loaded) in {:.1f} s'.format(len(self.data), i+1, elapsed))
                #print('Discarded', discard_count)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
class StructureLoader():
    """Length-bucketed batch iterator over a StructureDataset-like sequence.

    Entries are sorted by sequence length and greedily packed so that
    (longest length in batch) * (entries in batch) stays within `batch_size`
    tokens. Batch order is reshuffled (via NumPy's global RNG) each epoch.

    Args:
        dataset: indexable of dicts with a 'seq' key.
        batch_size: token budget per batch.
        shuffle, collate_fn, drop_last: accepted for API compatibility;
            not consulted by this implementation.
    """

    def __init__(self, dataset, batch_size=100, shuffle=True,
                 collate_fn=lambda x:x, drop_last=False):
        self.dataset = dataset
        self.size = len(dataset)
        self.lengths = [len(dataset[i]['seq']) for i in range(self.size)]
        self.batch_size = batch_size
        sorted_ix = np.argsort(self.lengths)
        # Cluster into batches of similar sizes. Indices are visited in
        # ascending length order, so the current entry is always the longest
        # in its batch and size * (len(batch) + 1) bounds the token count.
        clusters, batch = [], []
        batch_max = 0
        for ix in sorted_ix:
            size = self.lengths[ix]
            if size * (len(batch) + 1) <= self.batch_size:
                batch.append(ix)
                batch_max = size
            else:
                # Bug fix: the original discarded the overflowing entry (and
                # could append an empty first batch); start the next batch
                # with it instead so no data is silently lost.
                if batch:
                    clusters.append(batch)
                batch, batch_max = [ix], size
        if len(batch) > 0:
            clusters.append(batch)
        self.clusters = clusters

    def __len__(self):
        return len(self.clusters)

    def __iter__(self):
        np.random.shuffle(self.clusters)
        for b_idx in self.clusters:
            batch = [self.dataset[i] for i in b_idx]
            yield batch
def worker_init_fn(worker_id):
    """DataLoader worker initializer: reseed NumPy's global RNG from OS
    entropy so each worker process draws a different random stream."""
    np.random.seed(None)
class NoamOpt:
    """Optimizer wrapper implementing the Noam (Transformer) learning-rate
    schedule:

        lr = factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)

    i.e. linear warmup for `warmup` steps followed by inverse-sqrt decay.
    """

    def __init__(self, model_size, factor, warmup, optimizer, step):
        self.optimizer = optimizer
        self._step = step
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def param_groups(self):
        """Return the wrapped optimizer's param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Advance one step: bump the counter, install the scheduled lr on
        every param group, then step the wrapped optimizer."""
        self._step += 1
        lr = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self._rate = lr
        self.optimizer.step()

    def rate(self, step = None):
        """Learning rate at `step` (defaults to the internal counter)."""
        current = self._step if step is None else step
        warmup_term = current * self.warmup ** (-1.5)
        decay_term = current ** (-0.5)
        return self.factor * self.model_size ** (-0.5) * min(decay_term, warmup_term)

    def zero_grad(self):
        """Clear gradients on the wrapped optimizer."""
        self.optimizer.zero_grad()
def get_std_opt(parameters, d_model, step):
    """Build the standard NoamOpt used for training: Adam with lr=0,
    betas=(0.9, 0.98), eps=1e-9, wrapped with factor 2 and 4000 warmup
    steps, resuming from `step`."""
    base_optimizer = torch.optim.Adam(parameters, lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, 2, 4000, base_optimizer, step)
def get_pdbs(data_loader, repeat=1, max_length=10000, num_units=1000000):
    """Drain `data_loader` and convert each raw item into the per-chain dict
    format consumed by StructureDataset.

    For every loaded item carrying a 'label' key, builds a dict with
    'seq_chain_<letter>' and 'coords_chain_<letter>' (N/CA/C/O coordinate
    lists) per chain, plus 'name', concatenated 'seq', 'masked_list',
    'visible_list' and 'num_of_chains'. Only items whose concatenated
    sequence is <= max_length are kept; collection stops once num_units
    dicts are gathered (checked per loader pass).

    NOTE(review): the indexing below ([res][0,], res[:,:-k], res.shape[1])
    assumes specific shapes for t['idx'], t['seq'] and t['xyz'] as produced
    by the loader/DataLoader collation — confirm against loader_pdb before
    modifying.
    """
    # Chain labels: A..Z, a..z first, then "0".."299" for very large assemblies.
    init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G','H', 'I', 'J','K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V','W','X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g','h', 'i', 'j','k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v','w','x', 'y', 'z']
    extra_alphabet = [str(item) for item in list(np.arange(300))]
    chain_alphabet = init_alphabet + extra_alphabet
    c = 0
    c1 = 0  # number of loader items seen (c is unused)
    pdb_dict_list = []
    t0 = time.time()
    for _ in range(repeat):
        for step,t in enumerate(data_loader):
            # Drop the leading batch dimension (DataLoader batch size 1).
            t = {k:v[0] for k,v in t.items()}
            c1 += 1
            # Items without 'label' are sentinels for failed loads (see loader_pdb).
            if 'label' in list(t):
                my_dict = {}
                s = 0
                concat_seq = ''
                concat_N = []
                concat_CA = []
                concat_C = []
                concat_O = []
                concat_mask = []
                coords_dict = {}
                mask_list = []
                visible_list = []
                # Skip assemblies with more chains than available labels (52 + 300).
                if len(list(np.unique(t['idx']))) < 352:
                    for idx in list(np.unique(t['idx'])):
                        letter = chain_alphabet[idx]
                        # Residue positions belonging to this chain.
                        res = np.argwhere(t['idx']==idx)
                        initial_sequence= "".join(list(np.array(list(t['seq']))[res][0,]))
                        # Trim a polyhistidine (HHHHHH) purification tag found
                        # at, or up to a few residues from, either terminus.
                        if initial_sequence[-6:] == "HHHHHH":
                            res = res[:,:-6]
                        if initial_sequence[0:6] == "HHHHHH":
                            res = res[:,6:]
                        if initial_sequence[-7:-1] == "HHHHHH":
                            res = res[:,:-7]
                        if initial_sequence[-8:-2] == "HHHHHH":
                            res = res[:,:-8]
                        if initial_sequence[-9:-3] == "HHHHHH":
                            res = res[:,:-9]
                        if initial_sequence[-10:-4] == "HHHHHH":
                            res = res[:,:-10]
                        if initial_sequence[1:7] == "HHHHHH":
                            res = res[:,7:]
                        if initial_sequence[2:8] == "HHHHHH":
                            res = res[:,8:]
                        if initial_sequence[3:9] == "HHHHHH":
                            res = res[:,9:]
                        if initial_sequence[4:10] == "HHHHHH":
                            res = res[:,10:]
                        # Discard chains that are too short after trimming.
                        if res.shape[1] < 4:
                            pass
                        else:
                            my_dict['seq_chain_'+letter]= "".join(list(np.array(list(t['seq']))[res][0,]))
                            concat_seq += my_dict['seq_chain_'+letter]
                            # 'masked' chains are homologs of the query (see loader_pdb).
                            if idx in t['masked']:
                                mask_list.append(letter)
                            else:
                                visible_list.append(letter)
                            coords_dict_chain = {}
                            all_atoms = np.array(t['xyz'][res,])[0,] #[L, 14, 3]
                            # NOTE(review): assumes atom order N, CA, C, O in
                            # the atom axis — confirm against the loader.
                            coords_dict_chain['N_chain_'+letter]=all_atoms[:,0,:].tolist()
                            coords_dict_chain['CA_chain_'+letter]=all_atoms[:,1,:].tolist()
                            coords_dict_chain['C_chain_'+letter]=all_atoms[:,2,:].tolist()
                            coords_dict_chain['O_chain_'+letter]=all_atoms[:,3,:].tolist()
                            my_dict['coords_chain_'+letter]=coords_dict_chain
                my_dict['name']= t['label']
                my_dict['masked_list']= mask_list
                my_dict['visible_list']= visible_list
                my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
                my_dict['seq'] = concat_seq
                if len(concat_seq) <= max_length:
                    pdb_dict_list.append(my_dict)
                # Note: this only ends the current pass; the outer `repeat`
                # loop may start another.
                if len(pdb_dict_list) >= num_units:
                    break
    return pdb_dict_list
class PDB_dataset(torch.utils.data.Dataset):
    """Map-style dataset over cluster IDs.

    Each __getitem__ picks one member of the indexed cluster uniformly at
    random (via NumPy's global RNG) and loads it with the supplied loader
    callable.
    """

    def __init__(self, IDs, loader, train_dict, params):
        self.IDs = IDs
        self.train_dict = train_dict
        self.loader = loader
        self.params = params

    def __len__(self):
        return len(self.IDs)

    def __getitem__(self, index):
        cluster_id = self.IDs[index]
        members = self.train_dict[cluster_id]
        pick = np.random.randint(0, len(members))
        return self.loader(members[pick], self.params)
def loader_pdb(item,params):
    """Load one chain (item[0] = '<pdbid>_<chainid>') and assemble it with
    its biological-assembly partners from preprocessed .pt files under
    params['DIR'].

    Returns a dict with 'seq', 'xyz', 'idx' (per-residue chain index),
    'masked' (indices of chains homologous to the query, sequence identity
    > params['HOMO']) and 'label'. On a missing or unusable entry returns
    the sentinel {'seq': np.zeros(5)}, which downstream code (get_pdbs)
    filters out via the absent 'label' key.
    """
    pdbid,chid = item[0].split('_')
    PREFIX = "%s/pdb/%s/%s"%(params['DIR'],pdbid[1:3],pdbid)
    # load metadata
    if not os.path.isfile(PREFIX+".pt"):
        return {'seq': np.zeros(5)}
    meta = torch.load(PREFIX+".pt")
    asmb_ids = meta['asmb_ids']
    asmb_chains = meta['asmb_chains']  # comma-separated chain IDs per assembly
    chids = np.array(meta['chains'])
    # find candidate assemblies which contain chid chain
    asmb_candidates = set([a for a,b in zip(asmb_ids,asmb_chains)
                           if chid in b.split(',')])
    # if the chain is missing from all assemblies,
    # then return this chain alone as a monomer
    if len(asmb_candidates)<1:
        chain = torch.load("%s_%s.pt"%(PREFIX,chid))
        L = len(chain['seq'])
        return {'seq' : chain['seq'],
                'xyz' : chain['xyz'],
                'idx' : torch.zeros(L).int(),
                'masked' : torch.Tensor([0]).int(),
                'label' : item[0]}
    # randomly pick one assembly from candidates
    # (note: random.sample returns a 1-element list; the == below broadcasts)
    asmb_i = random.sample(list(asmb_candidates), 1)
    # indices of selected transforms
    idx = np.where(np.array(asmb_ids)==asmb_i)[0]
    # load relevant chains
    # NOTE(review): iterating `asmb_chains[i]` walks the string character by
    # character, so this relies on single-character chain IDs — confirm.
    chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))
              for i in idx for c in asmb_chains[i]
              if c in meta['chains']}
    # generate assembly
    asmb = {}
    for k in idx:
        # pick k-th xform
        xform = meta['asmb_xform%d'%k]
        u = xform[:,:3,:3]  # rotation part
        r = xform[:,:3,3]   # translation part
        # select chains which k-th xform should be applied to
        s1 = set(meta['chains'])
        s2 = set(asmb_chains[k].split(','))
        chains_k = s1&s2
        # transform selected chains
        for c in chains_k:
            try:
                xyz = chains[c]['xyz']
                # apply every transform copy: (copies, L, atoms, 3)
                xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:,None,None,:]
                asmb.update({(c,k,i):xyz_i for i,xyz_i in enumerate(xyz_ru)})
            except KeyError:
                # chain file was not loaded above — give up on this item
                return {'seq': np.zeros(5)}
    # select chains which share considerable similarity to chid
    seqid = meta['tm'][chids==chid][0,:,1]
    homo = set([ch_j for seqid_j,ch_j in zip(seqid,chids)
                if seqid_j>params['HOMO']])
    # stack all chains in the assembly together
    seq,xyz,idx,masked = "",[],[],[]
    seq_list = []
    for counter,(k,v) in enumerate(asmb.items()):
        seq += chains[k[0]]['seq']
        seq_list.append(chains[k[0]]['seq'])
        xyz.append(v)
        # per-residue chain index for this assembly member
        idx.append(torch.full((v.shape[0],),counter))
        if k[0] in homo:
            masked.append(counter)
    return {'seq' : seq,
            'xyz' : torch.cat(xyz,dim=0),
            'idx' : torch.cat(idx,dim=0),
            'masked' : torch.Tensor(masked).int(),
            'label' : item[0]}
def build_training_clusters(params, debug):
    """Split the list.csv index into train/valid/test cluster dictionaries.

    params keys used:
        'VAL', 'TEST': paths to files of integer cluster IDs (one per line).
        'LIST': CSV whose rows (after a header) contain at least
            [chain id, deposition date, resolution, hash, cluster id].
        'RESCUT': maximum accepted resolution.
        'DATCUT': latest accepted deposition date (parsed with dateutil).

    Returns (train, valid, test): each maps cluster id -> list of
    [chain id, hash] rows. In debug mode the ID sets are emptied, only the
    first 20 rows are used, and valid aliases train.
    """
    # Context managers so the ID files are closed (the original leaked handles).
    with open(params['VAL']) as f:
        val_ids = set([int(l) for l in f.readlines()])
    with open(params['TEST']) as f:
        test_ids = set([int(l) for l in f.readlines()])
    if debug:
        val_ids = []
        test_ids = []

    # read & clean list.csv: keep rows within the resolution and date cutoffs.
    with open(params['LIST'], 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        rows = [[r[0],r[3],int(r[4])] for r in reader
                if float(r[2])<=params['RESCUT'] and
                parser.parse(r[1])<=parser.parse(params['DATCUT'])]

    # compile training and validation sets, grouped by cluster id (r[2]).
    train = {}
    valid = {}
    test = {}
    if debug:
        rows = rows[:20]
    for r in rows:
        if r[2] in val_ids:
            valid.setdefault(r[2], []).append(r[:2])
        elif r[2] in test_ids:
            test.setdefault(r[2], []).append(r[:2])
        else:
            train.setdefault(r[2], []).append(r[:2])
    if debug:
        valid=train
    return train, valid, test