Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
class PositionalEncodings(nn.Module):
    """Sinusoidal encoding of the signed offset between each node and its neighbors.

    Args:
        num_embeddings: controls the frequencies; one cosine and one sine
            channel is produced for every value in ``range(0, num_embeddings, 2)``.
        period_range: stored on the module (defaults to ``[2, 1000]``) but not
            read by ``forward``.
    """

    def __init__(self, num_embeddings, period_range=None):
        super(PositionalEncodings, self).__init__()
        self.num_embeddings = num_embeddings
        self.period_range = [2, 1000] if period_range is None else period_range

    def forward(self, E_idx):
        """Encode the offsets ``E_idx[b, i, k] - i``.

        Args:
            E_idx: neighbor indices indexed as [batch, node, neighbor].

        Returns:
            Tensor [batch, node, neighbor, 2 * len(range(0, num_embeddings, 2))]
            holding all cosines followed by all sines.
        """
        dev = E_idx.device
        node_pos = torch.arange(E_idx.size(1), dtype=torch.float32, device=dev)
        offsets = E_idx.float() - node_pos.view((1, -1, 1))

        # Original Transformer frequencies: 10000 ** (-j / num_embeddings), j even
        scale = -(np.log(10000.0) / self.num_embeddings)
        exponents = torch.arange(0, self.num_embeddings, 2, dtype=torch.float32, device=dev)
        freqs = torch.exp(exponents * scale)

        phase = offsets.unsqueeze(-1) * freqs.view((1, 1, 1, -1))
        return torch.cat((phase.cos(), phase.sin()), dim=-1)
class ProteinFeatures(nn.Module):
    """Extract protein features: build a k-nearest-neighbor graph over CA atoms
    and featurize its nodes (V) and edges (E).

    ``features_type`` selects the raw inputs (node width, edge width are in
    ``self.feature_dimensions``):
        'coarse': CA-trace angle features (3)   | positional + RBF + orientation (7)
        'full':   backbone dihedral cos/sin (6) | positional + RBF + orientation (7)
        'dist':   backbone dihedral cos/sin (6) | positional + RBF
        'hbonds': masked constant ones (3)      | positional + contacts + H-bonds
    The raw features are then embedded with a Linear layer and Normalize.
    """

    def __init__(self, edge_features, node_features, num_positional_embeddings=16, num_rbf=16, top_k=30, features_type='full', augment_eps=0., dropout=0.1):
        super(ProteinFeatures, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        ## Feature types: (node input width, edge input width) ##
        # The "+ 7" is the orientation features: 3 direction + 4 quaternion.
        # NOTE: PositionalEncodings emits 2 * ceil(n / 2) channels, which equals
        # num_positional_embeddings only when it is even.
        self.features_type = features_type
        self.feature_dimensions = {
            'coarse': (3, num_positional_embeddings + num_rbf + 7),
            'full': (6, num_positional_embeddings + num_rbf + 7),
            'dist': (6, num_positional_embeddings + num_rbf),
            'hbonds': (3, 2 * num_positional_embeddings)}

        ## Positional encoding ##
        self.embeddings = PositionalEncodings(num_positional_embeddings)
        self.dropout = nn.Dropout(dropout)

        ## Normalization and embedding ##
        node_in, edge_in = self.feature_dimensions[features_type]
        self.node_embedding = nn.Linear(node_in, node_features, bias=True)
        self.edge_embedding = nn.Linear(edge_in, edge_features, bias=True)
        self.norm_nodes = Normalize(node_features)
        self.norm_edges = Normalize(edge_features)

    def _dist(self, X, mask, eps=1E-6):
        """Pairwise Euclidean distances and the k nearest neighbors of each node.

        Args:
            X: coordinates [B, N, 3].
            mask: node mask [B, N] (1 = valid).
            eps: added under the square root (presumably to keep it
                differentiable at zero distance).

        Returns:
            D_neighbors: [B, N, K] neighbor distances in ascending order.
            E_idx: [B, N, K] neighbor indices, K = min(top_k, N).
            mask_neighbors: [B, N, K, 1] pair mask gathered at E_idx.
        """
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        # Invalid pairs get a large sentinel distance...
        D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
        # ...and are then pushed beyond every distance in the row, so topk
        # (smallest first) ranks them last.
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * (D_max + 1)
        D_neighbors, E_idx = torch.topk(D_adjust, min(self.top_k, D_adjust.shape[-1]), dim=-1, largest=False)
        mask_neighbors = gather_edges(mask_2D.unsqueeze(-1), E_idx)
        return D_neighbors, E_idx, mask_neighbors

    def _rbf(self, D):
        """Expand distances onto ``num_rbf`` Gaussian radial basis functions.

        Centers are evenly spaced on [0, 20] and share the width 20 / num_rbf.
        The centers are viewed as [1, 1, 1, num_rbf], so D is expected to be
        rank 3 (e.g. [B, N, K]); the result is [B, N, K, num_rbf].
        """
        D_min, D_max, D_count = 0., 20., self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)
        D_mu = D_mu.view([1, 1, 1, -1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        return torch.exp(-((D_expand - D_mu) / D_sigma)**2)  # return RBF

    def _quaternions(self, R):
        """Convert a batch of 3D rotations to unit quaternions.

        Args:
            R: rotation matrices [B, N, K, 3, 3].

        Returns:
            Q: [B, N, K, 4] unit quaternions ordered (x, y, z, w).
        """
        diag = torch.diagonal(R, dim1=-2, dim2=-1)
        Rxx, Ryy, Rzz = diag.unbind(-1)
        # |x|, |y|, |z| from the diagonal; abs() guards small negative round-off.
        magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
            Rxx - Ryy - Rzz,
            - Rxx + Ryy - Rzz,
            - Rxx - Ryy + Rzz
        ], -1)))
        _R = lambda i, j: R[:, :, :, i, j]
        # Signs of x, y, z from the antisymmetric part of R.
        signs = torch.sign(torch.stack([
            _R(2, 1) - _R(1, 2),
            _R(0, 2) - _R(2, 0),
            _R(1, 0) - _R(0, 1)
        ], -1))
        xyz = signs * magnitudes
        # relu guards the trace term against negative round-off before sqrt.
        w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
        Q = torch.cat((xyz, w), -1)
        Q = F.normalize(Q, dim=-1)
        return Q

    def _contacts(self, D_neighbors, mask_neighbors, cutoff=8):
        """Contact indicator: 1.0 where a valid neighbor is closer than ``cutoff``.

        Args:
            D_neighbors: [B, N, K] neighbor distances.
            mask_neighbors: [B, N, K, 1] neighbor mask.
            cutoff: distance threshold (same units as the coordinates).

        Returns:
            [B, N, K, 1] float contact map.
        """
        D_neighbors = D_neighbors.unsqueeze(-1)
        return mask_neighbors * (D_neighbors < cutoff).type(torch.float32)  # return neighbor_C

    def _hbonds(self, X, E_idx, mask_neighbors, eps=1E-3):
        """Hydrogen-bond indicator on the neighbor graph.

        Uses a DSSP-style electrostatic energy between the backbone N-H of
        residue i (rows) and the C=O of residue j (columns); a bond is declared
        where the energy is below -0.5.

        Args:
            X: backbone coordinates [B, N, atoms, 3] with atoms ordered N, CA, C, O.
            E_idx: neighbor indices [B, N, K].
            mask_neighbors: [B, N, K, 1] neighbor mask.
            eps: added to distances before inversion.

        Returns:
            [B, N, K, 1] float hydrogen-bond map.
        """
        X_atoms = dict(zip(['N', 'CA', 'C', 'O'], torch.unbind(X, 2)))

        # Virtual hydrogens. Index i of C_prev holds C of residue i-1 (shift by
        # one, zero vector for residue 0, which has no predecessor).
        X_atoms['C_prev'] = F.pad(X_atoms['C'][:, :-1, :], (0, 0, 1, 0), 'constant', 0)
        X_atoms['H'] = X_atoms['N'] + F.normalize(
            F.normalize(X_atoms['N'] - X_atoms['C_prev'], -1)
            + F.normalize(X_atoms['N'] - X_atoms['CA'], -1)
            , -1)

        def _distance(X_a, X_b):
            # [B, N_b, N_a]: entry (i, j) is |X_a[j] - X_b[i]|
            return torch.norm(X_a[:, None, :, :] - X_b[:, :, None, :], dim=-1)

        def _inv_distance(X_a, X_b):
            return 1. / (_distance(X_a, X_b) + eps)

        U = (0.084 * 332) * (
            _inv_distance(X_atoms['O'], X_atoms['N'])
            + _inv_distance(X_atoms['C'], X_atoms['H'])
            - _inv_distance(X_atoms['O'], X_atoms['H'])
            - _inv_distance(X_atoms['C'], X_atoms['N'])
        )
        HB = (U < -0.5).type(torch.float32)
        neighbor_HB = mask_neighbors * gather_edges(HB.unsqueeze(-1), E_idx)
        return neighbor_HB

    def _orientations_coarse(self, X, E_idx, eps=1e-6):
        """CA-trace angle features and neighbor relative orientations.

        Args:
            X: CA coordinates [B, N, 3].
            E_idx: neighbor indices [B, N, K].
            eps: keeps cosines strictly inside (-1, 1) before acos.

        Returns:
            AD_features: [B, N, 3] = (cos A, sin A cos D, sin A sin D) from the
                bond angle A and dihedral D of the trace, zero-padded with one
                row before and two after.
            O_features: [B, N, K, 7] = unit direction to each neighbor in the
                local frame (3) + relative-rotation quaternion (4).
        """
        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals. dim=-1 is explicit: without it torch.cross uses the
        # first dimension of size 3, which is the batch/sequence axis when
        # either of those happens to be 3.
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Bond angle calculation
        cosA = -(u_1 * u_0).sum(-1)
        cosA = torch.clamp(cosA, -1 + eps, 1 - eps)
        A = torch.acos(cosA)
        # Angle between normals (signed dihedral)
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # Backbone features
        AD_features = torch.stack((torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), 2)
        AD_features = F.pad(AD_features, (0, 0, 1, 2), 'constant', 0)

        # Build per-residue local frames [o_1, n_2, o_1 x n_2]
        o_1 = F.normalize(u_2 - u_1, dim=-1)
        O = torch.stack((o_1, n_2, torch.cross(o_1, n_2, dim=-1)), 2)
        O = O.view(list(O.shape[:2]) + [9])
        O = F.pad(O, (0, 0, 1, 2), 'constant', 0)
        O_neighbors = gather_nodes(O, E_idx)
        X_neighbors = gather_nodes(X, E_idx)

        # Re-view as rotation matrices
        O = O.view(list(O.shape[:2]) + [3, 3])
        O_neighbors = O_neighbors.view(list(O_neighbors.shape[:3]) + [3, 3])

        # Rotate into local reference frames
        dX = X_neighbors - X.unsqueeze(-2)
        dU = torch.matmul(O.unsqueeze(2), dX.unsqueeze(-1)).squeeze(-1)
        dU = F.normalize(dU, dim=-1)
        R = torch.matmul(O.unsqueeze(2).transpose(-1, -2), O_neighbors)
        Q = self._quaternions(R)

        # Orientation features
        O_features = torch.cat((dU, Q), dim=-1)
        return AD_features, O_features

    def _dihedrals(self, X, eps=1e-7):
        """Backbone dihedral features (cos and sin of three dihedrals per residue).

        Args:
            X: backbone coordinates [B, N, atoms, 3]; the first three atoms are
               taken as N, CA, C.
            eps: keeps cosines strictly inside (-1, 1) before acos.

        Returns:
            [B, N, 6]: cos of the three dihedrals followed by their sin. Positions
            without a defined dihedral (phi[0], psi[-1], omega[-1]) are padded
            with angle 0.
        """
        # First 3 coordinates are N, CA, C; flatten into one atom chain.
        X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)

        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals (explicit dim=-1, see _orientations_coarse)
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, (1, 2), 'constant', 0)
        D = D.view((D.size(0), int(D.size(1) / 3), 3))
        return torch.cat((torch.cos(D), torch.sin(D)), 2)  # return D_features

    def forward(self, X, L, mask):
        """Featurize backbone coordinates as an attributed k-NN graph.

        Args:
            X: backbone coordinates [B, N, atoms, 3]; atom 1 is used as CA.
            L: not used by this method.
            mask: node mask [B, N].

        Returns:
            V: node embeddings [B, N, node_features].
            E: edge embeddings [B, N, K, edge_features].
            E_idx: neighbor indices [B, N, K].
        """
        # Data augmentation (training only)
        if self.training and self.augment_eps > 0:
            X = X + self.augment_eps * torch.randn_like(X)

        # Build k-Nearest Neighbors graph
        X_ca = X[:, :, 1, :]  # [B, N, 3]
        D_neighbors, E_idx, mask_neighbors = self._dist(X_ca, mask)  # [B, N, K], [B, N, K], [B, N, K, 1]

        # Pairwise features
        AD_features, O_features = self._orientations_coarse(X_ca, E_idx)  # [B, N, 3], [B, N, K, 7]
        RBF = self._rbf(D_neighbors)  # [B, N, K, num_rbf]

        # Pairwise embeddings
        E_positional = self.embeddings(E_idx)  # [B, N, K, num_positional_embeddings] (even values)

        if self.features_type == 'coarse':
            # Coarse backbone features
            V = AD_features
            E = torch.cat((E_positional, RBF, O_features), -1)
        elif self.features_type == 'hbonds':
            # Hydrogen bonds and contacts
            neighbor_HB = self._hbonds(X, E_idx, mask_neighbors)
            # _contacts takes (D_neighbors, mask_neighbors[, cutoff])
            neighbor_C = self._contacts(D_neighbors, mask_neighbors)
            # Dropout
            neighbor_C = self.dropout(neighbor_C)
            neighbor_HB = self.dropout(neighbor_HB)
            # Pack: each [B, N, K, 1] indicator is tiled to half the positional width
            V = mask.unsqueeze(-1) * torch.ones_like(AD_features)
            neighbor_C = neighbor_C.expand(-1, -1, -1, int(self.num_positional_embeddings / 2))
            neighbor_HB = neighbor_HB.expand(-1, -1, -1, int(self.num_positional_embeddings / 2))
            E = torch.cat((E_positional, neighbor_C, neighbor_HB), -1)
        elif self.features_type == 'full':
            # Full backbone angles
            V = self._dihedrals(X)  # [B, N, 6]
            E = torch.cat((E_positional, RBF, O_features), -1)  # [B, N, K, pos + num_rbf + 7]
        elif self.features_type == 'dist':
            # Full backbone angles, distance-only edges
            V = self._dihedrals(X)
            E = torch.cat((E_positional, RBF), -1)

        # Embed the nodes and edges
        V = self.node_embedding(V)  # [B, N, node_in] -> [B, N, node_features]
        V = self.norm_nodes(V)
        E = self.edge_embedding(E)  # [B, N, K, edge_in] -> [B, N, K, edge_features]
        E = self.norm_edges(E)
        return V, E, E_idx
def gather_edges(edges, neighbor_idx):
    """Select per-neighbor entries from a dense pairwise feature tensor.

    Args:
        edges: pairwise features [B, N, N, C].
        neighbor_idx: neighbor indices [B, N, K].

    Returns:
        [B, N, K, C] where out[b, i, k] = edges[b, i, neighbor_idx[b, i, k]].
    """
    n_channels = edges.size(-1)
    index = neighbor_idx[..., None].expand(-1, -1, -1, n_channels)
    return edges.gather(2, index)
def gather_nodes(nodes, neighbor_idx):
    """Look up node features at neighbor indices.

    Args:
        nodes: node features [B, N, C].
        neighbor_idx: neighbor indices [B, N, K].

    Returns:
        [B, N, K, C] where out[b, i, k] = nodes[b, neighbor_idx[b, i, k]].
    """
    batch = neighbor_idx.shape[0]
    channels = nodes.size(2)
    # Flatten the (node, neighbor) axes so one gather along dim 1 suffices:
    # [B, N, K] -> [B, N*K] -> [B, N*K, C]
    flat_idx = neighbor_idx.view((batch, -1)).unsqueeze(-1).expand(-1, -1, channels)
    picked = torch.gather(nodes, 1, flat_idx)
    # Re-pack to [B, N, K, C]
    return picked.view([*neighbor_idx.shape[:3], -1])
def gather_nodes_t(nodes, neighbor_idx):
    """Look up node features for a single position's neighbors.

    Args:
        nodes: node features [B, N, C].
        neighbor_idx: neighbor indices [B, K].

    Returns:
        [B, K, C] where out[b, k] = nodes[b, neighbor_idx[b, k]].
    """
    index = neighbor_idx[..., None].expand(-1, -1, nodes.size(2))
    return nodes.gather(1, index)
def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    """Append each neighbor's node features to the per-edge features.

    Args:
        h_nodes: node features [B, N, C_v].
        h_neighbors: per-edge features [B, N, K, C_e].
        E_idx: neighbor indices [B, N, K].

    Returns:
        [B, N, K, C_e + C_v]: the edge features followed by the gathered
        neighbor node features.
    """
    neighbor_nodes = gather_nodes(h_nodes, E_idx)
    return torch.cat((h_neighbors, neighbor_nodes), dim=-1)
class TransformerLayer(nn.Module):
    """Graph transformer layer: neighbor self-attention followed by a
    position-wise feed-forward block, each with dropout, a residual
    connection, and Normalize.

    Args:
        num_hidden: node feature width.
        num_in: width of the per-neighbor features fed to attention keys/values.
        num_heads: number of attention heads.
        dropout: dropout probability on both residual branches.
    """

    def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.1):
        super(TransformerLayer, self).__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])
        self.attention = NeighborAttention(num_hidden, num_in, num_heads)
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """Parallel computation of the full transformer layer.

        Args:
            h_V: node features [B, N, num_hidden].
            h_E: neighbor features [B, N, K, num_in].
            mask_V: optional node mask [B, N]; masked nodes are zeroed.
            mask_attend: optional attention mask [B, N, K].

        Returns:
            Updated node features [B, N, num_hidden].
        """
        # Self-attention
        dh = self.attention(h_V, h_E, mask_attend)
        h_V = self.norm[0](h_V + self.dropout(dh))
        # Position-wise feedforward
        dh = self.dense(h_V)
        h_V = self.norm[1](h_V + self.dropout(dh))
        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V
        return h_V

    def step(self, t, h_V, h_E, mask_V=None, mask_attend=None, E_idx=None):
        """Sequential computation of step ``t`` of the transformer layer.

        Args:
            t: position to update.
            h_V: node features [B, N, num_hidden].
            h_E: neighbor features, forwarded to ``NeighborAttention.step``.
            mask_V: optional node mask [B, N].
            mask_attend: optional attention mask [B, N, K].
            E_idx: neighbor indices [B, N, K], required by
                ``NeighborAttention.step`` (keyword-only in practice, added at
                the end to keep existing positional calls valid).

        Returns:
            Updated features of node ``t``: [B, num_hidden].

        Raises:
            ValueError: if ``E_idx`` is not given.
        """
        if E_idx is None:
            raise ValueError("TransformerLayer.step requires E_idx (neighbor indices [B, N, K])")
        # Self-attention. NeighborAttention.step's signature is
        # (t, h_V, h_E, E_idx, mask_attend), so the mask must not go in the
        # E_idx slot.
        h_V_t = h_V[:, t, :]
        dh_t = self.attention.step(t, h_V, h_E, E_idx, mask_attend=mask_attend)
        h_V_t = self.norm[0](h_V_t + self.dropout(dh_t))
        # Position-wise feedforward
        dh_t = self.dense(h_V_t)
        h_V_t = self.norm[1](h_V_t + self.dropout(dh_t))
        if mask_V is not None:
            mask_V_t = mask_V[:, t].unsqueeze(-1)
            h_V_t = mask_V_t * h_V_t
        return h_V_t
class MPNNLayer(nn.Module):
    """Message-passing layer: a 3-layer MLP produces one message per neighbor.
    The messages are summed and scaled, then followed by a position-wise
    feed-forward block. Each update uses dropout, a residual connection,
    and Normalize.

    Args:
        num_hidden: node feature width.
        num_in: per-neighbor feature width.
        dropout: dropout probability on both residual branches.
        num_heads: accepted for signature parity with TransformerLayer; unused.
        scale: divisor applied to the summed messages.
    """

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(MPNNLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])
        self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """Parallel message-passing update of all nodes.

        Args:
            h_V: node features [B, N, num_hidden].
            h_E: neighbor features [B, N, K, num_in].
            mask_V: optional node mask [B, N]; masked nodes are zeroed.
            mask_attend: optional neighbor mask [B, N, K]; masked messages are zeroed.

        Returns:
            Updated node features [B, N, num_hidden].
        """
        n_neighbors = h_E.size(-2)
        # Pair every edge with the features of its center node.
        center = h_V.unsqueeze(-2).expand(-1, -1, n_neighbors, -1)
        message = torch.cat([center, h_E], -1)
        message = F.relu(self.W1(message))
        message = F.relu(self.W2(message))
        message = self.W3(message)
        if mask_attend is not None:
            message = mask_attend.unsqueeze(-1) * message
        update = message.sum(dim=-2) / self.scale
        h_V = self.norm[0](h_V + self.dropout(update))

        # Position-wise feedforward
        update = self.dense(h_V)
        h_V = self.norm[1](h_V + self.dropout(update))

        if mask_V is None:
            return h_V
        return mask_V.unsqueeze(-1) * h_V
class Normalize(nn.Module):
    """Layer normalization with a learned per-feature gain and bias.

    Args:
        features: size of the normalized dimension (length of gain and bias).
        epsilon: added to the variance inside the square root and again to
            the resulting standard deviation.
    """

    def __init__(self, features, epsilon=1e-6):
        super(Normalize, self).__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        """Normalize ``x`` along ``dim`` (uses the unbiased variance of ``Tensor.var``)."""
        eps = self.epsilon
        mean = x.mean(dim, keepdim=True)
        std = torch.sqrt(x.var(dim, keepdim=True) + eps)
        g, b = self.gain, self.bias
        if dim != -1:
            # Broadcast gain/bias along the normalized axis.
            bshape = [1 for _ in range(mean.dim())]
            bshape[dim] = self.gain.size()[0]
            g, b = g.view(bshape), b.view(bshape)
        return g * (x - mean) / (std + eps) + b
class PositionWiseFeedForward(nn.Module):
    """Two-layer ReLU MLP applied independently at every position.

    Args:
        num_hidden: input and output width.
        num_ff: inner (expanded) width.
    """

    def __init__(self, num_hidden, num_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.W_in = nn.Linear(num_hidden, num_ff, bias=True)
        self.W_out = nn.Linear(num_ff, num_hidden, bias=True)

    def forward(self, h_V):
        """Map [..., num_hidden] -> [..., num_hidden] through the inner ReLU layer."""
        return self.W_out(F.relu(self.W_in(h_V)))
class NeighborAttention(nn.Module):
    """Multi-head attention of each node over its K graph neighbors.

    Args:
        num_hidden: node feature width (queries, values and output); assumed
            divisible by ``num_heads``.
        num_in: per-neighbor feature width used for keys and values.
        num_heads: number of attention heads.
    """

    def __init__(self, num_hidden, num_in, num_heads=4):
        super(NeighborAttention, self).__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden

        # Self-attention layers: {queries, keys, values, output}
        self.W_Q = nn.Linear(num_hidden, num_hidden, bias=False)
        self.W_K = nn.Linear(num_in, num_hidden, bias=False)
        self.W_V = nn.Linear(num_in, num_hidden, bias=False)
        self.W_O = nn.Linear(num_hidden, num_hidden, bias=False)

    def _masked_softmax(self, attend_logits, mask_attend, dim=-1):
        """Softmax over ``dim`` that ignores positions where ``mask_attend <= 0``.

        Masked logits are replaced by the most negative float32 (not -inf, so a
        fully masked row yields finite values instead of NaN), and masked
        probabilities are zeroed after the softmax.
        """
        negative_inf = np.finfo(np.float32).min
        attend_logits = torch.where(mask_attend > 0, attend_logits, torch.tensor(negative_inf, device=attend_logits.device))
        attend = F.softmax(attend_logits, dim)
        attend = mask_attend * attend
        return attend

    def forward(self, h_V, h_E, mask_attend=None):
        """Self-attention, graph-structured O(Nk).

        Args:
            h_V: Node features [N_batch, N_nodes, N_hidden]
            h_E: Neighbor features [N_batch, N_nodes, K, num_in]
            mask_attend: Mask for attention [N_batch, N_nodes, K]

        Returns:
            h_V_update: Node update [N_batch, N_nodes, N_hidden]
        """
        # Queries, Keys, Values
        n_batch, n_nodes, n_neighbors = h_E.shape[:3]
        n_heads = self.num_heads
        d = int(self.num_hidden / n_heads)
        Q = self.W_Q(h_V).view([n_batch, n_nodes, 1, n_heads, 1, d])
        K = self.W_K(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d, 1])
        V = self.W_V(h_E).view([n_batch, n_nodes, n_neighbors, n_heads, d])

        # Attention with scaled inner product. Each weight along the
        # n_neighbors axis is the dot product of the center node's query with
        # that neighbor's key.
        attend_logits = torch.matmul(Q, K).view([n_batch, n_nodes, n_neighbors, n_heads]).transpose(-2, -1)
        attend_logits = attend_logits / np.sqrt(d)  # [N_batch, N_nodes, n_heads, K]

        if mask_attend is not None:
            # Masked softmax
            mask = mask_attend.unsqueeze(2).expand(-1, -1, n_heads, -1)  # [N_batch, N_nodes, n_heads, K]
            attend = self._masked_softmax(attend_logits, mask)
        else:
            attend = F.softmax(attend_logits, -1)

        # Attentive reduction (aggregates neighbor values):
        # [B, N, H, 1, K] @ [B, N, H, K, d] -> [B, N, H, 1, d]
        h_V_update = torch.matmul(attend.unsqueeze(-2), V.transpose(2, 3))
        h_V_update = h_V_update.view([n_batch, n_nodes, self.num_hidden])
        h_V_update = self.W_O(h_V_update)
        return h_V_update

    def step(self, t, h_V, h_E, E_idx, mask_attend=None):
        """Self-attention for a specific time step t.

        Args:
            t: position to update.
            h_V: Node features [N_batch, N_nodes, N_hidden]
            h_E: Neighbor features [N_batch, N_nodes, K, C]
            E_idx: Neighbor indices [N_batch, N_nodes, K]
            mask_attend: Mask for attention [N_batch, N_nodes, K]

        Returns:
            h_V_t: Node update [N_batch, N_hidden], head-merged and projected
            by W_O like ``forward``.
        """
        # Dimensions. Integer head width: view() rejects float sizes.
        n_batch, n_nodes, n_neighbors = h_E.shape[:3]
        n_heads = self.num_heads
        d = self.num_hidden // n_heads

        # Per time-step tensors
        h_V_t = h_V[:, t, :]
        h_E_t = h_E[:, t, :, :]
        E_idx_t = E_idx[:, t, :]

        # Single time-step: append the neighbors' current node features.
        # NOTE(review): unlike forward(), node features are concatenated here,
        # so W_K / W_V (input width num_in) require C + N_hidden == num_in at
        # this call site. Confirm with callers.
        h_V_neighbors_t = gather_nodes_t(h_V, E_idx_t)
        E_t = torch.cat([h_E_t, h_V_neighbors_t], -1)

        # Queries, Keys, Values
        Q = self.W_Q(h_V_t).view([n_batch, 1, n_heads, 1, d])
        K = self.W_K(E_t).view([n_batch, n_neighbors, n_heads, d, 1])
        V = self.W_V(E_t).view([n_batch, n_neighbors, n_heads, d])

        # Attention with scaled inner product -> [N_batch, N_heads, K]
        attend_logits = torch.matmul(Q, K).view([n_batch, n_neighbors, n_heads]).transpose(-2, -1)
        attend_logits = attend_logits / np.sqrt(d)

        if mask_attend is not None:
            # Masked softmax: [N_batch, K] -> [N_batch, N_heads, K]
            mask_t = mask_attend[:, t, :].unsqueeze(1).expand(-1, n_heads, -1)
            attend = self._masked_softmax(attend_logits, mask_t)
        else:
            # Logits are already scaled by 1/sqrt(d) above.
            attend = F.softmax(attend_logits, -1)

        # Attentive reduction: [B, H, 1, K] @ [B, H, K, d] -> [B, H, 1, d]
        h_V_t_update = torch.matmul(attend.unsqueeze(-2), V.transpose(1, 2))
        # Merge heads and project, matching forward()
        h_V_t_update = h_V_t_update.view([n_batch, self.num_hidden])
        return self.W_O(h_V_t_update)
class Struct2Seq(nn.Module):
    """Graph labeling network: a graph encoder over structure features and
    an autoregressive graph decoder that emits per-position letter scores."""

    def __init__(self, num_letters, node_features, edge_features,
                 hidden_dim, num_encoder_layers=3, num_decoder_layers=3,
                 vocab=33, k_neighbors=30, protein_features='full', augment_eps=0.,
                 dropout=0.1, forward_attention_decoder=True, use_mpnn=False):
        """Graph labeling network.

        Args:
            num_letters: output width of ``W_out``.
            node_features: input width of ``W_v``.
            edge_features: input width of ``W_e``.
            hidden_dim: hidden width of all layers.
            num_encoder_layers / num_decoder_layers: layer counts.
            vocab: number of sequence tokens embedded by ``W_s``.
            k_neighbors, protein_features: not used in this constructor.
            augment_eps: std of coordinate noise used by ``forward_sequential``.
            dropout: dropout probability inside each layer.
            forward_attention_decoder: stored as an attribute.
            use_mpnn: use MPNNLayer instead of TransformerLayer.
        """
        super(Struct2Seq, self).__init__()

        # Hyperparameters (plain attributes; not part of the state_dict)
        self.node_features = node_features
        self.edge_features = edge_features
        self.hidden_dim = hidden_dim
        self.augment_eps = augment_eps

        # Embedding layers
        self.W_v = nn.Linear(node_features, hidden_dim, bias=True)
        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_s = nn.Embedding(vocab, hidden_dim)
        layer = MPNNLayer if use_mpnn else TransformerLayer

        # Encoder layers: neighbor input = [h_E, h_V neighbors] -> 2 * hidden
        self.encoder_layers = nn.ModuleList([
            layer(hidden_dim, hidden_dim * 2, dropout=dropout)
            for _ in range(num_encoder_layers)
        ])

        # Decoder layers: neighbor input = [h_E, h_S, h_V neighbors] -> 3 * hidden
        self.forward_attention_decoder = forward_attention_decoder
        self.decoder_layers = nn.ModuleList([
            layer(hidden_dim, hidden_dim * 3, dropout=dropout)
            for _ in range(num_decoder_layers)
        ])
        self.W_out = nn.Linear(hidden_dim, num_letters, bias=True)

        # Initialization (order of self.parameters() fixes the RNG draw order)
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _autoregressive_mask(self, E_idx):
        """Return a float mask [B, N, K] that is 1.0 where neighbor E_idx[b, i, k] < i."""
        N_nodes = E_idx.size(1)
        ii = torch.arange(N_nodes, device=E_idx.device)
        ii = ii.view((1, -1, 1))
        mask = E_idx < ii
        mask = mask.type(torch.float32)
        return mask

    def forward_sequential(self, X, S, L, mask=None):
        """Compute the decoder one position at a time, for purposes of debugging.

        Runs the encoder once, then decodes positions t = 0..N-1 in order,
        writing the embedding of the given token ``S[:, t]`` back after each step.

        NOTE(review): calls ``self.features`` (expected to return
        ``(V, E, E_idx)``), which ``__init__`` does not create. It must be
        attached to the model before this method is used.

        Args:
            X: coordinates; X.size(0) is the batch and X.size(1) the length.
            S: integer tokens indexed [batch, position].
            L: forwarded unchanged to ``self.features``.
            mask: position mask [batch, position]; defaults to all ones.

        Returns:
            log_probs: [batch, position, num_letters] log-softmax scores.
        """
        # Coordinate noise: prefer an attached ``args`` namespace if present,
        # otherwise the augment_eps given to __init__.
        if hasattr(self, 'args'):
            augment_eps = self.args.augment_eps
        else:
            augment_eps = getattr(self, 'augment_eps', 0.)
        if augment_eps > 0:
            X = X + augment_eps * torch.randn_like(X)
        if mask is None:
            mask = torch.ones(X.shape[:2], dtype=X.dtype, device=X.device)

        # Prepare node and edge embeddings
        V, E, E_idx = self.features(X, L, mask)
        h_V = self.W_v(V)
        h_E = self.W_e(E)

        # Encoder: self-attention over all valid neighbors (padding masked only)
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
            h_V = layer(h_V, h_EV, mask_V=mask, mask_attend=mask_attend)

        # Decoder: neighbors earlier in the sequence (mask_bw) see decoder
        # states; later neighbors (mask_fw) see encoder states.
        mask_attend = self._autoregressive_mask(E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = X.size(0), X.size(1)
        # One column per output letter, on h_V's device and dtype.
        log_probs = h_V.new_zeros((N_batch, N_nodes, self.W_out.out_features))
        h_S = torch.zeros_like(h_V)
        h_V_stack = [h_V] + [torch.zeros_like(h_V) for _ in range(len(self.decoder_layers))]
        for t in range(N_nodes):
            # Hidden layers
            E_idx_t = E_idx[:, t:t + 1, :]
            h_E_t = h_E[:, t:t + 1, :, :]
            h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
            # Stale (encoder) relational features for future states
            h_ESV_encoder_t = mask_fw[:, t:t + 1, :, :] * cat_neighbors_nodes(h_V, h_ES_t, E_idx_t)
            for l, layer in enumerate(self.decoder_layers):
                # Updated relational features for past states
                h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                h_V_t = h_V_stack[l][:, t:t + 1, :]
                h_ESV_t = mask_bw[:, t:t + 1, :, :] * h_ESV_decoder_t + h_ESV_encoder_t
                h_V_stack[l + 1][:, t, :] = layer(
                    h_V_t, h_ESV_t, mask_V=mask[:, t:t + 1]
                ).squeeze(1)

            # Scoring step
            h_V_t = h_V_stack[-1][:, t, :]
            logits = self.W_out(h_V_t)
            log_probs[:, t, :] = F.log_softmax(logits, dim=-1)

            # Update with the given token (teacher forcing)
            h_S[:, t, :] = self.W_s(S[:, t])
        return log_probs