import itertools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.modules.graphtrans_module import gather_nodes, cat_neighbors_nodes
from src.modules.proteinmpnn_module import EncLayer, DecLayer, ProteinFeatures


class ProteinMPNN_Model(nn.Module):
    def __init__(self, args, **kwargs):
        """Graph labeling network."""
        super().__init__()

        # Hyperparameters
        self.node_features = args.hidden
        self.edge_features = args.hidden
        self.hidden_dim = args.hidden
        self.k_neighbors = args.k_neighbors
        self.augment_eps = args.augment_eps
        self.vocab = args.vocab
        self.num_letters = args.num_letters
        self.num_encoder_layers = args.num_encoder_layers
        self.num_decoder_layers = args.num_decoder_layers
        self.dropout = args.dropout
        self.proteinmpnn_type = args.proteinmpnn_type
        if args.proteinmpnn_type == 1:
            self.augment_eps = 0.02
        self.init_flex_features = args.init_flex_features
        self.use_dynamics = args.use_dynamics

        # Featurization layers
        self.features = ProteinFeatures(
            self.node_features, self.edge_features,
            top_k=self.k_neighbors, augment_eps=self.augment_eps,
            proteinmpnn_type=self.proteinmpnn_type)
        self.W_e = nn.Linear(self.edge_features, self.hidden_dim, bias=True)
        self.W_s = nn.Embedding(self.vocab, self.hidden_dim)

        # Optional warm start from a pretrained ProteinMPNN checkpoint.
        # TODO: check that the checkpoint path is read correctly from the config.
        self.init_pmpnn_weights = args.use_pmpnn_checkpoint
        self.pmpnn_init_weights_path = args.starting_checkpoint_path if self.init_pmpnn_weights else None

        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            EncLayer(self.hidden_dim, self.hidden_dim * 2,
                     dropout=self.dropout, proteinmpnn_type=self.proteinmpnn_type)
            for _ in range(self.num_encoder_layers)
        ])

        # Decoder layers
        self.decoder_layers = nn.ModuleList([
            DecLayer(self.hidden_dim, self.hidden_dim * 3, dropout=self.dropout)
            for _ in range(self.num_decoder_layers)
        ])
        self.W_out = nn.Linear(self.hidden_dim, self.num_letters, bias=True)

        # Xavier initialization for all weight matrices
        self._init_params()
        # self.gt_flex_cache = {}

    def _init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _autoregressive_mask(self, E_idx):
        N_nodes = E_idx.size(1)
        ii = torch.arange(N_nodes)
        ii = ii.view((1, -1, 1)).to(E_idx.device)
        mask = E_idx - ii < 0
        mask = mask.type(torch.float32)
        return mask

    def _get_features(self, batch):
        return batch

    def forward(self, batch, use_input_decoding_order=False, decoding_order=None):
        """Graph-conditioned sequence model."""
        X, S, score, mask, lengths, chain_M, chain_M_pos, residue_idx, chain_encoding_all = (
            batch['X'], batch['S'], batch['score'], batch['mask'], batch['lengths'],
            batch['chain_M'], batch['chain_M_pos'], batch['residue_idx'], batch['chain_encoding_all'])
        randn = torch.randn(chain_M.shape, device=X.device)
        chain_M = chain_M * chain_M_pos
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        if self.init_flex_features:  # and self.use_dynamics:
            gt_flex = batch['gt_flex']
            # Alternative (currently disabled): predict flexibility on the fly and cache it.
            # gt_seq = batch['S']
            # anm_input = batch['enm_vals']
            # trail_idcs = torch.argmax((batch['S'] == 0).int(), dim=1)
            # trail_idcs[trail_idcs == 0] = batch['S'].shape[1]
            # cache_keys = list(batch['title'])
            # all_keys_in_cache = all(cache_key in self.gt_flex_cache for cache_key in cache_keys)
            # if not all_keys_in_cache:
            #     gt_flex = self.flex_model(None, anm_input, trail_idcs, attention_mask=batch['mask'],
            #                               sampled_pmpnn_sequence=gt_seq, alphabet='pmpnn')['predicted_flex'][:, :-1, 0]
            #     for key, val in zip(cache_keys, gt_flex):
            #         self.gt_flex_cache[key] = val
            # else:
            #     gt_flex = torch.cat([self.gt_flex_cache[key] for key in cache_keys], dim=0)
            # Initialize node features by broadcasting the per-residue flexibility value.
            h_V = gt_flex.unsqueeze(-1).expand(-1, -1, E.shape[-1]).clone()
            h_V = torch.nan_to_num(h_V, nan=0.0)
        else:
            h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Concatenate sequence embeddings for autoregressive decoder
        h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        if self.proteinmpnn_type == 4:
            mask_attend = self._autoregressive_mask(E_idx).unsqueeze(-1)
        else:
            chain_M = chain_M * mask  # update chain_M to include missing regions
            if not use_input_decoding_order:
                # Sort keys are smaller where chain_M == 0.0 and larger where chain_M == 1.0,
                # so fixed positions are decoded first.
                decoding_order = torch.argsort((chain_M + 0.0001) * (torch.abs(randn)))  # [B, L]
            mask_size = E_idx.shape[1]
            permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()  # [B, L, L]
            order_mask_backward = torch.einsum(
                'ij, biq, bjp->bqp',
                (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))),
                permutation_matrix_reverse, permutation_matrix_reverse)
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for layer in self.decoder_layers:
            # Masked positions attend to encoder information; unmasked positions also
            # see the sequence embeddings of already-decoded neighbors.
            h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
            h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
            h_V = layer(h_V, h_ESV, mask)

        logits = self.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        return {'log_probs': log_probs, 'logits': logits}
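
    # `sample` generates a sequence autoregressively: positions are visited in a
    # random decoding order biased so that fixed (chain_mask == 0) positions come
    # first, already-decoded neighbors contribute their sequence embeddings via
    # mask_bw, and not-yet-decoded neighbors contribute encoder features via mask_fw.
    # Note that the current code takes the argmax of the per-position distribution
    # (greedy decoding) rather than drawing from torch.multinomial.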
    def sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None,
               temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None,
               pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None,
               pssm_log_odds_mask=None, pssm_bias_flag=None, bias_by_res=None):
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Decoder uses masked self-attention
        chain_mask = chain_mask * chain_M_pos * mask  # update chain_M to include missing regions
        # Sort keys are smaller where chain_mask == 0.0 and larger where chain_mask == 1.0,
        # so fixed positions are decoded first.
        decoding_order = torch.argsort((chain_mask + 0.0001) * (torch.abs(randn)))
        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
        order_mask_backward = torch.einsum(
            'ij, biq, bjp->bqp',
            (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))),
            permutation_matrix_reverse, permutation_matrix_reverse)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = X.size(0), X.size(1)
        # Output dimension follows W_out (originally hard-coded to 33).
        log_probs = torch.zeros((N_batch, N_nodes, self.num_letters), device=device)
        all_probs = torch.zeros((N_batch, N_nodes, self.num_letters), device=device, dtype=torch.float32)
        h_S = torch.zeros_like(h_V, device=device)
        S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
        h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
        # constant = torch.tensor(omit_AAs_np, device=device)
        # constant_bias = torch.tensor(bias_AAs_np, device=device)
        # chain_mask_combined = chain_mask * chain_M_pos
        omit_AA_mask_flag = omit_AA_mask is not None

        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for t_ in range(N_nodes):
            t = decoding_order[:, t_]  # [B]
            chain_mask_gathered = torch.gather(chain_mask, 1, t[:, None])  # [B]
            mask_gathered = torch.gather(mask, 1, t[:, None])  # [B]
            # bias_by_res_gathered = torch.gather(bias_by_res, 1, t[:, None, None].repeat(1, 1, 21))[:, 0, :]  # [B, 21]
            if (mask_gathered == 0).all():  # for padded or missing regions only
                S_t = torch.gather(S_true, 1, t[:, None])
            else:
                # Hidden layers
                E_idx_t = torch.gather(E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1]))
                h_E_t = torch.gather(h_E, 1, t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]))
                h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                h_EXV_encoder_t = torch.gather(h_EXV_encoder_fw, 1, t[:, None, None, None].repeat(1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]))
                mask_t = torch.gather(mask, 1, t[:, None])
                for l, layer in enumerate(self.decoder_layers):
                    # Updated relational features for future states
                    h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                    h_V_t = torch.gather(h_V_stack[l], 1, t[:, None, None].repeat(1, 1, h_V_stack[l].shape[-1]))
                    h_ESV_t = torch.gather(mask_bw, 1, t[:, None, None, None].repeat(1, 1, mask_bw.shape[-2], mask_bw.shape[-1])) * h_ESV_decoder_t + h_EXV_encoder_t
                    h_V_stack[l + 1].scatter_(1, t[:, None, None].repeat(1, 1, h_V.shape[-1]), layer(h_V_t, h_ESV_t, mask_V=mask_t))
                # Sampling step
                h_V_t = torch.gather(h_V_stack[-1], 1, t[:, None, None].repeat(1, 1, h_V_stack[-1].shape[-1]))[:, 0]
                logits = self.W_out(h_V_t) / temperature
                # probs = F.softmax(logits - constant[None, :] * 1e8 + constant_bias[None, :] / temperature + bias_by_res_gathered / temperature, dim=-1)
                probs = F.softmax(logits, dim=-1)
                if pssm_bias_flag:
                    pssm_coef_gathered = torch.gather(pssm_coef, 1, t[:, None])[:, 0]
                    pssm_bias_gathered = torch.gather(pssm_bias, 1, t[:, None, None].repeat(1, 1, pssm_bias.shape[-1]))[:, 0]
                    probs = (1 - pssm_multi * pssm_coef_gathered[:, None]) * probs + pssm_multi * pssm_coef_gathered[:, None] * pssm_bias_gathered
                if pssm_log_odds_flag:
                    pssm_log_odds_mask_gathered = torch.gather(pssm_log_odds_mask, 1, t[:, None, None].repeat(1, 1, pssm_log_odds_mask.shape[-1]))[:, 0]  # [B, num_letters]
                    probs_masked = probs * pssm_log_odds_mask_gathered
                    probs_masked += probs * 0.001
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, num_letters]
                if omit_AA_mask_flag:
                    omit_AA_mask_gathered = torch.gather(omit_AA_mask, 1, t[:, None, None].repeat(1, 1, omit_AA_mask.shape[-1]))[:, 0]  # [B, num_letters]
                    probs_masked = probs * (1.0 - omit_AA_mask_gathered)
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, num_letters]
                # Greedy decoding (argmax) instead of stochastic sampling:
                # S_t = torch.multinomial(probs, 1)
                S_t = probs.argmax(dim=-1, keepdim=True)
                all_probs.scatter_(1, t[:, None, None].repeat(1, 1, self.num_letters),
                                   (chain_mask_gathered[:, :, None] * probs[:, None, :]).float())
            S_true_gathered = torch.gather(S_true, 1, t[:, None])
            S_t = (S_t * chain_mask_gathered + S_true_gathered * (1.0 - chain_mask_gathered)).long()
            temp1 = self.W_s(S_t)
            h_S.scatter_(1, t[:, None, None].repeat(1, 1, temp1.shape[-1]), temp1)
            S.scatter_(1, t[:, None], S_t)
        output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
        return output_dict
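
    # `tied_sample` is the symmetry-aware variant of `sample`: positions grouped
    # together in `tied_pos` are decoded jointly, their logits are averaged with
    # the per-position weights in `tied_beta`, and the same sampled residue is
    # written back to every position in the group.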
    def tied_sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None,
                    temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None,
                    pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None,
                    pssm_log_odds_mask=None, pssm_bias_flag=None, tied_pos=None, tied_beta=None, bias_by_res=None):
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Decoder uses masked self-attention
        chain_mask = chain_mask * chain_M_pos * mask  # update chain_M to include missing regions
        # Sort keys are smaller where chain_mask == 0.0 and larger where chain_mask == 1.0,
        # so fixed positions are decoded first.
        decoding_order = torch.argsort((chain_mask + 0.0001) * (torch.abs(randn)))

        # Group tied positions so that each group is decoded together.
        new_decoding_order = []
        for t_dec in list(decoding_order[0, ].cpu().data.numpy()):
            if t_dec not in list(itertools.chain(*new_decoding_order)):
                list_a = [item for item in tied_pos if t_dec in item]
                if list_a:
                    new_decoding_order.append(list_a[0])
                else:
                    new_decoding_order.append([t_dec])
        decoding_order = torch.tensor(list(itertools.chain(*new_decoding_order)), device=device)[None, ].repeat(X.shape[0], 1)

        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
        order_mask_backward = torch.einsum(
            'ij, biq, bjp->bqp',
            (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))),
            permutation_matrix_reverse, permutation_matrix_reverse)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = X.size(0), X.size(1)
        # Output dimension follows W_out (originally hard-coded to 21).
        log_probs = torch.zeros((N_batch, N_nodes, self.num_letters), device=device)
        all_probs = torch.zeros((N_batch, N_nodes, self.num_letters), device=device, dtype=torch.float32)
        h_S = torch.zeros_like(h_V, device=device)
        S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
        h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
        constant = torch.tensor(omit_AAs_np, device=device)
        constant_bias = torch.tensor(bias_AAs_np, device=device)
        omit_AA_mask_flag = omit_AA_mask is not None

        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for t_list in new_decoding_order:
            logits = 0.0
            logit_list = []
            done_flag = False
            for t in t_list:
                if (chain_mask[:, t] == 0).all():
                    # Fixed positions: copy the true sequence for the whole tied group.
                    S_t = S_true[:, t]
                    for t in t_list:
                        h_S[:, t, :] = self.W_s(S_t)
                        S[:, t] = S_t
                    done_flag = True
                    break
                else:
                    E_idx_t = E_idx[:, t:t + 1, :]
                    h_E_t = h_E[:, t:t + 1, :, :]
                    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                    h_EXV_encoder_t = h_EXV_encoder_fw[:, t:t + 1, :, :]
                    mask_t = mask[:, t:t + 1]
                    for l, layer in enumerate(self.decoder_layers):
                        h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                        h_V_t = h_V_stack[l][:, t:t + 1, :]
                        h_ESV_t = mask_bw[:, t:t + 1, :, :] * h_ESV_decoder_t + h_EXV_encoder_t
                        h_V_stack[l + 1][:, t, :] = layer(h_V_t, h_ESV_t, mask_V=mask_t).squeeze(1)
                    h_V_t = h_V_stack[-1][:, t, :]
                    logit_list.append((self.W_out(h_V_t) / temperature) / len(t_list))
                    # Average the logits over the tied group, weighted by tied_beta.
                    logits += tied_beta[t] * (self.W_out(h_V_t) / temperature) / len(t_list)
            if not done_flag:
                bias_by_res_gathered = bias_by_res[:, t, :]  # [B, num_letters]
                probs = F.softmax(logits - constant[None, :] * 1e8 + constant_bias[None, :] / temperature + bias_by_res_gathered / temperature, dim=-1)
                if pssm_bias_flag:
                    pssm_coef_gathered = pssm_coef[:, t]
                    pssm_bias_gathered = pssm_bias[:, t]
                    probs = (1 - pssm_multi * pssm_coef_gathered[:, None]) * probs + pssm_multi * pssm_coef_gathered[:, None] * pssm_bias_gathered
                if pssm_log_odds_flag:
                    pssm_log_odds_mask_gathered = pssm_log_odds_mask[:, t]
                    probs_masked = probs * pssm_log_odds_mask_gathered
                    probs_masked += probs * 0.001
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, num_letters]
                if omit_AA_mask_flag:
                    omit_AA_mask_gathered = omit_AA_mask[:, t]
                    probs_masked = probs * (1.0 - omit_AA_mask_gathered)
                    probs = probs_masked / torch.sum(probs_masked, dim=-1, keepdim=True)  # [B, num_letters]
                S_t_repeat = torch.multinomial(probs, 1).squeeze(-1)
                # Write the same sampled residue to every tied position.
                for t in t_list:
                    h_S[:, t, :] = self.W_s(S_t_repeat)
                    S[:, t] = S_t_repeat
                    all_probs[:, t, :] = probs.float()
        output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
        return output_dict
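
    # `conditional_probs` scores one designable position at a time. For each index
    # with chain_M == 1 it builds a decoding order in which that position is decoded
    # last, so its log-probabilities are conditioned on the rest of the sequence;
    # with `backbone_only=True` the position is decoded first and conditioned on the
    # backbone alone.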
    def conditional_probs(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, backbone_only=False):
        """Graph-conditioned sequence model."""
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V_enc = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V_enc, h_E = layer(h_V_enc, h_E, E_idx, mask, mask_attend)

        # Concatenate sequence embeddings for autoregressive decoder
        h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V_enc, h_EX_encoder, E_idx)

        chain_M = chain_M * mask  # update chain_M to include missing regions
        chain_M_np = chain_M.cpu().numpy()
        idx_to_loop = np.argwhere(chain_M_np[0, :] == 1)[:, 0]
        # Output dimension follows W_out (originally hard-coded to 21).
        log_conditional_probs = torch.zeros([X.shape[0], chain_M.shape[1], self.num_letters], device=device).float()

        for idx in idx_to_loop:
            h_V = torch.clone(h_V_enc)
            if backbone_only:
                # Decode `idx` first: its prediction is conditioned on the backbone only.
                order_mask = torch.ones(chain_M.shape[1], device=device).float()
                order_mask[idx] = 0.
            else:
                # Decode `idx` last: its prediction is conditioned on the rest of the sequence.
                order_mask = torch.zeros(chain_M.shape[1], device=device).float()
                order_mask[idx] = 1.
            # Sort keys are smaller where order_mask == 0.0 and larger where order_mask == 1.0.
            decoding_order = torch.argsort((order_mask[None, ] + 0.0001) * (torch.abs(randn)))
            mask_size = E_idx.shape[1]
            permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
            order_mask_backward = torch.einsum(
                'ij, biq, bjp->bqp',
                (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))),
                permutation_matrix_reverse, permutation_matrix_reverse)
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1. - mask_attend)

            h_EXV_encoder_fw = mask_fw * h_EXV_encoder
            for layer in self.decoder_layers:
                # Masked positions attend to encoder information; unmasked positions also
                # see the sequence embeddings of already-decoded neighbors.
                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
                h_V = layer(h_V, h_ESV, mask)

            logits = self.W_out(h_V)
            log_probs = F.log_softmax(logits, dim=-1)
            log_conditional_probs[:, idx, :] = log_probs[:, idx, :]
        return log_conditional_probs
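
    # `unconditional_probs` uses an all-zero backward mask, so every position is
    # predicted from encoder (structure) features only, with no conditioning on
    # any part of the sequence.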
    def unconditional_probs(self, X, mask, residue_idx, chain_encoding_all):
        """Graph-conditioned sequence model."""
        device = X.device

        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_V), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        # All-zero backward mask: the decoder never sees any sequence information.
        order_mask_backward = torch.zeros([X.shape[0], X.shape[1], X.shape[1]], device=device)
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for layer in self.decoder_layers:
            h_V = layer(h_V, h_EXV_encoder_fw, mask)

        logits = self.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs
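

# Illustrative sketch (not part of the original module): a minimal, self-contained
# demonstration of the decoding-order machinery shared by `forward`, `sample`, and
# `tied_sample`. The batch size, length, and chain mask below are made-up toy values.
if __name__ == "__main__":
    B, L = 1, 5
    chain_M = torch.tensor([[1., 1., 0., 1., 0.]])  # 1 = designable, 0 = fixed
    randn = torch.randn(B, L)
    # Fixed positions get the smallest sort keys, so they are decoded first.
    decoding_order = torch.argsort((chain_M + 0.0001) * torch.abs(randn))
    P = F.one_hot(decoding_order, num_classes=L).float()
    # order_mask_backward[b, q, p] == 1 iff position p is decoded strictly before q,
    # i.e. q is allowed to condition on p's sequence identity.
    order_mask_backward = torch.einsum(
        'ij, biq, bjp->bqp',
        (1 - torch.triu(torch.ones(L, L))), P, P)
    print("decoding order:", decoding_order[0].tolist())
    print(order_mask_backward[0])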