import itertools
import math

import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric
from collections.abc import Mapping, Sequence
from torch_geometric.data import Data, Batch
from torch.utils.data.dataloader import default_collate
from transformers import AutoTokenizer

# ESM-2 tokenizer; its mask token id is 32.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir='./cache_dir/')


def _normalize(tensor, dim=-1):
    '''
    Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
    '''
    return torch.nan_to_num(
        torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))


def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design

    Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
    That is, if `D` has shape [...dims], then the returned tensor will have
    shape [...dims, D_count].
    '''
    D_mu = torch.linspace(D_min, D_max, D_count, device=device)
    D_mu = D_mu.view([1, -1])
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)

    RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
    return RBF


def shuffle_subset(n, p):
    # Shuffle a random binomial(n, p)-sized subset of the indices 0..n-1,
    # leaving the remaining indices in place.
    n_shuffle = np.random.binomial(n, p)
    ix = np.arange(n)
    ix_subset = np.random.choice(ix, size=n_shuffle, replace=False)
    ix_subset_shuffled = np.copy(ix_subset)
    np.random.shuffle(ix_subset_shuffled)
    ix[ix_subset] = ix_subset_shuffled
    return ix
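
# A minimal sanity-check sketch for `_rbf` (illustrative only, not part of the
# original pipeline): a vector of E distances is lifted to an [E, D_count]
# radial-basis feature tensor, one Gaussian bump per distance bin.
def _demo_rbf():
    D = torch.tensor([1.5, 3.8, 12.0])  # e.g. three edge lengths in Angstroms
    emb = _rbf(D, D_min=0., D_max=20., D_count=16)
    assert emb.shape == (3, 16)  # one 16-dim RBF vector per distance
    return emb
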

def featurize_AF(batch, shuffle_fraction=0.):
    """ Pack and pad batch into torch tensors """
    alphabet = 'ACDEFGHIKLMNPQRSTVWY'
    B = len(batch)
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 4, 3])
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.zeros([B, L_max])

    # Build the batch
    for i, b in enumerate(batch):
        x = np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1)  # [#atom, 4, 3]

        l = len(b['seq'])
        x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]],
                       'constant', constant_values=(np.nan, ))  # [L_max, 4, 3]
        X[i, :, :, :] = x_pad

        # Convert to labels
        indices = np.asarray([alphabet.index(a) for a in b['seq']], dtype=np.int32)
        if shuffle_fraction > 0.:
            idx_shuffle = shuffle_subset(l, shuffle_fraction)
            S[i, :l] = indices[idx_shuffle]
            score[i, :l] = b['score'][idx_shuffle]
        else:
            S[i, :l] = indices
            score[i, :l] = b['score']

    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)  # atom mask
    numbers = np.sum(mask, axis=1).astype(np.int32)  # np.int was removed from NumPy
    S_new = np.zeros_like(S)
    score_new = np.zeros_like(score)
    X_new = np.zeros_like(X) + np.nan
    # Compact each sample: drop residues with any missing backbone atom.
    for i, n in enumerate(numbers):
        X_new[i, :n, ::] = X[i][mask[i] == 1]
        S_new[i, :n] = S[i][mask[i] == 1]
        score_new[i, :n] = score[i][mask[i] == 1]

    X = X_new
    S = S_new
    score = score_new
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.
    # Conversion
    S = torch.from_numpy(S).to(dtype=torch.long)
    score = torch.from_numpy(score).float()
    X = torch.from_numpy(X).to(dtype=torch.float32)
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    return X, S, score, mask, lengths


def featurize_GTrans(batch):
    """ Pack and pad batch into torch tensors """
    batch = [one for one in batch if one is not None]
    B = len(batch)
    if B == 0:
        return None
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 4, 3])
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.ones([B, L_max]) * 100.0
    chain_mask = np.zeros([B, L_max]) - 1  # 1: masked positions to be predicted, 0: visible positions
    chain_encoding = np.zeros([B, L_max]) - 1

    # Build the batch
    for i, b in enumerate(batch):
        x = np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1)  # [#atom, 4, 3]

        l = len(b['seq'])
        x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]],
                       'constant', constant_values=(np.nan, ))  # [L_max, 4, 3]
        X[i, :, :, :] = x_pad

        # Convert to labels via the ESM-2 tokenizer
        indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
        S[i, :l] = indices
        chain_mask[i, :l] = b['chain_mask']
        chain_encoding[i, :l] = b['chain_encoding']

    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)  # atom mask
    numbers = np.sum(mask, axis=1).astype(np.int32)
    S_new = np.zeros_like(S)
    X_new = np.zeros_like(X) + np.nan
    # Compact each sample: drop residues with any missing backbone atom.
    for i, n in enumerate(numbers):
        X_new[i, :n, ::] = X[i][mask[i] == 1]
        S_new[i, :n] = S[i][mask[i] == 1]

    X = X_new
    S = S_new
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.
    # Conversion
    S = torch.from_numpy(S).to(dtype=torch.long)
    score = torch.from_numpy(score).float()
    X = torch.from_numpy(X).to(dtype=torch.float32)
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    lengths = torch.from_numpy(lengths)
    chain_mask = torch.from_numpy(chain_mask)
    chain_encoding = torch.from_numpy(chain_encoding)
    return {"title": [b['title'] for b in batch],
            "X": X, "S": S, "score": score, "mask": mask, "lengths": lengths,
            "chain_mask": chain_mask, "chain_encoding": chain_encoding}
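
# A hedged usage sketch (assumed setup, not from the original file):
# `featurize_GTrans` is written to serve as a DataLoader collate_fn. Each
# dataset item is assumed to be a dict with 'title', 'seq', per-residue
# coordinate arrays 'N'/'CA'/'C'/'O' of shape [L, 3], plus 'chain_mask' and
# 'chain_encoding' of length L.
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=8, collate_fn=featurize_GTrans)
#   for feats in loader:
#       if feats is None:  # every item in this batch failed to parse
#           continue
#       X, S, mask = feats['X'], feats['S'], feats['mask']  # [B,L,4,3], [B,L], [B,L]
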

class featurize_GVP:
    def __init__(self, num_positional_embeddings=16, top_k=30, num_rbf=16):
        self.top_k = top_k
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

    def featurize(self, batch):
        data_all = []
        for b in batch:
            if b is None:
                continue
            coords = torch.tensor(np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1))
            seq = torch.tensor(np.array(tokenizer.encode(b['seq'], add_special_tokens=False)))

            mask = torch.isfinite(coords.sum(dim=(1, 2)))
            coords[~mask] = np.inf

            X_ca = coords[:, 1].float()
            # torch_geometric's knn_graph is a drop-in replacement for
            # torch_cluster.knn_graph(X_ca, k=self.top_k).
            edge_index = torch_geometric.nn.knn_graph(X_ca, k=self.top_k)
            pos_embeddings = self._positional_embeddings(edge_index)  # [E, 16]
            E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]     # [E, 3]
            rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf)  # [E, 16]

            dihedrals = self._dihedrals(coords)       # [n, 6]
            orientations = self._orientations(X_ca)   # [n, 2, 3]
            sidechains = self._sidechains(coords)     # [n, 3]

            node_s = dihedrals.float()                # [n, 6]
            node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2).float()  # [n, 3, 3]
            edge_s = torch.cat([rbf, pos_embeddings], dim=-1).float()  # [E, 32]
            edge_v = _normalize(E_vectors).unsqueeze(-2).float()       # [E, 1, 3]

            node_s, node_v, edge_s, edge_v = map(
                torch.nan_to_num, (node_s, node_v, edge_s, edge_v))

            data = torch_geometric.data.Data(x=X_ca, seq=seq,
                                             node_s=node_s, node_v=node_v,
                                             edge_s=edge_s, edge_v=edge_v,
                                             edge_index=edge_index, mask=mask)
            data_all.append(data)
        return data_all

    def _positional_embeddings(self, edge_index,
                               num_embeddings=None,
                               period_range=[2, 1000]):
        # From https://github.com/jingraham/neurips19-graph-protein-design
        num_embeddings = num_embeddings or self.num_positional_embeddings
        d = edge_index[0] - edge_index[1]

        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    def _dihedrals(self, X, eps=1e-7):
        # From https://github.com/jingraham/neurips19-graph-protein-design
        X = torch.reshape(X[:, :3], [3 * X.shape[0], 3])
        dX = X[1:] - X[:-1]
        U = _normalize(dX, dim=-1)
        u_2 = U[:-2]
        u_1 = U[1:-1]
        u_0 = U[2:]

        # Backbone normals (explicit dim=-1: the dim-less torch.cross is deprecated)
        n_2 = _normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = _normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = torch.sum(n_2 * n_1, -1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [-1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
        return D_features

    def _orientations(self, X):
        forward = _normalize(X[1:] - X[:-1])
        backward = _normalize(X[:-1] - X[1:])
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    def _sidechains(self, X):
        n, origin, c = X[:, 0], X[:, 1], X[:, 2]
        c, n = _normalize(c - origin), _normalize(n - origin)
        bisector = _normalize(c + n)
        perp = _normalize(torch.cross(c, n, dim=-1))
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec

    def collate(self, batch):
        batch = self.featurize(batch)
        if (batch is None) or (len(batch) == 0):
            return None
        elem = batch[0]
        if isinstance(elem, Data):
            return Batch.from_data_list(batch)
        elif isinstance(elem, torch.Tensor):
            return default_collate(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, str):
            return batch
        elif isinstance(elem, Mapping):
            return {key: self.collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            return type(elem)(*(self.collate(s) for s in zip(*batch)))
        elif isinstance(elem, Sequence) and not isinstance(elem, str):
            return [self.collate(s) for s in zip(*batch)]
        raise TypeError('DataLoader found invalid type: {}'.format(type(elem)))
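
# A hedged usage sketch (assumed setup, not from the original file):
# `featurize_GVP.collate` turns a list of raw items into a torch_geometric
# `Batch` of residue graphs, so it also plugs directly into a DataLoader.
#
#   from torch.utils.data import DataLoader
#   gvp_featurizer = featurize_GVP(num_positional_embeddings=16, top_k=30, num_rbf=16)
#   loader = DataLoader(dataset, batch_size=8, collate_fn=gvp_featurizer.collate)
#   for graph_batch in loader:
#       node_s, node_v = graph_batch.node_s, graph_batch.node_v  # [N,6], [N,3,3]
#       edge_s, edge_v = graph_batch.edge_s, graph_batch.edge_v  # [E,32], [E,1,3]
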

def featurize_ProteinMPNN(batch, is_testing=False, chain_dict=None,
                          fixed_position_dict=None, omit_AA_dict=None,
                          tied_positions_dict=None, pssm_dict=None,
                          bias_by_res_dict=None):
    """ Pack and pad batch into torch tensors """
    batch = [one for one in batch if one is not None]
    B = len(batch)
    if B == 0:  # check before touching batch[0] below
        return None
    # Dynamics-related targets are optional; featurize them only if present.
    USING_DYNAMICS = (('norm_bfactors' in batch[0].keys())
                      or ('gt_flex' in batch[0].keys())
                      or ('enm_vals' in batch[0].keys())
                      or ('original_gt_flex' in batch[0].keys())
                      or ('eng_mask' in batch[0].keys()))
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)  # sum of chain seq lengths
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 4, 3])
    residue_idx = -100 * np.ones([B, L_max], dtype=np.int32)
    chain_M = np.zeros([B, L_max], dtype=np.int32)  # 1.0 for the bits that need to be predicted
    pssm_coef_all = np.zeros([B, L_max], dtype=np.float32)
    pssm_bias_all = np.zeros([B, L_max, 21], dtype=np.float32)
    pssm_log_odds_all = 10000.0 * np.ones([B, L_max, 21], dtype=np.float32)
    chain_M_pos = np.zeros([B, L_max], dtype=np.int32)  # 1.0 for the bits that need to be predicted
    bias_by_res_all = np.zeros([B, L_max, 21], dtype=np.float32)
    chain_encoding_all = np.zeros([B, L_max], dtype=np.int32)
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.zeros([B, L_max])
    omit_AA_mask = np.zeros([B, L_max, len(alphabet)], dtype=np.int32)
    # Build the batch
    letter_list_list = []
    visible_list_list = []
    masked_list_list = []
    masked_chain_length_list_list = []
    tied_pos_list_of_lists_list = []
    if USING_DYNAMICS:
        if 'norm_bfactors' in batch[0].keys():
            b_factors = np.zeros([B, L_max])
        if 'gt_flex' in batch[0].keys():
            gt_flex = np.zeros([B, L_max])
        if 'enm_vals' in batch[0].keys():
            enm_vals = np.zeros([B, L_max])
        if 'original_gt_flex' in batch[0].keys():
            original_gt_flex = np.zeros([B, L_max])
        if 'eng_mask' in batch[0].keys():
            eng_mask = np.zeros([B, L_max])
    # Determine masked vs. visible chains per sample (chain shuffling is
    # disabled below).
    for i, b in enumerate(batch):
        if chain_dict is not None:
            # masked_chains is a list of chain letters to predict, e.g. ['A', 'D', 'F']
            masked_chains, visible_chains = chain_dict[b['name']]
        else:
            # Single-chain samples use an empty chain letter and no visible chains.
            masked_chains = ['']
            visible_chains = []
        all_chains = masked_chains + visible_chains
        # random.shuffle(all_chains)
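    # Main per-sample loop: for each complex, walk its chains in `all_chains`
    # order, collect per-chain coordinates/sequence/masks, then concatenate the
    # chains and pad everything to L_max. `chain_M` is 1.0 on residues to be
    # designed (masked chains) and 0.0 on visible ones; `residue_idx` offsets
    # each chain by 100 so relative-position features cannot leak across chains.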
    for i, b in enumerate(batch):
        mask_dict = {}
        a = 0
        x_chain_list = []
        chain_mask_list = []
        chain_seq_list = []
        chain_encoding_list = []
        c = 1
        letter_list = []
        global_idx_start_list = [0]
        visible_list = []
        masked_list = []
        masked_chain_length_list = []
        fixed_position_mask_list = []
        omit_AA_mask_list = []
        pssm_coef_list = []
        pssm_bias_list = []
        pssm_log_odds_list = []
        bias_by_res_list = []
        if USING_DYNAMICS:
            if 'norm_bfactors' in batch[0].keys():
                b_factors_list = []
            if 'gt_flex' in batch[0].keys():
                gt_flex_list = []
            if 'enm_vals' in batch[0].keys():
                enm_vals_list = []
            if 'original_gt_flex' in batch[0].keys():
                original_gt_flex_list = []
            if 'eng_mask' in batch[0].keys():
                eng_mask_list = []
        l0 = 0
        l1 = 0
        for step, letter in enumerate(all_chains):
            if letter in visible_chains:
                letter_list.append(letter)
                visible_list.append(letter)
                chain_seq = b[f'seq_chain_{letter}']
                chain_seq = ''.join([a if a != '-' else 'X' for a in chain_seq])
                chain_length = len(chain_seq)
                global_idx_start_list.append(global_idx_start_list[-1] + chain_length)
                chain_coords = b[f'coords_chain_{letter}']  # this is a dictionary
                chain_mask = np.zeros(chain_length)  # 0.0 for visible chains
                x_chain = np.stack([chain_coords[c] for c in [f'N_chain_{letter}',
                                                              f'CA_chain_{letter}',
                                                              f'C_chain_{letter}',
                                                              f'O_chain_{letter}']], 1)  # [chain_length, 4, 3]
                x_chain_list.append(x_chain)
                chain_mask_list.append(chain_mask)
                chain_seq_list.append(chain_seq)
                chain_encoding_list.append(c * np.ones(np.array(chain_mask).shape[0]))
                l1 += chain_length
                residue_idx[i, l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
                l0 += chain_length
                c += 1
                fixed_position_mask = np.ones(chain_length)
                fixed_position_mask_list.append(fixed_position_mask)
                omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
                omit_AA_mask_list.append(omit_AA_mask_temp)
                pssm_coef = np.zeros(chain_length)
                pssm_bias = np.zeros([chain_length, 21])
                pssm_log_odds = 10000.0 * np.ones([chain_length, 21])
                pssm_coef_list.append(pssm_coef)
                pssm_bias_list.append(pssm_bias)
                pssm_log_odds_list.append(pssm_log_odds)
                bias_by_res_list.append(np.zeros([chain_length, 21]))
            if letter in masked_chains:
                masked_list.append(letter)
                letter_list.append(letter)
                if USING_DYNAMICS:
                    if 'norm_bfactors' in batch[0].keys():
                        b_factors_list.append(b['norm_bfactors'])
                    if 'gt_flex' in batch[0].keys():
                        gt_flex_list.append(b['gt_flex'])
                    if 'enm_vals' in batch[0].keys():
                        enm_vals_list.append(b['enm_vals'])
                    if 'original_gt_flex' in batch[0].keys():
                        original_gt_flex_list.append(b['original_gt_flex'])
                    if 'eng_mask' in batch[0].keys():
                        eng_mask_list.append(b['eng_mask'])
                chain_seq = b[f'seq{letter}']  # single-chain layout: b['seq'] when letter == ''
                chain_seq = ''.join([a if a != '-' else 'X' for a in chain_seq])
                chain_length = len(chain_seq)
                global_idx_start_list.append(global_idx_start_list[-1] + chain_length)
                masked_chain_length_list.append(chain_length)
                chain_coords = b  # single-chain layout keeps 'N'/'CA'/'C'/'O' at the top level
                chain_mask = np.ones(chain_length)  # 1.0 for masked
                x_chain = np.stack([chain_coords[c] for c in ['N', 'CA', 'C', 'O']], 1)  # [chain_length, 4, 3]
                x_chain_list.append(x_chain)
                chain_mask_list.append(chain_mask)
                chain_seq_list.append(chain_seq)
                chain_encoding_list.append(c * np.ones(np.array(chain_mask).shape[0]))
                l1 += chain_length
                residue_idx[i, l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
                l0 += chain_length
                c += 1
                fixed_position_mask = np.ones(chain_length)
                if fixed_position_dict is not None:
                    fixed_pos_list = fixed_position_dict[b['name']][letter]
                    if fixed_pos_list:
                        fixed_position_mask[np.array(fixed_pos_list) - 1] = 0.0
                fixed_position_mask_list.append(fixed_position_mask)
                omit_AA_mask_temp = np.zeros([chain_length, len(alphabet)], np.int32)
                if omit_AA_dict is not None:
                    for item in omit_AA_dict[b['name']][letter]:
                        idx_AA = np.array(item[0]) - 1
                        AA_idx = np.array([np.argwhere(np.array(list(alphabet)) == AA)[0][0]
                                           for AA in item[1]]).repeat(idx_AA.shape[0])
                        idx_ = np.array([[a, b] for a in idx_AA for b in AA_idx])
                        omit_AA_mask_temp[idx_[:, 0], idx_[:, 1]] = 1
                omit_AA_mask_list.append(omit_AA_mask_temp)
                pssm_coef = np.zeros(chain_length)
                pssm_bias = np.zeros([chain_length, 21])
                pssm_log_odds = 10000.0 * np.ones([chain_length, 21])
                if pssm_dict:
                    if pssm_dict[b['name']][letter]:
                        pssm_coef = pssm_dict[b['name']][letter]['pssm_coef']
                        pssm_bias = pssm_dict[b['name']][letter]['pssm_bias']
                        pssm_log_odds = pssm_dict[b['name']][letter]['pssm_log_odds']
                pssm_coef_list.append(pssm_coef)
                pssm_bias_list.append(pssm_bias)
                pssm_log_odds_list.append(pssm_log_odds)
                if bias_by_res_dict:
                    bias_by_res_list.append(bias_by_res_dict[b['name']][letter])
                else:
                    bias_by_res_list.append(np.zeros([chain_length, 21]))
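        # Tied positions: map per-chain 1-based residue indices from
        # `tied_positions_dict` to global 0-based indices via
        # `global_idx_start_list`, so tied residues can later share a decoding
        # step; `tied_beta` stores the per-position tying weights.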
        letter_list_np = np.array(letter_list)
        tied_pos_list_of_lists = []
        tied_beta = np.ones(L_max)
        if tied_positions_dict is not None:
            tied_pos_list = tied_positions_dict[b['name']]
            if tied_pos_list:
                set_chains_tied = set(list(itertools.chain(*[list(item) for item in tied_pos_list])))
                for tied_item in tied_pos_list:
                    one_list = []
                    for k, v in tied_item.items():
                        start_idx = global_idx_start_list[np.argwhere(letter_list_np == k)[0][0]]
                        if isinstance(v[0], list):
                            for v_count in range(len(v[0])):
                                one_list.append(start_idx + v[0][v_count] - 1)  # make 0 the first residue
                                tied_beta[start_idx + v[0][v_count] - 1] = v[1][v_count]
                        else:
                            for v_ in v:
                                one_list.append(start_idx + v_ - 1)  # make 0 the first residue
                    tied_pos_list_of_lists.append(one_list)
        tied_pos_list_of_lists_list.append(tied_pos_list_of_lists)

        x = np.concatenate(x_chain_list, 0)  # [L, 4, 3]
        if USING_DYNAMICS:
            if 'norm_bfactors' in batch[0].keys():
                bf = np.concatenate(b_factors_list, 0)  # [L,]
            if 'gt_flex' in batch[0].keys():
                gt = np.concatenate(gt_flex_list, 0)  # [L,]
            if 'enm_vals' in batch[0].keys():
                enm = np.concatenate(enm_vals_list, 0)
            if 'original_gt_flex' in batch[0].keys():
                orig_gt = np.concatenate(original_gt_flex_list, 0)
            if 'eng_mask' in batch[0].keys():
                eng = np.concatenate(eng_mask_list, 0)
        all_sequence = "".join(chain_seq_list)
        m = np.concatenate(chain_mask_list, 0)  # [L,], 1.0 for places that need to be predicted
        chain_encoding = np.concatenate(chain_encoding_list, 0)
        m_pos = np.concatenate(fixed_position_mask_list, 0)  # [L,], 0.0 for fixed positions
        pssm_coef_ = np.concatenate(pssm_coef_list, 0)  # [L,]
        pssm_bias_ = np.concatenate(pssm_bias_list, 0)  # [L, 21]
        pssm_log_odds_ = np.concatenate(pssm_log_odds_list, 0)  # [L, 21]
        bias_by_res_ = np.concatenate(bias_by_res_list, 0)  # [L, 21], 0.0 where AA frequencies don't need to be tweaked

        l = len(all_sequence)
        x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]],
                       'constant', constant_values=(np.nan, ))
        if USING_DYNAMICS:
            if 'norm_bfactors' in batch[0].keys():
                bf_pad = np.pad(bf, [[0, L_max - l]], 'constant', constant_values=(np.nan, ))
            if 'gt_flex' in batch[0].keys():
                gt_pad = np.pad(gt, [[0, L_max - l]], 'constant', constant_values=(np.nan, ))
            if 'enm_vals' in batch[0].keys():
                enm_pad = np.pad(enm, [[0, L_max - l]], 'constant', constant_values=(np.nan, ))
            if 'original_gt_flex' in batch[0].keys():
                orig_gt_pad = np.pad(orig_gt, [[0, L_max - l]], 'constant', constant_values=(0, ))
            if 'eng_mask' in batch[0].keys():
                eng_pad = np.pad(eng, [[0, L_max - l]], 'constant', constant_values=(0, ))
        X[i, :, :, :] = x_pad
        if USING_DYNAMICS:
            if 'norm_bfactors' in batch[0].keys():
                b_factors[i, :] = bf_pad
            if 'gt_flex' in batch[0].keys():
                gt_flex[i, :] = gt_pad[:-1]
            if 'enm_vals' in batch[0].keys():
                enm_vals[i, :] = enm_pad
            if 'original_gt_flex' in batch[0].keys():
                original_gt_flex[i, :] = orig_gt_pad[:-1]
            if 'eng_mask' in batch[0].keys():
                eng_mask[i, :] = eng_pad[:-1]
        if 'score' in b.keys():
            score[i, :l] = b['score']
        else:
            score[i, :l] = 100.0
        m_pad = np.pad(m, [[0, L_max - l]], 'constant', constant_values=(0.0, ))
        m_pos_pad = np.pad(m_pos, [[0, L_max - l]], 'constant', constant_values=(0.0, ))
        omit_AA_mask_pad = np.pad(np.concatenate(omit_AA_mask_list, 0),
                                  [[0, L_max - l], [0, 0]], 'constant', constant_values=(0.0, ))
        chain_M[i, :] = m_pad
        chain_M_pos[i, :] = m_pos_pad
        omit_AA_mask[i, ] = omit_AA_mask_pad
        chain_encoding_pad = np.pad(chain_encoding, [[0, L_max - l]], 'constant', constant_values=(0.0, ))
        chain_encoding_all[i, :] = chain_encoding_pad

        pssm_coef_pad = np.pad(pssm_coef_, [[0, L_max - l]], 'constant', constant_values=(0.0, ))
        pssm_bias_pad = np.pad(pssm_bias_, [[0, L_max - l], [0, 0]], 'constant', constant_values=(0.0, ))
        pssm_log_odds_pad = np.pad(pssm_log_odds_, [[0, L_max - l], [0, 0]], 'constant', constant_values=(0.0, ))

        pssm_coef_all[i, :] = pssm_coef_pad
        pssm_bias_all[i, :] = pssm_bias_pad
        pssm_log_odds_all[i, :] = pssm_log_odds_pad

        bias_by_res_pad = np.pad(bias_by_res_, [[0, L_max - l], [0, 0]], 'constant', constant_values=(0.0, ))
        bias_by_res_all[i, :] = bias_by_res_pad

        # Convert to labels via the ESM-2 tokenizer
        indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
        S[i, :l] = indices

        letter_list_list.append(letter_list)
        visible_list_list.append(visible_list)
        masked_list_list.append(masked_list)
        masked_chain_length_list_list.append(masked_chain_length_list)
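    # After the per-sample loop: derive the atom-validity mask from the NaN
    # padding, zero out the NaNs, build per-dihedral masks at chain jumps, and
    # convert all arrays to torch tensors.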
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.
    # Conversion
    pssm_coef_all = torch.from_numpy(pssm_coef_all).to(dtype=torch.float32)
    pssm_bias_all = torch.from_numpy(pssm_bias_all).to(dtype=torch.float32)
    pssm_log_odds_all = torch.from_numpy(pssm_log_odds_all).to(dtype=torch.float32)
    tied_beta = torch.from_numpy(tied_beta).to(dtype=torch.float32)
    jumps = ((residue_idx[:, 1:] - residue_idx[:, :-1]) == 1).astype(np.float32)
    bias_by_res_all = torch.from_numpy(bias_by_res_all).to(dtype=torch.float32)
    phi_mask = np.pad(jumps, [[0, 0], [1, 0]])
    psi_mask = np.pad(jumps, [[0, 0], [0, 1]])
    omega_mask = np.pad(jumps, [[0, 0], [0, 1]])
    dihedral_mask = np.concatenate([phi_mask[:, :, None], psi_mask[:, :, None], omega_mask[:, :, None]], -1)  # [B, L, 3]
    dihedral_mask = torch.from_numpy(dihedral_mask).to(dtype=torch.float32)
    residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long)
    S = torch.from_numpy(S).to(dtype=torch.long)
    X = torch.from_numpy(X).to(dtype=torch.float32)
    if USING_DYNAMICS:
        if 'norm_bfactors' in batch[0].keys():
            b_factors = torch.from_numpy(b_factors).to(dtype=torch.float32)
        if 'gt_flex' in batch[0].keys():
            gt_flex = torch.from_numpy(gt_flex).to(dtype=torch.float32)
        if 'enm_vals' in batch[0].keys():
            enm_vals = torch.from_numpy(enm_vals).to(dtype=torch.float32)
        if 'original_gt_flex' in batch[0].keys():
            original_gt_flex = torch.from_numpy(original_gt_flex).to(dtype=torch.float32)
        if 'eng_mask' in batch[0].keys():
            eng_mask = torch.from_numpy(eng_mask).to(dtype=torch.float32)
    score = torch.from_numpy(score).float()
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32)
    chain_M_pos = torch.from_numpy(chain_M_pos).to(dtype=torch.float32)
    omit_AA_mask = torch.from_numpy(omit_AA_mask).to(dtype=torch.float32)
    chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long)

    # The training and testing branches previously built identical dictionaries,
    # so they are collapsed into one return value; `is_testing` currently does
    # not change the output.
    retVal = {"title": [b['title'] for b in batch],
              "X": X, "S": S, "score": score, "mask": mask, "lengths": lengths,
              "chain_M": chain_M, "chain_M_pos": chain_M_pos,
              "residue_idx": residue_idx, "chain_encoding_all": chain_encoding_all}
    if USING_DYNAMICS:
        if 'norm_bfactors' in batch[0].keys():
            retVal['norm_bfactors'] = b_factors
        if 'gt_flex' in batch[0].keys():
            retVal['gt_flex'] = gt_flex
        if 'enm_vals' in batch[0].keys():
            retVal['enm_vals'] = enm_vals
        if 'original_gt_flex' in batch[0].keys():
            retVal['original_gt_flex'] = original_gt_flex
        if 'eng_mask' in batch[0].keys():
            retVal['eng_mask'] = eng_mask
    return retVal
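
# A hedged usage sketch (assumed setup, not from the original file): with the
# default chain_dict=None, each item is treated as a single masked chain, so a
# minimal item only needs 'title', 'seq' and top-level 'N'/'CA'/'C'/'O' arrays.
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=8,
#                       collate_fn=lambda b: featurize_ProteinMPNN(b, is_testing=False))
#   for feats in loader:
#       X, S = feats['X'], feats['S']  # [B,L,4,3], [B,L]
#       chain_M, residue_idx = feats['chain_M'], feats['residue_idx']
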

def featurize_Inversefolding(batch, shuffle_fraction=0.):
    """ Pack and pad batch into torch tensors """
    B = len(batch)
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
    L_max = max([len(b['seq']) for b in batch])
    X = np.zeros([B, L_max, 3, 3])
    S = np.zeros([B, L_max], dtype=np.int32)
    score = np.ones([B, L_max]) * 100.0
    chain_mask = np.zeros([B, L_max]) - 1  # 1: masked positions to be predicted, 0: visible positions
    chain_encoding = np.zeros([B, L_max]) - 1

    # Build the batch
    for i, b in enumerate(batch):
        x = np.stack([b[c] for c in ['N', 'CA', 'C']], 1)  # [#atom, 3, 3]

        l = len(b['seq'])
        x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]],
                       'constant', constant_values=(np.nan, ))  # [L_max, 3, 3]
        X[i, :, :, :] = x_pad

        # Convert to labels via the ESM-2 tokenizer
        indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
        if shuffle_fraction > 0.:
            idx_shuffle = shuffle_subset(l, shuffle_fraction)
            S[i, :l] = indices[idx_shuffle]
        else:
            S[i, :l] = indices

        chain_mask[i, :l] = b['chain_mask']
        chain_encoding[i, :l] = b['chain_encoding']

    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)  # atom mask
    numbers = np.sum(mask, axis=1).astype(np.int32)  # np.int was removed from NumPy
    S_new = np.zeros_like(S)
    X_new = np.zeros_like(X) + np.nan
    # Compact each sample: drop residues with any missing backbone atom.
    for i, n in enumerate(numbers):
        X_new[i, :n, ::] = X[i][mask[i] == 1]
        S_new[i, :n] = S[i][mask[i] == 1]

    X = X_new
    S = S_new
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.
    # Conversion
    S = torch.from_numpy(S).to(dtype=torch.long)
    score = torch.from_numpy(score).float()
    X = torch.from_numpy(X).to(dtype=torch.float32)
    mask = torch.from_numpy(mask).to(dtype=torch.float32)
    chain_mask = torch.from_numpy(chain_mask)
    chain_encoding = torch.from_numpy(chain_encoding)
    return {"title": [b['title'] for b in batch],
            "X": X, "S": S, "score": score, "mask": mask, "lengths": lengths,
            "chain_mask": chain_mask, "chain_encoding": chain_encoding}
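
# A minimal smoke test under stated assumptions (synthetic single-chain data,
# hypothetical values; not part of the original file).
if __name__ == '__main__':
    L = 10
    dummy = {
        'title': 'toy', 'seq': 'ACDEFGHIKL',
        'N': np.random.randn(L, 3), 'CA': np.random.randn(L, 3),
        'C': np.random.randn(L, 3), 'O': np.random.randn(L, 3),
        'chain_mask': np.ones(L), 'chain_encoding': np.ones(L),
    }
    feats = featurize_Inversefolding([dummy])
    print(feats['X'].shape, feats['S'].shape)  # torch.Size([1, 10, 3, 3]) torch.Size([1, 10])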