Spaces:
Running on Zero
Running on Zero
| from __future__ import print_function | |
| import itertools | |
| import sys | |
| import numpy as np | |
| import torch | |
| class ProteinMPNN(torch.nn.Module): | |
| def __init__( | |
| self, | |
| num_letters=21, | |
| node_features=128, | |
| edge_features=128, | |
| hidden_dim=128, | |
| num_encoder_layers=3, | |
| num_decoder_layers=3, | |
| vocab=21, | |
| k_neighbors=48, | |
| augment_eps=0.0, | |
| dropout=0.0, | |
| device=None, | |
| atom_context_num=0, | |
| model_type="protein_mpnn", | |
| ligand_mpnn_use_side_chain_context=False, | |
| ): | |
| super(ProteinMPNN, self).__init__() | |
| self.model_type = model_type | |
| self.node_features = node_features | |
| self.edge_features = edge_features | |
| self.hidden_dim = hidden_dim | |
| if self.model_type == "ligand_mpnn": | |
| self.features = ProteinFeaturesLigand( | |
| node_features, | |
| edge_features, | |
| top_k=k_neighbors, | |
| augment_eps=augment_eps, | |
| device=device, | |
| atom_context_num=atom_context_num, | |
| use_side_chains=ligand_mpnn_use_side_chain_context, | |
| ) | |
| self.W_v = torch.nn.Linear(node_features, hidden_dim, bias=True) | |
| self.W_c = torch.nn.Linear(hidden_dim, hidden_dim, bias=True) | |
| self.W_nodes_y = torch.nn.Linear(hidden_dim, hidden_dim, bias=True) | |
| self.W_edges_y = torch.nn.Linear(hidden_dim, hidden_dim, bias=True) | |
| self.V_C = torch.nn.Linear(hidden_dim, hidden_dim, bias=False) | |
| self.V_C_norm = torch.nn.LayerNorm(hidden_dim) | |
| self.context_encoder_layers = torch.nn.ModuleList( | |
| [ | |
| DecLayer(hidden_dim, hidden_dim * 2, dropout=dropout) | |
| for _ in range(2) | |
| ] | |
| ) | |
| self.y_context_encoder_layers = torch.nn.ModuleList( | |
| [DecLayerJ(hidden_dim, hidden_dim, dropout=dropout) for _ in range(2)] | |
| ) | |
| elif self.model_type == "protein_mpnn" or self.model_type == "soluble_mpnn": | |
| self.features = ProteinFeatures( | |
| node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps | |
| ) | |
| elif ( | |
| self.model_type == "per_residue_label_membrane_mpnn" | |
| or self.model_type == "global_label_membrane_mpnn" | |
| ): | |
| self.W_v = torch.nn.Linear(node_features, hidden_dim, bias=True) | |
| self.features = ProteinFeaturesMembrane( | |
| node_features, | |
| edge_features, | |
| top_k=k_neighbors, | |
| augment_eps=augment_eps, | |
| num_classes=3, | |
| ) | |
| else: | |
| print("Choose --model_type flag from currently available models") | |
| sys.exit() | |
| self.W_e = torch.nn.Linear(edge_features, hidden_dim, bias=True) | |
| self.W_s = torch.nn.Embedding(vocab, hidden_dim) | |
| self.dropout = torch.nn.Dropout(dropout) | |
| # Encoder layers | |
| self.encoder_layers = torch.nn.ModuleList( | |
| [ | |
| EncLayer(hidden_dim, hidden_dim * 2, dropout=dropout) | |
| for _ in range(num_encoder_layers) | |
| ] | |
| ) | |
| # Decoder layers | |
| self.decoder_layers = torch.nn.ModuleList( | |
| [ | |
| DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout) | |
| for _ in range(num_decoder_layers) | |
| ] | |
| ) | |
| self.W_out = torch.nn.Linear(hidden_dim, num_letters, bias=True) | |
| for p in self.parameters(): | |
| if p.dim() > 1: | |
| torch.nn.init.xavier_uniform_(p) | |
| def encode(self, feature_dict): | |
| # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed | |
| # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords | |
| # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords | |
| # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type | |
| # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask | |
| # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O | |
| S_true = feature_dict[ | |
| "S" | |
| ] # [B,L] - integer protein sequence encoded using "restype_STRtoINT" | |
| # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index | |
| mask = feature_dict[ | |
| "mask" | |
| ] # [B,L] - mask for missing regions - should be removed! all ones most of the time | |
| # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters | |
| B, L = S_true.shape | |
| device = S_true.device | |
| if self.model_type == "ligand_mpnn": | |
| V, E, E_idx, Y_nodes, Y_edges, Y_m = self.features(feature_dict) | |
| h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device) | |
| h_E = self.W_e(E) | |
| h_E_context = self.W_v(V) | |
| mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) | |
| mask_attend = mask.unsqueeze(-1) * mask_attend | |
| for layer in self.encoder_layers: | |
| h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend) | |
| h_V_C = self.W_c(h_V) | |
| Y_m_edges = Y_m[:, :, :, None] * Y_m[:, :, None, :] | |
| Y_nodes = self.W_nodes_y(Y_nodes) | |
| Y_edges = self.W_edges_y(Y_edges) | |
| for i in range(len(self.context_encoder_layers)): | |
| Y_nodes = self.y_context_encoder_layers[i]( | |
| Y_nodes, Y_edges, Y_m, Y_m_edges | |
| ) | |
| h_E_context_cat = torch.cat([h_E_context, Y_nodes], -1) | |
| h_V_C = self.context_encoder_layers[i]( | |
| h_V_C, h_E_context_cat, mask, Y_m | |
| ) | |
| h_V_C = self.V_C(h_V_C) | |
| h_V = h_V + self.V_C_norm(self.dropout(h_V_C)) | |
| elif self.model_type == "protein_mpnn" or self.model_type == "soluble_mpnn": | |
| E, E_idx = self.features(feature_dict) | |
| h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device) | |
| h_E = self.W_e(E) | |
| mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) | |
| mask_attend = mask.unsqueeze(-1) * mask_attend | |
| for layer in self.encoder_layers: | |
| h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend) | |
| elif ( | |
| self.model_type == "per_residue_label_membrane_mpnn" | |
| or self.model_type == "global_label_membrane_mpnn" | |
| ): | |
| V, E, E_idx = self.features(feature_dict) | |
| h_V = self.W_v(V) | |
| h_E = self.W_e(E) | |
| mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1) | |
| mask_attend = mask.unsqueeze(-1) * mask_attend | |
| for layer in self.encoder_layers: | |
| h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend) | |
| return h_V, h_E, E_idx | |
| def sample(self, feature_dict): | |
| # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed | |
| # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords | |
| # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords | |
| # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type | |
| # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask | |
| # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O | |
| B_decoder = feature_dict["batch_size"] | |
| S_true = feature_dict[ | |
| "S" | |
| ] # [B,L] - integer proitein sequence encoded using "restype_STRtoINT" | |
| # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index | |
| mask = feature_dict[ | |
| "mask" | |
| ] # [B,L] - mask for missing regions - should be removed! all ones most of the time | |
| chain_mask = feature_dict[ | |
| "chain_mask" | |
| ] # [B,L] - mask for which residues need to be fixed; 0.0 - fixed; 1.0 - will be designed | |
| bias = feature_dict["bias"] # [B,L,21] - amino acid bias per position | |
| # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters | |
| randn = feature_dict[ | |
| "randn" | |
| ] # [B,L] - random numbers for decoding order; only the first entry is used since decoding within a batch needs to match for symmetry | |
| temperature = feature_dict[ | |
| "temperature" | |
| ] # float - sampling temperature; prob = softmax(logits/temperature) | |
| symmetry_list_of_lists = feature_dict[ | |
| "symmetry_residues" | |
| ] # [[0, 1, 14], [10,11,14,15], [20, 21]] #indices to select X over length - L | |
| symmetry_weights_list_of_lists = feature_dict[ | |
| "symmetry_weights" | |
| ] # [[1.0, 1.0, 1.0], [-2.0,1.1,0.2,1.1], [2.3, 1.1]] | |
| B, L = S_true.shape | |
| device = S_true.device | |
| h_V, h_E, E_idx = self.encode(feature_dict) | |
| chain_mask = mask * chain_mask # update chain_M to include missing regions | |
| decoding_order = torch.argsort( | |
| (chain_mask + 0.0001) * (torch.abs(randn)) | |
| ) # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0] | |
| if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1: | |
| E_idx = E_idx.repeat(B_decoder, 1, 1) | |
| permutation_matrix_reverse = torch.nn.functional.one_hot( | |
| decoding_order, num_classes=L | |
| ).float() | |
| order_mask_backward = torch.einsum( | |
| "ij, biq, bjp->bqp", | |
| (1 - torch.triu(torch.ones(L, L, device=device))), | |
| permutation_matrix_reverse, | |
| permutation_matrix_reverse, | |
| ) | |
| mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) | |
| mask_1D = mask.view([B, L, 1, 1]) | |
| mask_bw = mask_1D * mask_attend | |
| mask_fw = mask_1D * (1.0 - mask_attend) | |
| # repeat for decoding | |
| S_true = S_true.repeat(B_decoder, 1) | |
| h_V = h_V.repeat(B_decoder, 1, 1) | |
| h_E = h_E.repeat(B_decoder, 1, 1, 1) | |
| chain_mask = chain_mask.repeat(B_decoder, 1) | |
| mask = mask.repeat(B_decoder, 1) | |
| bias = bias.repeat(B_decoder, 1, 1) | |
| all_probs = torch.zeros( | |
| (B_decoder, L, 20), device=device, dtype=torch.float32 | |
| ) | |
| all_log_probs = torch.zeros( | |
| (B_decoder, L, 21), device=device, dtype=torch.float32 | |
| ) | |
| h_S = torch.zeros_like(h_V, device=device) | |
| S = 20 * torch.ones((B_decoder, L), dtype=torch.int64, device=device) | |
| h_V_stack = [h_V] + [ | |
| torch.zeros_like(h_V, device=device) | |
| for _ in range(len(self.decoder_layers)) | |
| ] | |
| h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx) | |
| h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) | |
| h_EXV_encoder_fw = mask_fw * h_EXV_encoder | |
| for t_ in range(L): | |
| t = decoding_order[:, t_] # [B] | |
| chain_mask_t = torch.gather(chain_mask, 1, t[:, None])[:, 0] # [B] | |
| mask_t = torch.gather(mask, 1, t[:, None])[:, 0] # [B] | |
| bias_t = torch.gather(bias, 1, t[:, None, None].repeat(1, 1, 21))[ | |
| :, 0, : | |
| ] # [B,21] | |
| E_idx_t = torch.gather( | |
| E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1]) | |
| ) | |
| h_E_t = torch.gather( | |
| h_E, | |
| 1, | |
| t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]), | |
| ) | |
| h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t) | |
| h_EXV_encoder_t = torch.gather( | |
| h_EXV_encoder_fw, | |
| 1, | |
| t[:, None, None, None].repeat( | |
| 1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1] | |
| ), | |
| ) | |
| mask_bw_t = torch.gather( | |
| mask_bw, | |
| 1, | |
| t[:, None, None, None].repeat( | |
| 1, 1, mask_bw.shape[-2], mask_bw.shape[-1] | |
| ), | |
| ) | |
| for l, layer in enumerate(self.decoder_layers): | |
| h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t) | |
| h_V_t = torch.gather( | |
| h_V_stack[l], | |
| 1, | |
| t[:, None, None].repeat(1, 1, h_V_stack[l].shape[-1]), | |
| ) | |
| h_ESV_t = mask_bw_t * h_ESV_decoder_t + h_EXV_encoder_t | |
| h_V_stack[l + 1].scatter_( | |
| 1, | |
| t[:, None, None].repeat(1, 1, h_V.shape[-1]), | |
| layer(h_V_t, h_ESV_t, mask_V=mask_t), | |
| ) | |
| h_V_t = torch.gather( | |
| h_V_stack[-1], | |
| 1, | |
| t[:, None, None].repeat(1, 1, h_V_stack[-1].shape[-1]), | |
| )[:, 0] | |
| logits = self.W_out(h_V_t) # [B,21] | |
| log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # [B,21] | |
| probs = torch.nn.functional.softmax( | |
| (logits + bias_t) / temperature, dim=-1 | |
| ) # [B,21] | |
| probs_sample = probs[:, :20] / torch.sum( | |
| probs[:, :20], dim=-1, keepdim=True | |
| ) # hard omit X #[B,20] | |
| S_t = torch.multinomial(probs_sample, 1)[:, 0] # [B] | |
| all_probs.scatter_( | |
| 1, | |
| t[:, None, None].repeat(1, 1, 20), | |
| (chain_mask_t[:, None, None] * probs_sample[:, None, :]).float(), | |
| ) | |
| all_log_probs.scatter_( | |
| 1, | |
| t[:, None, None].repeat(1, 1, 21), | |
| (chain_mask_t[:, None, None] * log_probs[:, None, :]).float(), | |
| ) | |
| S_true_t = torch.gather(S_true, 1, t[:, None])[:, 0] | |
| S_t = (S_t * chain_mask_t + S_true_t * (1.0 - chain_mask_t)).long() | |
| h_S.scatter_( | |
| 1, | |
| t[:, None, None].repeat(1, 1, h_S.shape[-1]), | |
| self.W_s(S_t)[:, None, :], | |
| ) | |
| S.scatter_(1, t[:, None], S_t[:, None]) | |
| output_dict = { | |
| "S": S, | |
| "sampling_probs": all_probs, | |
| "log_probs": all_log_probs, | |
| "decoding_order": decoding_order, | |
| } | |
| else: | |
| # weights for symmetric design | |
| symmetry_weights = torch.ones([L], device=device, dtype=torch.float32) | |
| for i1, item_list in enumerate(symmetry_list_of_lists): | |
| for i2, item in enumerate(item_list): | |
| symmetry_weights[item] = symmetry_weights_list_of_lists[i1][i2] | |
| new_decoding_order = [] | |
| for t_dec in list(decoding_order[0,].cpu().data.numpy()): | |
| if t_dec not in list(itertools.chain(*new_decoding_order)): | |
| list_a = [item for item in symmetry_list_of_lists if t_dec in item] | |
| if list_a: | |
| new_decoding_order.append(list_a[0]) | |
| else: | |
| new_decoding_order.append([t_dec]) | |
| decoding_order = torch.tensor( | |
| list(itertools.chain(*new_decoding_order)), device=device | |
| )[None,].repeat(B, 1) | |
| permutation_matrix_reverse = torch.nn.functional.one_hot( | |
| decoding_order, num_classes=L | |
| ).float() | |
| order_mask_backward = torch.einsum( | |
| "ij, biq, bjp->bqp", | |
| (1 - torch.triu(torch.ones(L, L, device=device))), | |
| permutation_matrix_reverse, | |
| permutation_matrix_reverse, | |
| ) | |
| mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) | |
| mask_1D = mask.view([B, L, 1, 1]) | |
| mask_bw = mask_1D * mask_attend | |
| mask_fw = mask_1D * (1.0 - mask_attend) | |
| # repeat for decoding | |
| S_true = S_true.repeat(B_decoder, 1) | |
| h_V = h_V.repeat(B_decoder, 1, 1) | |
| h_E = h_E.repeat(B_decoder, 1, 1, 1) | |
| E_idx = E_idx.repeat(B_decoder, 1, 1) | |
| mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1) | |
| mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1) | |
| chain_mask = chain_mask.repeat(B_decoder, 1) | |
| mask = mask.repeat(B_decoder, 1) | |
| bias = bias.repeat(B_decoder, 1, 1) | |
| all_probs = torch.zeros( | |
| (B_decoder, L, 20), device=device, dtype=torch.float32 | |
| ) | |
| all_log_probs = torch.zeros( | |
| (B_decoder, L, 21), device=device, dtype=torch.float32 | |
| ) | |
| h_S = torch.zeros_like(h_V, device=device) | |
| S = 20 * torch.ones((B_decoder, L), dtype=torch.int64, device=device) | |
| h_V_stack = [h_V] + [ | |
| torch.zeros_like(h_V, device=device) | |
| for _ in range(len(self.decoder_layers)) | |
| ] | |
| h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx) | |
| h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) | |
| h_EXV_encoder_fw = mask_fw * h_EXV_encoder | |
| for t_list in new_decoding_order: | |
| total_logits = 0.0 | |
| for t in t_list: | |
| chain_mask_t = chain_mask[:, t] # [B] | |
| mask_t = mask[:, t] # [B] | |
| bias_t = bias[:, t] # [B, 21] | |
| E_idx_t = E_idx[:, t : t + 1] | |
| h_E_t = h_E[:, t : t + 1] | |
| h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t) | |
| h_EXV_encoder_t = h_EXV_encoder_fw[:, t : t + 1] | |
| for l, layer in enumerate(self.decoder_layers): | |
| h_ESV_decoder_t = cat_neighbors_nodes( | |
| h_V_stack[l], h_ES_t, E_idx_t | |
| ) | |
| h_V_t = h_V_stack[l][:, t : t + 1] | |
| h_ESV_t = ( | |
| mask_bw[:, t : t + 1] * h_ESV_decoder_t + h_EXV_encoder_t | |
| ) | |
| h_V_stack[l + 1][:, t : t + 1, :] = layer( | |
| h_V_t, h_ESV_t, mask_V=mask_t[:, None] | |
| ) | |
| h_V_t = h_V_stack[-1][:, t] | |
| logits = self.W_out(h_V_t) # [B,21] | |
| log_probs = torch.nn.functional.log_softmax( | |
| logits, dim=-1 | |
| ) # [B,21] | |
| all_log_probs[:, t] = ( | |
| chain_mask_t[:, None] * log_probs | |
| ).float() # [B,21] | |
| total_logits += symmetry_weights[t] * logits | |
| probs = torch.nn.functional.softmax( | |
| (total_logits + bias_t) / temperature, dim=-1 | |
| ) # [B,21] | |
| probs_sample = probs[:, :20] / torch.sum( | |
| probs[:, :20], dim=-1, keepdim=True | |
| ) # hard omit X #[B,20] | |
| S_t = torch.multinomial(probs_sample, 1)[:, 0] # [B] | |
| for t in t_list: | |
| chain_mask_t = chain_mask[:, t] # [B] | |
| all_probs[:, t] = ( | |
| chain_mask_t[:, None] * probs_sample | |
| ).float() # [B,20] | |
| S_true_t = S_true[:, t] # [B] | |
| S_t = (S_t * chain_mask_t + S_true_t * (1.0 - chain_mask_t)).long() | |
| h_S[:, t] = self.W_s(S_t) | |
| S[:, t] = S_t | |
| output_dict = { | |
| "S": S, | |
| "sampling_probs": all_probs, | |
| "log_probs": all_log_probs, | |
| "decoding_order": decoding_order.repeat(B_decoder, 1), | |
| } | |
| return output_dict | |
| def single_aa_score(self, feature_dict, use_sequence: bool): | |
| """ | |
| feature_dict - input features | |
| use_sequence - False using backbone info only | |
| """ | |
| B_decoder = feature_dict["batch_size"] | |
| S_true_enc = feature_dict[ | |
| "S" | |
| ] | |
| mask_enc = feature_dict[ | |
| "mask" | |
| ] | |
| chain_mask_enc = feature_dict[ | |
| "chain_mask" | |
| ] | |
| randn = feature_dict[ | |
| "randn" | |
| ] | |
| B, L = S_true_enc.shape | |
| device = S_true_enc.device | |
| h_V_enc, h_E_enc, E_idx_enc = self.encode(feature_dict) | |
| log_probs_out = torch.zeros([B_decoder, L, 21], device=device).float() | |
| logits_out = torch.zeros([B_decoder, L, 21], device=device).float() | |
| decoding_order_out = torch.zeros([B_decoder, L, L], device=device).float() | |
| for idx in range(L): | |
| h_V = torch.clone(h_V_enc) | |
| E_idx = torch.clone(E_idx_enc) | |
| mask = torch.clone(mask_enc) | |
| S_true = torch.clone(S_true_enc) | |
| if not use_sequence: | |
| order_mask = torch.zeros(chain_mask_enc.shape[1], device=device).float() | |
| order_mask[idx] = 1. | |
| else: | |
| order_mask = torch.ones(chain_mask_enc.shape[1], device=device).float() | |
| order_mask[idx] = 0. | |
| decoding_order = torch.argsort( | |
| (order_mask + 0.0001) * (torch.abs(randn)) | |
| ) # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0] | |
| E_idx = E_idx.repeat(B_decoder, 1, 1) | |
| permutation_matrix_reverse = torch.nn.functional.one_hot( | |
| decoding_order, num_classes=L | |
| ).float() | |
| order_mask_backward = torch.einsum( | |
| "ij, biq, bjp->bqp", | |
| (1 - torch.triu(torch.ones(L, L, device=device))), | |
| permutation_matrix_reverse, | |
| permutation_matrix_reverse, | |
| ) | |
| mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) | |
| mask_1D = mask.view([B, L, 1, 1]) | |
| mask_bw = mask_1D * mask_attend | |
| mask_fw = mask_1D * (1.0 - mask_attend) | |
| S_true = S_true.repeat(B_decoder, 1) | |
| h_V = h_V.repeat(B_decoder, 1, 1) | |
| h_E = h_E_enc.repeat(B_decoder, 1, 1, 1) | |
| mask = mask.repeat(B_decoder, 1) | |
| h_S = self.W_s(S_true) | |
| h_ES = cat_neighbors_nodes(h_S, h_E, E_idx) | |
| # Build encoder embeddings | |
| h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx) | |
| h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) | |
| h_EXV_encoder_fw = mask_fw * h_EXV_encoder | |
| for layer in self.decoder_layers: | |
| # Masked positions attend to encoder information, unmasked see. | |
| h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx) | |
| h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw | |
| h_V = layer(h_V, h_ESV, mask) | |
| logits = self.W_out(h_V) | |
| log_probs = torch.nn.functional.log_softmax(logits, dim=-1) | |
| log_probs_out[:,idx,:] = log_probs[:,idx,:] | |
| logits_out[:,idx,:] = logits[:,idx,:] | |
| decoding_order_out[:,idx,:] = decoding_order | |
| output_dict = { | |
| "S": S_true, | |
| "log_probs": log_probs_out, | |
| "logits": logits_out, | |
| "decoding_order": decoding_order_out, | |
| } | |
| return output_dict | |
| def score(self, feature_dict, use_sequence: bool): | |
| B_decoder = feature_dict["batch_size"] | |
| S_true = feature_dict[ | |
| "S" | |
| ] | |
| mask = feature_dict[ | |
| "mask" | |
| ] | |
| chain_mask = feature_dict[ | |
| "chain_mask" | |
| ] | |
| randn = feature_dict[ | |
| "randn" | |
| ] | |
| symmetry_list_of_lists = feature_dict[ | |
| "symmetry_residues" | |
| ] | |
| B, L = S_true.shape | |
| device = S_true.device | |
| h_V, h_E, E_idx = self.encode(feature_dict) | |
| chain_mask = mask * chain_mask # update chain_M to include missing regions | |
| decoding_order = torch.argsort( | |
| (chain_mask + 0.0001) * (torch.abs(randn)) | |
| ) # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0] | |
| if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1: | |
| E_idx = E_idx.repeat(B_decoder, 1, 1) | |
| permutation_matrix_reverse = torch.nn.functional.one_hot( | |
| decoding_order, num_classes=L | |
| ).float() | |
| order_mask_backward = torch.einsum( | |
| "ij, biq, bjp->bqp", | |
| (1 - torch.triu(torch.ones(L, L, device=device))), | |
| permutation_matrix_reverse, | |
| permutation_matrix_reverse, | |
| ) | |
| mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) | |
| mask_1D = mask.view([B, L, 1, 1]) | |
| mask_bw = mask_1D * mask_attend | |
| mask_fw = mask_1D * (1.0 - mask_attend) | |
| else: | |
| new_decoding_order = [] | |
| for t_dec in list(decoding_order[0,].cpu().data.numpy()): | |
| if t_dec not in list(itertools.chain(*new_decoding_order)): | |
| list_a = [item for item in symmetry_list_of_lists if t_dec in item] | |
| if list_a: | |
| new_decoding_order.append(list_a[0]) | |
| else: | |
| new_decoding_order.append([t_dec]) | |
| decoding_order = torch.tensor( | |
| list(itertools.chain(*new_decoding_order)), device=device | |
| )[None,].repeat(B, 1) | |
| permutation_matrix_reverse = torch.nn.functional.one_hot( | |
| decoding_order, num_classes=L | |
| ).float() | |
| order_mask_backward = torch.einsum( | |
| "ij, biq, bjp->bqp", | |
| (1 - torch.triu(torch.ones(L, L, device=device))), | |
| permutation_matrix_reverse, | |
| permutation_matrix_reverse, | |
| ) | |
| mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) | |
| mask_1D = mask.view([B, L, 1, 1]) | |
| mask_bw = mask_1D * mask_attend | |
| mask_fw = mask_1D * (1.0 - mask_attend) | |
| E_idx = E_idx.repeat(B_decoder, 1, 1) | |
| mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1) | |
| mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1) | |
| decoding_order = decoding_order.repeat(B_decoder, 1) | |
| S_true = S_true.repeat(B_decoder, 1) | |
| h_V = h_V.repeat(B_decoder, 1, 1) | |
| h_E = h_E.repeat(B_decoder, 1, 1, 1) | |
| mask = mask.repeat(B_decoder, 1) | |
| h_S = self.W_s(S_true) | |
| h_ES = cat_neighbors_nodes(h_S, h_E, E_idx) | |
| # Build encoder embeddings | |
| h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx) | |
| h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx) | |
| h_EXV_encoder_fw = mask_fw * h_EXV_encoder | |
| if not use_sequence: | |
| for layer in self.decoder_layers: | |
| h_V = layer(h_V, h_EXV_encoder_fw, mask) | |
| else: | |
| for layer in self.decoder_layers: | |
| # Masked positions attend to encoder information, unmasked see. | |
| h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx) | |
| h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw | |
| h_V = layer(h_V, h_ESV, mask) | |
| logits = self.W_out(h_V) | |
| log_probs = torch.nn.functional.log_softmax(logits, dim=-1) | |
| output_dict = { | |
| "S": S_true, | |
| "log_probs": log_probs, | |
| "logits": logits, | |
| "decoding_order": decoding_order, | |
| } | |
| return output_dict | |
| class ProteinFeaturesLigand(torch.nn.Module): | |
| def __init__( | |
| self, | |
| edge_features, | |
| node_features, | |
| num_positional_embeddings=16, | |
| num_rbf=16, | |
| top_k=30, | |
| augment_eps=0.0, | |
| device=None, | |
| atom_context_num=16, | |
| use_side_chains=False, | |
| ): | |
| """Extract protein features""" | |
| super(ProteinFeaturesLigand, self).__init__() | |
| self.use_side_chains = use_side_chains | |
| self.edge_features = edge_features | |
| self.node_features = node_features | |
| self.top_k = top_k | |
| self.augment_eps = augment_eps | |
| self.num_rbf = num_rbf | |
| self.num_positional_embeddings = num_positional_embeddings | |
| self.embeddings = PositionalEncodings(num_positional_embeddings) | |
| edge_in = num_positional_embeddings + num_rbf * 25 | |
| self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False) | |
| self.norm_edges = torch.nn.LayerNorm(edge_features) | |
| self.node_project_down = torch.nn.Linear( | |
| 5 * num_rbf + 64 + 4, node_features, bias=True | |
| ) | |
| self.norm_nodes = torch.nn.LayerNorm(node_features) | |
| self.type_linear = torch.nn.Linear(147, 64) | |
| self.y_nodes = torch.nn.Linear(147, node_features, bias=False) | |
| self.y_edges = torch.nn.Linear(num_rbf, node_features, bias=False) | |
| self.norm_y_edges = torch.nn.LayerNorm(node_features) | |
| self.norm_y_nodes = torch.nn.LayerNorm(node_features) | |
| self.atom_context_num = atom_context_num | |
| # the last 32 atoms in the 37 atom representation | |
| self.side_chain_atom_types = torch.tensor( | |
| [ | |
| 6, | |
| 6, | |
| 6, | |
| 8, | |
| 8, | |
| 16, | |
| 6, | |
| 6, | |
| 6, | |
| 7, | |
| 7, | |
| 8, | |
| 8, | |
| 16, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 7, | |
| 7, | |
| 7, | |
| 8, | |
| 8, | |
| 6, | |
| 7, | |
| 7, | |
| 8, | |
| 6, | |
| 6, | |
| 6, | |
| 7, | |
| 8, | |
| ], | |
| device=device, | |
| ) | |
| self.periodic_table_features = torch.tensor( | |
| [ | |
| [ | |
| 0, | |
| 1, | |
| 2, | |
| 3, | |
| 4, | |
| 5, | |
| 6, | |
| 7, | |
| 8, | |
| 9, | |
| 10, | |
| 11, | |
| 12, | |
| 13, | |
| 14, | |
| 15, | |
| 16, | |
| 17, | |
| 18, | |
| 19, | |
| 20, | |
| 21, | |
| 22, | |
| 23, | |
| 24, | |
| 25, | |
| 26, | |
| 27, | |
| 28, | |
| 29, | |
| 30, | |
| 31, | |
| 32, | |
| 33, | |
| 34, | |
| 35, | |
| 36, | |
| 37, | |
| 38, | |
| 39, | |
| 40, | |
| 41, | |
| 42, | |
| 43, | |
| 44, | |
| 45, | |
| 46, | |
| 47, | |
| 48, | |
| 49, | |
| 50, | |
| 51, | |
| 52, | |
| 53, | |
| 54, | |
| 55, | |
| 56, | |
| 57, | |
| 58, | |
| 59, | |
| 60, | |
| 61, | |
| 62, | |
| 63, | |
| 64, | |
| 65, | |
| 66, | |
| 67, | |
| 68, | |
| 69, | |
| 70, | |
| 71, | |
| 72, | |
| 73, | |
| 74, | |
| 75, | |
| 76, | |
| 77, | |
| 78, | |
| 79, | |
| 80, | |
| 81, | |
| 82, | |
| 83, | |
| 84, | |
| 85, | |
| 86, | |
| 87, | |
| 88, | |
| 89, | |
| 90, | |
| 91, | |
| 92, | |
| 93, | |
| 94, | |
| 95, | |
| 96, | |
| 97, | |
| 98, | |
| 99, | |
| 100, | |
| 101, | |
| 102, | |
| 103, | |
| 104, | |
| 105, | |
| 106, | |
| 107, | |
| 108, | |
| 109, | |
| 110, | |
| 111, | |
| 112, | |
| 113, | |
| 114, | |
| 115, | |
| 116, | |
| 117, | |
| 118, | |
| ], | |
| [ | |
| 0, | |
| 1, | |
| 18, | |
| 1, | |
| 2, | |
| 13, | |
| 14, | |
| 15, | |
| 16, | |
| 17, | |
| 18, | |
| 1, | |
| 2, | |
| 13, | |
| 14, | |
| 15, | |
| 16, | |
| 17, | |
| 18, | |
| 1, | |
| 2, | |
| 3, | |
| 4, | |
| 5, | |
| 6, | |
| 7, | |
| 8, | |
| 9, | |
| 10, | |
| 11, | |
| 12, | |
| 13, | |
| 14, | |
| 15, | |
| 16, | |
| 17, | |
| 18, | |
| 1, | |
| 2, | |
| 3, | |
| 4, | |
| 5, | |
| 6, | |
| 7, | |
| 8, | |
| 9, | |
| 10, | |
| 11, | |
| 12, | |
| 13, | |
| 14, | |
| 15, | |
| 16, | |
| 17, | |
| 18, | |
| 1, | |
| 2, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 4, | |
| 5, | |
| 6, | |
| 7, | |
| 8, | |
| 9, | |
| 10, | |
| 11, | |
| 12, | |
| 13, | |
| 14, | |
| 15, | |
| 16, | |
| 17, | |
| 18, | |
| 1, | |
| 2, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 4, | |
| 5, | |
| 6, | |
| 7, | |
| 8, | |
| 9, | |
| 10, | |
| 11, | |
| 12, | |
| 13, | |
| 14, | |
| 15, | |
| 16, | |
| 17, | |
| 18, | |
| ], | |
| [ | |
| 0, | |
| 1, | |
| 1, | |
| 2, | |
| 2, | |
| 2, | |
| 2, | |
| 2, | |
| 2, | |
| 2, | |
| 2, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 3, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 4, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 5, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 6, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| 7, | |
| ], | |
| ], | |
| dtype=torch.long, | |
| device=device, | |
| ) | |
| def _make_angle_features(self, A, B, C, Y): | |
| v1 = A - B | |
| v2 = C - B | |
| e1 = torch.nn.functional.normalize(v1, dim=-1) | |
| e1_v2_dot = torch.einsum("bli, bli -> bl", e1, v2)[..., None] | |
| u2 = v2 - e1 * e1_v2_dot | |
| e2 = torch.nn.functional.normalize(u2, dim=-1) | |
| e3 = torch.cross(e1, e2, dim=-1) | |
| R_residue = torch.cat( | |
| (e1[:, :, :, None], e2[:, :, :, None], e3[:, :, :, None]), dim=-1 | |
| ) | |
| local_vectors = torch.einsum( | |
| "blqp, blyq -> blyp", R_residue, Y - B[:, :, None, :] | |
| ) | |
| rxy = torch.sqrt(local_vectors[..., 0] ** 2 + local_vectors[..., 1] ** 2 + 1e-8) | |
| f1 = local_vectors[..., 0] / rxy | |
| f2 = local_vectors[..., 1] / rxy | |
| rxyz = torch.norm(local_vectors, dim=-1) + 1e-8 | |
| f3 = rxy / rxyz | |
| f4 = local_vectors[..., 2] / rxyz | |
| f = torch.cat([f1[..., None], f2[..., None], f3[..., None], f4[..., None]], -1) | |
| return f | |
| def _dist(self, X, mask, eps=1e-6): | |
| mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2) | |
| dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2) | |
| D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps) | |
| D_max, _ = torch.max(D, -1, keepdim=True) | |
| D_adjust = D + (1.0 - mask_2D) * D_max | |
| D_neighbors, E_idx = torch.topk( | |
| D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False | |
| ) | |
| return D_neighbors, E_idx | |
| def _rbf(self, D): | |
| device = D.device | |
| D_min, D_max, D_count = 2.0, 22.0, self.num_rbf | |
| D_mu = torch.linspace(D_min, D_max, D_count, device=device) | |
| D_mu = D_mu.view([1, 1, 1, -1]) | |
| D_sigma = (D_max - D_min) / D_count | |
| D_expand = torch.unsqueeze(D, -1) | |
| RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2)) | |
| return RBF | |
| def _get_rbf(self, A, B, E_idx): | |
| D_A_B = torch.sqrt( | |
| torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6 | |
| ) # [B, L, L] | |
| D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[ | |
| :, :, :, 0 | |
| ] # [B,L,K] | |
| RBF_A_B = self._rbf(D_A_B_neighbors) | |
| return RBF_A_B | |
| def forward(self, input_features): | |
| Y = input_features["Y"] | |
| Y_m = input_features["Y_m"] | |
| Y_t = input_features["Y_t"] | |
| X = input_features["X"] | |
| mask = input_features["mask"] | |
| R_idx = input_features["R_idx"] | |
| chain_labels = input_features["chain_labels"] | |
| if self.augment_eps > 0: | |
| X = X + self.augment_eps * torch.randn_like(X) | |
| Y = Y + self.augment_eps * torch.randn_like(Y) | |
| B, L, _, _ = X.shape | |
| Ca = X[:, :, 1, :] | |
| N = X[:, :, 0, :] | |
| C = X[:, :, 2, :] | |
| O = X[:, :, 3, :] | |
| b = Ca - N | |
| c = C - Ca | |
| a = torch.cross(b, c, dim=-1) | |
| Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca # shift from CA | |
| D_neighbors, E_idx = self._dist(Ca, mask) | |
| RBF_all = [] | |
| RBF_all.append(self._rbf(D_neighbors)) # Ca-Ca | |
| RBF_all.append(self._get_rbf(N, N, E_idx)) # N-N | |
| RBF_all.append(self._get_rbf(C, C, E_idx)) # C-C | |
| RBF_all.append(self._get_rbf(O, O, E_idx)) # O-O | |
| RBF_all.append(self._get_rbf(Cb, Cb, E_idx)) # Cb-Cb | |
| RBF_all.append(self._get_rbf(Ca, N, E_idx)) # Ca-N | |
| RBF_all.append(self._get_rbf(Ca, C, E_idx)) # Ca-C | |
| RBF_all.append(self._get_rbf(Ca, O, E_idx)) # Ca-O | |
| RBF_all.append(self._get_rbf(Ca, Cb, E_idx)) # Ca-Cb | |
| RBF_all.append(self._get_rbf(N, C, E_idx)) # N-C | |
| RBF_all.append(self._get_rbf(N, O, E_idx)) # N-O | |
| RBF_all.append(self._get_rbf(N, Cb, E_idx)) # N-Cb | |
| RBF_all.append(self._get_rbf(Cb, C, E_idx)) # Cb-C | |
| RBF_all.append(self._get_rbf(Cb, O, E_idx)) # Cb-O | |
| RBF_all.append(self._get_rbf(O, C, E_idx)) # O-C | |
| RBF_all.append(self._get_rbf(N, Ca, E_idx)) # N-Ca | |
| RBF_all.append(self._get_rbf(C, Ca, E_idx)) # C-Ca | |
| RBF_all.append(self._get_rbf(O, Ca, E_idx)) # O-Ca | |
| RBF_all.append(self._get_rbf(Cb, Ca, E_idx)) # Cb-Ca | |
| RBF_all.append(self._get_rbf(C, N, E_idx)) # C-N | |
| RBF_all.append(self._get_rbf(O, N, E_idx)) # O-N | |
| RBF_all.append(self._get_rbf(Cb, N, E_idx)) # Cb-N | |
| RBF_all.append(self._get_rbf(C, Cb, E_idx)) # C-Cb | |
| RBF_all.append(self._get_rbf(O, Cb, E_idx)) # O-Cb | |
| RBF_all.append(self._get_rbf(C, O, E_idx)) # C-O | |
| RBF_all = torch.cat(tuple(RBF_all), dim=-1) | |
| offset = R_idx[:, :, None] - R_idx[:, None, :] | |
| offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0] # [B, L, K] | |
| d_chains = ( | |
| (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0 | |
| ).long() # find self vs non-self interaction | |
| E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0] | |
| E_positional = self.embeddings(offset.long(), E_chains) | |
| E = torch.cat((E_positional, RBF_all), -1) | |
| E = self.edge_embedding(E) | |
| E = self.norm_edges(E) | |
| if self.use_side_chains: | |
| xyz_37 = input_features["xyz_37"] | |
| xyz_37_m = input_features["xyz_37_m"] | |
| E_idx_sub = E_idx[:, :, :16] # [B, L, 15] | |
| mask_residues = input_features["chain_mask"] | |
| xyz_37_m = xyz_37_m * (1 - mask_residues[:, :, None]) | |
| R_m = gather_nodes(xyz_37_m[:, :, 5:], E_idx_sub) | |
| X_sidechain = xyz_37[:, :, 5:, :].view(B, L, -1) | |
| R = gather_nodes(X_sidechain, E_idx_sub).view( | |
| B, L, E_idx_sub.shape[2], -1, 3 | |
| ) | |
| R_t = self.side_chain_atom_types[None, None, None, :].repeat( | |
| B, L, E_idx_sub.shape[2], 1 | |
| ) | |
| # Side chain atom context | |
| R = R.view(B, L, -1, 3) # coordinates | |
| R_m = R_m.view(B, L, -1) # mask | |
| R_t = R_t.view(B, L, -1) # atom types | |
| # Ligand atom context | |
| Y = torch.cat((R, Y), 2) # [B, L, atoms, 3] | |
| Y_m = torch.cat((R_m, Y_m), 2) # [B, L, atoms] | |
| Y_t = torch.cat((R_t, Y_t), 2) # [B, L, atoms] | |
| Cb_Y_distances = torch.sum((Cb[:, :, None, :] - Y) ** 2, -1) | |
| mask_Y = mask[:, :, None] * Y_m | |
| Cb_Y_distances_adjusted = Cb_Y_distances * mask_Y + (1.0 - mask_Y) * 10000.0 | |
| _, E_idx_Y = torch.topk( | |
| Cb_Y_distances_adjusted, self.atom_context_num, dim=-1, largest=False | |
| ) | |
| Y = torch.gather(Y, 2, E_idx_Y[:, :, :, None].repeat(1, 1, 1, 3)) | |
| Y_t = torch.gather(Y_t, 2, E_idx_Y) | |
| Y_m = torch.gather(Y_m, 2, E_idx_Y) | |
| Y_t = Y_t.long() | |
| Y_t_g = self.periodic_table_features[1][Y_t] # group; 19 categories including 0 | |
| Y_t_p = self.periodic_table_features[2][Y_t] # period; 8 categories including 0 | |
| Y_t_g_1hot_ = torch.nn.functional.one_hot(Y_t_g, 19) # [B, L, M, 19] | |
| Y_t_p_1hot_ = torch.nn.functional.one_hot(Y_t_p, 8) # [B, L, M, 8] | |
| Y_t_1hot_ = torch.nn.functional.one_hot(Y_t, 120) # [B, L, M, 120] | |
| Y_t_1hot_ = torch.cat( | |
| [Y_t_1hot_, Y_t_g_1hot_, Y_t_p_1hot_], -1 | |
| ) # [B, L, M, 147] | |
| Y_t_1hot = self.type_linear(Y_t_1hot_.float()) | |
| D_N_Y = self._rbf( | |
| torch.sqrt(torch.sum((N[:, :, None, :] - Y) ** 2, -1) + 1e-6) | |
| ) # [B, L, M, num_bins] | |
| D_Ca_Y = self._rbf( | |
| torch.sqrt(torch.sum((Ca[:, :, None, :] - Y) ** 2, -1) + 1e-6) | |
| ) | |
| D_C_Y = self._rbf(torch.sqrt(torch.sum((C[:, :, None, :] - Y) ** 2, -1) + 1e-6)) | |
| D_O_Y = self._rbf(torch.sqrt(torch.sum((O[:, :, None, :] - Y) ** 2, -1) + 1e-6)) | |
| D_Cb_Y = self._rbf( | |
| torch.sqrt(torch.sum((Cb[:, :, None, :] - Y) ** 2, -1) + 1e-6) | |
| ) | |
| f_angles = self._make_angle_features(N, Ca, C, Y) # [B, L, M, 4] | |
| D_all = torch.cat( | |
| (D_N_Y, D_Ca_Y, D_C_Y, D_O_Y, D_Cb_Y, Y_t_1hot, f_angles), dim=-1 | |
| ) # [B,L,M,5*num_bins+5] | |
| V = self.node_project_down(D_all) # [B, L, M, node_features] | |
| V = self.norm_nodes(V) | |
| Y_edges = self._rbf( | |
| torch.sqrt( | |
| torch.sum((Y[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6 | |
| ) | |
| ) # [B, L, M, M, num_bins] | |
| Y_edges = self.y_edges(Y_edges) | |
| Y_nodes = self.y_nodes(Y_t_1hot_.float()) | |
| Y_edges = self.norm_y_edges(Y_edges) | |
| Y_nodes = self.norm_y_nodes(Y_nodes) | |
| return V, E, E_idx, Y_nodes, Y_edges, Y_m | |
| class ProteinFeatures(torch.nn.Module): | |
| def __init__( | |
| self, | |
| edge_features, | |
| node_features, | |
| num_positional_embeddings=16, | |
| num_rbf=16, | |
| top_k=48, | |
| augment_eps=0.0, | |
| ): | |
| """Extract protein features""" | |
| super(ProteinFeatures, self).__init__() | |
| self.edge_features = edge_features | |
| self.node_features = node_features | |
| self.top_k = top_k | |
| self.augment_eps = augment_eps | |
| self.num_rbf = num_rbf | |
| self.num_positional_embeddings = num_positional_embeddings | |
| self.embeddings = PositionalEncodings(num_positional_embeddings) | |
| edge_in = num_positional_embeddings + num_rbf * 25 | |
| self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False) | |
| self.norm_edges = torch.nn.LayerNorm(edge_features) | |
| def _dist(self, X, mask, eps=1e-6): | |
| mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2) | |
| dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2) | |
| D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps) | |
| D_max, _ = torch.max(D, -1, keepdim=True) | |
| D_adjust = D + (1.0 - mask_2D) * D_max | |
| D_neighbors, E_idx = torch.topk( | |
| D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False | |
| ) | |
| return D_neighbors, E_idx | |
| def _rbf(self, D): | |
| device = D.device | |
| D_min, D_max, D_count = 2.0, 22.0, self.num_rbf | |
| D_mu = torch.linspace(D_min, D_max, D_count, device=device) | |
| D_mu = D_mu.view([1, 1, 1, -1]) | |
| D_sigma = (D_max - D_min) / D_count | |
| D_expand = torch.unsqueeze(D, -1) | |
| RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2)) | |
| return RBF | |
| def _get_rbf(self, A, B, E_idx): | |
| D_A_B = torch.sqrt( | |
| torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6 | |
| ) # [B, L, L] | |
| D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[ | |
| :, :, :, 0 | |
| ] # [B,L,K] | |
| RBF_A_B = self._rbf(D_A_B_neighbors) | |
| return RBF_A_B | |
| def forward(self, input_features): | |
| X = input_features["X"] | |
| mask = input_features["mask"] | |
| R_idx = input_features["R_idx"] | |
| chain_labels = input_features["chain_labels"] | |
| if self.augment_eps > 0: | |
| X = X + self.augment_eps * torch.randn_like(X) | |
| b = X[:, :, 1, :] - X[:, :, 0, :] | |
| c = X[:, :, 2, :] - X[:, :, 1, :] | |
| a = torch.cross(b, c, dim=-1) | |
| Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + X[:, :, 1, :] | |
| Ca = X[:, :, 1, :] | |
| N = X[:, :, 0, :] | |
| C = X[:, :, 2, :] | |
| O = X[:, :, 3, :] | |
| D_neighbors, E_idx = self._dist(Ca, mask) | |
| RBF_all = [] | |
| RBF_all.append(self._rbf(D_neighbors)) # Ca-Ca | |
| RBF_all.append(self._get_rbf(N, N, E_idx)) # N-N | |
| RBF_all.append(self._get_rbf(C, C, E_idx)) # C-C | |
| RBF_all.append(self._get_rbf(O, O, E_idx)) # O-O | |
| RBF_all.append(self._get_rbf(Cb, Cb, E_idx)) # Cb-Cb | |
| RBF_all.append(self._get_rbf(Ca, N, E_idx)) # Ca-N | |
| RBF_all.append(self._get_rbf(Ca, C, E_idx)) # Ca-C | |
| RBF_all.append(self._get_rbf(Ca, O, E_idx)) # Ca-O | |
| RBF_all.append(self._get_rbf(Ca, Cb, E_idx)) # Ca-Cb | |
| RBF_all.append(self._get_rbf(N, C, E_idx)) # N-C | |
| RBF_all.append(self._get_rbf(N, O, E_idx)) # N-O | |
| RBF_all.append(self._get_rbf(N, Cb, E_idx)) # N-Cb | |
| RBF_all.append(self._get_rbf(Cb, C, E_idx)) # Cb-C | |
| RBF_all.append(self._get_rbf(Cb, O, E_idx)) # Cb-O | |
| RBF_all.append(self._get_rbf(O, C, E_idx)) # O-C | |
| RBF_all.append(self._get_rbf(N, Ca, E_idx)) # N-Ca | |
| RBF_all.append(self._get_rbf(C, Ca, E_idx)) # C-Ca | |
| RBF_all.append(self._get_rbf(O, Ca, E_idx)) # O-Ca | |
| RBF_all.append(self._get_rbf(Cb, Ca, E_idx)) # Cb-Ca | |
| RBF_all.append(self._get_rbf(C, N, E_idx)) # C-N | |
| RBF_all.append(self._get_rbf(O, N, E_idx)) # O-N | |
| RBF_all.append(self._get_rbf(Cb, N, E_idx)) # Cb-N | |
| RBF_all.append(self._get_rbf(C, Cb, E_idx)) # C-Cb | |
| RBF_all.append(self._get_rbf(O, Cb, E_idx)) # O-Cb | |
| RBF_all.append(self._get_rbf(C, O, E_idx)) # C-O | |
| RBF_all = torch.cat(tuple(RBF_all), dim=-1) | |
| offset = R_idx[:, :, None] - R_idx[:, None, :] | |
| offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0] # [B, L, K] | |
| d_chains = ( | |
| (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0 | |
| ).long() # find self vs non-self interaction | |
| E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0] | |
| E_positional = self.embeddings(offset.long(), E_chains) | |
| E = torch.cat((E_positional, RBF_all), -1) | |
| E = self.edge_embedding(E) | |
| E = self.norm_edges(E) | |
| return E, E_idx | |
| class ProteinFeaturesMembrane(torch.nn.Module): | |
| def __init__( | |
| self, | |
| edge_features, | |
| node_features, | |
| num_positional_embeddings=16, | |
| num_rbf=16, | |
| top_k=48, | |
| augment_eps=0.0, | |
| num_classes=3, | |
| ): | |
| """Extract protein features""" | |
| super(ProteinFeaturesMembrane, self).__init__() | |
| self.edge_features = edge_features | |
| self.node_features = node_features | |
| self.top_k = top_k | |
| self.augment_eps = augment_eps | |
| self.num_rbf = num_rbf | |
| self.num_positional_embeddings = num_positional_embeddings | |
| self.num_classes = num_classes | |
| self.embeddings = PositionalEncodings(num_positional_embeddings) | |
| edge_in = num_positional_embeddings + num_rbf * 25 | |
| self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False) | |
| self.norm_edges = torch.nn.LayerNorm(edge_features) | |
| self.node_embedding = torch.nn.Linear( | |
| self.num_classes, node_features, bias=False | |
| ) | |
| self.norm_nodes = torch.nn.LayerNorm(node_features) | |
| def _dist(self, X, mask, eps=1e-6): | |
| mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2) | |
| dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2) | |
| D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps) | |
| D_max, _ = torch.max(D, -1, keepdim=True) | |
| D_adjust = D + (1.0 - mask_2D) * D_max | |
| D_neighbors, E_idx = torch.topk( | |
| D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False | |
| ) | |
| return D_neighbors, E_idx | |
| def _rbf(self, D): | |
| device = D.device | |
| D_min, D_max, D_count = 2.0, 22.0, self.num_rbf | |
| D_mu = torch.linspace(D_min, D_max, D_count, device=device) | |
| D_mu = D_mu.view([1, 1, 1, -1]) | |
| D_sigma = (D_max - D_min) / D_count | |
| D_expand = torch.unsqueeze(D, -1) | |
| RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2)) | |
| return RBF | |
| def _get_rbf(self, A, B, E_idx): | |
| D_A_B = torch.sqrt( | |
| torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6 | |
| ) # [B, L, L] | |
| D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[ | |
| :, :, :, 0 | |
| ] # [B,L,K] | |
| RBF_A_B = self._rbf(D_A_B_neighbors) | |
| return RBF_A_B | |
| def forward(self, input_features): | |
| X = input_features["X"] | |
| mask = input_features["mask"] | |
| R_idx = input_features["R_idx"] | |
| chain_labels = input_features["chain_labels"] | |
| membrane_per_residue_labels = input_features["membrane_per_residue_labels"] | |
| if self.augment_eps > 0: | |
| X = X + self.augment_eps * torch.randn_like(X) | |
| b = X[:, :, 1, :] - X[:, :, 0, :] | |
| c = X[:, :, 2, :] - X[:, :, 1, :] | |
| a = torch.cross(b, c, dim=-1) | |
| Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + X[:, :, 1, :] | |
| Ca = X[:, :, 1, :] | |
| N = X[:, :, 0, :] | |
| C = X[:, :, 2, :] | |
| O = X[:, :, 3, :] | |
| D_neighbors, E_idx = self._dist(Ca, mask) | |
| RBF_all = [] | |
| RBF_all.append(self._rbf(D_neighbors)) # Ca-Ca | |
| RBF_all.append(self._get_rbf(N, N, E_idx)) # N-N | |
| RBF_all.append(self._get_rbf(C, C, E_idx)) # C-C | |
| RBF_all.append(self._get_rbf(O, O, E_idx)) # O-O | |
| RBF_all.append(self._get_rbf(Cb, Cb, E_idx)) # Cb-Cb | |
| RBF_all.append(self._get_rbf(Ca, N, E_idx)) # Ca-N | |
| RBF_all.append(self._get_rbf(Ca, C, E_idx)) # Ca-C | |
| RBF_all.append(self._get_rbf(Ca, O, E_idx)) # Ca-O | |
| RBF_all.append(self._get_rbf(Ca, Cb, E_idx)) # Ca-Cb | |
| RBF_all.append(self._get_rbf(N, C, E_idx)) # N-C | |
| RBF_all.append(self._get_rbf(N, O, E_idx)) # N-O | |
| RBF_all.append(self._get_rbf(N, Cb, E_idx)) # N-Cb | |
| RBF_all.append(self._get_rbf(Cb, C, E_idx)) # Cb-C | |
| RBF_all.append(self._get_rbf(Cb, O, E_idx)) # Cb-O | |
| RBF_all.append(self._get_rbf(O, C, E_idx)) # O-C | |
| RBF_all.append(self._get_rbf(N, Ca, E_idx)) # N-Ca | |
| RBF_all.append(self._get_rbf(C, Ca, E_idx)) # C-Ca | |
| RBF_all.append(self._get_rbf(O, Ca, E_idx)) # O-Ca | |
| RBF_all.append(self._get_rbf(Cb, Ca, E_idx)) # Cb-Ca | |
| RBF_all.append(self._get_rbf(C, N, E_idx)) # C-N | |
| RBF_all.append(self._get_rbf(O, N, E_idx)) # O-N | |
| RBF_all.append(self._get_rbf(Cb, N, E_idx)) # Cb-N | |
| RBF_all.append(self._get_rbf(C, Cb, E_idx)) # C-Cb | |
| RBF_all.append(self._get_rbf(O, Cb, E_idx)) # O-Cb | |
| RBF_all.append(self._get_rbf(C, O, E_idx)) # C-O | |
| RBF_all = torch.cat(tuple(RBF_all), dim=-1) | |
| offset = R_idx[:, :, None] - R_idx[:, None, :] | |
| offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0] # [B, L, K] | |
| d_chains = ( | |
| (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0 | |
| ).long() # find self vs non-self interaction | |
| E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0] | |
| E_positional = self.embeddings(offset.long(), E_chains) | |
| E = torch.cat((E_positional, RBF_all), -1) | |
| E = self.edge_embedding(E) | |
| E = self.norm_edges(E) | |
| C_1hot = torch.nn.functional.one_hot( | |
| membrane_per_residue_labels, self.num_classes | |
| ).float() | |
| V = self.node_embedding(C_1hot) | |
| V = self.norm_nodes(V) | |
| return V, E, E_idx | |
| class DecLayerJ(torch.nn.Module): | |
| def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30): | |
| super(DecLayerJ, self).__init__() | |
| self.num_hidden = num_hidden | |
| self.num_in = num_in | |
| self.scale = scale | |
| self.dropout1 = torch.nn.Dropout(dropout) | |
| self.dropout2 = torch.nn.Dropout(dropout) | |
| self.norm1 = torch.nn.LayerNorm(num_hidden) | |
| self.norm2 = torch.nn.LayerNorm(num_hidden) | |
| self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True) | |
| self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True) | |
| self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True) | |
| self.act = torch.nn.GELU() | |
| self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4) | |
| def forward(self, h_V, h_E, mask_V=None, mask_attend=None): | |
| """Parallel computation of full transformer layer""" | |
| # Concatenate h_V_i to h_E_ij | |
| h_V_expand = h_V.unsqueeze(-2).expand( | |
| -1, -1, -1, h_E.size(-2), -1 | |
| ) # the only difference | |
| h_EV = torch.cat([h_V_expand, h_E], -1) | |
| h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV))))) | |
| if mask_attend is not None: | |
| h_message = mask_attend.unsqueeze(-1) * h_message | |
| dh = torch.sum(h_message, -2) / self.scale | |
| h_V = self.norm1(h_V + self.dropout1(dh)) | |
| # Position-wise feedforward | |
| dh = self.dense(h_V) | |
| h_V = self.norm2(h_V + self.dropout2(dh)) | |
| if mask_V is not None: | |
| mask_V = mask_V.unsqueeze(-1) | |
| h_V = mask_V * h_V | |
| return h_V | |
| class PositionWiseFeedForward(torch.nn.Module): | |
| def __init__(self, num_hidden, num_ff): | |
| super(PositionWiseFeedForward, self).__init__() | |
| self.W_in = torch.nn.Linear(num_hidden, num_ff, bias=True) | |
| self.W_out = torch.nn.Linear(num_ff, num_hidden, bias=True) | |
| self.act = torch.nn.GELU() | |
| def forward(self, h_V): | |
| h = self.act(self.W_in(h_V)) | |
| h = self.W_out(h) | |
| return h | |
| class PositionalEncodings(torch.nn.Module): | |
| def __init__(self, num_embeddings, max_relative_feature=32): | |
| super(PositionalEncodings, self).__init__() | |
| self.num_embeddings = num_embeddings | |
| self.max_relative_feature = max_relative_feature | |
| self.linear = torch.nn.Linear(2 * max_relative_feature + 1 + 1, num_embeddings) | |
| def forward(self, offset, mask): | |
| d = torch.clip( | |
| offset + self.max_relative_feature, 0, 2 * self.max_relative_feature | |
| ) * mask + (1 - mask) * (2 * self.max_relative_feature + 1) | |
| d_onehot = torch.nn.functional.one_hot(d, 2 * self.max_relative_feature + 1 + 1) | |
| E = self.linear(d_onehot.float()) | |
| return E | |
| class DecLayer(torch.nn.Module): | |
| def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30): | |
| super(DecLayer, self).__init__() | |
| self.num_hidden = num_hidden | |
| self.num_in = num_in | |
| self.scale = scale | |
| self.dropout1 = torch.nn.Dropout(dropout) | |
| self.dropout2 = torch.nn.Dropout(dropout) | |
| self.norm1 = torch.nn.LayerNorm(num_hidden) | |
| self.norm2 = torch.nn.LayerNorm(num_hidden) | |
| self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True) | |
| self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True) | |
| self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True) | |
| self.act = torch.nn.GELU() | |
| self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4) | |
| def forward(self, h_V, h_E, mask_V=None, mask_attend=None): | |
| """Parallel computation of full transformer layer""" | |
| # Concatenate h_V_i to h_E_ij | |
| h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1) | |
| h_EV = torch.cat([h_V_expand, h_E], -1) | |
| h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV))))) | |
| if mask_attend is not None: | |
| h_message = mask_attend.unsqueeze(-1) * h_message | |
| dh = torch.sum(h_message, -2) / self.scale | |
| h_V = self.norm1(h_V + self.dropout1(dh)) | |
| # Position-wise feedforward | |
| dh = self.dense(h_V) | |
| h_V = self.norm2(h_V + self.dropout2(dh)) | |
| if mask_V is not None: | |
| mask_V = mask_V.unsqueeze(-1) | |
| h_V = mask_V * h_V | |
| return h_V | |
| class EncLayer(torch.nn.Module): | |
| def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30): | |
| super(EncLayer, self).__init__() | |
| self.num_hidden = num_hidden | |
| self.num_in = num_in | |
| self.scale = scale | |
| self.dropout1 = torch.nn.Dropout(dropout) | |
| self.dropout2 = torch.nn.Dropout(dropout) | |
| self.dropout3 = torch.nn.Dropout(dropout) | |
| self.norm1 = torch.nn.LayerNorm(num_hidden) | |
| self.norm2 = torch.nn.LayerNorm(num_hidden) | |
| self.norm3 = torch.nn.LayerNorm(num_hidden) | |
| self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True) | |
| self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True) | |
| self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True) | |
| self.W11 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True) | |
| self.W12 = torch.nn.Linear(num_hidden, num_hidden, bias=True) | |
| self.W13 = torch.nn.Linear(num_hidden, num_hidden, bias=True) | |
| self.act = torch.nn.GELU() | |
| self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4) | |
| def forward(self, h_V, h_E, E_idx, mask_V=None, mask_attend=None): | |
| """Parallel computation of full transformer layer""" | |
| h_EV = cat_neighbors_nodes(h_V, h_E, E_idx) | |
| h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1) | |
| h_EV = torch.cat([h_V_expand, h_EV], -1) | |
| h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV))))) | |
| if mask_attend is not None: | |
| h_message = mask_attend.unsqueeze(-1) * h_message | |
| dh = torch.sum(h_message, -2) / self.scale | |
| h_V = self.norm1(h_V + self.dropout1(dh)) | |
| dh = self.dense(h_V) | |
| h_V = self.norm2(h_V + self.dropout2(dh)) | |
| if mask_V is not None: | |
| mask_V = mask_V.unsqueeze(-1) | |
| h_V = mask_V * h_V | |
| h_EV = cat_neighbors_nodes(h_V, h_E, E_idx) | |
| h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1) | |
| h_EV = torch.cat([h_V_expand, h_EV], -1) | |
| h_message = self.W13(self.act(self.W12(self.act(self.W11(h_EV))))) | |
| h_E = self.norm3(h_E + self.dropout3(h_message)) | |
| return h_V, h_E | |
| # The following gather functions | |
| def gather_edges(edges, neighbor_idx): | |
| # Features [B,N,N,C] at Neighbor indices [B,N,K] => Neighbor features [B,N,K,C] | |
| neighbors = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, edges.size(-1)) | |
| edge_features = torch.gather(edges, 2, neighbors) | |
| return edge_features | |
| def gather_nodes(nodes, neighbor_idx): | |
| # Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C] | |
| # Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C] | |
| neighbors_flat = neighbor_idx.reshape((neighbor_idx.shape[0], -1)) | |
| neighbors_flat = neighbors_flat.unsqueeze(-1).expand(-1, -1, nodes.size(2)) | |
| # Gather and re-pack | |
| neighbor_features = torch.gather(nodes, 1, neighbors_flat) | |
| neighbor_features = neighbor_features.view(list(neighbor_idx.shape)[:3] + [-1]) | |
| return neighbor_features | |
| def gather_nodes_t(nodes, neighbor_idx): | |
| # Features [B,N,C] at Neighbor index [B,K] => Neighbor features[B,K,C] | |
| idx_flat = neighbor_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2)) | |
| neighbor_features = torch.gather(nodes, 1, idx_flat) | |
| return neighbor_features | |
| def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx): | |
| h_nodes = gather_nodes(h_nodes, E_idx) | |
| h_nn = torch.cat([h_neighbors, h_nodes], -1) | |
| return h_nn | |