# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options, utils
from fairseq.modules import (
    AdaptiveSoftmax,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
)


EncoderOut = namedtuple(
    "TransformerEncoderOut",
    [
        "encoder_out",  # T x B x C
        "encoder_padding_mask",  # B x T
        "encoder_embedding",  # B x T x C
        "encoder_states",  # List[T x B x C]
    ],
)
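
# Each module below implements forward(self, input) over a plain tuple and
# returns a tuple, so the whole encoder/decoder can be expressed as a flat
# chain of stages whose outputs feed the next stage directly (the calling
# convention expected by pipeline engines such as fairscale/torch Pipe).
# Conventional shapes, following the comments above:
#   encoder stages: (x: T x B x C, encoder_padding_mask: B x T, prev_output_tokens: B x T_tgt)
#   decoder stages: (x: T x B x C, encoder_out: T x B x C, encoder_padding_mask: B x T)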


class TransformerEncoderEmbedding(nn.Module):
    """Encoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.embed_tokens = embed_tokens
        if isinstance(embed_tokens, nn.ModuleList):
            self.padding_idx = embed_tokens[0].padding_idx
            embed_dim = sum(e.embedding_dim for e in embed_tokens)
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

    def forward(self, input):
        # embed tokens and positions
        src_tokens = input[0]
        prev_output_tokens = input[2]
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(src_tokens))
            embedded = torch.cat(x_embed_list, dim=-1)
        else:
            embedded = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * embedded
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        return (x, encoder_padding_mask, prev_output_tokens)
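
# A hedged shape sketch (toy numbers; `tok_embed` is a hypothetical embedding,
# not part of this module). Note that forward only reads input[0] and
# input[2]; input[1] (presumably src_lengths upstream) is unused here:
#   tok_embed = Embedding(1000, 512, padding_idx=1)
#   emb = TransformerEncoderEmbedding(args, tok_embed)
#   src = torch.randint(2, 1000, (8, 20))             # B x T
#   x, pad_mask, prev = emb((src, None, prev_tokens))
#   # x: 20 x 8 x 512 (T x B x C), pad_mask: 8 x 20 (B x T)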


class TransformerEncoderLayerNorm(nn.Module):
    """
    Layer norm at the end of all encoder layers if
    args.encoder_normalize_before = True
    """

    def __init__(self, args, embed_dim):
        super().__init__()
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input):
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        if self.layer_norm:
            x = self.layer_norm(x)
        # keeping track of the incremental_state is not supported yet
        return (x, encoder_padding_mask, prev_output_tokens)
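
# When encoder_normalize_before=True the per-layer norms run pre-residual, so
# this extra stage supplies the final LayerNorm over the last layer's output;
# with post-norm encoders it is an identity pass-through.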


class TransformerDecoderEmbedding(nn.Module):
    """Decoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        input_embed_dim = (
            sum(e.embedding_dim for e in embed_tokens)
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.embedding_dim
        )
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim
        padding_idx = (
            embed_tokens[0].padding_idx
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.padding_idx
        )
        self.max_target_positions = args.max_target_positions
        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim
        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )
        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

    def forward(self, input):
        mt_task = False
        if isinstance(input, tuple):
            if len(input) == 3:
                encoder_out = input[0]
                encoder_padding_mask = input[1]
                prev_output_tokens = input[2]
                incremental_state = None  # Hardcoding to avoid passing of None objects
                mt_task = True
            else:
                # HACK for now, need to fix (TODO sidgoyal)
                prev_output_tokens = input[0]
                # discard "src_lengths"
                encoder_out = None
                encoder_padding_mask = None
                incremental_state = None
        else:
            prev_output_tokens = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(prev_output_tokens))
            x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
        else:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x
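
# The branching above lets one module serve three call patterns: a 3-tuple
# coming from an encoder (translation, mt_task=True), a shorter tuple whose
# extra fields (e.g. src_lengths) are discarded, or a bare tensor of previous
# output tokens (decoder-only use).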


class TransformerDecoderOutputLayer(nn.Module):
    def __init__(self, args, embed_tokens, dictionary):
        super().__init__()
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.embed_tokens = embed_tokens
        self.output_embed_dim = args.decoder_output_dim
        embed_dim = args.decoder_embed_dim

        self.project_out_dim = (
            Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
            else None
        )
        self.adaptive_softmax = None
        if args.adaptive_softmax_cutoff is not None:
            assert not isinstance(embed_tokens, nn.ModuleList)
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            self.embed_tokens = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(
                self.embed_tokens, mean=0, std=self.output_embed_dim**-0.5
            )

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input, apply_final_proj=True):
        if isinstance(input, tuple):
            x = input[0]
        else:
            x = input

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        if apply_final_proj:
            x = self.output_layer(x)
        return x

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            if self.share_input_output_embed:
                if isinstance(self.embed_tokens, nn.ModuleList):
                    output = None
                    for i, emb in enumerate(self.embed_tokens):
                        sidx = i * emb.embedding_dim
                        eidx = (i + 1) * emb.embedding_dim
                        if output is None:
                            output = F.linear(features[:, :, sidx:eidx], emb.weight)
                        else:
                            output += F.linear(features[:, :, sidx:eidx], emb.weight)
                    return output
                else:
                    return F.linear(features, self.embed_tokens.weight)
            else:
                return F.linear(features, self.embed_tokens)
        else:
            return features
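
# With shared input/output embeddings, output_layer computes
# logits = features @ E^T via F.linear, where E is the (vocab x dim) token
# embedding matrix, so no separate projection matrix is allocated. For an
# nn.ModuleList of factored embeddings, each factor projects its own slice of
# the feature dimension and the partial logits are summed.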


class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
                input[2] (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing

        Returns:
            output (Tuple):
                output[0] (Tensor): encoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (ByteTensor/FloatTensor): encoder padding mask
                output[2] (LongTensor): previous decoder outputs
        """
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return (x, encoder_padding_mask, prev_output_tokens)

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x
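
# maybe_layer_norm encodes both residual orderings in one XOR: with
# normalize_before=False (post-norm, the paper's default) the norm fires only
# on the after=True call (sublayer -> dropout -> residual -> norm); with
# normalize_before=True (pre-norm, tensor2tensor style) it fires only on the
# before=True call (norm -> sublayer -> dropout -> residual).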


class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (Tensor): encoder output of shape `(src_len, batch, embed_dim)`
                input[2] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.

        Returns:
            output (Tuple):
                output[0] (Tensor): decoded output of shape `(seq_len, batch, embed_dim)`
                output[1] (Tensor): encoder output, passed through unchanged
                output[2] (ByteTensor/FloatTensor): encoder padding mask, passed
                    through unchanged
        """
        # Note: incremental state is not yet supported
        mt_task = False
        if isinstance(input, tuple):
            x = input[0]
            encoder_out = input[1]
            encoder_padding_mask = input[2]
            incremental_state = None
            mt_task = True
        else:
            x = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # TODO: add back prev_self_attn_state, prev_attn_state,
        # self_attn_padding_mask
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None

        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]
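
    # The cached future mask is strictly upper-triangular: -inf above the main
    # diagonal, 0 elsewhere. Added to the attention logits, it prevents each
    # position from attending to later positions. For dim = 3:
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]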

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m
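
# The init above (std = embedding_dim ** -0.5, padding row zeroed) pairs with
# the embed_scale = sqrt(embed_dim) factor used by the embedding modules, so
# scaled token embeddings start out with roughly unit variance.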


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
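

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original module. It builds
    # a hypothetical args namespace with only the fields that
    # TransformerEncoderLayer actually reads, then pushes random data through
    # one layer; it assumes the fairseq modules imported at the top of this
    # file are available.
    from types import SimpleNamespace

    toy_args = SimpleNamespace(  # hypothetical minimal config
        encoder_embed_dim=16,
        encoder_attention_heads=4,
        encoder_ffn_embed_dim=32,
        attention_dropout=0.0,
        dropout=0.0,
        activation_fn="relu",
        activation_dropout=0.0,
        relu_dropout=0.0,
        encoder_normalize_before=False,
    )
    layer = TransformerEncoderLayer(toy_args)
    x = torch.randn(5, 2, 16)  # T x B x C
    pad_mask = torch.zeros(2, 5, dtype=torch.bool)  # B x T, no padding
    prev_out = torch.zeros(2, 3, dtype=torch.long)  # passed through untouched
    out, mask, prev = layer((x, pad_mask, prev_out))
    assert out.shape == x.shape  # encoder layers preserve T x B x C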