# flexpert/Flexpert-Design/src/models/proteinmpnn_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
from src.modules.graphtrans_module import gather_nodes, cat_neighbors_nodes
from src.modules.proteinmpnn_module import EncLayer, DecLayer, ProteinFeatures
class ProteinMPNN_Model(nn.Module):
def __init__(self, args, **kwargs):
""" Graph labeling network """
super(ProteinMPNN_Model, self).__init__()
# Hyperparameters
self.node_features = args.hidden
self.edge_features = args.hidden
self.hidden_dim = args.hidden
self.k_neighbors = args.k_neighbors
self.augment_eps = args.augment_eps
self.vocab = args.vocab
self.num_letters = args.num_letters
self.num_encoder_layers = args.num_encoder_layers
self.num_decoder_layers = args.num_decoder_layers
self.dropout = args.dropout
self.proteinmpnn_type = args.proteinmpnn_type
if args.proteinmpnn_type == 1:
self.augment_eps = 0.02
self.init_flex_features = args.init_flex_features
self.use_dynamics = args.use_dynamics
# Featurization layers
self.features = ProteinFeatures(self.node_features, self.edge_features, top_k=self.k_neighbors, augment_eps=self.augment_eps, proteinmpnn_type=self.proteinmpnn_type)
self.W_e = nn.Linear(self.edge_features, self.hidden_dim, bias=True)
self.W_s = nn.Embedding(self.vocab, self.hidden_dim)
        # TODO: check that the checkpoint path is read correctly from the config
self.init_pmpnn_weights = args.use_pmpnn_checkpoint
self.pmpnn_init_weights_path = None if not self.init_pmpnn_weights else args.starting_checkpoint_path
# Encoder layers
self.encoder_layers = nn.ModuleList([
EncLayer(self.hidden_dim, self.hidden_dim*2, dropout=self.dropout, proteinmpnn_type=self.proteinmpnn_type)
for _ in range(self.num_encoder_layers)
])
# Decoder layers
self.decoder_layers = nn.ModuleList([
DecLayer(self.hidden_dim, self.hidden_dim*3, dropout=self.dropout)
for _ in range(self.num_decoder_layers)
])
self.W_out = nn.Linear(self.hidden_dim, self.num_letters, bias=True)
        # Xavier-uniform initialization of all weight matrices
        self._init_params()
# self.gt_flex_cache = {}
def _init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def _autoregressive_mask(self, E_idx):
N_nodes = E_idx.size(1)
ii = torch.arange(N_nodes)
ii = ii.view((1, -1, 1)).to(E_idx.device)
mask = E_idx - ii < 0
mask = mask.type(torch.float32)
return mask
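    # Shapes inferred from usage (comment added for clarity): E_idx is [B, L, K] neighbor
    # indices, so the returned mask is [B, L, K] with 1.0 wherever a neighbor index is
    # strictly smaller than the query position, i.e. only earlier residues are visible.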
def _get_features(self, batch):
return batch
def forward(self, batch, use_input_decoding_order=False, decoding_order=None):
""" Graph-conditioned sequence model """
        X, S, score, mask, lengths, chain_M, chain_M_pos, residue_idx, chain_encoding_all = (
            batch['X'], batch['S'], batch['score'], batch['mask'], batch['lengths'],
            batch['chain_M'], batch['chain_M_pos'], batch['residue_idx'], batch['chain_encoding_all'])
        device = X.device
        randn = torch.randn(chain_M.shape, device=device)
        chain_M = chain_M * chain_M_pos
# Prepare node and edge embeddings
E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        if self.init_flex_features:
            # Initialize node features from per-residue ground-truth flexibility values,
            # broadcast to the edge-feature dimension; NaNs (e.g. from padding) are zeroed.
            gt_flex = batch['gt_flex']
            h_V = gt_flex.unsqueeze(-1).expand(-1, -1, E.shape[-1]).clone()
            h_V = torch.nan_to_num(h_V, nan=0.0)
else:
h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
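        # Node features h_V start either as broadcast per-residue flexibility values (above)
        # or as zeros, following the ProteinMPNN convention of building node representations
        # purely through message passing over the edge features E.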
h_E = self.W_e(E)
# Encoder is unmasked self-attention
mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
mask_attend = mask.unsqueeze(-1) * mask_attend
for layer in self.encoder_layers:
h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
# Concatenate sequence embeddings for autoregressive decoder
h_S = self.W_s(S)
h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
# Build encoder embeddings
h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
if self.proteinmpnn_type == 4:
mask_attend = self._autoregressive_mask(E_idx).unsqueeze(-1)
else:
chain_M = chain_M*mask #update chain_M to include missing regions
if not use_input_decoding_order:
                decoding_order = torch.argsort((chain_M+0.0001)*(torch.abs(randn)))  # [B, L]
                # Values are smaller where chain_M = 0.0 and larger where chain_M = 1.0,
                # so fixed/context positions are decoded before designable ones.
            mask_size = E_idx.shape[1]
            permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()  # [B, L, L]
order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
mask_bw = mask_1D * mask_attend
mask_fw = mask_1D * (1. - mask_attend)
h_EXV_encoder_fw = mask_fw * h_EXV_encoder
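        # The permutation-matrix einsum above converts the random per-residue decoding order
        # into a pairwise "decoded-before" matrix (order_mask_backward[b, q, p] == 1 iff
        # residue p is decoded before residue q); gathering along E_idx restricts it to each
        # residue's K nearest neighbors, yielding mask_bw (visible past) and mask_fw (hidden
        # future). A small self-contained illustration of this construction is sketched at
        # the bottom of this file.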
for layer in self.decoder_layers:
            # Not-yet-decoded neighbors contribute encoder-only features (mask_fw);
            # already-decoded neighbors contribute sequence-aware features (mask_bw).
h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
h_V = layer(h_V, h_ESV, mask)
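        # After the final decoder layer, every position's representation combines encoder
        # (structure) messages from not-yet-decoded neighbors with sequence-aware messages
        # from neighbors that precede it in the decoding order.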
logits = self.W_out(h_V)
log_probs = F.log_softmax(logits, dim=-1)
return {'log_probs':log_probs, 'logits':logits}
def sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None, pssm_bias_flag=None, bias_by_res=None):
device = X.device
# Prepare node and edge embeddings
E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
h_E = self.W_e(E)
# Encoder is unmasked self-attention
mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
mask_attend = mask.unsqueeze(-1) * mask_attend
for layer in self.encoder_layers:
h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
# Decoder uses masked self-attention
chain_mask = chain_mask*chain_M_pos*mask #update chain_M to include missing regions
        decoding_order = torch.argsort((chain_mask+0.0001)*(torch.abs(randn)))
        # Values are smaller where chain_mask = 0.0, so fixed/context positions are decoded first.
mask_size = E_idx.shape[1]
permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
mask_bw = mask_1D * mask_attend
mask_fw = mask_1D * (1. - mask_attend)
N_batch, N_nodes = X.size(0), X.size(1)
log_probs = torch.zeros((N_batch, N_nodes, 33), device=device)
all_probs = torch.zeros((N_batch, N_nodes, 33), device=device, dtype=torch.float32)
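        # The class dimension is hard-coded to 33 here (vs. 21 in tied_sample below),
        # presumably matching this repo's extended token vocabulary; log_probs is allocated
        # but never filled or returned by this method.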
h_S = torch.zeros_like(h_V, device=device)
S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
        omit_AA_mask_flag = omit_AA_mask is not None
h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
h_EXV_encoder_fw = mask_fw * h_EXV_encoder
for t_ in range(N_nodes):
t = decoding_order[:,t_] #[B]
chain_mask_gathered = torch.gather(chain_mask, 1, t[:,None]) #[B]
mask_gathered = torch.gather(mask, 1, t[:,None]) #[B]
if (mask_gathered==0).all(): #for padded or missing regions only
S_t = torch.gather(S_true, 1, t[:,None])
else:
# Hidden layers
E_idx_t = torch.gather(E_idx, 1, t[:,None,None].repeat(1,1,E_idx.shape[-1]))
h_E_t = torch.gather(h_E, 1, t[:,None,None,None].repeat(1,1,h_E.shape[-2], h_E.shape[-1]))
h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
h_EXV_encoder_t = torch.gather(h_EXV_encoder_fw, 1, t[:,None,None,None].repeat(1,1,h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]))
mask_t = torch.gather(mask, 1, t[:,None])
for l, layer in enumerate(self.decoder_layers):
# Updated relational features for future states
h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
h_V_t = torch.gather(h_V_stack[l], 1, t[:,None,None].repeat(1,1,h_V_stack[l].shape[-1]))
h_ESV_t = torch.gather(mask_bw, 1, t[:,None,None,None].repeat(1,1,mask_bw.shape[-2], mask_bw.shape[-1])) * h_ESV_decoder_t + h_EXV_encoder_t
h_V_stack[l+1].scatter_(1, t[:,None,None].repeat(1,1,h_V.shape[-1]), layer(h_V_t, h_ESV_t, mask_V=mask_t))
# Sampling step
h_V_t = torch.gather(h_V_stack[-1], 1, t[:,None,None].repeat(1,1,h_V_stack[-1].shape[-1]))[:,0]
logits = self.W_out(h_V_t) / temperature
                # Amino-acid omission/bias terms from the original ProteinMPNN sampler are
                # disabled here; probabilities come directly from the temperature-scaled logits.
                probs = F.softmax(logits, dim=-1)
if pssm_bias_flag:
pssm_coef_gathered = torch.gather(pssm_coef, 1, t[:,None])[:,0]
pssm_bias_gathered = torch.gather(pssm_bias, 1, t[:,None,None].repeat(1,1,pssm_bias.shape[-1]))[:,0]
probs = (1-pssm_multi*pssm_coef_gathered[:,None])*probs + pssm_multi*pssm_coef_gathered[:,None]*pssm_bias_gathered
if pssm_log_odds_flag:
pssm_log_odds_mask_gathered = torch.gather(pssm_log_odds_mask, 1, t[:,None, None].repeat(1,1,pssm_log_odds_mask.shape[-1]))[:,0] #[B, 21]
probs_masked = probs*pssm_log_odds_mask_gathered
probs_masked += probs * 0.001
probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True) #[B, 21]
if omit_AA_mask_flag:
omit_AA_mask_gathered = torch.gather(omit_AA_mask, 1, t[:,None, None].repeat(1,1,omit_AA_mask.shape[-1]))[:,0] #[B, 21]
probs_masked = probs*(1.0-omit_AA_mask_gathered)
probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True) #[B, 21]
                # Greedy decoding: take the argmax instead of sampling from the distribution.
                S_t = probs.argmax(dim=-1, keepdim=True)
all_probs.scatter_(1, t[:,None,None].repeat(1,1,33), (chain_mask_gathered[:,:,None,]*probs[:,None,:]).float())
S_true_gathered = torch.gather(S_true, 1, t[:,None])
S_t = (S_t*chain_mask_gathered+S_true_gathered*(1.0-chain_mask_gathered)).long()
temp1 = self.W_s(S_t)
h_S.scatter_(1, t[:,None,None].repeat(1,1,temp1.shape[-1]), temp1)
S.scatter_(1, t[:,None], S_t)
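        # Each step pushes only the single position t through the decoder stack, with
        # h_V_stack caching per-layer states, so autoregressive sampling costs one
        # single-position decoder evaluation per layer per step rather than re-running
        # the full sequence at every step.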
output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
return output_dict
def tied_sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0, omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None, pssm_bias_flag=None, tied_pos=None, tied_beta=None, bias_by_res=None):
device = X.device
# Prepare node and edge embeddings
E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
h_E = self.W_e(E)
# Encoder is unmasked self-attention
mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
mask_attend = mask.unsqueeze(-1) * mask_attend
for layer in self.encoder_layers:
h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
# Decoder uses masked self-attention
chain_mask = chain_mask*chain_M_pos*mask #update chain_M to include missing regions
        decoding_order = torch.argsort((chain_mask+0.0001)*(torch.abs(randn)))
        # Values are smaller where chain_mask = 0.0, so fixed/context positions are decoded first.
new_decoding_order = []
for t_dec in list(decoding_order[0,].cpu().data.numpy()):
if t_dec not in list(itertools.chain(*new_decoding_order)):
list_a = [item for item in tied_pos if t_dec in item]
if list_a:
new_decoding_order.append(list_a[0])
else:
new_decoding_order.append([t_dec])
decoding_order = torch.tensor(list(itertools.chain(*new_decoding_order)), device=device)[None,].repeat(X.shape[0],1)
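        # Tied positions (e.g. symmetric copies in homo-oligomer design) are grouped into a
        # single decoding step; their logits are averaged with tied_beta weights below and
        # one residue type is written to every member of the group.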
mask_size = E_idx.shape[1]
permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
mask_bw = mask_1D * mask_attend
mask_fw = mask_1D * (1. - mask_attend)
N_batch, N_nodes = X.size(0), X.size(1)
log_probs = torch.zeros((N_batch, N_nodes, 21), device=device)
all_probs = torch.zeros((N_batch, N_nodes, 21), device=device, dtype=torch.float32)
h_S = torch.zeros_like(h_V, device=device)
S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=device)
h_V_stack = [h_V] + [torch.zeros_like(h_V, device=device) for _ in range(len(self.decoder_layers))]
constant = torch.tensor(omit_AAs_np, device=device)
constant_bias = torch.tensor(bias_AAs_np, device=device)
        omit_AA_mask_flag = omit_AA_mask is not None
h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
h_EXV_encoder_fw = mask_fw * h_EXV_encoder
for t_list in new_decoding_order:
logits = 0.0
logit_list = []
done_flag = False
for t in t_list:
if (chain_mask[:,t]==0).all():
S_t = S_true[:,t]
for t in t_list:
h_S[:,t,:] = self.W_s(S_t)
S[:,t] = S_t
done_flag = True
break
else:
E_idx_t = E_idx[:,t:t+1,:]
h_E_t = h_E[:,t:t+1,:,:]
h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
h_EXV_encoder_t = h_EXV_encoder_fw[:,t:t+1,:,:]
mask_t = mask[:,t:t+1]
for l, layer in enumerate(self.decoder_layers):
h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
h_V_t = h_V_stack[l][:,t:t+1,:]
h_ESV_t = mask_bw[:,t:t+1,:,:] * h_ESV_decoder_t + h_EXV_encoder_t
h_V_stack[l+1][:,t,:] = layer(h_V_t, h_ESV_t, mask_V=mask_t).squeeze(1)
h_V_t = h_V_stack[-1][:,t,:]
logit_list.append((self.W_out(h_V_t) / temperature)/len(t_list))
logits += tied_beta[t]*(self.W_out(h_V_t) / temperature)/len(t_list)
            if not done_flag:
bias_by_res_gathered = bias_by_res[:,t,:] #[B, 21]
probs = F.softmax(logits-constant[None,:]*1e8+constant_bias[None,:]/temperature+bias_by_res_gathered/temperature, dim=-1)
if pssm_bias_flag:
pssm_coef_gathered = pssm_coef[:,t]
pssm_bias_gathered = pssm_bias[:,t]
probs = (1-pssm_multi*pssm_coef_gathered[:,None])*probs + pssm_multi*pssm_coef_gathered[:,None]*pssm_bias_gathered
if pssm_log_odds_flag:
pssm_log_odds_mask_gathered = pssm_log_odds_mask[:,t]
probs_masked = probs*pssm_log_odds_mask_gathered
probs_masked += probs * 0.001
probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True) #[B, 21]
if omit_AA_mask_flag:
omit_AA_mask_gathered = omit_AA_mask[:,t]
probs_masked = probs*(1.0-omit_AA_mask_gathered)
probs = probs_masked/torch.sum(probs_masked, dim=-1, keepdim=True) #[B, 21]
S_t_repeat = torch.multinomial(probs, 1).squeeze(-1)
for t in t_list:
h_S[:,t,:] = self.W_s(S_t_repeat)
S[:,t] = S_t_repeat
all_probs[:,t,:] = probs.float()
output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
return output_dict
def conditional_probs(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, backbone_only=False):
""" Graph-conditioned sequence model """
device=X.device
# Prepare node and edge embeddings
E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
h_V_enc = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
h_E = self.W_e(E)
# Encoder is unmasked self-attention
mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
mask_attend = mask.unsqueeze(-1) * mask_attend
for layer in self.encoder_layers:
h_V_enc, h_E = layer(h_V_enc, h_E, E_idx, mask, mask_attend)
# Concatenate sequence embeddings for autoregressive decoder
h_S = self.W_s(S)
h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
# Build encoder embeddings
h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
h_EXV_encoder = cat_neighbors_nodes(h_V_enc, h_EX_encoder, E_idx)
chain_M = chain_M*mask #update chain_M to include missing regions
chain_M_np = chain_M.cpu().numpy()
idx_to_loop = np.argwhere(chain_M_np[0,:]==1)[:,0]
log_conditional_probs = torch.zeros([X.shape[0], chain_M.shape[1], 21], device=device).float()
for idx in idx_to_loop:
h_V = torch.clone(h_V_enc)
            if backbone_only:
                # Decode position idx first, so it sees only backbone/structure features.
                order_mask = torch.ones(chain_M.shape[1], device=device).float()
                order_mask[idx] = 0.
            else:
                # Decode position idx last, so it sees the true identities of all other residues.
                order_mask = torch.zeros(chain_M.shape[1], device=device).float()
                order_mask[idx] = 1.
            decoding_order = torch.argsort((order_mask[None,]+0.0001)*(torch.abs(randn)))
            # Positions with order_mask = 0.0 get smaller values and are decoded earlier.
mask_size = E_idx.shape[1]
permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
mask_bw = mask_1D * mask_attend
mask_fw = mask_1D * (1. - mask_attend)
h_EXV_encoder_fw = mask_fw * h_EXV_encoder
for layer in self.decoder_layers:
                # Not-yet-decoded neighbors contribute encoder-only features (mask_fw);
                # already-decoded neighbors contribute sequence-aware features (mask_bw).
h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
h_V = layer(h_V, h_ESV, mask)
logits = self.W_out(h_V)
log_probs = F.log_softmax(logits, dim=-1)
log_conditional_probs[:,idx,:] = log_probs[:,idx,:]
return log_conditional_probs
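    # conditional_probs() scores each designable position given the true identities of all
    # other residues (or the backbone alone when backbone_only=True), using one decoder pass
    # per position and assuming the chain_M mask of the first batch element applies to the
    # whole batch; unconditional_probs() below zeroes the order mask so every position is
    # scored from structure features only.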
def unconditional_probs(self, X, mask, residue_idx, chain_encoding_all):
""" Graph-conditioned sequence model """
device=X.device
# Prepare node and edge embeddings
E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
h_E = self.W_e(E)
# Encoder is unmasked self-attention
mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
mask_attend = mask.unsqueeze(-1) * mask_attend
for layer in self.encoder_layers:
h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
# Build encoder embeddings
h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_V), h_E, E_idx)
h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
order_mask_backward = torch.zeros([X.shape[0], X.shape[1], X.shape[1]], device=device)
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
mask_bw = mask_1D * mask_attend
mask_fw = mask_1D * (1. - mask_attend)
h_EXV_encoder_fw = mask_fw * h_EXV_encoder
for layer in self.decoder_layers:
h_V = layer(h_V, h_EXV_encoder_fw, mask)
logits = self.W_out(h_V)
log_probs = F.log_softmax(logits, dim=-1)
return log_probs
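

if __name__ == "__main__":
    # Minimal, self-contained sketch (added for illustration; not part of the original
    # training or inference pipeline). It reproduces, on a toy example, the random
    # decoding-order mask construction used in forward() and sample(); the names and
    # sizes below (B, L, K) are assumptions, not values taken from the repo's configs.
    torch.manual_seed(0)
    B, L, K = 1, 5, 3                                  # batch, sequence length, neighbors
    chain_M = torch.ones(B, L)                         # all positions treated as designable
    randn = torch.randn(B, L)
    E_idx = torch.randint(0, L, (B, L, K))             # toy k-NN neighbor indices
    decoding_order = torch.argsort((chain_M + 0.0001) * torch.abs(randn))
    P = F.one_hot(decoding_order, num_classes=L).float()
    lower = 1.0 - torch.triu(torch.ones(L, L))         # strictly lower-triangular ones
    order_mask_backward = torch.einsum('ij,biq,bjp->bqp', lower, P, P)
    mask_attend = torch.gather(order_mask_backward, 2, E_idx)
    # mask_attend[b, i, k] == 1 iff neighbor E_idx[b, i, k] is decoded before position i,
    # i.e. it belongs to the "visible past" selected by mask_bw in the decoder.
    print("decoding order:", decoding_order[0].tolist())
    print("visible-neighbor mask:\n", mask_attend[0])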