import logging
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import (
    CompositeEncoder,
    FairseqDecoder,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    DownsampledMultiHeadAttention,
    FairseqDropout,
    GradMultiply,
    LayerNorm,
    LearnedPositionalEmbedding,
    LinearizedConvolution,
)

logger = logging.getLogger(__name__)

@register_model("fconv_self_att")
class FConvModelSelfAtt(FairseqEncoderDecoderModel):
    @classmethod
    def hub_models(cls):
        return {
            "conv.stories.pretrained": {
                "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
                "checkpoint_file": "pretrained_checkpoint.pt",
                "tokenizer": "nltk",
            },
            "conv.stories": {
                "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
                "checkpoint_file": "fusion_checkpoint.pt",
                "tokenizer": "nltk",
                "pretrained": "True",
                "pretrained_checkpoint": "./pretrained_checkpoint.pt",
            },
            # test set containing dictionaries
            "data.stories": "https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2",
        }
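
    # Illustrative usage sketch (not part of the original module): the entries
    # above are consumed by fairseq's hub interface, so loading one of these
    # checkpoints should look roughly like the lines below; the local paths are
    # hypothetical and exact keyword handling depends on the installed fairseq
    # and torch.hub versions.
    #
    #   model = FConvModelSelfAtt.from_pretrained(
    #       "checkpoints/",                         # hypothetical checkpoint dir
    #       checkpoint_file="fusion_checkpoint.pt",
    #       data_name_or_path="data-bin/stories",   # hypothetical data dir
    #       tokenizer="nltk",
    #   )
    #   # or, roughly equivalently, via torch.hub:
    #   # model = torch.hub.load("pytorch/fairseq", "conv.stories", tokenizer="nltk")
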
|
|
| def __init__(self, encoder, decoder, pretrained_encoder=None): |
| super().__init__(encoder, decoder) |
| self.encoder.num_attention_layers = sum( |
| layer is not None for layer in decoder.attention |
| ) |
| self.pretrained_encoder = pretrained_encoder |
| if self.pretrained_encoder is None: |
| encoders = {"encoder": encoder} |
| else: |
| encoders = {"encoder": encoder, "pretrained": self.pretrained_encoder} |
| |
| |
| self.encoder = CompositeEncoder(encoders) |
|
|
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-layers', type=str, metavar='EXPR',
                            help='encoder layers [(dim, kernel_size), ...]')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
                            help='decoder layers [(dim, kernel_size), ...]')
        parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
                            help='decoder output embedding dimension')
        parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
                            help='decoder attention [True, ...]')
        parser.add_argument('--self-attention', type=str, metavar='EXPR',
                            help='decoder self-attention layers, ex: [True] + [False]*5')
        parser.add_argument('--multihead-attention-nheads', type=int,
                            help='Number of heads to use in attention')
        parser.add_argument('--multihead-self-attention-nheads', type=int,
                            help='Number of heads to use in self-attention')
        parser.add_argument('--encoder-attention', type=str, metavar='EXPR',
                            help='encoder attention [True, ...]')
        parser.add_argument('--encoder-attention-nheads', type=int,
                            help='Number of heads to use in encoder attention')
        parser.add_argument('--project-input', type=str, metavar='EXPR',
                            help='Use projections in self-attention [True, ...]')
        parser.add_argument('--gated-attention', type=str, metavar='EXPR',
                            help='Use GLU layers in self-attention projections [True, ...]')
        parser.add_argument('--downsample', type=str, metavar='EXPR',
                            help='Use downsampling in self-attention [True, ...]')
        parser.add_argument('--pretrained-checkpoint', metavar='DIR',
                            help='path to load checkpoint from pretrained model')
        parser.add_argument('--pretrained', type=str, metavar='EXPR',
                            help='use pretrained model when training [True, ...]')

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        trained_encoder, trained_decoder = None, None
        pretrained = eval(args.pretrained)
        if pretrained:
            logger.info("loading pretrained model")
            if not os.path.exists(args.pretrained_checkpoint):
                new_pretrained_checkpoint = os.path.join(
                    args.data, args.pretrained_checkpoint
                )
                if os.path.exists(new_pretrained_checkpoint):
                    args.pretrained_checkpoint = new_pretrained_checkpoint
            trained_model = checkpoint_utils.load_model_ensemble(
                filenames=[args.pretrained_checkpoint],
                task=task,
            )[0][0]
            trained_decoder = list(trained_model.children())[1]
            trained_encoder = list(trained_model.children())[0]

            # freeze the pretrained model's parameters
            for param in trained_decoder.parameters():
                param.requires_grad = False
            for param in trained_encoder.parameters():
                param.requires_grad = False

        encoder = FConvEncoder(
            task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            convolutions=eval(args.encoder_layers),
            dropout=args.dropout,
            max_positions=args.max_source_positions,
            attention=eval(args.encoder_attention),
            attention_nheads=args.encoder_attention_nheads,
        )

        decoder = FConvDecoder(
            task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            convolutions=eval(args.decoder_layers),
            out_embed_dim=args.decoder_out_embed_dim,
            attention=eval(args.decoder_attention),
            dropout=args.dropout,
            max_positions=args.max_target_positions,
            selfattention=eval(args.self_attention),
            attention_nheads=args.multihead_attention_nheads,
            selfattention_nheads=args.multihead_self_attention_nheads,
            project_input=eval(args.project_input),
            gated_attention=eval(args.gated_attention),
            downsample=eval(args.downsample),
            pretrained=pretrained,
            trained_decoder=trained_decoder,
        )
        model = FConvModelSelfAtt(encoder, decoder, trained_encoder)

        return model

    @property
    def pretrained(self):
        return self.pretrained_encoder is not None

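
# Illustrative construction sketch (not from the original source; the dictionary
# and layer sizes below are hypothetical). The model can also be assembled
# directly from its sub-modules, mirroring what build_model() does:
#
#   from fairseq.data import Dictionary
#   d = Dictionary()
#   encoder = FConvEncoder(d, embed_dim=256, convolutions=((256, 3),) * 3)
#   decoder = FConvDecoder(d, embed_dim=256, convolutions=((256, 3),) * 3)
#   model = FConvModelSelfAtt(encoder, decoder)
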

class FConvEncoder(FairseqEncoder):
    """Convolutional encoder"""

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        max_positions=1024,
        convolutions=((512, 3),) * 20,
        dropout=0.1,
        attention=False,
        attention_nheads=1,
    ):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.num_attention_layers = None

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            self.padding_idx,
        )

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)

        in_channels = convolutions[0][0]
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            self.convolutions.append(
                ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout)
            )
            self.attention.append(
                SelfAttention(out_channels, embed_dim, attention_nheads)
                if attention[i]
                else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, embed_dim)

    def forward(self, src_tokens, src_lengths):
        # embed tokens and positions
        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
        x = self.dropout_module(x)
        input_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        # used to mask padding in input
        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()  # -> T x B
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        for proj, conv, attention in zip(
            self.projections, self.convolutions, self.attention
        ):
            residual = x if proj is None else proj(x)

            if encoder_padding_mask is not None:
                x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

            x = self.dropout_module(x)
            padding_l = (conv.kernel_size[0] - 1) // 2
            padding_r = conv.kernel_size[0] // 2
            x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
            x = conv(x)
            x = F.glu(x, dim=2)
            if attention is not None:
                x = attention(x)
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of embedding
        x = self.fc2(x)

        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.t()  # -> B x T
            x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

        # scale gradients (this only affects backward, not forward)
        x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5)

        return {
            "encoder_out": (x, y),
            "encoder_padding_mask": encoder_padding_mask,
        }
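
    # Note on the returned pair (descriptive comment, not in the original
    # source): "encoder_out" holds (x, y), both B x T x C. The decoder's
    # _split_encoder_out() transposes this pair once and passes it to attention
    # as (encoder_a, encoder_b), so x serves as the attention keys and y
    # (output plus input embedding) as the values, following the fconv
    # attention scheme.
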
    def reorder_encoder_out(self, encoder_out, new_order):
        encoder_out["encoder_out"] = tuple(
            eo.index_select(0, new_order) for eo in encoder_out["encoder_out"]
        )

        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)

        if "pretrained" in encoder_out:
            encoder_out["pretrained"]["encoder_out"] = tuple(
                eo.index_select(0, new_order)
                for eo in encoder_out["pretrained"]["encoder_out"]
            )

        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions


@with_incremental_state
class FConvDecoder(FairseqDecoder):
    """Convolutional decoder"""

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        out_embed_dim=256,
        max_positions=1024,
        convolutions=((512, 3),) * 8,
        attention=True,
        dropout=0.1,
        selfattention=False,
        attention_nheads=1,
        selfattention_nheads=1,
        project_input=False,
        gated_attention=False,
        downsample=False,
        pretrained=False,
        trained_decoder=None,
    ):
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([2]))
        self.pretrained = pretrained
        self.pretrained_decoder = trained_decoder
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.need_attn = True
        in_channels = convolutions[0][0]

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)
        selfattention = expand_bool_array(selfattention)

        if not isinstance(attention, list) or len(attention) != len(convolutions):
            raise ValueError(
                "Attention is expected to be a list of booleans of "
                "length equal to the number of layers."
            )

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)

        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            padding_idx,
        )

        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.selfattention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            self.convolutions.append(
                LinearizedConv1d(
                    in_channels,
                    out_channels * 2,
                    kernel_size,
                    padding=(kernel_size - 1),
                    dropout=dropout,
                )
            )

            self.attention.append(
                DownsampledMultiHeadAttention(
                    out_channels,
                    embed_dim,
                    attention_nheads,
                    project_input=project_input,
                    gated=False,
                    downsample=False,
                )
                if attention[i]
                else None
            )

            self.attproj.append(
                Linear(out_channels, embed_dim, dropout=dropout)
                if attention[i]
                else None
            )
            self.selfattention.append(
                SelfAttention(
                    out_channels,
                    embed_dim,
                    selfattention_nheads,
                    project_input=project_input,
                    gated=gated_attention,
                    downsample=downsample,
                )
                if selfattention[i]
                else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, out_embed_dim)
        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

        # model fusion
        if self.pretrained:
            # independent gates are learned from the concatenated input
            self.gate1 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            self.gate2 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            # pretrained and trained models are joined
            self.joining = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim),
                LayerNorm(out_embed_dim),
            )
            # the pretrained model has an output layer mapping hidden size to
            # vocabulary size, but the models are combined in their hidden
            # states; this hook stores the pretrained decoder's fc2 output
            self.pretrained_outputs = {}

            def save_output():
                def hook(a, b, output):
                    self.pretrained_outputs["out"] = output

                return hook

            self.pretrained_decoder.fc2.register_forward_hook(save_output())

    def forward(self, prev_output_tokens, encoder_out):
        trained_encoder_out = encoder_out["pretrained"] if self.pretrained else None
        encoder_out = encoder_out["encoder"]["encoder_out"]

        encoder_a, encoder_b = self._split_encoder_out(encoder_out)

        # embed positions
        positions = self.embed_positions(prev_output_tokens)

        # embed tokens and positions
        x = self.embed_tokens(prev_output_tokens) + positions
        x = self.dropout_module(x)
        target_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        avg_attn_scores = None
        for proj, conv, attention, selfattention, attproj in zip(
            self.projections,
            self.convolutions,
            self.attention,
            self.selfattention,
            self.attproj,
        ):
            residual = x if proj is None else proj(x)

            x = self.dropout_module(x)
            x = conv(x)
            x = F.glu(x, dim=2)

            # attention over the encoder outputs
            if attention is not None:
                r = x
                x, attn_scores = attention(
                    attproj(x) + target_embedding, encoder_a, encoder_b
                )
                x = x + r
                if not self.training and self.need_attn:
                    if avg_attn_scores is None:
                        avg_attn_scores = attn_scores
                    else:
                        avg_attn_scores.add_(attn_scores)

            if selfattention is not None:
                x = selfattention(x)

            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        x = self.fc2(x)
        x = self.dropout_module(x)
        if not self.pretrained:
            x = self.fc3(x)

        # fusion gating: combine the hidden states of the trained and the
        # pretrained decoders with two sigmoid gates before the final output
        # projection
        if self.pretrained:
            trained_x, _ = self.pretrained_decoder.forward(
                prev_output_tokens, trained_encoder_out
            )
            y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1)
            gate1 = self.gate1(y)
            gate2 = self.gate2(y)
            gated_x1 = gate1 * x
            gated_x2 = gate2 * self.pretrained_outputs["out"]
            fusion = torch.cat([gated_x1, gated_x2], dim=-1)
            fusion = self.joining(fusion)
            fusion_output = self.fc3(fusion)
            return fusion_output, avg_attn_scores
        else:
            return x, avg_attn_scores

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn

    def _split_encoder_out(self, encoder_out):
        """Split and transpose encoder outputs."""
        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(0, 1).contiguous()
        encoder_b = encoder_b.transpose(0, 1).contiguous()
        result = (encoder_a, encoder_b)
        return result


class SelfAttention(nn.Module):
    def __init__(
        self,
        out_channels,
        embed_dim,
        num_heads,
        project_input=False,
        gated=False,
        downsample=False,
    ):
        super().__init__()
        self.attention = DownsampledMultiHeadAttention(
            out_channels,
            embed_dim,
            num_heads,
            dropout=0,
            bias=True,
            project_input=project_input,
            gated=gated,
            downsample=downsample,
        )
        self.in_proj_q = Linear(out_channels, embed_dim)
        self.in_proj_k = Linear(out_channels, embed_dim)
        self.in_proj_v = Linear(out_channels, embed_dim)
        self.ln = LayerNorm(out_channels)

    def forward(self, x):
        residual = x
        query = self.in_proj_q(x)
        key = self.in_proj_k(x)
        value = self.in_proj_v(x)
        x, _ = self.attention(
            query, key, value, mask_future_timesteps=True, use_scalar_bias=True
        )
        return self.ln(x + residual)
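
# Illustrative shape sketch (not part of the original module). SelfAttention is
# applied to time-first tensors inside the convolutional stacks above, so a
# standalone call with hypothetical sizes would look like:
#
#   attn = SelfAttention(out_channels=512, embed_dim=512, num_heads=4)
#   x = torch.randn(10, 2, 512)   # T x B x C
#   out = attn(x)                 # same shape; residual + LayerNorm applied
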

def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    m.weight.data.normal_(0, 0.1)
    return m


def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
    m.weight.data.normal_(0, 0.1)
    return m


def Linear(in_features, out_features, dropout=0.0):
    """Weight-normalized Linear layer (input: N x T x C)"""
    m = nn.Linear(in_features, out_features)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    m.bias.data.zero_()
    return m


def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding"""
    m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return m


def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Weight-normalized Conv1d layer"""
    from fairseq.modules import ConvTBC

    m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return m


@register_model_architecture("fconv_self_att", "fconv_self_att")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.1)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_layers = getattr(args, "encoder_layers", "[(512, 3)] * 3")
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_layers = getattr(args, "decoder_layers", "[(512, 3)] * 8")
    args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256)
    args.decoder_attention = getattr(args, "decoder_attention", "True")
    args.self_attention = getattr(args, "self_attention", "False")
    args.encoder_attention = getattr(args, "encoder_attention", "False")
    args.multihead_attention_nheads = getattr(args, "multihead_attention_nheads", 1)
    args.multihead_self_attention_nheads = getattr(
        args, "multihead_self_attention_nheads", 1
    )
    args.encoder_attention_nheads = getattr(args, "encoder_attention_nheads", 1)
    args.project_input = getattr(args, "project_input", "False")
    args.gated_attention = getattr(args, "gated_attention", "False")
    args.downsample = getattr(args, "downsample", "False")
    args.pretrained_checkpoint = getattr(args, "pretrained_checkpoint", "")
    args.pretrained = getattr(args, "pretrained", "False")


@register_model_architecture("fconv_self_att", "fconv_self_att_wp")
def fconv_self_att_wp(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_layers = getattr(
        args, "encoder_layers", "[(128, 3)] * 2 + [(512, 3)] * 1"
    )
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
    args.decoder_layers = getattr(
        args, "decoder_layers", "[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1"
    )
    args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256)
    args.self_attention = getattr(args, "self_attention", "True")
    args.multihead_self_attention_nheads = getattr(
        args, "multihead_self_attention_nheads", 4
    )
    args.project_input = getattr(args, "project_input", "True")
    args.gated_attention = getattr(args, "gated_attention", "True")
    args.downsample = getattr(args, "downsample", "True")
    base_architecture(args)
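
# Illustrative note (not part of the original module): the architectures
# registered above are selected on the fairseq command line via --arch, e.g.
#
#   fairseq-train <data-bin> --arch fconv_self_att_wp ...
#
# where <data-bin> is a placeholder for a preprocessed data directory. Any
# hyperparameter not given explicitly falls back to the getattr defaults set
# in base_architecture() and fconv_self_att_wp().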