# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
    Embedding,
    TransformerDecoderEmbedding,
    TransformerDecoderLayer,
    TransformerDecoderOutputLayer,
    TransformerEncoderEmbedding,
    TransformerEncoderLayer,
    TransformerEncoderLayerNorm,
)
from fairseq.models import (
    BaseFairseqModel,
    FairseqDecoder,
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
    base_architecture,
    transformer_iwslt_de_en,
    transformer_wmt_en_de_big,
)
from fairseq.modules import SinusoidalPositionalEmbedding


logger = logging.getLogger(__name__)


DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
TORCH_PIPE = False
RPC_INIT = False
def import_pipe():
    global TORCH_PIPE
    global RPC_INIT
    try:
        from torch.distributed.pipeline.sync import Pipe  # noqa

        global Pipe
        from torch.distributed.pipeline.sync.utils import partition_model

        global partition_model
        from torch.distributed import rpc
        import tempfile

        TORCH_PIPE = True
        # Initialize single process RPC agent since TORCH_PIPE requires
        # RRef. RRef depends on RPC being initialized and as a result we initialize
        # RPC with a single node.
        tmpfile = tempfile.NamedTemporaryFile()
        if not RPC_INIT:
            rpc.init_rpc(
                name="worker",
                rank=0,
                world_size=1,
                rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
                    init_method="file://{}".format(tmpfile.name),
                ),
            )
            RPC_INIT = True
        logger.info("Using torch pipe")
    except ImportError:
        try:
            from fairscale.nn import Pipe  # noqa

            logger.info("Using fairscale pipe")
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")
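

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal, hedged example of the torch Pipe path that import_pipe() sets up:
# the stages are wrapped in nn.Sequential, split across devices with
# partition_model, and Pipe returns an RRef whose value is read with
# .local_value(). The layer sizes and device ids below are hypothetical, and
# this assumes import_pipe() has already run (so Pipe/partition_model exist in
# the module namespace and single-process RPC is initialized) and that at
# least two CUDA devices are available.
def _example_torch_pipe_usage():
    stages = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    balance = [2, 1]  # first two modules on devices[0], last one on devices[1]
    devices = [0, 1]  # hypothetical CUDA device ids
    pipe = Pipe(partition_model(stages, balance, devices), chunks=2)
    x = torch.randn(8, 16, device="cuda:{}".format(devices[0]))
    # micro-batches are pipelined across the two partitions
    return pipe(x).local_value()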
@register_model("pipeline_parallel_transformer")
class PipelineParallelTransformerModel(BaseFairseqModel):
    def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
        import_pipe()
        super().__init__()
        assert isinstance(encoder, FairseqEncoder)
        assert isinstance(decoder, FairseqDecoder)
        encoder_module_list = (
            [encoder.embedding_layer]
            + list(encoder.encoder_layers)
            + [encoder.final_layer_norm]
        )
        self.num_encoder_modules = len(encoder_module_list)
        decoder_module_list = (
            [decoder.embedding_layer]
            + list(decoder.decoder_layers)
            + [decoder.decoder_output_layer]
        )
        self.num_decoder_modules = len(decoder_module_list)
        module_list = encoder_module_list + decoder_module_list
        self.devices = devices
        if TORCH_PIPE:
            self.model = Pipe(
                partition_model(nn.Sequential(*module_list), balance, devices),
                chunks=chunks,
                checkpoint=checkpoint,
            )
        else:
            self.model = Pipe(
                nn.Sequential(*module_list),
                balance=balance,
                devices=devices,
                chunks=chunks,
                checkpoint=checkpoint,
            )
        self.encoder_max_positions = self.max_positions_helper(
            encoder.embedding_layer, "max_source_positions"
        )
        self.decoder_max_positions = self.max_positions_helper(
            decoder.embedding_layer, "max_target_positions"
        )
        self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
        # Note: To be populated during inference
        self.encoder = None
        self.decoder = None
    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        if self.training:
            input_lst = [src_tokens, src_lengths, prev_output_tokens]
            input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst)
            if TORCH_PIPE:
                return self.model(input).local_value()
            else:
                return self.model(input)
        else:
            assert self.encoder is not None and self.decoder is not None, (
                "encoder and decoder need to be initialized by "
                + "calling the `prepare_for_inference_()` method"
            )
            # During inference the unpacked encoder/decoder are run directly on
            # the raw inputs rather than on the training-time `input` tuple.
            encoder_output_tuple = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_output_tuple)
    def prepare_for_inference_(self, cfg):
        if self.encoder is not None and self.decoder is not None:
            logger.info("Encoder and Decoder already initialized")
            return
        encoder_module_list = []
        decoder_module_list = []
        module_count = 0
        for partition in self.model.partitions:
            for module in partition:
                if module_count < self.num_encoder_modules:
                    encoder_module_list.append(module)
                else:
                    decoder_module_list.append(module)
                module_count += 1
        self.model = None
        self.encoder = TransformerEncoder(
            cfg.distributed_training, None, None, encoder_module_list
        )
        self.decoder = TransformerDecoder(
            cfg.distributed_training,
            None,
            None,
            decoder_module_list=decoder_module_list,
        )
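
    # Note (illustrative, not part of the original module): after training,
    # prepare_for_inference_ above unpacks the flat Pipe partitions back into
    # standalone TransformerEncoder/TransformerDecoder objects, so generation
    # can run the usual encoder-then-decoder path instead of the fused
    # pipeline used by forward() during training.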
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1,
                            help='Number of embedding layer chunks (enables more even distribution'
                                 ' of optimizer states across data parallel nodes'
                                 ' when using optimizer state sharding and'
                                 ' a big embedding vocabulary)')
        # fmt: on
    @classmethod
    def build_model_base(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, "max_source_positions"):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, "max_target_positions"):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
            assert embed_dim % num_embed_chunks == 0, (
                f"Embedding dimension = {embed_dim} should be "
                + f"divisible by the number of embedding chunks = {num_embed_chunks}"
            )
            assert path is None or num_embed_chunks == 1, (
                "Loading embedding from a path with number of embedding chunks > 1"
                + " is not yet supported"
            )
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            # if provided, load from preloaded dictionaries
            if path:
                emb = Embedding(num_embeddings, embed_dim, padding_idx)
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            else:
                embed_chunk_dim = embed_dim // num_embed_chunks
                emb = nn.ModuleList()
                for i in range(num_embed_chunks):
                    emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx))
            return emb

        num_embed_chunks = args.num_embedding_chunks
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
                "Not sharing decoder I/O embeddings is not yet supported with number of "
                + "embedding chunks > 1"
            )
            encoder_embed_tokens = build_embedding(
                src_dict,
                args.encoder_embed_dim,
                args.encoder_embed_path,
                num_embed_chunks,
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict,
                args.decoder_embed_dim,
                args.decoder_embed_path,
                num_embed_chunks,
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return (encoder, decoder)
    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerDecoder(args, tgt_dict, embed_tokens)
    @classmethod
    def build_model(cls, args, task):
        encoder, decoder = cls.build_model_base(args, task)
        return PipelineParallelTransformerModel(
            encoder=encoder,
            decoder=decoder,
            balance=utils.eval_str_list(args.pipeline_balance, type=int),
            devices=utils.eval_str_list(args.pipeline_devices, type=int),
            chunks=args.pipeline_chunks,
            checkpoint=args.pipeline_checkpoint,
        )
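
    # Note (illustrative, not part of the original module): the pipeline
    # arguments above are strings that utils.eval_str_list turns into Python
    # lists, e.g. a hypothetical --pipeline-balance '[4,4,4,4]' with
    # --pipeline-devices '[0,1,2,3]' places four consecutive modules on each
    # of four GPUs; the balance entries must sum to the total number of
    # encoder + decoder modules built above.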
    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder_max_positions, self.decoder_max_positions)

    def max_positions_helper(
        self, embedding_layer, max_positions_field="max_source_positions"
    ):
        """Maximum input length supported by the encoder or decoder."""
        if embedding_layer.embed_positions is None:
            return getattr(embedding_layer, max_positions_field)
        return min(
            getattr(embedding_layer, max_positions_field),
            embedding_layer.embed_positions.max_positions,
        )

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
            if sample is not None:
                assert "target" in sample
                target = sample["target"]
            else:
                target = None
            out = self.adaptive_softmax.get_log_prob(net_output, target=target)
            return out.exp_() if not log_probs else out

        # A Pipe() module returns a tuple of tensors as the output.
        # In this case, the tuple has one element - the output tensor of logits
        logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1, onnx_trace=False)
        else:
            return utils.softmax(logits, dim=-1, onnx_trace=False)

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder_max_positions

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """
        self.upgrade_state_dict(state_dict)
        is_regular_transformer = not any("model.partitions" in k for k in state_dict)
        if is_regular_transformer:
            state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict)
        return super().load_state_dict(state_dict, strict)
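
    # Note (illustrative, not part of the original module): the conversion
    # below remaps flat transformer checkpoint keys onto the Pipe partition
    # layout, e.g. a key like 'encoder.layers.0.fc1.weight' becomes
    # 'model.partitions.<pid>.<mid>.fc1.weight' for whichever partition <pid>
    # and module index <mid> the first encoder layer was assigned to.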
    def convert_to_pipeline_parallel_state_dict(self, state_dict):
        new_state_dict = self.state_dict()
        encoder_layer_idx = 0
        decoder_layer_idx = 0
        encoder_key_suffixes = [
            "self_attn.k_proj.weight",
            "self_attn.k_proj.bias",
            "self_attn.v_proj.weight",
            "self_attn.v_proj.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "self_attn_layer_norm.weight",
            "self_attn_layer_norm.bias",
            "fc1.weight",
            "fc1.bias",
            "fc2.weight",
            "fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ]
        decoder_key_suffixes = [
            "self_attn.k_proj.weight",
            "self_attn.k_proj.bias",
            "self_attn.v_proj.weight",
            "self_attn.v_proj.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "self_attn_layer_norm.weight",
            "self_attn_layer_norm.bias",
            "encoder_attn.k_proj.weight",
            "encoder_attn.k_proj.bias",
            "encoder_attn.v_proj.weight",
            "encoder_attn.v_proj.bias",
            "encoder_attn.q_proj.weight",
            "encoder_attn.q_proj.bias",
            "encoder_attn.out_proj.weight",
            "encoder_attn.out_proj.bias",
            "encoder_attn_layer_norm.weight",
            "encoder_attn_layer_norm.bias",
            "fc1.weight",
            "fc1.bias",
            "fc2.weight",
            "fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ]
        for pid, partition in enumerate(self.model.partitions):
            logger.info(f"Begin Partition {pid}")
            for mid, module in enumerate(partition):
                # fmt: off
                if isinstance(module, TransformerEncoderEmbedding):
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
                if isinstance(module, TransformerEncoderLayer):
                    for suffix in encoder_key_suffixes:
                        new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
                    encoder_layer_idx += 1
                if isinstance(module, TransformerDecoderLayer):
                    for suffix in decoder_key_suffixes:
                        new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}']
                    decoder_layer_idx += 1
                if isinstance(module, TransformerEncoderLayerNorm):
                    if 'encoder.layer_norm.weight' in state_dict:
                        new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight']
                        new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
                if isinstance(module, TransformerDecoderEmbedding):
                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
                if isinstance(module, TransformerDecoderOutputLayer):
                    new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
                # fmt: on
        return new_state_dict
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        import_pipe()
        self.use_pipeline = encoder_module_list is not None
        if not self.use_pipeline:
            self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
            self.encoder_layers = nn.Sequential(
                *[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
            )
            if isinstance(embed_tokens, nn.ModuleList):
                emb_dim = sum(e.embedding_dim for e in embed_tokens)
            else:
                emb_dim = embed_tokens.embedding_dim
            self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
        else:
            encoder_balance = utils.eval_str_list(
                args.pipeline_encoder_balance, type=int
            )
            encoder_devices = utils.eval_str_list(
                args.pipeline_encoder_devices, type=int
            )
            assert sum(encoder_balance) == len(encoder_module_list), (
                f"Sum of encoder_balance={encoder_balance} is not equal "
                + f"to num_encoder_modules={len(encoder_module_list)}"
            )
            if TORCH_PIPE:
                self.model = Pipe(
                    module=partition_model(
                        nn.Sequential(*encoder_module_list),
                        encoder_balance,
                        encoder_devices,
                    ),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            else:
                self.model = Pipe(
                    module=nn.Sequential(*encoder_module_list),
                    balance=encoder_balance,
                    devices=encoder_devices,
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
    def forward(self, src_tokens, src_lengths):
        """
        Args:
            input_tuple(
                src_tokens (LongTensor): tokens in the source language of shape
                    `(batch, src_len)`
                src_lengths (torch.LongTensor): lengths of each source sentence of
                    shape `(batch)`
            )

        Returns:
            output_tuple(
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - prev_output_tokens
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
            )
        """
        dummy_prev_output_tokens = torch.zeros(
            1, dtype=src_tokens.dtype, device=src_tokens.device
        )
        input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens)
        if self.use_pipeline:
            input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
            if TORCH_PIPE:
                encoder_out = self.model(input_tuple).local_value()
            else:
                encoder_out = self.model(input_tuple)
        else:
            encoder_embed_output_tuple = self.embedding_layer(input_tuple)
            encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple)
            encoder_out = self.final_layer_norm(encoder_layers_output)
        # first element is the encoder output
        # second element is the encoder padding mask
        # the remaining elements of EncoderOut are not computed by
        # the PipelineParallelTransformer
        return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None)
    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if encoder_out.encoder_out is not None:
            encoder_out = encoder_out._replace(
                encoder_out=encoder_out.encoder_out.index_select(1, new_order)
            )
        if encoder_out.encoder_padding_mask is not None:
            encoder_out = encoder_out._replace(
                encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_embedding is not None:
            encoder_out = encoder_out._replace(
                encoder_embedding=encoder_out.encoder_embedding.index_select(
                    0, new_order
                )
            )
        if encoder_out.encoder_states is not None:
            for idx, state in enumerate(encoder_out.encoder_states):
                encoder_out.encoder_states[idx] = state.index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embedding_layer.embed_positions is None:
            return self.embedding_layer.max_source_positions
        return min(
            self.embedding_layer.max_source_positions,
            self.embedding_layer.embed_positions.max_positions,
        )
class TransformerDecoder(FairseqDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        decoder_module_list=None,
    ):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        import_pipe()
        self.use_pipeline = decoder_module_list is not None
        if not self.use_pipeline:
            self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
            self.decoder_layers = nn.Sequential(
                *[
                    TransformerDecoderLayer(args, no_encoder_attn)
                    for _ in range(args.decoder_layers)
                ]
            )
            self.decoder_output_layer = TransformerDecoderOutputLayer(
                args, embed_tokens, dictionary
            )
        else:
            decoder_balance = utils.eval_str_list(
                args.pipeline_decoder_balance, type=int
            )
            decoder_devices = utils.eval_str_list(
                args.pipeline_decoder_devices, type=int
            )
            assert sum(decoder_balance) == len(decoder_module_list), (
                f"Sum of decoder_balance={decoder_balance} is not equal "
                + f"to num_decoder_modules={len(decoder_module_list)}"
            )
            if TORCH_PIPE:
                self.model = Pipe(
                    module=partition_model(
                        nn.Sequential(*decoder_module_list),
                        decoder_balance,
                        decoder_devices,
                    ),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
            else:
                self.model = Pipe(
                    module=nn.Sequential(*decoder_module_list),
                    balance=decoder_balance,
                    devices=decoder_devices,
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
        """
        input_tuple = (
            encoder_out.encoder_out,
            encoder_out.encoder_padding_mask,
            prev_output_tokens,
        )
        if self.use_pipeline:
            input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple)
            if TORCH_PIPE:
                return (self.model(input_tuple).local_value(),)
            else:
                return (self.model(input_tuple),)
        else:
            embed_layer_output = self.embedding_layer(input_tuple)
            state = self.decoder_layers(embed_layer_output)
            return (self.decoder_output_layer(state),)
    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_out)
        else:
            return features
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embedding_layer.embed_positions is None:
            return self.embedding_layer.max_target_positions
        return min(
            self.embedding_layer.max_target_positions,
            self.embedding_layer.embed_positions.max_positions,
        )

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        for i in range(len(self.layers)):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
)
def transformer_iwslt_de_en_dist(args):
    transformer_iwslt_de_en(args)


@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
)
def transformer_wmt_en_de_big_dist(args):
    transformer_wmt_en_de_big(args)
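

# --- Illustrative usage note (not part of the original module) --------------
# A hedged sketch of how this model is typically selected from the fairseq
# CLI. The flag spellings below are assumptions inferred from the pipeline_*
# attributes read above (pipeline_balance, pipeline_devices, pipeline_chunks,
# pipeline_checkpoint) and may differ between fairseq versions:
#
#   fairseq-train <data-bin> \
#       --arch transformer_iwslt_de_en_pipeline_parallel \
#       --pipeline-model-parallel \
#       --pipeline-balance '[4,4,4,4]' --pipeline-devices '[0,1,2,3]' \
#       --pipeline-chunks 4 --pipeline-checkpoint never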