# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq.models import FairseqEncoder
from fairseq.modules import LayerNorm, TransformerEncoderLayer


class TransformerEncoderNoEmb(FairseqEncoder):
    """Transformer encoder without token embeddings."""

    def __init__(self, args):
        super().__init__(None)

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )

        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, encoder_padding_mask, return_all_hiddens=False):
        # x: pre-computed features of shape T x B x C;
        # encoder_padding_mask: boolean mask of shape B x T (True = padding).
        encoder_states = []

        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask]
            if encoder_padding_mask is not None and encoder_padding_mask.any()
            else [],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to *new_order* (e.g. during beam search)."""
        new_encoder_out = (
            []
            if len(encoder_out["encoder_out"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
        )

        new_encoder_padding_mask = (
            []
            if len(encoder_out["encoder_padding_mask"]) == 0
            else [
                x.index_select(0, new_order)
                for x in encoder_out["encoder_padding_mask"]
            ]
        )

        new_encoder_embedding = (
            []
            if len(encoder_out["encoder_embedding"]) == 0
            else [
                x.index_select(0, new_order)
                for x in encoder_out["encoder_embedding"]
            ]
        )

        # Note: the stored hidden states are reordered in place.
        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],  # B x T
            "src_lengths": [],  # B x 1
        }
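

# --- Usage sketch (not part of the original file) ---
# A minimal smoke test showing how this encoder is driven: it consumes
# pre-computed T x B x C features rather than token IDs, so we feed it a
# random tensor directly. This assumes a fairseq version whose
# TransformerEncoderLayer accepts an argparse.Namespace with the fields
# below; the exact set of required hyperparameters varies across fairseq
# releases, so treat this as illustrative, not canonical.
if __name__ == "__main__":
    import argparse

    import torch

    args = argparse.Namespace(
        encoder_layers=2,
        encoder_embed_dim=16,
        encoder_ffn_embed_dim=32,
        encoder_attention_heads=2,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        activation_fn="relu",
        encoder_normalize_before=True,
    )
    encoder = TransformerEncoderNoEmb(args)

    seq_len, batch_size = 5, 3
    x = torch.randn(seq_len, batch_size, args.encoder_embed_dim)  # T x B x C
    padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # B x T

    out = encoder(x, padding_mask, return_all_hiddens=True)
    print(out["encoder_out"][0].shape)  # torch.Size([5, 3, 16])
    print(len(out["encoder_states"]))   # 2, one hidden state per layer

    # Reorder the batch dimension, as beam search would between steps.
    new_order = torch.tensor([2, 0, 1])
    reordered = encoder.reorder_encoder_out(out, new_order)
    print(reordered["encoder_out"][0].shape)  # still T x B x C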