import torch
import numpy as np
import itertools
import torch.nn.functional as F
import math
import torch_geometric
from collections.abc import Mapping, Sequence
from torch_geometric.data import Data, Batch
from torch.utils.data.dataloader import default_collate
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir='./cache_dir/')  # mask token id: 32


def _normalize(tensor, dim=-1):
    '''
    Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
    '''
    return torch.nan_to_num(
        torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))
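
# Usage sketch (hypothetical values, for illustration only): _normalize maps
# zero-norm rows to zeros instead of NaNs:
#   v = torch.tensor([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]])
#   _normalize(v)  # -> [[0.6, 0.8, 0.0], [0.0, 0.0, 0.0]]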


def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design

    Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
    That is, if `D` has shape [...dims], then the returned tensor will have
    shape [...dims, D_count].
    '''
    D_mu = torch.linspace(D_min, D_max, D_count, device=device)
    D_mu = D_mu.view([1, -1])
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)
    RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
    return RBF
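
# Shape sketch (illustrative only): with the defaults, a distance tensor of any
# shape gains a trailing RBF axis of size D_count:
#   d = torch.rand(7, 30) * 20.0
#   _rbf(d).shape  # -> torch.Size([7, 30, 16])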


def shuffle_subset(n, p):
    '''
    Return a permutation of range(n) in which a Binomial(n, p)-sized random
    subset of positions is shuffled while all other positions stay in place.
    '''
    n_shuffle = np.random.binomial(n, p)
    ix = np.arange(n)
    ix_subset = np.random.choice(ix, size=n_shuffle, replace=False)
    ix_subset_shuffled = np.copy(ix_subset)
    np.random.shuffle(ix_subset_shuffled)
    ix[ix_subset] = ix_subset_shuffled
    return ix
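
# Example sketch (illustrative): shuffle_subset(5, 0.5) might return
# array([0, 3, 2, 1, 4]) -- positions {1, 3} were drawn and swapped while the
# rest stayed fixed. Used below to corrupt a fraction of sequence labels.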


def featurize_AF(batch, shuffle_fraction=0.):
    """ Pack and pad batch into torch tensors """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    B = len(batch)
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 4, 3])
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.zeros([B, L_max])

    # Build the batch
    for i, b in enumerate(batch):
        x = np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1)  # [L, 4, 3]
        l = len(b['seq'])
        x_pad = np.pad(x, [[0, L_max-l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan,))  # [L_max, 4, 3]
        X[i, :, :, :] = x_pad

        # Convert to labels
        indices = np.asarray([alphabet.index(a) for a in b['seq']], dtype=np.int32)
        if shuffle_fraction > 0.:
            idx_shuffle = shuffle_subset(l, shuffle_fraction)
            S[i, :l] = indices[idx_shuffle]
            score[i, :l] = b['score'][idx_shuffle]
        else:
            S[i, :l] = indices
            score[i, :l] = b['score']

    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)  # residue mask (all four atoms finite)
    numbers = np.sum(mask, axis=1).astype(np.int32)  # np.int is removed in NumPy >= 1.24
    S_new = np.zeros_like(S)
    score_new = np.zeros_like(score)
    X_new = np.zeros_like(X) + np.nan
    for i, n in enumerate(numbers):
        # Compact each sample so that valid residues come first
        X_new[i, :n, ::] = X[i][mask[i] == 1]
        S_new[i, :n] = S[i][mask[i] == 1]
        score_new[i, :n] = score[i][mask[i] == 1]
    X = X_new
    S = S_new
    score = score_new
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.

    # Conversion
    S = torch.from_numpy(S).to(dtype=torch.long)
    score = torch.from_numpy(score).float()
    X = torch.from_numpy(X).to(dtype=torch.float32)
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    return X, S, score, mask, lengths
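
# Usage sketch (hypothetical data, for illustration only): each batch entry is
# a dict with per-residue backbone coordinates and a per-residue score:
#   L = 10
#   b = {'seq': 'ACDEFGHIKL',
#        'N': np.random.randn(L, 3), 'CA': np.random.randn(L, 3),
#        'C': np.random.randn(L, 3), 'O': np.random.randn(L, 3),
#        'score': np.zeros(L)}
#   X, S, score, mask, lengths = featurize_AF([b])
#   X.shape, S.shape  # -> (torch.Size([1, 10, 4, 3]), torch.Size([1, 10]))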


def featurize_GTrans(batch):
    """ Pack and pad batch into torch tensors """
    batch = [one for one in batch if one is not None]
    B = len(batch)
    if B == 0:
        return None
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 4, 3])
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.ones([B, L_max]) * 100.0
    chain_mask = np.zeros([B, L_max]) - 1  # 1: masked positions to be predicted, 0: visible positions
    chain_encoding = np.zeros([B, L_max]) - 1

    # Build the batch
    for i, b in enumerate(batch):
        x = np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1)  # [L, 4, 3]
        l = len(b['seq'])
        x_pad = np.pad(x, [[0, L_max-l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan,))  # [L_max, 4, 3]
        X[i, :, :, :] = x_pad

        # Convert to labels with the ESM-2 tokenizer
        indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
        S[i, :l] = indices
        chain_mask[i, :l] = b['chain_mask']
        chain_encoding[i, :l] = b['chain_encoding']

    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)  # residue mask
    numbers = np.sum(mask, axis=1).astype(np.int32)
    S_new = np.zeros_like(S)
    X_new = np.zeros_like(X) + np.nan
    for i, n in enumerate(numbers):
        # Compact each sample so that valid residues come first
        X_new[i, :n, ::] = X[i][mask[i] == 1]
        S_new[i, :n] = S[i][mask[i] == 1]
    X = X_new
    S = S_new
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.

    # Conversion
    S = torch.from_numpy(S).to(dtype=torch.long)
    score = torch.from_numpy(score).float()
    X = torch.from_numpy(X).to(dtype=torch.float32)
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    lengths = torch.from_numpy(lengths)
    chain_mask = torch.from_numpy(chain_mask)
    chain_encoding = torch.from_numpy(chain_encoding)
    return {"title": [b['title'] for b in batch],
            "X": X,
            "S": S,
            "score": score,
            "mask": mask,
            "lengths": lengths,
            "chain_mask": chain_mask,
            "chain_encoding": chain_encoding}


class featurize_GVP:
    def __init__(self, num_positional_embeddings=16, top_k=30, num_rbf=16):
        self.top_k = top_k
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

    def featurize(self, batch):
        data_all = []
        for b in batch:
            if b is None:
                continue
            coords = torch.tensor(np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1))
            seq = torch.tensor(np.array(tokenizer.encode(b['seq'], add_special_tokens=False)))

            mask = torch.isfinite(coords.sum(dim=(1, 2)))
            coords[~mask] = np.inf

            X_ca = coords[:, 1].float()
            # torch_geometric.nn.knn_graph replaces torch_cluster.knn_graph here
            edge_index = torch_geometric.nn.knn_graph(X_ca, k=self.top_k)
            pos_embeddings = self._positional_embeddings(edge_index)  # [E, 16]
            E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]     # [E, 3]
            rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf)  # [E, 16]

            dihedrals = self._dihedrals(coords)       # [n, 6]
            orientations = self._orientations(X_ca)   # [n, 2, 3]
            sidechains = self._sidechains(coords)     # [n, 3]

            node_s = dihedrals.float()  # [n, 6]
            node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2).float()  # [n, 3, 3]
            edge_s = torch.cat([rbf, pos_embeddings], dim=-1).float()  # [E, 32]
            edge_v = _normalize(E_vectors).unsqueeze(-2).float()       # [E, 1, 3]
            node_s, node_v, edge_s, edge_v = map(torch.nan_to_num, (node_s, node_v, edge_s, edge_v))

            data = torch_geometric.data.Data(x=X_ca, seq=seq,
                                             node_s=node_s, node_v=node_v,
                                             edge_s=edge_s, edge_v=edge_v,
                                             edge_index=edge_index, mask=mask)
            data_all.append(data)
        return data_all

    def _positional_embeddings(self, edge_index,
                               num_embeddings=None,
                               period_range=[2, 1000]):
        # From https://github.com/jingraham/neurips19-graph-protein-design
        num_embeddings = num_embeddings or self.num_positional_embeddings
        d = edge_index[0] - edge_index[1]
        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    def _dihedrals(self, X, eps=1e-7):
        # From https://github.com/jingraham/neurips19-graph-protein-design
        X = torch.reshape(X[:, :3], [3*X.shape[0], 3])
        dX = X[1:] - X[:-1]
        U = _normalize(dX, dim=-1)
        u_2 = U[:-2]
        u_1 = U[1:-1]
        u_0 = U[2:]

        # Backbone normals (explicit dim=-1 for torch.cross)
        n_2 = _normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = _normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = torch.sum(n_2 * n_1, -1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [-1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
        return D_features

    def _orientations(self, X):
        forward = _normalize(X[1:] - X[:-1])
        backward = _normalize(X[:-1] - X[1:])
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    def _sidechains(self, X):
        n, origin, c = X[:, 0], X[:, 1], X[:, 2]
        c, n = _normalize(c - origin), _normalize(n - origin)
        bisector = _normalize(c + n)
        perp = _normalize(torch.cross(c, n, dim=-1))
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec

    def collate(self, batch):
        batch = self.featurize(batch)
        if (batch is None) or (len(batch) == 0):
            return None
        elem = batch[0]
        if isinstance(elem, Data):
            return Batch.from_data_list(batch)
        elif isinstance(elem, torch.Tensor):
            return default_collate(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, str):
            return batch
        elif isinstance(elem, Mapping):
            return {key: self.collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            return type(elem)(*(self.collate(s) for s in zip(*batch)))
        elif isinstance(elem, Sequence) and not isinstance(elem, str):
            return [self.collate(s) for s in zip(*batch)]
        raise TypeError('DataLoader found invalid type: {}'.format(type(elem)))
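
# Usage sketch (hypothetical dataset, for illustration): the class is meant to
# be instantiated once and its bound collate method handed to a DataLoader, so
# each batch arrives as a single torch_geometric Batch of k-NN graphs:
#   featurizer = featurize_GVP(top_k=30)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8,
#                                        collate_fn=featurizer.collate)
#   graph = next(iter(loader))  # graph.node_s: [N, 6], graph.edge_s: [E, 32]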


def featurize_ProteinMPNN(batch, is_testing=False, chain_dict=None, fixed_position_dict=None, omit_AA_dict=None, tied_positions_dict=None, pssm_dict=None, bias_by_res_dict=None):
    """ Pack and pad batch into torch tensors """
    batch = [one for one in batch if one is not None]
    B = len(batch)
    if B == 0:  # check emptiness before touching batch[0]
        return None
    dynamics_keys = ('norm_bfactors', 'gt_flex', 'enm_vals', 'original_gt_flex', 'eng_mask')
    USING_DYNAMICS = any(key in batch[0] for key in dynamics_keys)
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)  # sum of chain seq lengths
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 4, 3])
    residue_idx = -100*np.ones([B, L_max], dtype=np.int32)
    chain_M = np.zeros([B, L_max], dtype=np.int32)  # 1.0 for positions that need to be predicted
    pssm_coef_all = np.zeros([B, L_max], dtype=np.float32)
    pssm_bias_all = np.zeros([B, L_max, 21], dtype=np.float32)
    pssm_log_odds_all = 10000.0*np.ones([B, L_max, 21], dtype=np.float32)
    chain_M_pos = np.zeros([B, L_max], dtype=np.int32)  # 0.0 for fixed positions
    bias_by_res_all = np.zeros([B, L_max, 21], dtype=np.float32)
    chain_encoding_all = np.zeros([B, L_max], dtype=np.int32)
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.zeros([B, L_max])
    omit_AA_mask = np.zeros([B, L_max, len(alphabet)], dtype=np.int32)

    # Build the batch
    letter_list_list = []
    visible_list_list = []
    masked_list_list = []
    masked_chain_length_list_list = []
    tied_pos_list_of_lists_list = []
    if USING_DYNAMICS:
        if 'norm_bfactors' in batch[0]:
            b_factors = np.zeros([B, L_max])
        if 'gt_flex' in batch[0]:
            gt_flex = np.zeros([B, L_max])
        if 'enm_vals' in batch[0]:
            enm_vals = np.zeros([B, L_max])
        if 'original_gt_flex' in batch[0]:
            original_gt_flex = np.zeros([B, L_max])
        if 'eng_mask' in batch[0]:
            eng_mask = np.zeros([B, L_max])

    # Determine masked/visible chains before the main loop (chain shuffling is
    # disabled). Note: all_chains from the last batch element is reused for
    # every sample below, which is only safe when all samples share the same
    # chain layout (the default chain_dict=None case).
    for b in batch:
        if chain_dict is not None:
            masked_chains, visible_chains = chain_dict[b['name']]  # masked_chains: chain letters to predict, e.g. ['A', 'D', 'F']
        else:
            masked_chains = ['']
            visible_chains = []
        all_chains = masked_chains + visible_chains
        # random.shuffle(all_chains)

    for i, b in enumerate(batch):
        x_chain_list = []
        chain_mask_list = []
        chain_seq_list = []
        chain_encoding_list = []
        c = 1
        letter_list = []
        global_idx_start_list = [0]
        visible_list = []
        masked_list = []
        masked_chain_length_list = []
        fixed_position_mask_list = []
        omit_AA_mask_list = []
        pssm_coef_list = []
        pssm_bias_list = []
        pssm_log_odds_list = []
        bias_by_res_list = []
        if USING_DYNAMICS:
            if 'norm_bfactors' in batch[0]:
                b_factors_list = []
            if 'gt_flex' in batch[0]:
                gt_flex_list = []
            if 'enm_vals' in batch[0]:
                enm_vals_list = []
            if 'original_gt_flex' in batch[0]:
                original_gt_flex_list = []
            if 'eng_mask' in batch[0]:
                eng_mask_list = []
        l0 = 0
        l1 = 0
        for step, letter in enumerate(all_chains):
            if letter in visible_chains:
                letter_list.append(letter)
                visible_list.append(letter)
                chain_seq = b[f'seq_chain_{letter}']
                chain_seq = ''.join([a if a != '-' else 'X' for a in chain_seq])
                chain_length = len(chain_seq)
                global_idx_start_list.append(global_idx_start_list[-1] + chain_length)
                chain_coords = b[f'coords_chain_{letter}']  # this is a dictionary
                chain_mask = np.zeros(chain_length)  # 0.0 for visible chains
                x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']], 1)  # [chain_length, 4, 3]
                x_chain_list.append(x_chain)
                chain_mask_list.append(chain_mask)
                chain_seq_list.append(chain_seq)
                chain_encoding_list.append(c * np.ones(np.array(chain_mask).shape[0]))
                l1 += chain_length
                residue_idx[i, l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
                l0 += chain_length
                c += 1
                fixed_position_mask = np.ones(chain_length)
                fixed_position_mask_list.append(fixed_position_mask)
                omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
                omit_AA_mask_list.append(omit_AA_mask_temp)
                pssm_coef = np.zeros(chain_length)
                pssm_bias = np.zeros([chain_length, 21])
                pssm_log_odds = 10000.0 * np.ones([chain_length, 21])
                pssm_coef_list.append(pssm_coef)
                pssm_bias_list.append(pssm_bias)
                pssm_log_odds_list.append(pssm_log_odds)
                bias_by_res_list.append(np.zeros([chain_length, 21]))
            if letter in masked_chains:
                masked_list.append(letter)
                letter_list.append(letter)
                if USING_DYNAMICS:
                    if 'norm_bfactors' in batch[0]:
                        b_factors_list.append(b['norm_bfactors'])
                    if 'gt_flex' in batch[0]:
                        gt_flex_list.append(b['gt_flex'])
                    if 'enm_vals' in batch[0]:
                        enm_vals_list.append(b['enm_vals'])
                    if 'original_gt_flex' in batch[0]:
                        original_gt_flex_list.append(b['original_gt_flex'])
                    if 'eng_mask' in batch[0]:
                        eng_mask_list.append(b['eng_mask'])
                chain_seq = b[f'seq{letter}']  # with letter == '' this is b['seq']
                chain_seq = ''.join([a if a != '-' else 'X' for a in chain_seq])
                chain_length = len(chain_seq)
                global_idx_start_list.append(global_idx_start_list[-1] + chain_length)
                masked_chain_length_list.append(chain_length)
                chain_coords = b  # single-chain entries keep coords at the top level
                chain_mask = np.ones(chain_length)  # 1.0 for masked
                x_chain = np.stack([chain_coords[c] for c in ['N', 'CA', 'C', 'O']], 1)  # [chain_length, 4, 3]
                x_chain_list.append(x_chain)
                chain_mask_list.append(chain_mask)
                chain_seq_list.append(chain_seq)
                chain_encoding_list.append(c * np.ones(np.array(chain_mask).shape[0]))
                l1 += chain_length
                residue_idx[i, l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
                l0 += chain_length
                c += 1
                fixed_position_mask = np.ones(chain_length)
                if fixed_position_dict is not None:
                    fixed_pos_list = fixed_position_dict[b['name']][letter]
                    if fixed_pos_list:
                        fixed_position_mask[np.array(fixed_pos_list) - 1] = 0.0
                fixed_position_mask_list.append(fixed_position_mask)
                omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
                if omit_AA_dict is not None:
                    for item in omit_AA_dict[b['name']][letter]:
                        idx_AA = np.array(item[0]) - 1
                        AA_idx = np.array([np.argwhere(np.array(list(alphabet)) == AA)[0][0] for AA in item[1]]).repeat(idx_AA.shape[0])
                        idx_ = np.array([[a, b] for a in idx_AA for b in AA_idx])
                        omit_AA_mask_temp[idx_[:, 0], idx_[:, 1]] = 1
                omit_AA_mask_list.append(omit_AA_mask_temp)
                pssm_coef = np.zeros(chain_length)
                pssm_bias = np.zeros([chain_length, 21])
                pssm_log_odds = 10000.0 * np.ones([chain_length, 21])
                if pssm_dict:
                    if pssm_dict[b['name']][letter]:
                        pssm_coef = pssm_dict[b['name']][letter]['pssm_coef']
                        pssm_bias = pssm_dict[b['name']][letter]['pssm_bias']
                        pssm_log_odds = pssm_dict[b['name']][letter]['pssm_log_odds']
                pssm_coef_list.append(pssm_coef)
                pssm_bias_list.append(pssm_bias)
                pssm_log_odds_list.append(pssm_log_odds)
                if bias_by_res_dict:
                    bias_by_res_list.append(bias_by_res_dict[b['name']][letter])
                else:
                    bias_by_res_list.append(np.zeros([chain_length, 21]))
        letter_list_np = np.array(letter_list)
        tied_pos_list_of_lists = []
        tied_beta = np.ones(L_max)
        if tied_positions_dict is not None:
            tied_pos_list = tied_positions_dict[b['name']]
            if tied_pos_list:
                set_chains_tied = set(list(itertools.chain(*[list(item) for item in tied_pos_list])))
                for tied_item in tied_pos_list:
                    one_list = []
                    for k, v in tied_item.items():
                        start_idx = global_idx_start_list[np.argwhere(letter_list_np == k)[0][0]]
                        if isinstance(v[0], list):
                            for v_count in range(len(v[0])):
                                one_list.append(start_idx + v[0][v_count] - 1)  # make 0 the first residue
                                tied_beta[start_idx + v[0][v_count] - 1] = v[1][v_count]
                        else:
                            for v_ in v:
                                one_list.append(start_idx + v_ - 1)  # make 0 the first residue
                    tied_pos_list_of_lists.append(one_list)
        tied_pos_list_of_lists_list.append(tied_pos_list_of_lists)

        x = np.concatenate(x_chain_list, 0)  # [L, 4, 3]
        if USING_DYNAMICS:
            if 'norm_bfactors' in batch[0]:
                bf = np.concatenate(b_factors_list, 0)  # [L,]
            if 'gt_flex' in batch[0]:
                gt = np.concatenate(gt_flex_list, 0)  # [L,]
            if 'enm_vals' in batch[0]:
                enm = np.concatenate(enm_vals_list, 0)
            if 'original_gt_flex' in batch[0]:
                orig_gt = np.concatenate(original_gt_flex_list, 0)
            if 'eng_mask' in batch[0]:
                eng = np.concatenate(eng_mask_list, 0)
        all_sequence = "".join(chain_seq_list)
        m = np.concatenate(chain_mask_list, 0)  # [L,], 1.0 for places that need to be predicted
        chain_encoding = np.concatenate(chain_encoding_list, 0)
        m_pos = np.concatenate(fixed_position_mask_list, 0)  # [L,], 0.0 for fixed positions
        pssm_coef_ = np.concatenate(pssm_coef_list, 0)  # [L,]
        pssm_bias_ = np.concatenate(pssm_bias_list, 0)  # [L, 21]
        pssm_log_odds_ = np.concatenate(pssm_log_odds_list, 0)  # [L, 21]
        bias_by_res_ = np.concatenate(bias_by_res_list, 0)  # [L, 21], 0.0 where AA frequencies don't need to be tweaked

        l = len(all_sequence)
        x_pad = np.pad(x, [[0, L_max-l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan,))
        if USING_DYNAMICS:
            if 'norm_bfactors' in batch[0]:
                bf_pad = np.pad(bf, [[0, L_max-l]], 'constant', constant_values=(np.nan,))
            if 'gt_flex' in batch[0]:
                gt_pad = np.pad(gt, [[0, L_max-l]], 'constant', constant_values=(np.nan,))
            if 'enm_vals' in batch[0]:
                enm_pad = np.pad(enm, [[0, L_max-l]], 'constant', constant_values=(np.nan,))
            if 'original_gt_flex' in batch[0]:
                orig_gt_pad = np.pad(orig_gt, [[0, L_max-l]], 'constant', constant_values=(0,))
            if 'eng_mask' in batch[0]:
                eng_pad = np.pad(eng, [[0, L_max-l]], 'constant', constant_values=(0,))
        X[i, :, :, :] = x_pad
        if USING_DYNAMICS:
            if 'norm_bfactors' in batch[0]:
                b_factors[i, :] = bf_pad
            if 'gt_flex' in batch[0]:
                gt_flex[i, :] = gt_pad[:-1]  # [:-1] implies the raw flex arrays carry one extra trailing entry
            if 'enm_vals' in batch[0]:
                enm_vals[i, :] = enm_pad
            if 'original_gt_flex' in batch[0]:
                original_gt_flex[i, :] = orig_gt_pad[:-1]
            if 'eng_mask' in batch[0]:
                eng_mask[i, :] = eng_pad[:-1]
        if 'score' in b:
            score[i, :l] = b['score']
        else:
            score[i, :l] = 100.0
        m_pad = np.pad(m, [[0, L_max-l]], 'constant', constant_values=(0.0,))
        m_pos_pad = np.pad(m_pos, [[0, L_max-l]], 'constant', constant_values=(0.0,))
        omit_AA_mask_pad = np.pad(np.concatenate(omit_AA_mask_list, 0), [[0, L_max-l], [0, 0]], 'constant', constant_values=(0.0,))
        chain_M[i, :] = m_pad
        chain_M_pos[i, :] = m_pos_pad
        omit_AA_mask[i, :] = omit_AA_mask_pad
        chain_encoding_pad = np.pad(chain_encoding, [[0, L_max-l]], 'constant', constant_values=(0.0,))
        chain_encoding_all[i, :] = chain_encoding_pad
        pssm_coef_pad = np.pad(pssm_coef_, [[0, L_max-l]], 'constant', constant_values=(0.0,))
        pssm_bias_pad = np.pad(pssm_bias_, [[0, L_max-l], [0, 0]], 'constant', constant_values=(0.0,))
        pssm_log_odds_pad = np.pad(pssm_log_odds_, [[0, L_max-l], [0, 0]], 'constant', constant_values=(0.0,))
        pssm_coef_all[i, :] = pssm_coef_pad
        pssm_bias_all[i, :] = pssm_bias_pad
        pssm_log_odds_all[i, :] = pssm_log_odds_pad
        bias_by_res_pad = np.pad(bias_by_res_, [[0, L_max-l], [0, 0]], 'constant', constant_values=(0.0,))
        bias_by_res_all[i, :] = bias_by_res_pad

        # Convert to labels with the ESM-2 tokenizer
        indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
        S[i, :l] = indices
        letter_list_list.append(letter_list)
        visible_list_list.append(visible_list)
        masked_list_list.append(masked_list)
        masked_chain_length_list_list.append(masked_chain_length_list)

    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.

    # Conversion
    pssm_coef_all = torch.from_numpy(pssm_coef_all).to(dtype=torch.float32)
    pssm_bias_all = torch.from_numpy(pssm_bias_all).to(dtype=torch.float32)
    pssm_log_odds_all = torch.from_numpy(pssm_log_odds_all).to(dtype=torch.float32)
    tied_beta = torch.from_numpy(tied_beta).to(dtype=torch.float32)
    bias_by_res_all = torch.from_numpy(bias_by_res_all).to(dtype=torch.float32)
    jumps = ((residue_idx[:, 1:] - residue_idx[:, :-1]) == 1).astype(np.float32)
    phi_mask = np.pad(jumps, [[0, 0], [1, 0]])
    psi_mask = np.pad(jumps, [[0, 0], [0, 1]])
    omega_mask = np.pad(jumps, [[0, 0], [0, 1]])
    dihedral_mask = np.concatenate([phi_mask[:, :, None], psi_mask[:, :, None], omega_mask[:, :, None]], -1)  # [B, L, 3]
    dihedral_mask = torch.from_numpy(dihedral_mask).to(dtype=torch.float32)
    residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long)
    S = torch.from_numpy(S).to(dtype=torch.long)
    X = torch.from_numpy(X).to(dtype=torch.float32)
    if USING_DYNAMICS:
        if 'norm_bfactors' in batch[0]:
            b_factors = torch.from_numpy(b_factors).to(dtype=torch.float32)
        if 'gt_flex' in batch[0]:
            gt_flex = torch.from_numpy(gt_flex).to(dtype=torch.float32)
        if 'enm_vals' in batch[0]:
            enm_vals = torch.from_numpy(enm_vals).to(dtype=torch.float32)
        if 'original_gt_flex' in batch[0]:
            original_gt_flex = torch.from_numpy(original_gt_flex).to(dtype=torch.float32)
        if 'eng_mask' in batch[0]:
            eng_mask = torch.from_numpy(eng_mask).to(dtype=torch.float32)
    score = torch.from_numpy(score).float()
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32)
    chain_M_pos = torch.from_numpy(chain_M_pos).to(dtype=torch.float32)
    omit_AA_mask = torch.from_numpy(omit_AA_mask).to(dtype=torch.float32)
    chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long)

    # The is_testing flag currently does not change the returned dict, so build
    # it once instead of duplicating two identical branches.
    retVal = {"title": [b['title'] for b in batch],
              "X": X,
              "S": S,
              "score": score,
              "mask": mask,
              "lengths": lengths,
              "chain_M": chain_M,
              "chain_M_pos": chain_M_pos,
              "residue_idx": residue_idx,
              "chain_encoding_all": chain_encoding_all}
    if USING_DYNAMICS:
        if 'norm_bfactors' in batch[0]:
            retVal['norm_bfactors'] = b_factors
        if 'gt_flex' in batch[0]:
            retVal['gt_flex'] = gt_flex
        if 'enm_vals' in batch[0]:
            retVal['enm_vals'] = enm_vals
        if 'original_gt_flex' in batch[0]:
            retVal['original_gt_flex'] = original_gt_flex
        if 'eng_mask' in batch[0]:
            retVal['eng_mask'] = eng_mask
    return retVal
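
# Usage sketch (hypothetical single-chain data, for illustration): with the
# default chain_dict=None every entry is treated as one fully masked chain, so
# only 'seq', backbone coordinates, 'title' and 'name' are required:
#   b = {'title': 'demo', 'name': 'demo', 'seq': 'ACDEFG',
#        'N': np.random.randn(6, 3), 'CA': np.random.randn(6, 3),
#        'C': np.random.randn(6, 3), 'O': np.random.randn(6, 3)}
#   out = featurize_ProteinMPNN([b])
#   out['chain_M'].shape  # -> torch.Size([1, 6]), all ones (everything masked)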


def featurize_Inversefolding(batch, shuffle_fraction=0.):
    """ Pack and pad batch into torch tensors """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    B = len(batch)
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 3, 3])
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.ones([B, L_max]) * 100.0
    chain_mask = np.zeros([B, L_max]) - 1  # 1: masked positions to be predicted, 0: visible positions
    chain_encoding = np.zeros([B, L_max]) - 1

    # Build the batch
    for i, b in enumerate(batch):
        x = np.stack([b[c] for c in ['N', 'CA', 'C']], 1)  # [L, 3, 3]
        l = len(b['seq'])
        x_pad = np.pad(x, [[0, L_max-l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan,))  # [L_max, 3, 3]
        X[i, :, :, :] = x_pad

        # Convert to labels with the ESM-2 tokenizer
        indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
        if shuffle_fraction > 0.:
            idx_shuffle = shuffle_subset(l, shuffle_fraction)
            S[i, :l] = indices[idx_shuffle]
        else:
            S[i, :l] = indices
        chain_mask[i, :l] = b['chain_mask']
        chain_encoding[i, :l] = b['chain_encoding']

    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)  # residue mask
    numbers = np.sum(mask, axis=1).astype(np.int32)  # np.int is removed in NumPy >= 1.24
    S_new = np.zeros_like(S)
    X_new = np.zeros_like(X) + np.nan
    for i, n in enumerate(numbers):
        X_new[i, :n, ::] = X[i][mask[i] == 1]
        S_new[i, :n] = S[i][mask[i] == 1]
    X = X_new
    S = S_new
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.

    # Conversion
    S = torch.from_numpy(S).to(dtype=torch.long)
    score = torch.from_numpy(score).float()
    X = torch.from_numpy(X).to(dtype=torch.float32)
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    chain_mask = torch.from_numpy(chain_mask)
    chain_encoding = torch.from_numpy(chain_encoding)
    return {"title": [b['title'] for b in batch],
            "X": X,
            "S": S,
            "score": score,
            "mask": mask,
            "lengths": lengths,
            "chain_mask": chain_mask,
            "chain_encoding": chain_encoding}