import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class PositionalEncodings(nn.Module):
    """Sinusoidal encoding of the index offset between a node and each neighbor."""

    def __init__(self, num_embeddings, period_range=None):
        """
        Args:
            num_embeddings: width of the returned encoding (cos and sin halves).
            period_range: stored on the module but not read by `forward`.
                Defaults to [2, 1000].
        """
        if period_range is None:
            period_range = [2, 1000]
        super(PositionalEncodings, self).__init__()
        self.num_embeddings = num_embeddings
        self.period_range = period_range

    def forward(self, E_idx):
        """Encode `E_idx[b, i, k] - i` with transformer-style frequencies.

        Args:
            E_idx: neighbor indices, indexed as [batch, node, neighbor].

        Returns:
            Tensor [batch, node, neighbor, 2 * len(arange(0, num_embeddings, 2))]
            holding cos(angles) followed by sin(angles).
        """
        N_nodes = E_idx.size(1)
        ii = torch.arange(
            N_nodes, dtype=torch.float32, device=E_idx.device
        ).view((1, -1, 1))
        # Signed offset (neighbor index - own index), with a trailing axis for
        # broadcasting against the frequencies.
        d = (E_idx.float() - ii).unsqueeze(-1)
        # Original Transformer frequencies
        frequency = torch.exp(
            torch.arange(
                0, self.num_embeddings, 2,
                dtype=torch.float32, device=E_idx.device
            ) * -(np.log(10000.0) / self.num_embeddings)
        )
        angles = d * frequency.view((1, 1, 1, -1))
        return torch.cat((torch.cos(angles), torch.sin(angles)), -1)


class ProteinFeatures(nn.Module):
    """Extract protein features: build a k-NN graph and embed nodes and edges."""

    def __init__(self, edge_features, node_features,
                 num_positional_embeddings=16, num_rbf=16, top_k=30,
                 features_type='full', augment_eps=0., dropout=0.1):
        """
        Args:
            edge_features: output width of the edge embedding.
            node_features: output width of the node embedding.
            num_positional_embeddings: width of the positional encoding.
            num_rbf: number of radial basis functions for distances.
            top_k: maximum number of neighbors per node.
            features_type: one of 'coarse', 'full', 'dist', 'hbonds'.
            augment_eps: std-dev of Gaussian noise added to coordinates
                while the module is in training mode.
            dropout: dropout probability (applied to the 'hbonds' features).

        Raises:
            KeyError: if `features_type` is not one of the four known types.
        """
        super(ProteinFeatures, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        # Feature types: (raw node width, raw edge width)
        self.features_type = features_type
        self.feature_dimensions = {
            'coarse': (3, num_positional_embeddings + num_rbf + 7),
            'full': (6, num_positional_embeddings + num_rbf + 7),
            'dist': (6, num_positional_embeddings + num_rbf),
            'hbonds': (3, 2 * num_positional_embeddings)}

        # Positional encoding
        self.embeddings = PositionalEncodings(num_positional_embeddings)
        self.dropout = nn.Dropout(dropout)

        # Normalization and embedding
        node_in, edge_in = self.feature_dimensions[features_type]
        self.node_embedding = nn.Linear(node_in, node_features, bias=True)
        self.edge_embedding = nn.Linear(edge_in, edge_features, bias=True)
        self.norm_nodes = Normalize(node_features)
        self.norm_edges = Normalize(edge_features)

    def _dist(self, X, mask, eps=1E-6):
        """Pairwise Euclidean distance and k-nearest-neighbor selection.

        Args:
            X: coordinates [B, N, 3].
            mask: [B, N]; used multiplicatively, so presumably 0/1 floats
                (TODO confirm with caller).
            eps: added under the square root for numerical stability.

        Returns:
            (D_neighbors [B, N, k], E_idx [B, N, k], mask_neighbors [B, N, k, 1])
            with k = min(top_k, N).
        """
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        # Masked pairs get a large constant distance.
        D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt(torch.sum(dX ** 2, 3) + eps)

        # Push masked pairs beyond the row maximum so topk picks them last.
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * (D_max + 1)
        D_neighbors, E_idx = torch.topk(
            D_adjust, min(self.top_k, D_adjust.shape[-1]), dim=-1, largest=False
        )
        mask_neighbors = gather_edges(mask_2D.unsqueeze(-1), E_idx)
        return D_neighbors, E_idx, mask_neighbors

    def _rbf(self, D):
        """Distance radial basis functions.

        Expands every distance into `num_rbf` Gaussian responses whose centers
        are evenly spaced on [0, 20].

        Args:
            D: 3-D tensor of distances (centers are viewed as [1, 1, 1, -1]).

        Returns:
            Tensor of shape D.shape + [num_rbf].
        """
        D_min, D_max, D_count = 0., 20., self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
        D_mu = D_mu.view([1, 1, 1, -1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        return torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)

    def _quaternions(self, R):
        """Convert a batch of 3D rotations [R] to quaternions [Q].

        Args:
            R: rotation matrices [..., 3, 3].

        Returns:
            Unit-normalized quaternions [..., 4], ordered (x, y, z, w).
        """
        diag = torch.diagonal(R, dim1=-2, dim2=-1)
        Rxx, Ryy, Rzz = diag.unbind(-1)
        # abs() keeps the sqrt argument non-negative.
        magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
            Rxx - Ryy - Rzz,
            - Rxx + Ryy - Rzz,
            - Rxx - Ryy + Rzz
        ], -1)))

        def _R(i, j):
            return R[..., i, j]

        # Signs are taken from the antisymmetric part of R.
        signs = torch.sign(torch.stack([
            _R(2, 1) - _R(1, 2),
            _R(0, 2) - _R(2, 0),
            _R(1, 0) - _R(0, 1)
        ], -1))
        xyz = signs * magnitudes
        w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
        Q = torch.cat((xyz, w), -1)
        Q = F.normalize(Q, dim=-1)
        return Q

    def _contacts(self, D_neighbors, mask_neighbors, cutoff=8):
        """Contacts: 1.0 where a valid neighbor is closer than `cutoff`.

        Args:
            D_neighbors: neighbor distances [B, N, K].
            mask_neighbors: neighbor validity mask [B, N, K, 1].
            cutoff: distance threshold.

        Returns:
            Float tensor [B, N, K, 1].
        """
        D_neighbors = D_neighbors.unsqueeze(-1)
        return mask_neighbors * (D_neighbors < cutoff).type(torch.float32)

    def _hbonds(self, X, E_idx, mask_neighbors, eps=1E-3):
        """Hydrogen bonds between neighbors.

        Args:
            X: backbone coordinates [B, N, 4, 3]; axis 2 is unbound into
                N, CA, C, O in that order.
            E_idx: neighbor indices [B, N, K].
            mask_neighbors: neighbor validity mask [B, N, K, 1].
            eps: added to distances before inversion.

        Returns:
            Float tensor [B, N, K, 1], 1.0 where the energy term is below -0.5.
        """
        X_atoms = dict(zip(['N', 'CA', 'C', 'O'], torch.unbind(X, 2)))

        # Virtual hydrogens.
        # C_prev[i] is the carbonyl carbon of residue i-1 (zeros for the first
        # residue): drop the last residue and pad one row at the front.
        X_atoms['C_prev'] = F.pad(X_atoms['C'][:, :-1, :], (0, 0, 1, 0), 'constant', 0)
        # `dim` must be given by keyword: the second positional argument of
        # F.normalize is the norm order `p`, not the axis.
        X_atoms['H'] = X_atoms['N'] + F.normalize(
            F.normalize(X_atoms['N'] - X_atoms['C_prev'], dim=-1)
            + F.normalize(X_atoms['N'] - X_atoms['CA'], dim=-1),
            dim=-1
        )

        def _distance(X_a, X_b):
            return torch.norm(X_a[:, None, :, :] - X_b[:, :, None, :], dim=-1)

        def _inv_distance(X_a, X_b):
            return 1. / (_distance(X_a, X_b) + eps)

        # NOTE(review): the constants look like the DSSP electrostatic
        # hydrogen-bond energy approximation -- confirm.
        U = (0.084 * 332) * (
            _inv_distance(X_atoms['O'], X_atoms['N'])
            + _inv_distance(X_atoms['C'], X_atoms['H'])
            - _inv_distance(X_atoms['O'], X_atoms['H'])
            - _inv_distance(X_atoms['C'], X_atoms['N'])
        )

        HB = (U < -0.5).type(torch.float32)
        neighbor_HB = mask_neighbors * gather_edges(HB.unsqueeze(-1), E_idx)
        return neighbor_HB

    def _orientations_coarse(self, X, E_idx, eps=1e-6):
        """Coarse backbone angles and relative orientations of neighbors.

        Args:
            X: one coordinate per node [B, N, 3].
            E_idx: neighbor indices [B, N, K].
            eps: clamp margin keeping acos away from +-1.

        Returns:
            (AD_features [B, N, 3], O_features [B, N, K, 7]); the 7 edge
            channels are a unit direction (3) plus a quaternion (4).
        """
        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals. dim=-1 is explicit: without it torch.cross picks
        # the first axis of size 3, which is wrong when B or N-3 equals 3.
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Bond angle calculation
        cosA = -(u_1 * u_0).sum(-1)
        cosA = torch.clamp(cosA, -1 + eps, 1 - eps)
        A = torch.acos(cosA)
        # Angle between normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # Backbone features; zero-pad the 1 leading / 2 trailing positions
        # that have no full set of shifted vectors.
        AD_features = torch.stack(
            (torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), 2
        )
        AD_features = F.pad(AD_features, (0, 0, 1, 2), 'constant', 0)

        # Build relative orientations
        o_1 = F.normalize(u_2 - u_1, dim=-1)
        O = torch.stack((o_1, n_2, torch.cross(o_1, n_2, dim=-1)), 2)
        O = O.view(list(O.shape[:2]) + [9])
        O = F.pad(O, (0, 0, 1, 2), 'constant', 0)

        O_neighbors = gather_nodes(O, E_idx)
        X_neighbors = gather_nodes(X, E_idx)

        # Re-view as rotation matrices
        O = O.view(list(O.shape[:2]) + [3, 3])
        O_neighbors = O_neighbors.view(list(O_neighbors.shape[:3]) + [3, 3])

        # Rotate into local reference frames
        dX = X_neighbors - X.unsqueeze(-2)
        dU = torch.matmul(O.unsqueeze(2), dX.unsqueeze(-1)).squeeze(-1)
        dU = F.normalize(dU, dim=-1)
        R = torch.matmul(O.unsqueeze(2).transpose(-1, -2), O_neighbors)
        Q = self._quaternions(R)

        # Orientation features
        O_features = torch.cat((dU, Q), dim=-1)
        return AD_features, O_features

    def _dihedrals(self, X, eps=1e-7):
        """Backbone dihedral angles as cos/sin features.

        Args:
            X: backbone coordinates [B, N, A, 3] with A >= 3; only the first
                three atoms per residue are used.
            eps: clamp margin keeping acos away from +-1.

        Returns:
            Tensor [B, N, 6]: cos of the three angles, then sin.
        """
        # First 3 coordinates are N, CA, C
        X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)

        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals (explicit dim, see _orientations_coarse)
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, (1, 2), 'constant', 0)
        D = D.view((D.size(0), D.size(1) // 3, 3))
        return torch.cat((torch.cos(D), torch.sin(D)), 2)

    def forward(self, X, L, mask):
        """Featurize coordinates as an attributed graph.

        Args:
            X: backbone coordinates [B, N, 4, 3]; atom index 1 is used as the
                node position for the k-NN graph.
            L: not used by this method.
            mask: [B, N] validity mask.

        Returns:
            (V [B, N, node_features], E [B, N, K, edge_features],
             E_idx [B, N, K]).
        """
        # Data augmentation
        if self.training and self.augment_eps > 0:
            X = X + self.augment_eps * torch.randn_like(X)

        # Build k-Nearest Neighbors graph
        X_ca = X[:, :, 1, :]  # e.g. [32, 483, 3]
        # e.g. [32, 483, 30], [32, 483, 30], [32, 483, 30, 1]
        D_neighbors, E_idx, mask_neighbors = self._dist(X_ca, mask)

        # Pairwise features, e.g. [32, 483, 3], [32, 483, 30, 7]
        AD_features, O_features = self._orientations_coarse(X_ca, E_idx)
        RBF = self._rbf(D_neighbors)  # e.g. [32, 483, 30, 16]

        # Pairwise embeddings
        E_positional = self.embeddings(E_idx)  # e.g. [32, 483, 30, 16]

        if self.features_type == 'coarse':
            # Coarse backbone features
            V = AD_features
            E = torch.cat((E_positional, RBF, O_features), -1)
        elif self.features_type == 'hbonds':
            # Hydrogen bonds and contacts
            neighbor_HB = self._hbonds(X, E_idx, mask_neighbors)
            # _contacts takes (distances, mask); E_idx is not one of its inputs.
            neighbor_C = self._contacts(D_neighbors, mask_neighbors)
            # Dropout
            neighbor_C = self.dropout(neighbor_C)
            neighbor_HB = self.dropout(neighbor_HB)
            # Pack: each binary channel is broadcast to half the positional width.
            V = mask.unsqueeze(-1) * torch.ones_like(AD_features)
            half = self.num_positional_embeddings // 2
            neighbor_C = neighbor_C.expand(-1, -1, -1, half)
            neighbor_HB = neighbor_HB.expand(-1, -1, -1, half)
            E = torch.cat((E_positional, neighbor_C, neighbor_HB), -1)
        elif self.features_type == 'full':
            # Full backbone angles
            V = self._dihedrals(X)  # e.g. [32, 483, 6]
            E = torch.cat((E_positional, RBF, O_features), -1)  # e.g. [32, 483, 30, 39]
        elif self.features_type == 'dist':
            # Full backbone angles
            V = self._dihedrals(X)
            E = torch.cat((E_positional, RBF), -1)

        # Embed the nodes and edges
        V = self.node_embedding(V)  # e.g. [32, 483, 6] --> [32, 483, 128]
        V = self.norm_nodes(V)
        E = self.edge_embedding(E)  # e.g. [32, 483, 30, 39] --> [32, 483, 30, 128]
        E = self.norm_edges(E)

        return V, E, E_idx


def gather_edges(edges, neighbor_idx):
    """Features [B,N,N,C] at neighbor indices [B,N,K] => neighbor features [B,N,K,C]."""
    neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1))
    return torch.gather(edges, 2, neighbors)


def gather_nodes(nodes, neighbor_idx):
    """Features [B,N,C] at neighbor indices [B,M,K] => [B,M,K,C]."""
    # Flatten and expand indices per batch [B,M,K] => [B,MK] => [B,MK,C].
    # reshape (rather than view) also accepts non-contiguous index slices.
    neighbors_flat = neighbor_idx.reshape((neighbor_idx.shape[0], -1))
    neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    # Gather and re-pack
    neighbor_features = torch.gather(nodes, 1, neighbors_flat)
    neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1])
    return neighbor_features


def gather_nodes_t(nodes, neighbor_idx):
    """Features [B,N,C] at neighbor index [B,K] => neighbor features [B,K,C]."""
    idx_flat = neighbor_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    return torch.gather(nodes, 1, idx_flat)


def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    """Append gathered node features to edge features along the last axis.

    Returns cat([h_neighbors, gather_nodes(h_nodes, E_idx)], -1).
    """
    h_nodes = gather_nodes(h_nodes, E_idx)
    return torch.cat([h_neighbors, h_nodes], -1)


class TransformerLayer(nn.Module):
    """Neighbor attention followed by a position-wise feed-forward block."""

    def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.1):
        """
        Args:
            num_hidden: node feature width.
            num_in: width of the per-neighbor inputs given to attention.
            num_heads: number of attention heads.
            dropout: dropout probability on both residual updates.
        """
        super(TransformerLayer, self).__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])

        self.attention = NeighborAttention(num_hidden, num_in, num_heads)
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """Parallel computation of full transformer layer.

        Args:
            h_V: node features [B, N, num_hidden], e.g. [32, 482, 128].
            h_E: per-neighbor features [B, N, K, num_in], e.g. [32, 482, 30, 256].
            mask_V: optional node mask [B, N].
            mask_attend: optional attention mask [B, N, K].

        Returns:
            Updated node features [B, N, num_hidden].
        """
        # Self-attention
        dh = self.attention(h_V, h_E, mask_attend)
        h_V = self.norm[0](h_V + self.dropout(dh))

        # Position-wise feedforward
        dh = self.dense(h_V)
        h_V = self.norm[1](h_V + self.dropout(dh))

        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V
        return h_V

    def step(self, t, h_V, h_E, mask_V=None, mask_attend=None, E_idx=None):
        """Sequential computation of step t of a transformer layer.

        Args:
            t: node position to update.
            h_V: node features [B, N, num_hidden].
            h_E: edge features [B, N, K, num_in - num_hidden]; the attention
                step appends the gathered neighbor node features itself.
            mask_V: optional node mask [B, N].
            mask_attend: optional attention mask [B, N, K].
            E_idx: neighbor indices [B, N, K]. Required.

        Returns:
            Updated features of node t, [B, num_hidden].

        Raises:
            ValueError: if `E_idx` is not supplied.
        """
        if E_idx is None:
            raise ValueError(
                "TransformerLayer.step requires E_idx (neighbor indices)"
            )
        # Self-attention
        h_V_t = h_V[:, t, :]
        dh_t = self.attention.step(t, h_V, h_E, E_idx, mask_attend)
        h_V_t = self.norm[0](h_V_t + self.dropout(dh_t))

        # Position-wise feedforward
        dh_t = self.dense(h_V_t)
        h_V_t = self.norm[1](h_V_t + self.dropout(dh_t))

        if mask_V is not None:
            mask_V_t = mask_V[:, t].unsqueeze(-1)
            h_V_t = mask_V_t * h_V_t
        return h_V_t


class MPNNLayer(nn.Module):
    """Message-passing layer: MLP messages summed over neighbors, then feed-forward."""

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        """
        Args:
            num_hidden: node feature width.
            num_in: width of the per-neighbor inputs.
            dropout: dropout probability on both residual updates.
            num_heads: not used by this layer.
            scale: divisor applied to the summed messages.
        """
        super(MPNNLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])

        self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)

        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """Parallel computation of a full message-passing layer.

        Args:
            h_V: node features [B, N, num_hidden].
            h_E: per-neighbor features [B, N, K, num_in].
            mask_V: optional node mask [B, N].
            mask_attend: optional neighbor mask [B, N, K].

        Returns:
            Updated node features [B, N, num_hidden].
        """
        # Concatenate h_V_i to h_E_ij
        h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1)
        h_EV = torch.cat([h_V_expand, h_E], -1)

        h_message = self.W3(F.relu(self.W2(F.relu(self.W1(h_EV)))))
        if mask_attend is not None:
            h_message = mask_attend.unsqueeze(-1) * h_message
        dh = torch.sum(h_message, -2) / self.scale

        h_V = self.norm[0](h_V + self.dropout(dh))

        # Position-wise feedforward
        dh = self.dense(h_V)
        h_V = self.norm[1](h_V + self.dropout(dh))

        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V
        return h_V


class Normalize(nn.Module):
    """Normalization along one axis with learned gain and bias."""

    def __init__(self, features, epsilon=1e-6):
        """
        Args:
            features: size of the normalized axis.
            epsilon: stabilizer, used both under the sqrt and in the divisor.
        """
        super(Normalize, self).__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        """Return gain * (x - mean) / (std + epsilon) + bias along `dim`."""
        mu = x.mean(dim, keepdim=True)
        sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
        gain = self.gain
        bias = self.bias
        # Reshape so gain/bias broadcast along a non-trailing axis.
        if dim != -1:
            shape = [1] * len(mu.size())
            shape[dim] = self.gain.size()[0]
            gain = gain.view(shape)
            bias = bias.view(shape)
        return gain * (x - mu) / (sigma + self.epsilon) + bias


class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network (Linear -> ReLU -> Linear)."""

    def __init__(self, num_hidden, num_ff):
        """
        Args:
            num_hidden: input and output width.
            num_ff: width of the hidden layer.
        """
        super(PositionWiseFeedForward, self).__init__()
        self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
        self.W_out = nn.Linear(num_ff, num_hidden, bias=True)

    def forward(self, h_V):
        """Apply the feed-forward network to the last axis of `h_V`."""
        h = F.relu(self.W_in(h_V))
        h = self.W_out(h)
        return h


class NeighborAttention(nn.Module):
    """Multi-head attention of each node over its K neighbors."""

    def __init__(self, num_hidden, num_in, num_heads=4):
        """
        Args:
            num_hidden: node feature width (split evenly across heads).
            num_in: width of the per-neighbor inputs (keys and values).
            num_heads: number of attention heads.
        """
        super(NeighborAttention, self).__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden

        # Self-attention layers: {queries, keys, values, output}
        self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
        self.W_K = nn.Linear(num_in, num_hidden, bias=False)
        self.W_V = nn.Linear(num_in, num_hidden, bias=False)
        self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)

    def _masked_softmax(self, attend_logits, mask_attend, dim=-1):
        """Numerically stable masked softmax.

        Positions where `mask_attend` is not > 0 get the most negative finite
        value of the logits' dtype before the softmax, and are multiplied by
        the mask afterwards.
        """
        negative_inf = torch.finfo(attend_logits.dtype).min
        attend_logits = attend_logits.masked_fill(
            ~(mask_attend > 0), negative_inf
        )
        attend = F.softmax(attend_logits, dim)
        attend = mask_attend * attend
        return attend

    def forward(self, h_V, h_E, mask_attend=None):
        """Self-attention, graph-structured O(Nk).

        Args:
            h_V: Node features [N_batch, N_nodes, N_hidden]
            h_E: Neighbor features [N_batch, N_nodes, K, N_in]
            mask_attend: Mask for attention [N_batch, N_nodes, K]

        Returns:
            h_V: Node update [N_batch, N_nodes, N_hidden]
        """
        # Queries, Keys, Values
        n_batch, n_nodes, n_neighbors = h_E.shape[:3]
        n_heads = self.num_heads

        d = self.num_hidden // n_heads
        Q = self.W_Q(h_V).view([n_batch, n_nodes, 1, n_heads, 1, d])
        K = self.W_K(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d, 1])
        V = self.W_V(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d])

        # Attention with scaled inner product.
        # The n_neighbors axis carries the attention weights, which can be
        # seen as dot products between each neighbor and the central node.
        attend_logits = torch.matmul(Q, K).view(
            [n_batch, n_nodes, n_neighbors, n_heads]
        ).transpose(-2, -1)
        attend_logits = attend_logits / np.sqrt(d)  # [N_batch, N_nodes, n_heads, K]

        if mask_attend is not None:
            # Masked softmax; mask becomes [N_batch, N_nodes, n_heads, K]
            mask = mask_attend.unsqueeze(2).expand(-1, -1, n_heads, -1)
            attend = self._masked_softmax(attend_logits, mask)
        else:
            attend = F.softmax(attend_logits, -1)

        # Attentive reduction (aggregates neighbor information), e.g.
        # [32, 482, 4, 1, 30] x [32, 482, 4, 30, 32] --> [32, 482, 4, 1, 32]
        h_V_update = torch.matmul(attend.unsqueeze(-2), V.transpose(2, 3))
        h_V_update = h_V_update.view([n_batch, n_nodes, self.num_hidden])
        h_V_update = self.W_O(h_V_update)
        return h_V_update

    def step(self, t, h_V, h_E, E_idx, mask_attend=None):
        """Self-attention for a specific time step t.

        Matches `forward` at position t when `forward` is fed
        cat([h_E, gathered neighbor h_V], -1).

        Args:
            t: node position to update.
            h_V: Node features [N_batch, N_nodes, N_hidden]
            h_E: Edge features [N_batch, N_nodes, K, N_in - N_hidden]
            E_idx: Neighbor indices [N_batch, N_nodes, K]
            mask_attend: Mask for attention [N_batch, N_nodes, K]

        Returns:
            h_V_t: Node update [N_batch, N_hidden]
        """
        # Dimensions
        n_batch, n_nodes, n_neighbors = h_E.shape[:3]
        n_heads = self.num_heads
        # Integer division: `d` is used as a view() dimension.
        d = self.num_hidden // n_heads

        # Per time-step tensors
        h_V_t = h_V[:, t, :]
        h_E_t = h_E[:, t, :, :]
        E_idx_t = E_idx[:, t, :]

        # Single time-step
        h_V_neighbors_t = gather_nodes_t(h_V, E_idx_t)
        E_t = torch.cat([h_E_t, h_V_neighbors_t], -1)

        # Queries, Keys, Values
        Q = self.W_Q(h_V_t).view([n_batch, 1, n_heads, 1, d])
        K = self.W_K(E_t).view([n_batch, n_neighbors, n_heads, d, 1])
        V = self.W_V(E_t).view([n_batch, n_neighbors, n_heads, d])

        # Attention with scaled inner product
        attend_logits = torch.matmul(Q, K).view(
            [n_batch, n_neighbors, n_heads]
        ).transpose(-2, -1)
        attend_logits = attend_logits / np.sqrt(d)

        if mask_attend is not None:
            # Masked softmax: [N_batch, K] => [N_batch, N_heads, K]
            mask_t = mask_attend[:, t, :].unsqueeze(1).expand(-1, n_heads, -1)
            attend = self._masked_softmax(attend_logits, mask_t)
        else:
            # Logits are already scaled above; do not divide again.
            attend = F.softmax(attend_logits, -1)

        # Attentive reduction, then merge heads and project as in forward().
        h_V_t_update = torch.matmul(attend.unsqueeze(-2), V.transpose(1, 2))
        h_V_t_update = h_V_t_update.view([n_batch, self.num_hidden])
        h_V_t_update = self.W_O(h_V_t_update)
        return h_V_t_update


class Struct2Seq(nn.Module):
    """Graph labeling network."""

    def __init__(self, num_letters, node_features, edge_features,
                 hidden_dim, num_encoder_layers=3, num_decoder_layers=3,
                 vocab=33, k_neighbors=30, protein_features='full',
                 augment_eps=0., dropout=0.1, forward_attention_decoder=True,
                 use_mpnn=False):
        """
        Args:
            num_letters: number of output classes of `W_out`.
            node_features: width of incoming node features.
            edge_features: width of incoming edge features.
            hidden_dim: hidden width of all layers.
            num_encoder_layers: number of encoder layers.
            num_decoder_layers: number of decoder layers.
            vocab: size of the sequence embedding table.
            k_neighbors, protein_features, augment_eps: accepted but not read
                in this constructor.
            dropout: dropout probability inside the layers.
            forward_attention_decoder: stored on the module.
            use_mpnn: use MPNNLayer instead of TransformerLayer.
        """
        super(Struct2Seq, self).__init__()

        # Hyperparameters
        self.node_features = node_features
        self.edge_features = edge_features
        self.hidden_dim = hidden_dim

        # Embedding layers
        self.W_v = nn.Linear(node_features, hidden_dim, bias=True)
        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_s = nn.Embedding(vocab, hidden_dim)
        layer = MPNNLayer if use_mpnn else TransformerLayer

        # Encoder layers: per-neighbor input is [edge, neighbor node]
        self.encoder_layers = nn.ModuleList([
            layer(hidden_dim, hidden_dim * 2, dropout=dropout)
            for _ in range(num_encoder_layers)
        ])

        # Decoder layers: per-neighbor input is [edge, sequence, neighbor node]
        self.forward_attention_decoder = forward_attention_decoder
        self.decoder_layers = nn.ModuleList([
            layer(hidden_dim, hidden_dim * 3, dropout=dropout)
            for _ in range(num_decoder_layers)
        ])
        self.W_out = nn.Linear(hidden_dim, num_letters, bias=True)

        # Initialization
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _autoregressive_mask(self, E_idx):
        """Return a float mask [B, N, K] that is 1 where E_idx[b, i, k] < i."""
        N_nodes = E_idx.size(1)
        ii = torch.arange(N_nodes, device=E_idx.device)
        ii = ii.view((1, -1, 1))
        mask = E_idx < ii
        mask = mask.type(torch.float32)
        return mask

    def forward_sequential(self, X, S, L, mask=None):
        """Compute the transformer layer sequentially, for purposes of debugging.

        NOTE(review): this method reads `self.features` and
        `self.args.augment_eps`, neither of which is assigned in the visible
        `__init__`; they must be attached elsewhere before calling -- confirm.
        `mask` defaults to None but is used unconditionally below.

        Args:
            X: coordinates passed to `self.features`.
            S: integer sequence [B, N] (indexed as S[:, t]).
            L: passed through to `self.features`.
            mask: node mask [B, N].

        Returns:
            Log-probabilities [B, N, num_letters].
        """
        if self.args.augment_eps > 0:
            X = X + self.args.augment_eps * torch.randn_like(X)

        # Prepare node and edge embeddings
        V, E, E_idx = self.features(X, L, mask)
        h_V = self.W_v(V)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
            h_V = layer(h_V, h_EV, mask_V=mask, mask_attend=mask_attend)

        # Decoder alternates masked self-attention
        mask_attend = self._autoregressive_mask(E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)
        N_batch, N_nodes = X.size(0), X.size(1)
        # Width follows the output layer instead of a hard-coded 20.
        log_probs = torch.zeros((N_batch, N_nodes, self.W_out.out_features))
        h_S = torch.zeros_like(h_V)
        h_V_stack = [h_V] + [
            torch.zeros_like(h_V) for _ in range(len(self.decoder_layers))
        ]
        for t in range(N_nodes):
            # Hidden layers
            E_idx_t = E_idx[:, t:t + 1, :]
            h_E_t = h_E[:, t:t + 1, :, :]
            h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
            # Stale relational features for future states
            h_ESV_encoder_t = mask_fw[:, t:t + 1, :, :] * cat_neighbors_nodes(
                h_V, h_ES_t, E_idx_t
            )
            for l, layer in enumerate(self.decoder_layers):
                # Updated relational features for future states
                h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                h_V_t = h_V_stack[l][:, t:t + 1, :]
                h_ESV_t = mask_bw[:, t:t + 1, :, :] * h_ESV_decoder_t + h_ESV_encoder_t
                h_V_stack[l + 1][:, t, :] = layer(
                    h_V_t, h_ESV_t, mask_V=mask[:, t:t + 1]
                ).squeeze(1)

            # Sampling step
            h_V_t = h_V_stack[-1][:, t, :]
            logits = self.W_out(h_V_t)
            log_probs[:, t, :] = F.log_softmax(logits, dim=-1)

            # Update
            h_S[:, t, :] = self.W_s(S[:, t])
        return log_probs