Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| import time | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from src.modules.esmif_module import GVPTransformerEncoder, TransformerDecoder, CoordBatchConverter, Alphabet | |
| from transformers import AutoTokenizer | |
class GVPTransformerModel(nn.Module):
    """
    GVP-Transformer inverse folding model.

    Architecture: Geometric GVP-GNN as initial layers, followed by a
    sequence-to-sequence Transformer encoder and decoder.
    """

    def __init__(self, args):
        super().__init__()
        # The tokenizer provides the amino-acid vocabulary (size and padding
        # index) used to size both embedding tables.
        # NOTE(review): assumes the ESM-2 tokenizer vocabulary matches the
        # vocabulary the encoder/decoder weights were trained with — confirm.
        alphabet = AutoTokenizer.from_pretrained(
            "facebook/esm2_t33_650M_UR50D", cache_dir="./cache_dir/"
        )
        encoder_embed_tokens = self.build_embedding(
            args, alphabet, args.encoder_embed_dim,
        )
        decoder_embed_tokens = self.build_embedding(
            args, alphabet, args.decoder_embed_dim,
        )
        encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
        decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
        self.args = args
        self.encoder = encoder
        self.decoder = decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Construct the structure encoder (GVP-GNN + Transformer encoder)."""
        # Was missing @classmethod: `cls` only received the instance because
        # every call site used `self.build_encoder(...)`.
        return GVPTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Construct the autoregressive sequence decoder."""
        return TransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
        )

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim):
        """Build a token embedding table sized to the tokenizer vocabulary.

        Weights are initialized N(0, embed_dim ** -0.5) and the padding row
        is zeroed so padding tokens contribute nothing to the input.

        Args:
            args: unused; kept for signature compatibility with callers.
            dictionary: tokenizer-like object exposing __len__ and
                pad_token_id.
            embed_dim: embedding dimension.

        Returns:
            nn.Embedding of shape (len(dictionary), embed_dim).
        """
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad_token_id
        emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
        nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
        nn.init.constant_(emb.weight[padding_idx], 0)
        return emb

    def forward(
        self,
        coords,
        padding_mask,
        confidence,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
    ):
        """Run a full teacher-forced encoder/decoder pass.

        Args:
            coords: backbone coordinates fed to the structure encoder.
            padding_mask: mask marking padded positions in coords.
            confidence: per-position coordinate confidence scores.
            prev_output_tokens: shifted target tokens for the decoder.
            return_all_hiddens: also return intermediate hidden states.
            features_only: skip the decoder output projection.

        Returns:
            (logits, extra) as produced by the decoder.
        """
        encoder_out = self.encoder(coords, padding_mask, confidence,
                                   return_all_hiddens=return_all_hiddens)
        logits, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return logits, extra

    def sample(self, batch_coords, padding_mask, partial_seq=None,
               temperature=1.0, confidence=None):
        """Autoregressively decode a token sequence for one backbone.

        Despite the sampling-oriented signature, decoding is currently
        greedy (argmax); the multinomial line is kept commented out for
        reference, so ``temperature`` rescales logits but cannot change
        the argmax result.

        Args:
            batch_coords: coordinate tensor for one backbone; sequence
                length is taken from dim 1 (presumably (1, L, 3, 3) —
                TODO confirm against the encoder's expected shape).
            padding_mask: padding mask aligned with batch_coords.
            partial_seq: accepted for API compatibility but currently
                ignored — NOTE(review): partial-sequence conditioning is
                not implemented.
            temperature: kept for compatibility with sampling decoders.
            confidence: optional per-residue coordinate confidence scores.

        Returns:
            1-D long tensor of L predicted token ids (BOS slot dropped).
        """
        L = batch_coords.shape[1]
        # Slot 0 holds the BOS token (id 0); slots 1..L receive predictions.
        sampled_tokens = torch.zeros((1, 1 + L),
                                     device=batch_coords.device).long()
        # Incremental state lets the decoder cache past key/value tensors
        # instead of re-encoding the whole prefix each step.
        incremental_state = dict()
        # The encoder is structure-only, so it runs exactly once.
        encoder_out = self.encoder(batch_coords, padding_mask, confidence)
        # Decode one position at a time, feeding back prior predictions.
        for i in range(1, L + 1):
            logits, _ = self.decoder(
                sampled_tokens[:, :i],
                encoder_out=encoder_out,
                incremental_state=incremental_state,
            )
            logits = logits[0].transpose(0, 1)
            logits /= temperature
            probs = F.softmax(logits, dim=-1)
            # For stochastic decoding, replace the argmax with:
            # sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)
            sampled_tokens[:, i] = probs.argmax(dim=-1)
        sampled_seq = sampled_tokens[0, 1:]
        # Callers are responsible for mapping ids back to residue letters.
        return sampled_seq