| |
| |
| |
| |
|
|
| from typing import Dict, List, Optional |
|
|
| import torch |
| import torch.nn as nn |
| from fairseq import utils |
| from fairseq.modules import LayerNorm |
| from fairseq.modules.fairseq_dropout import FairseqDropout |
| from fairseq.modules.quant_noise import quant_noise |
| from torch import Tensor |
|
|
| from .unify_multihead_attention import MultiheadAttention |
|
|
|
|
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.

    In this version the sample axis is dim 1 (not dim 0), matching the
    `(seq_len, batch, embed_dim)` layout used by the layers in this file. Each
    sample is either zeroed entirely or kept and rescaled by ``1 / (1 - drop_prob)``.

    Args:
        x (Tensor): input with at least two dims; dim 1 is the per-sample axis.
        drop_prob (float, optional): probability of dropping each sample.
            ``None`` or ``0`` disables dropping.
        training (bool): dropping is applied only when ``True``.

    Returns:
        Tensor of the same shape as ``x`` (``x`` itself when dropping is disabled).
    """
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    # One draw per sample along dim 1, broadcast over every other axis.
    # For 3-D input this is exactly (1, batch, 1).
    shape = (1, x.shape[1]) + (1,) * (x.dim() - 2)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()
    if keep_prob == 0.0:
        # drop_prob == 1: every sample is dropped; skip the x / 0 rescale,
        # which would turn the zeroed outputs into NaN.
        return x * random_tensor
    output = x.div(keep_prob) * random_tensor
    return output
|
|
|
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path``; dropping is active only while the
    module is in training mode (``self.training``).

    Args:
        drop_prob (float, optional): per-sample drop probability. ``None`` (the
            default) is treated as ``0.0``, i.e. the module acts as an identity.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Normalise None to 0.0 so the value handed to drop_path is always numeric
        # (a None here would otherwise reach `1 - drop_prob` in training mode).
        self.drop_prob = 0.0 if drop_prob is None else drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)
|
|
|
|
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Optional extras, each enabled only when the corresponding ``args`` attribute
    is truthy: ``scale_attn`` (LayerNorm on the self-attention output),
    ``scale_fc`` (LayerNorm on the FFN hidden activations), ``scale_resids``
    (learnable per-channel scale on the FFN residual) and ``scale_heads``
    (forwarded to the attention module).

    Args:
        args (argparse.Namespace): parsed command-line arguments
        drop_path_rate (float, optional): stochastic-depth rate applied to each
            residual branch; ``0.0`` (default) disables it.
    """

    def __init__(self, args, drop_path_rate=0.0):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.quant_noise = getattr(args, 'quant_noise_pq', 0)
        # `or 8` also maps an explicit None/0 block size to the default of 8.
        self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
        self.self_attn = self.build_self_attention(self.embed_dim, args)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu') or "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # Fall back to the older `relu_dropout` option when
            # `activation_dropout` is unset or zero.
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.encoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.encoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        # Optional LayerNorm on the self-attention output (args.scale_attn).
        self.attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
        # Head count / per-head size mirrored from the attention module; not read inside this class.
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim

        # Optional LayerNorm on the FFN hidden activations (args.scale_fc).
        self.ffn_layernorm = LayerNorm(args.encoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
        # Optional learnable per-channel scale (shape (embed_dim,)) on the FFN residual (args.scale_resids).
        self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if getattr(args, 'scale_resids', False) else None

        self.final_layer_norm = LayerNorm(self.embed_dim)

        # Stochastic depth on each residual branch; identity when disabled.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the FFN input projection, wrapped with quant-noise."""
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the FFN output projection, wrapped with quant-noise."""
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_self_attention(self, embed_dim, args):
        """Build the self-attention module from ``args``."""
        return MultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scale_factor=args.attn_scale_factor,
            scale_heads=getattr(args, 'scale_heads', False)
        )

    def residual_connection(self, x, residual):
        """Add the (optionally drop-path'ed) branch output ``x`` to ``residual``."""
        return residual + self.drop_path(x)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`, then fill any of this layer's own state
        entries that are missing from ``state_dict`` with the module's current
        values (presumably so checkpoints lacking optional parameters such as
        ``w_resid`` still load — confirm against the checkpoint history).

        ``state_dict`` is modified in place. ``name`` is this layer's key prefix
        without the trailing dot; an empty ``name`` means unprefixed keys.
        """
        prefix = name + "." if name != "" else ""
        # state_dict() walks the whole submodule tree; build it once, not per lookup.
        own_state = self.state_dict()
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                old_key = "{}layer_norms.{}.{}".format(prefix, old, m)
                new_key = "{}{}.{}".format(prefix, new, m)
                if old_key in state_dict:
                    state_dict[new_key] = state_dict[old_key]
                    del state_dict[old_key]
                own_key = "{}.{}".format(new, m)
                if new_key not in state_dict and own_key in own_state:
                    state_dict[new_key] = own_state[own_key]

        for param_name, param_tensor in own_state.items():
            if (prefix + param_name) not in state_dict:
                state_dict[prefix + param_name] = param_tensor

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
        self_attn_bias: Optional[Tensor] = None
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.
            self_attn_bias (Tensor, optional): passed unchanged to the
                self-attention module as ``attn_bias``; presumably an additive
                bias on the attention logits — confirm its shape in
                ``unify_multihead_attention``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # Entries set in attn_mask are replaced by a large negative value so they
        # vanish after softmax. A finite value is used rather than -inf, presumably
        # to avoid NaNs on fully-masked rows. -1e4 for non-float32 inputs,
        # presumably because -1e8 is not representable in float16.
        # NOTE(review): this yields an additive mask only if attn_mask is already
        # floating point — confirm what callers pass.
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(
                attn_mask.to(torch.bool),
                -1e8 if x.dtype == torch.float32 else -1e4
            )

        # Self-attention sub-block.
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
            attn_bias=self_attn_bias
        )
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # Feed-forward sub-block.
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            # Per-channel scale, broadcast over the last (embed_dim) axis.
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x
|
|
|
|
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Optional extras, each enabled only when the corresponding ``args`` attribute
    is truthy: ``scale_attn`` (LayerNorm on the self- and cross-attention
    outputs), ``scale_fc`` (LayerNorm on the FFN hidden activations),
    ``scale_resids`` (learnable per-channel scale on the FFN residual) and
    ``scale_heads`` (forwarded to the attention modules).

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module.
        add_zero_attn (bool, optional): forwarded to the self-attention module.
        drop_path_rate (float, optional): stochastic-depth rate applied to each
            residual branch; ``0.0`` (default) disables it.
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, drop_path_rate=0.0
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        # Same fallback as the encoder layer: an explicit None/0 block size becomes 8.
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) or 8

        # When set, forward() prepends encoder_out to the self-attention keys/values.
        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        # Optional LayerNorms on the self-/cross-attention outputs (args.scale_attn).
        self.self_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
        self.cross_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
        # Head count / per-head size mirrored from the attention module; not read inside this class.
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim

        self.activation_fn = utils.get_activation_fn(
            activation=str(args.activation_fn)
            if getattr(args, "activation_fn", None) is not None
            else "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # Fall back to the older `relu_dropout` option when
            # `activation_dropout` is unset or zero.
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.decoder_normalize_before

        # `char_inputs` doubles as LayerNorm's `export` flag (presumably to pick a
        # non-fused, export-friendly implementation — TODO confirm).
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        # Optional LayerNorm on the FFN hidden activations (args.scale_fc).
        self.ffn_layernorm = LayerNorm(args.decoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
        # Optional learnable per-channel scale (shape (embed_dim,)) on the FFN residual (args.scale_resids).
        self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if getattr(args, 'scale_resids', False) else None

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        # Whether encoder-attention weights are computed at inference time
        # (see make_generation_fast_ and forward()).
        self.need_attn = True

        self.onnx_trace = False

        # Stochastic depth on each residual branch; identity when disabled.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the FFN input projection, wrapped with quant-noise."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Build the FFN output projection, wrapped with quant-noise."""
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the decoder self-attention module from ``args``."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not getattr(args, "cross_self_attention", False),
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scale_factor=args.attn_scale_factor,
            scale_heads=getattr(args, 'scale_heads', False)
        )

    def build_encoder_attention(self, embed_dim, args):
        """Build the encoder-decoder attention module from ``args``."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scale_factor=args.attn_scale_factor,
            scale_heads=getattr(args, 'scale_heads', False)
        )

    def prepare_for_onnx_export_(self):
        """Enable the ONNX-trace path, which also returns the self-attention state."""
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        """Add the (optionally drop-path'ed) branch output ``x`` to ``residual``."""
        return residual + self.drop_path(x)

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        self_attn_bias: Optional[Tensor] = None,
        cross_attn_bias: Optional[Tensor] = None
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): keys/values for encoder attention.
                With ``cross_self_attention`` it is also concatenated in front
                of ``x`` along dim 0 to form the self-attention keys/values.
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): attention cache used for
                incremental decoding; read and written by the attention modules.
            prev_self_attn_state (list of Tensor, optional):
                ``[prev_key, prev_value]`` plus an optional third
                ``prev_key_padding_mask``, written into the self-attention
                buffer of ``incremental_state`` before attending.
            prev_attn_state (list of Tensor, optional): same as above, for the
                encoder-attention buffer.
            self_attn_mask (Tensor, optional): attention mask for self-attention.
            self_attn_padding_mask (Tensor, optional): key padding mask for
                self-attention.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).
            self_attn_bias (Tensor, optional): passed to self-attention as
                ``attn_bias`` (presumably an additive logit bias — confirm).
            cross_attn_bias (Tensor, optional): passed to encoder attention as
                ``attn_bias`` (presumably an additive logit bias — confirm).

        Returns:
            tuple ``(x, attn, self_attn_state)``: ``x`` is the output of shape
            `(seq_len, batch, embed_dim)`; ``attn`` is the second value returned
            by the last attention call made (encoder attention when it runs);
            ``self_attn_state`` is the cached self-attention key/value list
            when ``onnx_trace`` is enabled and ``incremental_state`` is given,
            else ``None``.
        """
        if need_head_weights:
            need_attn = True

        # Self-attention sub-block.
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        # Cross-self-attention: attend over [encoder_out; x], unless cached keys
        # already exist in the incremental state.
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                # Encoder positions are never masked: prepend zero columns.
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
            attn_bias=self_attn_bias
        )
        if self.self_attn_ln is not None:
            x = self.self_attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # Encoder-attention sub-block.
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
                attn_bias=cross_attn_bias
            )
            if self.cross_attn_ln is not None:
                x = self.cross_attn_ln(x)
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # Feed-forward sub-block.
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            # Per-channel scale, broadcast over the last (embed_dim) axis.
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Set whether encoder-attention weights are computed at inference time."""
        self.need_attn = need_attn

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename legacy layer norm states `...layer_norms.0.*`,
        `...layer_norms.1.*` and `...layer_norms.2.*` to
        `...self_attn_layer_norm.*`, `...encoder_attn_layer_norm.*` and
        `...final_layer_norm.*` respectively, then fill any of this layer's own
        state entries that are missing from ``state_dict`` with the module's
        current values (presumably so checkpoints lacking optional parameters
        such as ``w_resid`` still load — confirm against the checkpoint history).

        ``state_dict`` is modified in place. ``name`` is this layer's key prefix
        without the trailing dot; an empty ``name`` means unprefixed keys.
        """
        prefix = name + "." if name != "" else ""
        # state_dict() walks the whole submodule tree; build it once, not per lookup.
        own_state = self.state_dict()
        layer_norm_map = {
            "0": "self_attn_layer_norm",
            "1": "encoder_attn_layer_norm",
            "2": "final_layer_norm",
        }
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                old_key = "{}layer_norms.{}.{}".format(prefix, old, m)
                new_key = "{}{}.{}".format(prefix, new, m)
                if old_key in state_dict:
                    state_dict[new_key] = state_dict[old_key]
                    del state_dict[old_key]
                own_key = "{}.{}".format(new, m)
                if new_key not in state_dict and own_key in own_state:
                    state_dict[new_key] = own_state[own_key]

        for param_name, param_tensor in own_state.items():
            if (prefix + param_name) not in state_dict:
                state_dict[prefix + param_name] = param_tensor
|
|