|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from fairseq import options, utils |
|
|
from fairseq.models import ( |
|
|
FairseqEncoder, |
|
|
FairseqIncrementalDecoder, |
|
|
FairseqEncoderDecoderModel, |
|
|
register_model, |
|
|
register_model_architecture, |
|
|
) |
|
|
from fairseq.modules import ( |
|
|
AdaptiveSoftmax, |
|
|
DynamicConv, |
|
|
FairseqDropout, |
|
|
LayerNorm, |
|
|
PositionalEmbedding, |
|
|
LightweightConv, |
|
|
MultiheadAttention, |
|
|
) |
|
|
|
|
|
|
|
|
@register_model('lightconv')
class LightConvModel(FairseqEncoderDecoderModel):
    """
    LightConv and DynamicConv model from `"Pay Less Attention with Lightweight and Dynamic Convolutions" (Wu, et al, 2019)
    <https://openreview.net/pdf?id=SkVhlh09tX>`_.
    To use LightConv please set ``--encoder-conv-type lightweight --decoder-conv-type lightweight``
    To use DynamicConv please set ``--encoder-conv-type dynamic --decoder-conv-type dynamic``

    Args:
        encoder (LightConvEncoder): the encoder
        decoder (LightConvDecoder): the decoder

    The LightConv model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.lightconv_parser
        :prog:
    """

    @classmethod
    def hub_models(cls):
        """Return the mapping of pretrained model names to their hub configs."""

        def moses_subword(path):
            # All released checkpoints use Moses tokenization + subword-nmt BPE.
            return {
                'path': path,
                'tokenizer': 'moses',
                'bpe': 'subword_nmt',
            }

        # NOTE: the *.wmt17.en-de keys point at the same wmt16 en-de archives
        # as the *.glu.wmt16.en-de keys.
        return {
            'lightconv.no_glu.iwslt14.de-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.lightconv.tar.gz'),
            'dynamicconv.no_glu.iwslt14.de-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.dynamicconv.tar.gz'),
            'lightconv.no_glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv.tar.gz'),
            'dynamicconv.no_glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv.tar.gz'),
            'lightconv.glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz'),
            'lightconv.glu.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz'),
            'lightconv.glu.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.dynamicconv-glu.tar.gz'),
            'lightconv.glu.wmt17.zh-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt17.zh-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.dynamicconv-glu.tar.gz'),
        }

    def __init__(self, encoder, decoder):
        """Wrap a :class:`LightConvEncoder` and a :class:`LightConvDecoder`."""
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--relu-dropout', type=float, metavar='D',
                            help='dropout probability after ReLU in FFN')
        parser.add_argument('--input-dropout', type=float, metavar='D',
                            help='dropout probability of the inputs')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-conv-dim', type=int, metavar='N',
                            help='encoder convolution dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads or LightConv/DynamicConv heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-conv-dim', type=int, metavar='N',
                            help='decoder convolution dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads or LightConv/DynamicConv heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')

        # LightConv and DynamicConv arguments
        parser.add_argument('--encoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int),
                            help='list of kernel size (default: "[3,7,15,31,31,31,31]")')
        parser.add_argument('--decoder-kernel-size-list', type=lambda x: options.eval_str_list(x, int),
                            help='list of kernel size (default: "[3,7,15,31,31,31]")')
        parser.add_argument('--encoder-glu', type=options.eval_bool,
                            help='glu after in proj')
        parser.add_argument('--decoder-glu', type=options.eval_bool,
                            help='glu after in proj')
        parser.add_argument('--encoder-conv-type', default='dynamic', type=str,
                            choices=['dynamic', 'lightweight'],
                            help='type of convolution')
        parser.add_argument('--decoder-conv-type', default='dynamic', type=str,
                            choices=['dynamic', 'lightweight'],
                            help='type of convolution')
        parser.add_argument('--weight-softmax', default=True, type=options.eval_bool,
                            help='apply softmax normalization to the convolution weights')
        parser.add_argument('--weight-dropout', type=float, metavar='D',
                            help='dropout probability for conv weights')

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Raises:
            RuntimeError: if ``--share-all-embeddings`` is requested with
                different source/target dictionaries, mismatched embedding
                dimensions, or a distinct ``--decoder-embed-path``.
        """

        # Fill in defaults for any architecture arguments missing from args.
        base_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # Optionally overwrite rows with pre-trained vectors from `path`.
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise RuntimeError('--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            # One embedding table serves encoder input, decoder input and output.
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = LightConvEncoder(args, src_dict, encoder_embed_tokens)
        decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens)
        # Use `cls` so subclasses of LightConvModel build instances of themselves.
        return cls(encoder, decoder)
|
|
|
|
|
|
|
|
class LightConvEncoder(FairseqEncoder):
    """
    Encoder built from *args.encoder_layers* stacked
    :class:`LightConvEncoderLayer` blocks on top of scaled token embeddings
    (plus positional embeddings unless they are disabled).

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)

        dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(dim)
        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions, dim, self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        # One layer per entry; layer `idx` uses kernel size list[idx].
        kernel_sizes = args.encoder_kernel_size_list
        self.layers = nn.ModuleList([])
        self.layers.extend([
            LightConvEncoderLayer(args, kernel_size=kernel_sizes[idx])
            for idx in range(args.encoder_layers)
        ])
        self.register_buffer('version', torch.Tensor([2]))

        # Pre-norm stacks get a final LayerNorm over the last layer's output.
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = LayerNorm(dim)

    def forward(self, src_tokens, **unused):
        """
        Embed *src_tokens* and run them through every encoder layer.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`

        Returns:
            dict:
                - **encoder_out** (Tensor): output of the last encoder layer
                  (after the optional final LayerNorm), of shape
                  `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (BoolTensor or None): ``True`` at
                  padding positions, of shape `(batch, src_len)`; ``None``
                  when the batch contains no padding at all
        """
        hidden = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            hidden += self.embed_positions(src_tokens)
        hidden = self.dropout_module(hidden)

        # (batch, time, channel) -> (time, batch, channel)
        hidden = hidden.transpose(0, 1)

        # Drop the mask entirely when nothing is padded so layers can skip it.
        pad_mask = src_tokens.eq(self.padding_idx)
        if not pad_mask.any():
            pad_mask = None

        for layer in self.layers:
            hidden = layer(hidden, pad_mask)

        if self.normalize:
            hidden = self.layer_norm(hidden)

        return {'encoder_out': hidden, 'encoder_padding_mask': pad_mask}

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to *new_order*.

        The dict passed in is updated in place and also returned.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        # encoder_out is time-major (batch on dim 1); the mask is batch-major.
        for key, batch_dim in (('encoder_out', 1), ('encoder_padding_mask', 0)):
            value = encoder_out[key]
            if value is not None:
                encoder_out[key] = value.index_select(batch_dim, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is not None:
            return min(self.max_source_positions, self.embed_positions.max_positions)
        return self.max_source_positions
|
|
|
|
|
|
|
|
class LightConvDecoder(FairseqIncrementalDecoder):
    """
    LightConv decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`LightConvDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
            Default: ``False``
        final_norm (bool, optional): apply a final LayerNorm after the last
            layer when ``args.decoder_normalize_before`` is also set.
            Default: ``True``
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
        self.share_input_output_embed = args.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        output_embed_dim = args.decoder_output_dim

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)

        # Bridge token embeddings to the model width when they differ.
        self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None

        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, embed_dim, padding_idx,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            LightConvDecoderLayer(args, no_encoder_attn, kernel_size=args.decoder_kernel_size_list[i])
            for i in range(args.decoder_layers)
        ])

        # These adaptive-softmax options are not registered by
        # LightConvModel.add_args, so they may be absent from `args` (reading
        # them directly raised AttributeError whenever an adaptive-softmax
        # cutoff was given). Fall back to "no tying" and a factor of 4.
        tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
        adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
        tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False)

        self.adaptive_softmax = None

        self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \
            if embed_dim != output_embed_dim and not tie_adaptive_weights else None

        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if tie_adaptive_weights else None,
                factor=adaptive_softmax_factor,
                tie_proj=tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # Separate output projection matrix (vocab x output_embed_dim).
            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output of the encoder's
                ``forward()``; its ``'encoder_out'`` and
                ``'encoder_padding_mask'`` entries are passed to every layer
                for encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`; when given, only the last time
                step of *prev_output_tokens* is processed

        Returns:
            tuple:
                - the decoder output of shape `(batch, tgt_len, vocab)`; when
                  an adaptive softmax is configured the output projection is
                  left to it and the features of shape
                  `(batch, tgt_len, output_embed_dim)` are returned instead
                - a dict with ``'attn'`` (the last layer's encoder attention
                  weights, or ``None``) and ``'inner_states'`` (the input to
                  the first layer followed by every layer's output, each
                  time-major)
        """
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            # Earlier steps are cached; compute only the newest position.
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # The encoder outputs are identical for every layer; look them up once.
        if encoder_out is not None:
            enc_states = encoder_out['encoder_out']
            enc_padding_mask = encoder_out['encoder_padding_mask']
        else:
            enc_states = None
            enc_padding_mask = None

        for layer in self.layers:
            x, attn = layer(
                x,
                enc_states,
                enc_padding_mask,
                incremental_state,
            )
            inner_states.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        if self.adaptive_softmax is None:
            # Project back to the vocabulary.
            if self.share_input_output_embed:
                x = F.linear(x, self.embed_tokens.weight)
            else:
                x = F.linear(x, self.embed_out)

        return x, {'attn': attn, 'inner_states': inner_states}

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        """Return a cached `(dim, dim)` upper-triangular ``-inf`` mask.

        ``dim`` is ``tensor.size(0)``. The cache is rebuilt when it is missing
        or lives on another device, and grown when it is too small. (Not called
        by this class's ``forward()``.)
        """
        dim = tensor.size(0)
        if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
        return self._future_mask[:dim, :dim]
|
|
|
|
|
|
|
|
class LightConvEncoderLayer(nn.Module):
    """Encoder layer block.

    Two residual sub-blocks, each wrapped by a LayerNorm that is applied
    before or after it depending on ``args.encoder_normalize_before``:

    1. input dropout -> Linear (followed by GLU when ``args.encoder_glu``) ->
       LightweightConv / DynamicConv -> Linear -> dropout -> residual add
    2. fc1 -> ReLU -> dropout -> fc2 -> dropout -> residual add

    Args:
        args (argparse.Namespace): parsed command-line arguments
        kernel_size: kernel size of the convolution

    Raises:
        NotImplementedError: if ``args.encoder_conv_type`` is neither
            ``'lightweight'`` nor ``'dynamic'``.
    """

    def __init__(self, args, kernel_size=0):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.conv_dim = args.encoder_conv_dim
        # Odd kernels: a single padding of kernel_size // 2.
        # Even kernels: a ((k - 1) // 2, k // 2) pair is passed instead.
        padding_l = kernel_size // 2 if kernel_size % 2 == 1 else ((kernel_size - 1) // 2, kernel_size // 2)

        if args.encoder_glu:
            # GLU halves the last dim, so project to twice the conv width.
            self.linear1 = Linear(self.embed_dim, 2*self.conv_dim)
            self.act = nn.GLU()
        else:
            self.linear1 = Linear(self.embed_dim, self.conv_dim)
            self.act = None
        if args.encoder_conv_type == 'lightweight':
            self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=padding_l,
                                        weight_softmax=args.weight_softmax,
                                        num_heads=args.encoder_attention_heads,
                                        weight_dropout=args.weight_dropout)
        elif args.encoder_conv_type == 'dynamic':
            self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=padding_l,
                                    weight_softmax=args.weight_softmax,
                                    num_heads=args.encoder_attention_heads,
                                    weight_dropout=args.weight_dropout)
        else:
            raise NotImplementedError(
                'unsupported encoder_conv_type: {!r}'.format(args.encoder_conv_type))
        self.linear2 = Linear(self.conv_dim, self.embed_dim)

        self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
        self.relu_dropout_module = FairseqDropout(args.relu_dropout, module_name=self.__class__.__name__)
        self.input_dropout_module = FairseqDropout(args.input_dropout, module_name=self.__class__.__name__)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        # layer_norms[0] wraps the conv sub-block, layer_norms[1] the FFN.
        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(2)])

    def forward(self, x, encoder_padding_mask):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (BoolTensor or None): mask of shape
                `(batch, src_len)` where padding elements are indicated by
                ``True``/``1``; ``None`` means no padding.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(0, x, before=True)
        x = self.input_dropout_module(x)
        x = self.linear1(x)
        if self.act is not None:
            x = self.act(x)
        if encoder_padding_mask is not None:
            # Zero padded time steps so the convolution cannot mix them into
            # neighbouring real positions; mask is made (seq_len, batch, 1).
            x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(2), 0)
        x = self.conv(x)
        x = self.linear2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.maybe_layer_norm(0, x, after=True)

        residual = x
        x = self.maybe_layer_norm(1, x, before=True)
        x = F.relu(self.fc1(x))
        x = self.relu_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.maybe_layer_norm(1, x, after=True)
        return x

    def maybe_layer_norm(self, i, x, before=False, after=False):
        """Apply ``layer_norms[i]`` if this call site matches the norm placement.

        Exactly one of *before* / *after* must be set. With pre-norm
        (``normalize_before``) only the ``before`` call normalizes; with
        post-norm only the ``after`` call does.
        """
        assert before ^ after
        if after ^ self.normalize_before:
            return self.layer_norms[i](x)
        else:
            return x

    def extra_repr(self):
        return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format(
            self.dropout_module.p, self.relu_dropout_module.p, self.input_dropout_module.p, self.normalize_before)
|
|
|
|
|
|
|
|
class LightConvDecoderLayer(nn.Module):
    """Decoder layer block.

    Up to three residual sub-blocks, each wrapped by a LayerNorm applied
    before or after it depending on ``args.decoder_normalize_before``:

    1. input dropout -> Linear (followed by GLU when ``args.decoder_glu``) ->
       LightweightConv / DynamicConv (left padding of ``kernel_size - 1``)
       -> Linear -> dropout -> residual add
    2. encoder-decoder multi-head attention (skipped if *no_encoder_attn*)
    3. fc1 -> ReLU -> dropout -> fc2 -> dropout -> residual add

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
            Default: ``False``
        kernel_size: kernel size of the convolution

    Raises:
        NotImplementedError: if ``args.decoder_conv_type`` is neither
            ``'lightweight'`` nor ``'dynamic'``.
    """

    def __init__(self, args, no_encoder_attn=False, kernel_size=0):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.conv_dim = args.decoder_conv_dim
        if args.decoder_glu:
            # GLU halves the last dim, so project to twice the conv width.
            self.linear1 = Linear(self.embed_dim, 2*self.conv_dim)
            self.act = nn.GLU()
        else:
            self.linear1 = Linear(self.embed_dim, self.conv_dim)
            self.act = None
        # padding_l=kernel_size-1 puts all padding on the left, presumably so
        # each output step only sees current and earlier inputs (causal conv).
        if args.decoder_conv_type == 'lightweight':
            self.conv = LightweightConv(self.conv_dim, kernel_size, padding_l=kernel_size-1,
                                        weight_softmax=args.weight_softmax,
                                        num_heads=args.decoder_attention_heads,
                                        weight_dropout=args.weight_dropout)
        elif args.decoder_conv_type == 'dynamic':
            self.conv = DynamicConv(self.conv_dim, kernel_size, padding_l=kernel_size-1,
                                    weight_softmax=args.weight_softmax,
                                    num_heads=args.decoder_attention_heads,
                                    weight_dropout=args.weight_dropout)
        else:
            raise NotImplementedError(
                'unsupported decoder_conv_type: {!r}'.format(args.decoder_conv_type))
        self.linear2 = Linear(self.conv_dim, self.embed_dim)

        self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__)
        self.relu_dropout_module = FairseqDropout(args.relu_dropout, module_name=self.__class__.__name__)
        self.input_dropout_module = FairseqDropout(args.input_dropout, module_name=self.__class__.__name__)
        self.normalize_before = args.decoder_normalize_before

        self.conv_layer_norm = LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout, encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim)
        # Attention weights are only requested in eval mode when this is True.
        self.need_attn = True

    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
                prev_conv_state=None, prev_attn_state=None, conv_mask=None,
                conv_padding_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor or None): encoder states used as attention
                keys/values
            encoder_padding_mask (BoolTensor or None): mask of shape
                `(batch, src_len)` where padding elements are indicated by
                ``True``/``1``.
            incremental_state (dict or None): incremental decoding cache
            prev_conv_state (optional): convolution buffer to install into
                *incremental_state* before running the convolution
            prev_attn_state (tuple, optional): ``(prev_key, prev_value)`` to
                install as the encoder-attention cache
            conv_mask, conv_padding_mask: accepted but not used by this
                implementation

        Returns:
            tuple: ``(x, attn)`` where ``x`` has shape
            `(seq_len, batch, embed_dim)` and ``attn`` holds the
            encoder-attention weights (``None`` when there is no encoder
            attention or weights were not requested).
        """
        residual = x
        x = self.maybe_layer_norm(self.conv_layer_norm, x, before=True)
        if prev_conv_state is not None:
            if incremental_state is None:
                incremental_state = {}
            self.conv._set_input_buffer(incremental_state, prev_conv_state)
        x = self.input_dropout_module(x)
        x = self.linear1(x)
        if self.act is not None:
            x = self.act(x)
        x = self.conv(x, incremental_state=incremental_state)
        x = self.linear2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.maybe_layer_norm(self.conv_layer_norm, x, after=True)

        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = self.dropout_module(x)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = F.relu(self.fc1(x))
        x = self.relu_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return x, attn

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply *layer_norm* if this call site matches the norm placement.

        Exactly one of *before* / *after* must be set. With pre-norm
        (``normalize_before``) only the ``before`` call normalizes; with
        post-norm only the ``after`` call does.
        """
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Record whether encoder-attention weights should be returned."""
        self.need_attn = need_attn

    def extra_repr(self):
        return 'dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}'.format(
            self.dropout_module.p, self.relu_dropout_module.p, self.input_dropout_module.p, self.normalize_before)
|
|
|
|
|
|
|
|
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Build an ``nn.Embedding`` with N(0, embedding_dim ** -0.5) weights.

    The row at *padding_idx* is zeroed. *padding_idx* may be ``None``, in which
    case no row is zeroed; previously ``weight[None]`` selected the whole
    matrix and wiped every embedding.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
|
|
|
|
|
|
|
|
def Linear(in_features, out_features, bias=True):
    """Build an ``nn.Linear`` with Xavier-uniform weights and a zero bias."""
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    # nn.Linear registers bias=None when bias is falsy, so this mirrors `if bias`.
    if layer.bias is not None:
        nn.init.constant_(layer.bias, 0.)
    return layer
|
|
|
|
|
|
|
|
@register_model_architecture('lightconv', 'lightconv')
def base_architecture(args):
    """Fill in default hyper-parameters for the base ``lightconv`` architecture.

    Attributes already present on *args* are left untouched; only missing ones
    receive a default. Order matters because some defaults depend on earlier
    ones (e.g. ``decoder_embed_dim`` falls back to ``encoder_embed_dim``).

    Raises:
        AssertionError: if a kernel-size list does not have exactly one entry
            per layer (a single-entry list is first broadcast to all layers).
    """
    args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
    args.encoder_layers = getattr(args, 'encoder_layers', 7)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
    args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.)
    args.relu_dropout = getattr(args, 'relu_dropout', 0.)
    args.dropout = getattr(args, 'dropout', 0.1)
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
    args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)

    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)

    args.encoder_conv_dim = getattr(args, 'encoder_conv_dim', args.encoder_embed_dim)
    args.decoder_conv_dim = getattr(args, 'decoder_conv_dim', args.decoder_embed_dim)

    args.encoder_kernel_size_list = getattr(args, 'encoder_kernel_size_list', [3, 7, 15, 31, 31, 31, 31])
    args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31])
    # A single kernel size is shorthand for "use it in every layer".
    if len(args.encoder_kernel_size_list) == 1:
        args.encoder_kernel_size_list = args.encoder_kernel_size_list * args.encoder_layers
    if len(args.decoder_kernel_size_list) == 1:
        args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers
    assert len(args.encoder_kernel_size_list) == args.encoder_layers, "encoder_kernel_size_list doesn't match encoder_layers"
    assert len(args.decoder_kernel_size_list) == args.decoder_layers, "decoder_kernel_size_list doesn't match decoder_layers"
    args.encoder_glu = getattr(args, 'encoder_glu', True)
    args.decoder_glu = getattr(args, 'decoder_glu', True)
    args.input_dropout = getattr(args, 'input_dropout', 0.1)
    args.weight_dropout = getattr(args, 'weight_dropout', args.attention_dropout)

    # The layers and decoder read these unconditionally, but they were
    # previously defaulted only by argparse (conv types, weight_softmax) or
    # never registered at all (adaptive-softmax extras). Defaulting them here
    # lets args built without LightConvModel.add_args still build a model.
    # The values mirror the argparse defaults and disable weight tying.
    args.encoder_conv_type = getattr(args, 'encoder_conv_type', 'dynamic')
    args.decoder_conv_type = getattr(args, 'decoder_conv_type', 'dynamic')
    args.weight_softmax = getattr(args, 'weight_softmax', True)
    args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
    args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
    args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False)
|
|
|
|
|
|
|
|
@register_model_architecture('lightconv', 'lightconv_iwslt_de_en')
def lightconv_iwslt_de_en(args):
    """IWSLT De-En preset: 512-d model, 1024-d FFN, 4 heads, no GLU.

    Only attributes missing from *args* receive these values; everything else
    is then filled in by :func:`base_architecture`.
    """
    iwslt_defaults = (
        ('encoder_embed_dim', 512),
        ('encoder_ffn_embed_dim', 1024),
        ('encoder_attention_heads', 4),
        ('encoder_layers', 7),
        ('decoder_embed_dim', 512),
        ('decoder_ffn_embed_dim', 1024),
        ('decoder_attention_heads', 4),
        ('decoder_layers', 6),
        ('attention_dropout', 0.1),
        ('weight_dropout', 0.1),
        ('encoder_glu', False),
        ('decoder_glu', False),
        ('input_dropout', 0.0),
    )
    for attr, fallback in iwslt_defaults:
        setattr(args, attr, getattr(args, attr, fallback))
    base_architecture(args)
|
|
|
|
|
|
|
|
@register_model_architecture('lightconv', 'lightconv_wmt_en_de')
def lightconv_wmt_en_de(args):
    """WMT En-De preset: identical to the base ``lightconv`` architecture."""
    return base_architecture(args)
|
|
|
|
|
|
|
|
@register_model_architecture('lightconv', 'lightconv_wmt_en_de_big')
def lightconv_wmt_en_de_big(args):
    """WMT En-De "big" preset: 1024-d model, 4096-d FFN, 16 heads, dropout 0.3.

    Only attributes missing from *args* receive these values; everything else
    is then filled in by :func:`base_architecture`.
    """
    big_defaults = [
        ('attention_dropout', 0.1),
        ('encoder_embed_dim', 1024),
        ('encoder_ffn_embed_dim', 4096),
        ('encoder_attention_heads', 16),
        ('encoder_normalize_before', False),
        ('decoder_embed_dim', 1024),
        ('decoder_ffn_embed_dim', 4096),
        ('decoder_attention_heads', 16),
        ('dropout', 0.3),
    ]
    for attr, fallback in big_defaults:
        setattr(args, attr, getattr(args, attr, fallback))
    base_architecture(args)
|
|
|
|
|
|
|
|
@register_model_architecture('lightconv', 'lightconv_wmt_en_fr_big')
def lightconv_wmt_en_fr_big(args):
    """WMT En-Fr "big" preset: the En-De big model with dropout 0.1."""
    if not hasattr(args, 'dropout'):
        args.dropout = 0.1
    lightconv_wmt_en_de_big(args)
|
|
|
|
|
|
|
|
@register_model_architecture('lightconv', 'lightconv_wmt_zh_en_big')
def lightconv_wmt_zh_en_big(args):
    """WMT Zh-En "big" preset: the En-De big model with all dropouts at 0.2."""
    for attr in ('dropout', 'attention_dropout', 'weight_dropout'):
        setattr(args, attr, getattr(args, attr, 0.2))
    lightconv_wmt_en_de_big(args)
|
|
|