import argparse
import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from esm.modules import SinusoidalPositionalEmbedding
from .features import GVPInputFeaturizer, DihedralFeatures
from .gvp_encoder import GVPEncoder
from .transformer_layer import TransformerEncoderLayer
from .util import nan_to_num, get_rotation_frames, rotate, rbf
|
|
|
class GVPTransformerEncoder(nn.Module):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """
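    # A hypothetical construction sketch (comments only, not executed). Of the
    # hyperparameters below, only ``dropout``, ``encoder_layers``,
    # ``gvp_node_hidden_dim_scalar`` and ``gvp_node_hidden_dim_vector`` are
    # referenced in this file; the namespace must also carry whatever
    # GVPEncoder and TransformerEncoderLayer expect, and the values shown are
    # placeholders rather than the trained model's settings.
    #
    #     args = argparse.Namespace(
    #         dropout=0.1,
    #         encoder_layers=8,
    #         gvp_node_hidden_dim_scalar=1024,
    #         gvp_node_hidden_dim_vector=256,
    #         # ... remaining GVPEncoder / TransformerEncoderLayer hyperparameters
    #     )
    #     encoder = GVPTransformerEncoder(args, dictionary, embed_tokens)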
|
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__()
        self.args = args
        self.dictionary = dictionary
|
        self.dropout_module = nn.Dropout(args.dropout)

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = SinusoidalPositionalEmbedding(
            embed_dim,
            self.padding_idx,
        )
        self.embed_gvp_input_features = nn.Linear(15, embed_dim)
        self.embed_confidence = nn.Linear(16, embed_dim)
        self.embed_dihedrals = DihedralFeatures(embed_dim)
|
        # Arguments prefixed with "gvp_" are re-namespaced (prefix stripped) and
        # handed to the GVP structure encoder, e.g. args.gvp_node_hidden_dim_scalar
        # becomes gvp_args.node_hidden_dim_scalar.
        gvp_args = argparse.Namespace()
        for k, v in vars(args).items():
            if k.startswith("gvp_"):
                setattr(gvp_args, k[4:], v)
        self.gvp_encoder = GVPEncoder(gvp_args)
        # GVP node output: scalar channels plus 3 coordinates per vector channel
        gvp_out_dim = gvp_args.node_hidden_dim_scalar + (
            3 * gvp_args.node_hidden_dim_vector)
        self.embed_gvp_output = nn.Linear(gvp_out_dim, embed_dim)
|
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [self.build_encoder_layer(args) for _ in range(args.encoder_layers)]
        )
        self.num_layers = len(self.layers)
        self.layer_norm = nn.LayerNorm(embed_dim)
|
    def build_encoder_layer(self, args):
        return TransformerEncoderLayer(args)
|
    def forward_embedding(self, coords, padding_mask, confidence):
        """
        Args:
            coords: N, CA, C backbone coordinates of shape
                batch_size x num_residues x 3 (atoms) x 3
            padding_mask: boolean Tensor (true for padding) of shape
                batch_size x num_residues
            confidence: confidence scores between 0 and 1 of shape
                batch_size x num_residues
        """
        components = dict()
        coord_mask = torch.all(torch.all(torch.isfinite(coords), dim=-1), dim=-1)
        coords = nan_to_num(coords)
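        # The encoder never sees the native sequence: padding positions get the
        # padding token and every other position gets the "<mask>" token.
        # (A tiny worked example of this masking is in the __main__ sketch at
        # the bottom of this file.)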
        mask_tokens = (
            padding_mask * self.dictionary.padding_idx +
            ~padding_mask * self.dictionary.get_idx("<mask>")
        )
        components["tokens"] = self.embed_tokens(mask_tokens) * self.embed_scale
        components["dihedrals"] = self.embed_dihedrals(coords)
|
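        # GVP structure encoder over the backbone geometry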
        gvp_out_scalars, gvp_out_vectors = self.gvp_encoder(coords,
                coord_mask, padding_mask, confidence)
        R = get_rotation_frames(coords)
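        # Rotate the GVP vector outputs into each residue's local frame so the
        # flattened features are invariant to global rotations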
        gvp_out_features = torch.cat([
            gvp_out_scalars,
            rotate(gvp_out_vectors, R.transpose(-2, -1)).flatten(-2, -1),
        ], dim=-1)
        components["gvp_out"] = self.embed_gvp_output(gvp_out_features)
|
        # Per-residue confidence expanded with radial basis functions spanning [0, 1]
        components["confidence"] = self.embed_confidence(
            rbf(confidence, 0., 1.))
|
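        # In addition to the GVP encoder output, directly embed the raw GVP input
        # node features (also rotated into the local frame) for the Transformer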
        scalar_features, vector_features = GVPInputFeaturizer.get_node_features(
            coords, coord_mask, with_coord_mask=False)
        features = torch.cat([
            scalar_features,
            rotate(vector_features, R.transpose(-2, -1)).flatten(-2, -1),
        ], dim=-1)
        components["gvp_input_features"] = self.embed_gvp_input_features(features)
|
        # The encoder input embedding is the sum of all components above:
        # tokens, dihedrals, gvp_out, confidence, and gvp_input_features
        embed = sum(components.values())
|
|
        x = embed
        x = x + self.embed_positions(mask_tokens)
        x = self.dropout_module(x)
        return x, components
|
    def forward(
        self,
        coords,
        encoder_padding_mask,
        confidence,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            coords (Tensor): backbone coordinates
                shape batch_size x num_residues x num_atoms (3 for N, CA, C) x 3
            encoder_padding_mask (ByteTensor): the positions of
                padding elements of shape `(batch_size x num_residues)`
            confidence (Tensor): the confidence score of shape (batch_size x
                num_residues). The value is between 0. and 1. for each residue
                coordinate, or -1. if no coordinate is given
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(num_residues, batch_size, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch_size, num_residues)`
                - **encoder_embedding** (dict): the embedding components from
                  :meth:`forward_embedding`, each of shape
                  `(batch_size, num_residues, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(num_residues, batch_size, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        x, encoder_embedding = self.forward_embedding(coords,
                encoder_padding_mask, confidence)
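        # Zero out embeddings at padding positions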
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
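        # B x T x C -> T x B x C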
        x = x.transpose(0, 1)
|
        encoder_states = []
|
        if return_all_hiddens:
            encoder_states.append(x)
|
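        # encoder layers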
        for layer in self.layers:
            x = layer(
                x, encoder_padding_mask=encoder_padding_mask
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
|
        if self.layer_norm is not None:
            x = self.layer_norm(x)
|
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # dict of embedding components
            "encoder_states": encoder_states,  # List[T x B x C]
        }
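

if __name__ == "__main__":
    # Minimal, self-contained sketch (not used by the library) of the masking
    # logic in forward_embedding above: how the coordinate validity mask and the
    # all-"<mask>" input tokens are derived. The token ids below (padding id 1,
    # "<mask>" id 32) are placeholders, not the real alphabet's values.
    coords = torch.randn(1, 4, 3, 3)  # batch x residues x atoms (N, CA, C) x xyz
    coords[0, 2] = float("inf")       # residue 2 has missing / invalid coordinates
    padding_mask = torch.tensor([[False, False, False, True]])

    coord_mask = torch.all(torch.all(torch.isfinite(coords), dim=-1), dim=-1)
    mask_tokens = padding_mask * 1 + ~padding_mask * 32
    print(coord_mask)   # residues with fully finite coordinates: True, True, False, True
    print(mask_tokens)  # token ids: 32, 32, 32, 1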
|
|