# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerModel,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params


def ensemble_encoder(func):
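    """Decorator for :meth:`FairseqNATEncoder.forward`.

    When ``self.ensemble_models`` holds more than one encoder, every encoder
    is run on the same inputs and the per-model outputs are stacked along a
    new trailing dimension, so callers see a single, ensembled encoder output.
    """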
    def wrapper(self, *args, **kwargs):
        if self.ensemble_models is None or len(self.ensemble_models) == 1:
            return func(self, *args, **kwargs)
        encoder_outs = [
            func(model, *args, **kwargs, return_all_hiddens=True)
            for model in self.ensemble_models
        ]
        _encoder_out = encoder_outs[0].copy()

        def stack(key):
            outs = [e[key][0] for e in encoder_outs]
            return [torch.stack(outs, -1) if outs[0] is not None else None]

        _encoder_out["encoder_out"] = stack("encoder_out")
        _encoder_out["encoder_embedding"] = stack("encoder_embedding")

        num_layers = len(_encoder_out["encoder_states"])
        if num_layers > 0:
            _encoder_out["encoder_states"] = [
                torch.stack([e["encoder_states"][i] for e in encoder_outs], -1)
                for i in range(num_layers)
            ]
        return _encoder_out

    return wrapper


def ensemble_decoder(func):
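    """Decorator for the output methods of ensemble-aware NAT decoders.

    Each decoder receives its own slice of the encoder output stacked by
    :func:`ensemble_encoder`; the per-model results are then combined
    (log-probabilities are averaged in probability space when ``normalize``
    is set, other outputs are stacked).
    """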
    def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs):
        if self.ensemble_models is None or len(self.ensemble_models) == 1:
            return func(
                self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs
            )

        def _replace(encoder_out, new_val):
            new_encoder_out = encoder_out.copy()
            new_encoder_out["encoder_out"] = [new_val]
            return new_encoder_out

        action_outs = [
            func(
                model,
                normalize=normalize,
                encoder_out=_replace(
                    encoder_out, encoder_out["encoder_out"][0][:, :, :, i]
                ),
                *args,
                **kwargs
            )
            for i, model in enumerate(self.ensemble_models)
        ]

        if not isinstance(action_outs[0], tuple):  # a single value was returned
            action_outs = [[a] for a in action_outs]
        else:
            action_outs = [list(a) for a in action_outs]

        ensembled_outs = []
        for i in range(len(action_outs[0])):
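            # When normalize=True the first output holds log-probabilities,
            # so the models are averaged in probability space:
            # logsumexp over the stacked outputs minus log(num_models).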
            if i == 0 and normalize:
                ensembled_outs += [
                    torch.logsumexp(
                        torch.stack([a[i] for a in action_outs], -1), dim=-1
                    )
                    - math.log(len(self.ensemble_models))
                ]
            elif action_outs[0][i] is not None:
                ensembled_outs += [torch.stack([a[i] for a in action_outs], -1)]
            else:
                ensembled_outs += [None]

        if len(ensembled_outs) == 1:
            return ensembled_outs[0]
        return tuple(ensembled_outs)

    return wrapper


class FairseqNATModel(TransformerModel):
    """
    Abstract base class for all non-autoregressive (NAT) models.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        self.tgt_dict = decoder.dictionary
        self.bos = decoder.dictionary.bos()
        self.eos = decoder.dictionary.eos()
        self.pad = decoder.dictionary.pad()
        self.unk = decoder.dictionary.unk()

        self.ensemble_models = None

    @property
    def allow_length_beam(self):
        return False

    @property
    def allow_ensemble(self):
        return True

    def enable_ensemble(self, models):
        self.encoder.ensemble_models = [m.encoder for m in models]
        self.decoder.ensemble_models = [m.decoder for m in models]
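
    # Illustrative usage (a sketch, not part of this module): with several
    # checkpoints of the same NAT architecture loaded, e.g. via
    # fairseq.checkpoint_utils.load_model_ensemble, one might write:
    #
    #     model = models[0]
    #     if len(models) > 1 and model.allow_ensemble:
    #         model.enable_ensemble(models)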
    @staticmethod
    def add_args(parser):
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        decoder = FairseqNATDecoder(args, tgt_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            decoder.apply(init_bert_params)
        return decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        encoder = FairseqNATEncoder(args, src_dict, embed_tokens)
        if getattr(args, "apply_bert_init", False):
            encoder.apply(init_bert_params)
        return encoder

    def forward_encoder(self, encoder_inputs):
        return self.encoder(*encoder_inputs)

    def forward_decoder(self, *args, **kwargs):
        raise NotImplementedError

    def initialize_output_tokens(self, *args, **kwargs):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class FairseqNATEncoder(TransformerEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        self.ensemble_models = None

    @ensemble_encoder
    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


class FairseqNATDecoder(TransformerDecoder):
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
        self.ensemble_models = None
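

# Concrete NAT decoders decorate their output methods with @ensemble_decoder.
# A minimal sketch (illustrative only; the method body is an assumption,
# not fairseq API):
#
#     import torch.nn.functional as F
#
#     class MyNATDecoder(FairseqNATDecoder):
#         @ensemble_decoder
#         def forward(self, normalize, encoder_out, prev_output_tokens, **unused):
#             features, _ = self.extract_features(prev_output_tokens, encoder_out)
#             logits = self.output_layer(features)
#             return F.log_softmax(logits, -1) if normalize else logits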