import math
import time

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer

from src.tools import gather_nodes, _dihedrals, _get_rbf, _get_dist, _rbf, _orientations_coarse_gl_tuple
from src.modules.pifold_module import *

# All backbone/virtual-atom pairs; a per-run subset is rebuilt inside
# _get_features according to the args.* pair flags.
pair_lst = ['N-N', 'C-C', 'O-O', 'Cb-Cb', 'Ca-N', 'Ca-C', 'Ca-O', 'Ca-Cb', 'N-C', 'N-O', 'N-Cb',
            'Cb-C', 'Cb-O', 'O-C', 'N-Ca', 'C-Ca', 'O-Ca', 'Cb-Ca', 'C-N', 'O-N', 'Cb-N',
            'C-Cb', 'O-Cb', 'C-O']
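
# ---------------------------------------------------------------------------
# Illustrative sketch (not in the original source): PiFold_Model reads many
# hyper-parameters off an `args` object. The helper below lists every
# attribute accessed in this file; the concrete values are assumptions for
# demonstration only, not the published configuration.
def _example_args():
    """Build a hypothetical args namespace for PiFold_Model (values are guesses)."""
    from types import SimpleNamespace
    return SimpleNamespace(
        augment_eps=0.0, node_features=128, edge_features=128, hidden_dim=128,
        dropout=0.1, num_encoder_layers=10, k_neighbors=30, dihedral_type=0,
        virtual_num=3, use_gvp_feat=False,
        node_dist=True, node_angle=True, node_direct=True,
        edge_dist=True, edge_angle=True, edge_direct=True,
        Ca_Ca=True, Ca_C=True, Ca_N=True, Ca_O=True, C_C=True,
        C_N=True, C_O=True, N_N=True, N_O=True, O_O=True,
        updating_edges=4, att_output_mlp=True, node_output_mlp=True,
        node_net='AttMLP', edge_net='EdgeMLP',
        node_context=True, edge_context=False,
        num_decoder_layers1=1, kernel_size1=3, act_type='silu', glu=0,
    )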

class PiFold_Model(nn.Module):
    def __init__(self, args, **kwargs):
        """Graph labeling network: encodes a residue graph and predicts sequence tokens."""
        super(PiFold_Model, self).__init__()
        self.args = args
        self.augment_eps = args.augment_eps
        node_features = args.node_features
        edge_features = args.edge_features
        hidden_dim = args.hidden_dim
        dropout = args.dropout
        num_encoder_layers = args.num_encoder_layers
        self.top_k = args.k_neighbors
        self.num_rbf = 16
        self.num_positional_embeddings = 16
        self.dihedral_type = args.dihedral_type

        # The ESM-2 tokenizer supplies the output vocabulary; token_mask flags
        # the 21 standard amino-acid tokens (20 residues + X) within it.
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/")
        alphabet = list('ACDEFGHIKLMNPQRSTVWYX')
        self.token_mask = torch.tensor([(one in alphabet) for one in self.tokenizer._token_to_id.keys()])

        # Learnable virtual-atom coordinates, expressed in the local backbone
        # frame of each residue; initialized from a fixed prior.
        prior_matrix = [
            [-0.58273431,  0.56802827, -0.54067466],
            [ 0.0,         0.83867057, -0.54463904],
            [ 0.01984028, -0.78380804, -0.54183614],
        ]
        self.virtual_atoms = nn.Parameter(torch.tensor(prior_matrix)[:self.args.virtual_num, :])
        # Node-feature input dimension, assembled from the enabled feature groups.
        node_in = 0
        if self.args.node_dist:
            pair_num = 6  # Ca-N, Ca-C, Ca-O, N-C, N-O, O-C
            if self.args.virtual_num > 0:
                pair_num += self.args.virtual_num * (self.args.virtual_num - 1)
            node_in += pair_num * self.num_rbf
        if self.args.node_angle:
            node_in += 12
        if self.args.node_direct:
            node_in += 9

        # Edge-feature input dimension; asymmetric pairs (e.g. Ca-C and C-Ca)
        # contribute two channels, symmetric pairs (e.g. Ca-Ca) one.
        edge_in = 0
        if self.args.edge_dist:
            pair_num = 0
            if self.args.Ca_Ca:
                pair_num += 1
            if self.args.Ca_C:
                pair_num += 2
            if self.args.Ca_N:
                pair_num += 2
            if self.args.Ca_O:
                pair_num += 2
            if self.args.C_C:
                pair_num += 1
            if self.args.C_N:
                pair_num += 2
            if self.args.C_O:
                pair_num += 2
            if self.args.N_N:
                pair_num += 1
            if self.args.N_O:
                pair_num += 2
            if self.args.O_O:
                pair_num += 1
            if self.args.virtual_num > 0:
                pair_num += self.args.virtual_num
                pair_num += self.args.virtual_num * (self.args.virtual_num - 1)
            edge_in += pair_num * self.num_rbf
        if self.args.edge_angle:
            edge_in += 4
        if self.args.edge_direct:
            edge_in += 12

        if self.args.use_gvp_feat:
            node_in = 12
            edge_in = 48 - 16
        edge_in += 16 + 16  # positional encoding + chain encoding
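        # Worked example (derived from the counting above): with virtual_num=3
        # and every flag enabled,
        #   node: pair_num = 6 + 3*2 = 12 -> 12*16 = 192; +12 (angles) +9 (directions) = 213
        #   edge: pair_num = 16 + 3 + 3*2 = 25 -> 25*16 = 400; +4 +12 = 416; +32 (pos/chain) = 448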

        self.node_embedding = nn.Linear(node_in, node_features, bias=True)
        self.edge_embedding = nn.Linear(edge_in, edge_features, bias=True)
        self.norm_nodes = nn.BatchNorm1d(node_features)
        self.norm_edges = nn.BatchNorm1d(edge_features)

        self.W_v = nn.Sequential(
            nn.Linear(node_features, hidden_dim, bias=True),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim, bias=True)
        )
        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_f = nn.Linear(edge_features, hidden_dim, bias=True)

        self.encoder = StructureEncoder(hidden_dim, num_encoder_layers, dropout,
                                        args.updating_edges, args.att_output_mlp, args.node_output_mlp,
                                        args.node_net, args.edge_net, args.node_context, args.edge_context)
        self.decoder = MLPDecoder(hidden_dim, hidden_dim, args.num_decoder_layers1, args.kernel_size1,
                                  args.act_type, args.glu, vocab=len(self.tokenizer._token_to_id))
        self._init_params()

        self.encode_t = 0
        self.decode_t = 0

    def forward(self, batch):
        # `batch` is the output of _get_features: flattened node/edge features
        # plus the sparse edge index and per-node graph ids.
        h_V, h_P, P_idx, batch_id = batch['_V'], batch['_E'], batch['E_idx'], batch['batch_id']

        t1 = time.time()
        h_V = self.W_v(self.norm_nodes(self.node_embedding(h_V)))
        h_P = self.W_e(self.norm_edges(self.edge_embedding(h_P)))
        h_V, h_P = self.encoder(h_V, h_P, P_idx, batch_id)
        t2 = time.time()
        log_probs, logits = self.decoder(h_V, batch_id, self.token_mask)
        t3 = time.time()

        # accumulate wall-clock timings for profiling
        self.encode_t += t2 - t1
        self.decode_t += t3 - t2
        return {'log_probs': log_probs}

    def _init_params(self):
        # Xavier-initialize all weight matrices; keep the virtual-atom prior.
        for name, p in self.named_parameters():
            if name == 'virtual_atoms':
                continue
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _full_dist(self, X, mask, top_k=30, eps=1E-6):
        # Pairwise Ca-Ca distances with invalid positions pushed far away,
        # then the K nearest valid neighbors per residue.
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)

        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * (D_max + 1)
        D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
        return D_neighbors, E_idx
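
    # Shape sketch for _full_dist above (illustrative, not original code):
    #   X: [B, N, 3] Ca coordinates, mask: [B, N] ->
    #   D_neighbors, E_idx: [B, N, K] with K = min(top_k, N). For example,
    #   X = torch.randn(2, 50, 3); mask = torch.ones(2, 50)
    #   D, E_idx = model._full_dist(X, mask, top_k=30)   # both [2, 50, 30]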

    def _get_features(self, batch):
        S, score, X, mask, chain_mask, chain_encoding = (
            batch['S'], batch['score'], batch['X'], batch['mask'],
            batch['chain_mask'], batch['chain_encoding'])
        mask_bool = (mask == 1)
        B, N, _, _ = X.shape
        X_ca = X[:, :, 1, :]
        D_neighbors, E_idx = self._full_dist(X_ca, mask, self.top_k)

        # Valid-edge mask and helpers that flatten batched tensors over the
        # valid nodes/edges only.
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = (mask.unsqueeze(-1) * mask_attend) == 1
        edge_mask_select = lambda x: torch.masked_select(x, mask_attend.unsqueeze(-1)).reshape(-1, x.shape[-1])
        node_mask_select = lambda x: torch.masked_select(x, mask_bool.unsqueeze(-1)).reshape(-1, x.shape[-1])

        # sequence
        S = torch.masked_select(S, mask_bool)
        if score is not None:
            score = torch.masked_select(score, mask_bool)
        chain_mask = torch.masked_select(chain_mask, mask_bool)
        chain_encoding = torch.masked_select(chain_encoding, mask_bool)

        # angle & direction
        V_angles = node_mask_select(_dihedrals(X, self.dihedral_type))
        V_direct, E_direct, E_angles = _orientations_coarse_gl_tuple(X, E_idx)
        V_direct = node_mask_select(V_direct)
        E_direct = edge_mask_select(E_direct)
        E_angles = edge_mask_select(E_angles)

        # distance: backbone atoms and, optionally, learned virtual atoms
        # placed in the local frame (b = Ca-N, c = C-Ca, a = b x c)
        atom_N = X[:, :, 0, :]
        atom_Ca = X[:, :, 1, :]
        atom_C = X[:, :, 2, :]
        atom_O = X[:, :, 3, :]
        b = atom_Ca - atom_N
        c = atom_C - atom_Ca
        a = torch.cross(b, c, dim=-1)

        # NOTE: the original code stashed atom tensors via vars(), which is
        # undefined behavior for function locals; an explicit dict is
        # equivalent and safe.
        atoms = {'N': atom_N, 'Ca': atom_Ca, 'C': atom_C, 'O': atom_O}
        if self.args.virtual_num > 0:
            virtual_atoms = self.virtual_atoms / torch.norm(self.virtual_atoms, dim=1, keepdim=True)
            for i in range(self.virtual_atoms.shape[0]):
                atoms['v' + str(i)] = (virtual_atoms[i][0] * a
                                       + virtual_atoms[i][1] * b
                                       + virtual_atoms[i][2] * c
                                       + atom_Ca)

        node_list = ['Ca-N', 'Ca-C', 'Ca-O', 'N-C', 'N-O', 'O-C']
        node_dist = []
        for pair in node_list:
            atom1, atom2 = pair.split('-')
            node_dist.append(node_mask_select(_get_rbf(atoms[atom1], atoms[atom2], None, self.num_rbf).squeeze()))
        if self.args.virtual_num > 0:
            for i in range(self.virtual_atoms.shape[0]):
                for j in range(0, i):
                    node_dist.append(node_mask_select(_get_rbf(atoms['v' + str(i)], atoms['v' + str(j)], None, self.num_rbf).squeeze()))
                    node_dist.append(node_mask_select(_get_rbf(atoms['v' + str(j)], atoms['v' + str(i)], None, self.num_rbf).squeeze()))
        V_dist = torch.cat(tuple(node_dist), dim=-1).squeeze()

        # Edge distance pairs, assembled according to the args.* flags.
        pair_lst = []
        if self.args.Ca_Ca:
            pair_lst.append('Ca-Ca')
        if self.args.Ca_C:
            pair_lst.extend(['Ca-C', 'C-Ca'])
        if self.args.Ca_N:
            pair_lst.extend(['Ca-N', 'N-Ca'])
        if self.args.Ca_O:
            pair_lst.extend(['Ca-O', 'O-Ca'])
        if self.args.C_C:
            pair_lst.append('C-C')
        if self.args.C_N:
            pair_lst.extend(['C-N', 'N-C'])
        if self.args.C_O:
            pair_lst.extend(['C-O', 'O-C'])
        if self.args.N_N:
            pair_lst.append('N-N')
        if self.args.N_O:
            pair_lst.extend(['N-O', 'O-N'])
        if self.args.O_O:
            pair_lst.append('O-O')

        edge_dist = []
        for pair in pair_lst:
            atom1, atom2 = pair.split('-')
            edge_dist.append(edge_mask_select(_get_rbf(atoms[atom1], atoms[atom2], E_idx, self.num_rbf)))
        if self.args.virtual_num > 0:
            for i in range(self.virtual_atoms.shape[0]):
                edge_dist.append(edge_mask_select(_get_rbf(atoms['v' + str(i)], atoms['v' + str(i)], E_idx, self.num_rbf)))
                for j in range(0, i):
                    edge_dist.append(edge_mask_select(_get_rbf(atoms['v' + str(i)], atoms['v' + str(j)], E_idx, self.num_rbf)))
                    edge_dist.append(edge_mask_select(_get_rbf(atoms['v' + str(j)], atoms['v' + str(i)], E_idx, self.num_rbf)))
        E_dist = torch.cat(tuple(edge_dist), dim=-1)

        h_V = []
        if self.args.node_dist:
            h_V.append(V_dist)
        if self.args.node_angle:
            h_V.append(V_angles)
        if self.args.node_direct:
            h_V.append(V_direct)

        h_E = []
        if self.args.edge_dist:
            h_E.append(E_dist)
        if self.args.edge_angle:
            h_E.append(E_angles)
        if self.args.edge_direct:
            h_E.append(E_direct)

        _V = torch.cat(h_V, dim=-1)
        _E = torch.cat(h_E, dim=-1)

        # edge index: convert per-protein neighbor indices into a flat
        # (dst, src) index over all valid residues in the batch
        shift = mask.sum(dim=1).cumsum(dim=0) - mask.sum(dim=1)
        src = shift.view(B, 1, 1) + E_idx
        src = torch.masked_select(src, mask_attend).view(1, -1)
        dst = shift.view(B, 1, 1) + torch.arange(0, N, device=src.device).view(1, -1, 1).expand_as(mask_attend)
        dst = torch.masked_select(dst, mask_attend).view(1, -1)
        E_idx = torch.cat((dst, src), dim=0).long()

        # relative-position and same-chain encodings appended to edge features
        pos_embed = self._positional_embeddings(E_idx, 16)
        _E = torch.cat([_E, pos_embed], dim=-1)
        d_chains = ((chain_encoding[dst.long()] - chain_encoding[src.long()]) == 0).long().reshape(-1)
        chain_embed = self._idx_embeddings(d_chains)
        _E = torch.cat([_E, chain_embed], dim=-1)

        # 3D points: keep only valid residues, remembering their graph id
        sparse_idx = mask.nonzero()  # indices of non-zero (valid) positions
        X = X[sparse_idx[:, 0], sparse_idx[:, 1], :, :]
        batch_id = sparse_idx[:, 0]
        mask = torch.masked_select(mask, mask_bool)

        batch.update({'X': X,
                      'S': S,
                      'score': score,
                      '_V': _V,
                      '_E': _E,
                      'E_idx': E_idx,
                      'batch_id': batch_id,
                      'mask': mask,
                      'chain_mask': chain_mask,
                      'chain_encoding': chain_encoding})
        return batch
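
    # Resulting batch layout (derived from the code above): with M valid
    # residues and Ek kept edges across the whole batch,
    #   batch['_V']      : [M, node_in]   node features
    #   batch['_E']      : [Ek, edge_in]  edge features (incl. pos./chain enc.)
    #   batch['E_idx']   : [2, Ek]        (dst, src) into the flat residue list
    #   batch['batch_id']: [M]            graph id of each residue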

    def _positional_embeddings(self, E_idx, num_embeddings=None):
        # Sinusoidal encoding of the sequence offset between the two endpoints
        # of each edge. From https://github.com/jingraham/neurips19-graph-protein-design
        num_embeddings = num_embeddings or self.num_positional_embeddings
        d = E_idx[0] - E_idx[1]
        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=E_idx.device)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d[:, None] * frequency[None, :]
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    def _idx_embeddings(self, d, num_embeddings=None):
        # Same sinusoidal encoding, applied to an arbitrary integer feature
        # (here: the same-chain indicator of each edge).
        # From https://github.com/jingraham/neurips19-graph-protein-design
        num_embeddings = num_embeddings or self.num_positional_embeddings
        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=d.device)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d[:, None] * frequency[None, :]
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E
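

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not in the original source). It assumes the
# src.* modules are importable and that the ESM-2 tokenizer can be fetched
# (network access, or a warm ./cache_dir/). Tensor shapes follow the indexing
# in _get_features: X is [B, N, 4, 3] with backbone atoms (N, Ca, C, O);
# the coordinates below are random, so the outputs are meaningless.
if __name__ == "__main__":
    args = _example_args()              # hypothetical config defined above
    model = PiFold_Model(args).eval()

    B, N = 2, 50
    batch = {
        'S': torch.randint(0, 20, (B, N)),   # dummy residue tokens
        'score': torch.ones(B, N),           # per-residue confidence placeholder
        'X': torch.randn(B, N, 4, 3),        # random (non-physical) backbone coords
        'mask': torch.ones(B, N),            # every residue marked valid
        'chain_mask': torch.ones(B, N),
        'chain_encoding': torch.ones(B, N),
    }
    with torch.no_grad():
        batch = model._get_features(batch)
        out = model(batch)
    print(out['log_probs'].shape)            # expected: [B*N, vocab_size]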