# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options, utils
from fairseq.models import (
    FairseqEncoder,
    FairseqIncrementalDecoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
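
# Overview (added commentary, not part of the upstream source): this module
# defines a LASER-style sequence-to-sequence LSTM. The encoder max-pools its
# hidden states into one fixed-size sentence embedding ("sentemb"); the decoder
# re-reads that embedding at every time step (optionally together with a
# target-language embedding) instead of using attention.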

# Registration name assumed from the upstream LASER code; adjust if this
# project registers the model under a different name.
@register_model("laser_lstm")
class LSTMModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens=None,
        tgt_tokens=None,
        tgt_lengths=None,
        target_language_id=None,
        dataset_name="",
    ):
        assert target_language_id is not None

        src_encoder_out = self.encoder(src_tokens, src_lengths, dataset_name)
        return self.decoder(
            prev_output_tokens, src_encoder_out, lang_id=target_language_id
        )
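
    # Shape conventions (added note): src_tokens and prev_output_tokens are
    # B x T LongTensors. The encoder returns a dict whose "sentemb" entry is
    # B x output_units and whose "encoder_out" entry is
    # (T x B x output_units, final_hiddens, final_cells). forward() returns
    # (logits B x T x V, attn_scores), matching the usual fairseq decoder contract.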

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--dropout",
            default=0.1,
            type=float,
            metavar="D",
            help="dropout probability",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-embed-path",
            default=None,
            type=str,
            metavar="STR",
            help="path to pre-trained encoder embedding",
        )
        parser.add_argument(
            "--encoder-hidden-size", type=int, metavar="N", help="encoder hidden size"
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="number of encoder layers"
        )
        parser.add_argument(
            "--encoder-bidirectional",
            action="store_true",
            help="make all layers of encoder bidirectional",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-embed-path",
            default=None,
            type=str,
            metavar="STR",
            help="path to pre-trained decoder embedding",
        )
        parser.add_argument(
            "--decoder-hidden-size", type=int, metavar="N", help="decoder hidden size"
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="number of decoder layers"
        )
        parser.add_argument(
            "--decoder-out-embed-dim",
            type=int,
            metavar="N",
            help="decoder output embedding dimension",
        )
        parser.add_argument(
            "--decoder-zero-init",
            type=str,
            metavar="BOOL",
            help="initialize the decoder hidden/cell state to zero",
        )
        parser.add_argument(
            "--decoder-lang-embed-dim",
            type=int,
            metavar="N",
            help="decoder language embedding dimension",
        )
        parser.add_argument(
            "--fixed-embeddings",
            action="store_true",
            help="keep embeddings fixed (ENCODER ONLY)",
        )  # TODO Also apply to decoder embeddings?

        # Granular dropout settings (if not specified these default to --dropout)
        parser.add_argument(
            "--encoder-dropout-in",
            type=float,
            metavar="D",
            help="dropout probability for encoder input embedding",
        )
        parser.add_argument(
            "--encoder-dropout-out",
            type=float,
            metavar="D",
            help="dropout probability for encoder output",
        )
        parser.add_argument(
            "--decoder-dropout-in",
            type=float,
            metavar="D",
            help="dropout probability for decoder input embedding",
        )
        parser.add_argument(
            "--decoder-dropout-out",
            type=float,
            metavar="D",
            help="dropout probability for decoder output",
        )
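
    # Illustrative flag values only (the exact task, data, and architecture
    # arguments depend on the surrounding training setup):
    #
    #   --arch laser_lstm --encoder-bidirectional \
    #   --encoder-hidden-size 512 --decoder-hidden-size 2048 \
    #   --decoder-zero-init False --decoder-lang-embed-dim 32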

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)

        def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
            embed_dict = utils.parse_embedding(embed_path)
            utils.print_embed_overlap(embed_dict, dictionary)
            return utils.load_embedding(embed_dict, dictionary, embed_tokens)

        pretrained_encoder_embed = None
        if args.encoder_embed_path:
            pretrained_encoder_embed = load_pretrained_embedding_from_file(
                args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim
            )
        pretrained_decoder_embed = None
        if args.decoder_embed_path:
            pretrained_decoder_embed = load_pretrained_embedding_from_file(
                args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim
            )

        num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0

        encoder = LSTMEncoder(
            dictionary=task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            hidden_size=args.encoder_hidden_size,
            num_layers=args.encoder_layers,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            bidirectional=args.encoder_bidirectional,
            pretrained_embed=pretrained_encoder_embed,
            fixed_embeddings=args.fixed_embeddings,
        )
        decoder = LSTMDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            hidden_size=args.decoder_hidden_size,
            out_embed_dim=args.decoder_out_embed_dim,
            num_layers=args.decoder_layers,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            zero_init=options.eval_bool(args.decoder_zero_init),
            encoder_embed_dim=args.encoder_embed_dim,
            encoder_output_units=encoder.output_units,
            pretrained_embed=pretrained_decoder_embed,
            num_langs=num_langs,
            lang_embed_dim=args.decoder_lang_embed_dim,
        )
        return cls(encoder, decoder)


class LSTMEncoder(FairseqEncoder):
    """LSTM encoder."""

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        hidden_size=512,
        num_layers=1,
        dropout_in=0.1,
        dropout_out=0.1,
        bidirectional=False,
        left_pad=True,
        pretrained_embed=None,
        padding_value=0.0,
        fixed_embeddings=False,
    ):
        super().__init__(dictionary)
        self.num_layers = num_layers
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        if pretrained_embed is None:
            self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        else:
            self.embed_tokens = pretrained_embed
        if fixed_embeddings:
            self.embed_tokens.weight.requires_grad = False

        self.lstm = LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=self.dropout_out if num_layers > 1 else 0.0,
            bidirectional=bidirectional,
        )
        self.left_pad = left_pad
        self.padding_value = padding_value

        self.output_units = hidden_size
        if bidirectional:
            self.output_units *= 2

    def forward(self, src_tokens, src_lengths, dataset_name):
        if self.left_pad:
            # convert left-padding to right-padding
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                self.padding_idx,
                left_to_right=True,
            )

        bsz, seqlen = src_tokens.size()

        # embed tokens
        x = self.embed_tokens(src_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # pack embedded source tokens into a PackedSequence
        try:
            packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())
        except BaseException:
            raise Exception(f"Packing failed in dataset {dataset_name}")

        # apply LSTM
        if self.bidirectional:
            state_size = 2 * self.num_layers, bsz, self.hidden_size
        else:
            state_size = self.num_layers, bsz, self.hidden_size
        h0 = x.data.new(*state_size).zero_()
        c0 = x.data.new(*state_size).zero_()
        packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))

        # unpack outputs and apply dropout
        x, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs, padding_value=self.padding_value
        )
        x = F.dropout(x, p=self.dropout_out, training=self.training)
        assert list(x.size()) == [seqlen, bsz, self.output_units]

        if self.bidirectional:

            def combine_bidir(outs):
                return torch.cat(
                    [
                        torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(
                            1, bsz, self.output_units
                        )
                        for i in range(self.num_layers)
                    ],
                    dim=0,
                )

            final_hiddens = combine_bidir(final_hiddens)
            final_cells = combine_bidir(final_cells)

        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()

        # Set padded outputs to -inf so they are not selected by max-pooling
        padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
        if padding_mask.any():
            x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)

        # Build the sentence embedding by max-pooling over the encoder outputs
        sentemb = x.max(dim=0)[0]

        return {
            "sentemb": sentemb,
            "encoder_out": (x, final_hiddens, final_cells),
            "encoder_padding_mask": encoder_padding_mask
            if encoder_padding_mask.any()
            else None,
        }

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        encoder_out_dict["sentemb"] = encoder_out_dict["sentemb"].index_select(
            0, new_order
        )
        encoder_out_dict["encoder_out"] = tuple(
            eo.index_select(1, new_order) for eo in encoder_out_dict["encoder_out"]
        )
        if encoder_out_dict["encoder_padding_mask"] is not None:
            encoder_out_dict["encoder_padding_mask"] = encoder_out_dict[
                "encoder_padding_mask"
            ].index_select(1, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number
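

# The helper below is added commentary and is not used by the model: it shows,
# on toy tensors, the masked max-pooling that LSTMEncoder.forward uses to
# reduce its T x B x C outputs to a single B x C sentence embedding.
def _example_masked_maxpool():
    T, B, C = 5, 2, 4
    outputs = torch.randn(T, B, C)
    # pretend the last two positions of the second sentence are padding
    padding_mask = torch.zeros(T, B, dtype=torch.bool)
    padding_mask[3:, 1] = True
    # padded positions are set to -inf so max() never selects them
    masked = outputs.masked_fill(padding_mask.unsqueeze(-1), float("-inf"))
    return masked.max(dim=0)[0]  # B x C sentence embeddings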


class LSTMDecoder(FairseqIncrementalDecoder):
    """LSTM decoder."""

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        hidden_size=512,
        out_embed_dim=512,
        num_layers=1,
        dropout_in=0.1,
        dropout_out=0.1,
        zero_init=False,
        encoder_embed_dim=512,
        encoder_output_units=512,
        pretrained_embed=None,
        num_langs=1,
        lang_embed_dim=0,
    ):
        super().__init__(dictionary)
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.hidden_size = hidden_size

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        if pretrained_embed is None:
            self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        else:
            self.embed_tokens = pretrained_embed

        self.layers = nn.ModuleList(
            [
                LSTMCell(
                    input_size=encoder_output_units + embed_dim + lang_embed_dim
                    if layer == 0
                    else hidden_size,
                    hidden_size=hidden_size,
                )
                for layer in range(num_layers)
            ]
        )
        if hidden_size != out_embed_dim:
            self.additional_fc = Linear(hidden_size, out_embed_dim)
        self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

        if zero_init:
            self.sentemb2init = None
        else:
            self.sentemb2init = Linear(
                encoder_output_units, 2 * num_layers * hidden_size
            )

        if lang_embed_dim == 0:
            self.embed_lang = None
        else:
            self.embed_lang = nn.Embedding(num_langs, lang_embed_dim)
            nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1)
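
    # Added note: the decoder is conditioned on the encoder purely through the
    # sentence embedding. At every step the B x C "sentemb" (plus the optional
    # language embedding) is concatenated to the token embedding fed to the
    # first LSTMCell, and unless zero_init is set, sentemb2init projects the
    # sentence embedding to the initial hidden/cell states of every layer.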

    def forward(
        self, prev_output_tokens, encoder_out_dict, incremental_state=None, lang_id=0
    ):
        sentemb = encoder_out_dict["sentemb"]
        encoder_out = encoder_out_dict["encoder_out"]

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, _, _ = encoder_out[:3]
        srclen = encoder_outs.size(0)

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # embed language identifier
        if self.embed_lang is not None:
            lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id)
            langemb = self.embed_lang(lang_ids)
            # TODO Should we dropout here???

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            num_layers = len(self.layers)
            if self.sentemb2init is None:
                prev_hiddens = [
                    x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers)
                ]
                prev_cells = [
                    x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers)
                ]
            else:
                init = self.sentemb2init(sentemb)
                prev_hiddens = [
                    init[:, (2 * i) * self.hidden_size : (2 * i + 1) * self.hidden_size]
                    for i in range(num_layers)
                ]
                prev_cells = [
                    init[
                        :,
                        (2 * i + 1) * self.hidden_size : (2 * i + 2) * self.hidden_size,
                    ]
                    for i in range(num_layers)
                ]
            input_feed = x.data.new(bsz, self.hidden_size).zero_()

        attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
        outs = []
        for j in range(seqlen):
            if self.embed_lang is None:
                input = torch.cat((x[j, :, :], sentemb), dim=1)
            else:
                input = torch.cat((x[j, :, :], sentemb, langemb), dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden, p=self.dropout_out, training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            out = hidden
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            "cached_state",
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        # project back to size of vocabulary
        if hasattr(self, "additional_fc"):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        x = self.fc_out(x)

        return x, attn_scores
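
    # Added note: attn_scores is returned only to satisfy the fairseq decoder
    # interface; it stays all-zero because this decoder has no attention.
    # Likewise, input_feed is cached in the incremental state but is never fed
    # back into the recurrent input.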

    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, "cached_state", new_state)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.uniform_(m.weight, -0.1, 0.1)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def LSTM(input_size, hidden_size, **kwargs):
    m = nn.LSTM(input_size, hidden_size, **kwargs)
    for name, param in m.named_parameters():
        if "weight" in name or "bias" in name:
            param.data.uniform_(-0.1, 0.1)
    return m


def LSTMCell(input_size, hidden_size, **kwargs):
    m = nn.LSTMCell(input_size, hidden_size, **kwargs)
    for name, param in m.named_parameters():
        if "weight" in name or "bias" in name:
            param.data.uniform_(-0.1, 0.1)
    return m


def Linear(in_features, out_features, bias=True, dropout=0):
    """Linear layer initialized uniformly in [-0.1, 0.1] (input: N x T x C).

    The ``dropout`` argument is accepted for interface compatibility but is not
    applied here; dropout is handled by the callers.
    """
    m = nn.Linear(in_features, out_features, bias=bias)
    m.weight.data.uniform_(-0.1, 0.1)
    if bias:
        m.bias.data.uniform_(-0.1, 0.1)
    return m
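

# Illustrative only: the factories above mirror the corresponding torch.nn
# modules but apply the uniform(-0.1, 0.1) initialization used throughout this
# model, e.g.
#
#   emb = Embedding(num_embeddings=1000, embedding_dim=512, padding_idx=1)
#   proj = Linear(512, 1000)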


# Architecture registration mirrors the model registration above (name assumed
# from the upstream LASER code).
@register_model_architecture("laser_lstm", "laser_lstm")
def base_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_hidden_size = getattr(
        args, "encoder_hidden_size", args.encoder_embed_dim
    )
    args.encoder_layers = getattr(args, "encoder_layers", 1)
    args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False)
    args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout)
    args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_hidden_size = getattr(
        args, "decoder_hidden_size", args.decoder_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 1)
    args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
    args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout)
    args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout)
    args.decoder_zero_init = getattr(args, "decoder_zero_init", "0")
    args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0)
    args.fixed_embeddings = getattr(args, "fixed_embeddings", False)
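

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; assumes fairseq and torch are
    # installed). _StubDictionary is a hypothetical stand-in for a fairseq
    # Dictionary -- LSTMEncoder only needs len() and pad() from it here.
    class _StubDictionary:
        def __len__(self):
            return 100

        def pad(self):
            return 1

    encoder = LSTMEncoder(
        _StubDictionary(), embed_dim=16, hidden_size=16, bidirectional=True
    )
    src_tokens = torch.randint(2, 100, (3, 7))  # B x T token ids (no padding)
    src_lengths = torch.tensor([7, 7, 7])  # equal lengths keep packing sorted
    out = encoder(src_tokens, src_lengths, dataset_name="demo")
    print(out["sentemb"].shape)  # torch.Size([3, 32]): 2 * hidden_size (bidirectional)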