# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from fairseq.models import FairseqEncoder
from fairseq.modules import LayerNorm, TransformerEncoderLayer


class TransformerEncoderNoEmb(FairseqEncoder):
    """Stack of Transformer encoder layers that consumes pre-computed features.

    No token-embedding step is performed here: ``forward`` feeds its input
    ``x`` directly to the first layer.
    """

    def __init__(self, args):
        """Build the layer stack and the optional final layer norm.

        Args:
            args: configuration object; this constructor reads
                ``encoder_layers``, ``encoder_normalize_before`` and
                ``encoder_embed_dim`` from it, and hands the whole object to
                every ``TransformerEncoderLayer``.
        """
        # The base encoder is initialised with ``None`` as its only argument.
        super().__init__(None)

        stack = [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        self.layers = nn.ModuleList(stack)

        # A closing LayerNorm exists only when ``encoder_normalize_before``
        # is truthy.
        if not args.encoder_normalize_before:
            self.layer_norm = None
        else:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)

    def forward(self, x, encoder_padding_mask, return_all_hiddens=False):
        """Run ``x`` through every layer and package the result.

        Args:
            x: input features, passed unchanged to the first layer
                (the output is annotated ``T x B x C`` below).
            encoder_padding_mask: padding mask forwarded to every layer
                (annotated ``B x T`` below); may be ``None``.
            return_all_hiddens: when true, the output of each layer is
                collected into ``"encoder_states"``.

        Returns:
            dict with the keys ``"encoder_out"``, ``"encoder_padding_mask"``,
            ``"encoder_embedding"``, ``"encoder_states"``, ``"src_tokens"``
            and ``"src_lengths"``; every value is a list.
        """
        per_layer_outputs = []
        for encoder_layer in self.layers:
            x = encoder_layer(x, encoder_padding_mask)
            if return_all_hiddens:
                per_layer_outputs.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The mask is only reported when it exists and marks at least one
        # position; otherwise an empty list is returned in its place.
        mask_entry = []
        if encoder_padding_mask is not None and encoder_padding_mask.any():
            mask_entry = [encoder_padding_mask]

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": mask_entry,  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": per_layer_outputs,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Re-index the batch dimension of an encoder output dict.

        Args:
            encoder_out: dict shaped like the return value of ``forward``.
            new_order: index tensor passed to ``index_select``.

        Returns:
            dict with the same keys as ``forward``'s output, holding the
            re-indexed tensors; ``"src_tokens"`` and ``"src_lengths"`` are
            always empty lists.

        Note:
            The list stored under ``encoder_out["encoder_states"]`` is
            updated in place and the same list object is returned.
        """
        # Entries annotated T x B x C are selected along dim 1, entries
        # annotated B x T (x C) along dim 0.
        if len(encoder_out["encoder_out"]) == 0:
            reordered_out = []
        else:
            reordered_out = [
                tensor.index_select(1, new_order)
                for tensor in encoder_out["encoder_out"]
            ]

        if len(encoder_out["encoder_padding_mask"]) == 0:
            reordered_mask = []
        else:
            reordered_mask = [
                tensor.index_select(0, new_order)
                for tensor in encoder_out["encoder_padding_mask"]
            ]

        if len(encoder_out["encoder_embedding"]) == 0:
            reordered_embedding = []
        else:
            reordered_embedding = [
                tensor.index_select(0, new_order)
                for tensor in encoder_out["encoder_embedding"]
            ]

        # Overwrite each stored layer output with its re-indexed version;
        # an empty list is simply left untouched.
        layer_states = encoder_out["encoder_states"]
        for position, layer_state in enumerate(layer_states):
            layer_states[position] = layer_state.index_select(1, new_order)

        return {
            "encoder_out": reordered_out,  # T x B x C
            "encoder_padding_mask": reordered_mask,  # B x T
            "encoder_embedding": reordered_embedding,  # B x T x C
            "encoder_states": layer_states,  # List[T x B x C]
            "src_tokens": [],  # B x T
            "src_lengths": [],  # B x 1
        }