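"""RoBERTa encoder-decoder model.

Wraps an (optionally pretrained) RoBERTa masked LM as the encoder of a
FairseqEncoderDecoderModel and pairs it with a TransformerDecoder whose
hyperparameters are read from the encoder's arguments.
"""
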
import argparse
import logging

import torch.nn as nn

import fairseq.checkpoint_utils
from fairseq.models import (
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerDecoder
from fairseq.models.roberta import model as roberta

logger = logging.getLogger(__name__)


@register_model("roberta_enc_dec")
class RobertaEncDecModel(FairseqEncoderDecoderModel):
    @staticmethod
    def add_args(parser):
        parser.add_argument(
            "--pretrained-mlm-checkpoint",
            default=None,
            type=str,
            metavar="PRETRAINED",
            help="path to pretrained mlm checkpoint",
        )
        parser.add_argument(
            "--pretrained-decoder", action="store_true", help="reload decoder"
        )
        parser.add_argument(
            "--hack-layernorm-embedding",
            action="store_true",
            help="hack to reload old models trained with encoder-normalize-before=False "
            "(no equivalent to encoder-normalize-before=False and layernorm_embedding=False)",
        )
        parser.add_argument(
            "--share-decoder-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--share-all-embeddings",
            action="store_true",
            help="share encoder, decoder and output embeddings"
            " (requires shared dictionary and embed dim)",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present
        base_enc_dec_architecture(args)
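        # Either reload a full pretrained RoBERTa MLM from a checkpoint, or build a
        # fresh RoBERTa encoder from the current args; either way, the resulting MLM
        # is then converted into an encoder-decoder via from_roberta().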
        if args.pretrained_mlm_checkpoint:
            arg_overrides = None
            if args.hack_layernorm_embedding:
                arg_overrides = {"layernorm_embedding": False}
            loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task(
                [args.pretrained_mlm_checkpoint], arg_overrides=arg_overrides
            )
            ([roberta_enc], _cfg, _task) = loaded
        else:
            # Do we need to edit untie_weights here?
            share_in_out = (
                args.share_decoder_input_output_embed or args.share_all_embeddings
            )
            args.untie_weights_roberta = not share_in_out
            if args.hack_layernorm_embedding:
                args.layernorm_embedding = False
                args.encoder_normalize_before = False
            roberta_enc = roberta.RobertaModel.build_model(args, task)

        return cls.from_roberta(roberta_enc, args, task.source_dictionary)

    @staticmethod
    def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary):
        encoder = roberta_enc.encoder.sentence_encoder
        vocab_size, embed_dim = encoder.embed_tokens.weight.shape
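
        # Reuse the pretrained LM head as the decoder output projection when all
        # embeddings are shared; otherwise create a fresh RobertaLMHead of the
        # same shape.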
        if args.share_all_embeddings:
            lm_head = roberta_enc.encoder.lm_head
            assert encoder.embed_tokens.weight is lm_head.weight, (
                "Can't use --share-all-embeddings with a model "
                "that was pretrained with --untie-weights-roberta"
            )
        else:
            lm_head = roberta.RobertaLMHead(
                embed_dim, vocab_size, roberta_enc.args.activation_fn
            )

        dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad())
        if args.share_all_embeddings or args.share_decoder_input_output_embed:
            # Note: I wasn't able to use Embedding _weight parameter to achieve this sharing.
            dec_embs.weight = lm_head.weight
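
        # The decoder mirrors the encoder's hyperparameters (see
        # read_args_from_roberta) and uses the (possibly shared) LM head as its
        # output projection.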
        decoder = TransformerDecoder(
            RobertaEncDecModel.read_args_from_roberta(roberta_enc.args),
            dictionary,
            dec_embs,
            no_encoder_attn=False,
            output_projection=lm_head,
        )
        if getattr(args, "pretrained_decoder", False):
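            # Warm-start the decoder from the pretrained encoder: reuse the encoder
            # weights, duplicate each self-attention block into the corresponding
            # encoder-attention block, and copy the LM head into output_projection.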
            decoder_dict = encoder.state_dict()

            # TODO: hide setting "encoder_attn" layers behind a flag.
            for k, w in list(decoder_dict.items()):
                if ".self_attn" in k:
                    k_enc_attn = k.replace(".self_attn", ".encoder_attn")
                    decoder_dict[k_enc_attn] = w.detach().clone()

            for k, w in lm_head.state_dict().items():
                decoder_dict["output_projection." + k] = w

            missing_keys, unexpected_keys = decoder.load_state_dict(
                decoder_dict, strict=False
            )
            # missing_keys = [m for m in missing_keys if ".encoder_attn" not in m]
            assert not missing_keys and not unexpected_keys, (
                "Failed to load state dict. "
                f"Missing keys: {missing_keys}. "
                f"Unexpected keys: {unexpected_keys}."
            )
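
        # Sanity-check that the requested weight sharing actually took effect.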
        if args.share_all_embeddings:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is decoder.embed_tokens.weight
        elif args.share_decoder_input_output_embed:
            assert decoder.output_projection.weight is decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight
        else:
            assert decoder.output_projection.weight is not decoder.embed_tokens.weight
            assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight

        return RobertaEncDecModel(encoder, decoder)

    @staticmethod
    def read_args_from_roberta(roberta_args: argparse.Namespace):
        # TODO: this would become easier if encoder/decoder were using a similar
        # TransformerConfig object
        args = argparse.Namespace(**vars(roberta_args))
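        # Mirror each encoder_* hyperparameter onto the decoder_* attribute that
        # TransformerDecoder expects.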
        attr_map = [
            ("encoder_attention_heads", "decoder_attention_heads"),
            ("encoder_embed_dim", "decoder_embed_dim"),
            ("encoder_embed_dim", "decoder_output_dim"),
            ("encoder_normalize_before", "decoder_normalize_before"),
            ("encoder_layers_to_keep", "decoder_layers_to_keep"),
            ("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"),
            ("encoder_layerdrop", "decoder_layerdrop"),
            ("encoder_layers", "decoder_layers"),
            ("encoder_learned_pos", "decoder_learned_pos"),
            # should this be set from here ?
            ("max_positions", "max_target_positions"),
        ]
        for k1, k2 in attr_map:
            setattr(args, k2, getattr(roberta_args, k1))

        args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
        args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
        args.share_decoder_input_output_embed = not roberta_args.untie_weights_roberta
        return args

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        super().upgrade_state_dict_named(state_dict, name)
        old_keys = list(state_dict.keys())

        # rename decoder -> encoder before upgrading children modules
        for k in old_keys:
            if k.startswith(prefix + "encoder.lm_head"):
                state_dict.pop(k)
                continue
            new_k = k
            new_k = new_k.replace(".sentence_encoder.", ".")
            new_k = new_k.replace("decoder.lm_head.", "decoder.output_projection.")
            if k == new_k:
                continue
            # print(k, "->", new_k)
            state_dict[new_k] = state_dict.pop(k)


@register_model_architecture("roberta_enc_dec", "roberta_enc_dec")
def base_enc_dec_architecture(args):
    args.hack_layernorm_embedding = getattr(args, "hack_layernorm_embedding", False)
    args.pretrained_mlm_checkpoint = getattr(args, "pretrained_mlm_checkpoint", None)
    args.pretrained_decoder = getattr(args, "pretrained_decoder", None)
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )

    roberta.base_architecture(args)
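

# Illustrative usage sketch (paths and hyperparameters are placeholders): assuming
# this module is importable by fairseq (e.g. via --user-dir if it lives outside the
# fairseq package) and the architecture is registered as "roberta_enc_dec",
# fine-tuning a seq2seq model from a pretrained RoBERTa MLM could look roughly like:
#
#   fairseq-train data-bin/ \
#       --task translation \
#       --arch roberta_enc_dec \
#       --pretrained-mlm-checkpoint /path/to/roberta/model.pt \
#       --share-all-embeddings \
#       --optimizer adam --lr 1e-4 --max-tokens 4096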