Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| from typing import Any, Dict, List, Optional | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from fairseq import utils | |
| from fairseq.distributed import fsdp_wrap | |
| from fairseq.models.transformer import TransformerConfig | |
| from fairseq.models.transformer.transformer_decoder import TransformerDecoderBase | |
| from fairseq.modules import ( | |
| LayerDropModuleList, | |
| SinusoidalPositionalEmbedding, | |
| transformer_layer_aug, | |
| ) | |
| from fairseq.modules.checkpoint_activations import checkpoint_wrapper | |
class AugTransformerDecoderBase(TransformerDecoderBase):
    """
    Transformer decoder augmented with an additional cross-attention. Each layer
    is a :class:`AugTransformerDecoderLayerBase`.

    Args:
        cfg (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        encoder_attn_merge_type (str, optional): the way to combine outputs from
            two cross-attention modules. If "sequential" is set, two cross-attention
            modules are stacked sequentially. If "parallel" is set, they are processed
            in parallel and combined before feeding it to FFN (default: sequential).
        dropnet_ratio (float, optional): a probability to drop each cross-attention
            module during training (default: 0.0).
    """

    def __init__(
        self,
        cfg,
        dictionary,
        embed_tokens,
        output_projection=None,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0.0,
    ):
        super().__init__(
            cfg,
            dictionary,
            embed_tokens,
            no_encoder_attn=False,
            output_projection=output_projection,
        )
        # assert cfg.cross_self_attention
        self.cross_self_attention = cfg.cross_self_attention

        # NOTE(review): the base __init__ presumably already built self.layers;
        # they are replaced here with augmented layers — confirm against
        # TransformerDecoderBase.
        if self.decoder_layerdrop > 0.0:
            # LayerDrop (Fan et al.) stochastically skips whole layers in training.
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(cfg, encoder_attn_merge_type, dropnet_ratio)
                for _ in range(cfg.decoder.layers)
            ]
        )

    def build_decoder_layer(
        self,
        cfg,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0,
    ):
        """Build one augmented decoder layer, optionally wrapped for activation
        checkpointing and FSDP."""
        layer = transformer_layer_aug.AugTransformerDecoderLayerBase(
            cfg,
            no_encoder_attn=False,
            encoder_attn_merge_type=encoder_attn_merge_type,
            dropnet_ratio=dropnet_ratio,
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention, should be of size T x B x C
            encoder_out_aug (optional): output from the second encoder, consumed
                by the additional cross-attention module
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            encoder_out_aug=encoder_out_aug,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )

        if not features_only:
            # Project features to vocabulary logits.
            x = self.output_layer(x)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        encoder_out_aug: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """Delegate to the TorchScript-friendly implementation below."""
        return self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            encoder_out_aug,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    """
    A scriptable subclass of this class has an extract_features method and calls
    super().extract_features, but super() is not supported in torchscript. A copy of
    this function is made to be used in the subclass instead.
    """

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        encoder_out_aug: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        bs, slen = prev_output_tokens.size()
        if alignment_layer is None:
            # Default to attention weights from the last layer.
            alignment_layer = self.num_layers - 1

        # Unpack primary encoder output (T x B x C) and its padding mask.
        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
            enc = encoder_out["encoder_out"][0]
        if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
            padding_mask = encoder_out["encoder_padding_mask"][0]

        # Unpack the augmented (second) encoder output and its padding mask.
        enc_aug: Optional[Tensor] = None
        padding_mask_aug: Optional[Tensor] = None
        if encoder_out_aug is not None and len(encoder_out_aug["encoder_out"]) > 0:
            enc_aug = encoder_out_aug["encoder_out"][0]
        if (
            encoder_out_aug is not None
            and len(encoder_out_aug["encoder_padding_mask"]) > 0
        ):
            padding_mask_aug = encoder_out_aug["encoder_padding_mask"][0]

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        if incremental_state is not None:
            # During incremental decoding only the newest token is processed.
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # Prevent torchscript exporting issue for dynamic quant embedding
        prev_output_tokens = prev_output_tokens.contiguous()

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        attn_aug: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                # Causal mask: each position may only attend to earlier ones.
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, layer_attn_aug, _ = layer(
                x,
                enc,
                padding_mask,
                enc_aug,
                padding_mask_aug,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            inner_states.append(x)
            # Keep attention weights only from the alignment layer.
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)
            if layer_attn_aug is not None and idx == alignment_layer:
                attn_aug = layer_attn_aug.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if attn_aug is not None:
            if alignment_heads is not None:
                attn_aug = attn_aug[:alignment_heads]

            # average probabilities over heads
            attn_aug = attn_aug.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": [attn], "attn_aug": [attn_aug], "inner_states": inner_states}

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if f"{name}.output_projection.weight" not in state_dict:
            if self.share_input_output_embed:
                embed_out_key = f"{name}.embed_tokens.weight"
            else:
                embed_out_key = f"{name}.embed_out"
            if embed_out_key in state_dict:
                state_dict[f"{name}.output_projection.weight"] = state_dict[
                    embed_out_key
                ]
                if not self.share_input_output_embed:
                    # The weight was moved, not shared, so drop the old key.
                    del state_dict[embed_out_key]

        for i in range(self.num_layers):
            # update layer norms: map old numeric layer_norms.N keys onto named
            # sub-modules (note "2" is the second cross-attention's layer norm).
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "encoder_attn_layer_norm2",
                "3": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
class AugTransformerDecoder(AugTransformerDecoderBase):
    """Legacy argparse-namespace wrapper around :class:`AugTransformerDecoderBase`.

    Converts the flat ``args`` namespace into a :class:`TransformerConfig`
    before delegating to the dataclass-based base class.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        output_projection (optional): pre-built output projection module
    """

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        output_projection=None,
    ):
        self.args = args
        # BUG FIX: AugTransformerDecoderBase.__init__ takes no ``no_encoder_attn``
        # parameter (it hard-codes no_encoder_attn=False itself), so forwarding
        # that keyword here raised TypeError on every instantiation.
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            output_projection=output_projection,
            encoder_attn_merge_type=getattr(
                args, "synthesizer_augmented_cross_attention_merge_type", "sequential"
            ),
            dropnet_ratio=getattr(args, "dropnet_ratio", 0),
        )

    def build_output_projection(self, args, dictionary, embed_tokens):
        """Build the output projection, translating ``args`` to a config first."""
        super().build_output_projection(
            TransformerConfig.from_namespace(args), dictionary, embed_tokens
        )

    def build_decoder_layer(
        self,
        args,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0,
    ):
        """Build one augmented decoder layer from a legacy ``args`` namespace."""
        # BUG FIX: the base build_decoder_layer has no ``no_encoder_attn``
        # parameter either (see AugTransformerDecoderBase.build_decoder_layer),
        # so the keyword is not forwarded.
        return super().build_decoder_layer(
            TransformerConfig.from_namespace(args),
            encoder_attn_merge_type=encoder_attn_merge_type,
            dropnet_ratio=dropnet_ratio,
        )