# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from .hub_interface import BARTHubInterface


logger = logging.getLogger(__name__)


@register_model('bart')
class BARTModel(TransformerModel):
    """Transformer encoder-decoder registered as ``bart``.

    On top of :class:`TransformerModel` this class adds named
    classification heads (stored in ``self.classification_heads``) and
    checkpoint-upgrade logic that reconciles those heads and the embedding
    tables with the current model.
    """

    @classmethod
    def hub_models(cls):
        """Return the mapping from hub model name to archive URL."""
        return {
            'bart.base': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz',
            'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz',
            'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz',
            'bart.large.cnn': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz',
            'bart.large.xsum': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz',
        }

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

        # Heads are registered later via register_classification_head().
        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        """Add the parent model's arguments plus the pooler options."""
        super(BARTModel, BARTModel).add_args(parser)
        parser.add_argument(
            '--pooler-dropout', type=float, metavar='D',
            help='dropout probability in the masked_lm pooler layers'
        )
        parser.add_argument(
            '--pooler-activation-fn',
            choices=utils.get_available_activation_fns(),
            help='activation function to use for pooler layer'
        )

    @property
    def supported_targets(self):
        return {'self'}

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens,
        features_only=False, classification_head_name=None, **kwargs
    ):
        """Run the encoder and decoder, optionally applying a classification head.

        Args:
            src_tokens: source token ids, passed to the encoder.
            src_lengths: source lengths, passed to the encoder.
            prev_output_tokens: decoder input tokens.
            features_only: forwarded to the decoder; forced to ``True``
                when *classification_head_name* is given.
            classification_head_name: name of a registered head to apply
                to the decoder features at the last EOS position of each
                example, or ``None`` to return the decoder output as is.
            **kwargs: forwarded to both the encoder and the decoder.

        Returns:
            Tuple ``(x, extra)``: the decoder output (or the head's output
            when a head is selected) and the decoder's extra outputs.

        Raises:
            ValueError: if a head is selected and the examples in
                *src_tokens* do not all contain the same number of EOS
                tokens.
        """
        if classification_head_name is not None:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            **kwargs,
        )
        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            **kwargs,
        )

        if classification_head_name is not None:
            eos_mask = src_tokens.eq(self.encoder.dictionary.eos())
            # The view() below regroups the selected rows per example; with
            # unequal EOS counts it would either fail with an opaque shape
            # error or silently mix rows between examples.
            if torch.unique(eos_mask.sum(dim=1)).numel() > 1:
                raise ValueError(
                    'All examples must have the same number of <eos> tokens.'
                )
            # Keep the features at the last EOS position of every example.
            sentence_representation = x[eos_mask, :].view(
                x.size(0), -1, x.size(-1)
            )[:, -1, :]
            x = self.classification_heads[classification_head_name](
                sentence_representation
            )
        return x, extra

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file='model.pt',
        data_name_or_path='.',
        bpe='gpt2',
        **kwargs,
    ):
        """Load a pretrained model and wrap it in a :class:`BARTHubInterface`.

        The checkpoint is loaded with ``load_checkpoint_heads=True``; the
        first loaded model is the one wrapped.
        """
        from fairseq import hub_utils
        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return BARTHubInterface(x['args'], x['task'], x['models'][0])

    def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
        """Register a classification head.

        A head already registered under *name* is replaced; a warning is
        logged when the replacement has different dimensions.

        Args:
            name: key under which the head is stored.
            num_classes: output size of the head.
            inner_dim: hidden size of the head; falls back to
                ``args.encoder_embed_dim`` when falsy.
            **kwargs: accepted but not used by this method.
        """
        logger.info("Registering classification head: {0}".format(name))
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    'and inner_dim {} (prev: {})'.format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = BARTClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )

    def upgrade_state_dict_named(self, state_dict, name):
        """Adapt a checkpoint *state_dict* in place to the current model.

        Steps, in order:

        1. Reconcile classification heads found in the checkpoint: register
           them when ``args.load_checkpoint_heads`` is set, otherwise drop
           heads that are unknown or have different dimensions.
        2. Drop the last embedding row when the checkpoint has exactly one
           more entry than the current dictionary and the dictionary has no
           mask symbol.
        3. For the ``multilingual_denoising`` task, grow the embedding
           tables when the current dictionary is larger than the
           checkpoint's.
        4. Copy heads that exist only in the current model into
           *state_dict* with their current weights.

        Args:
            state_dict: checkpoint state, modified in place.
            name: module name; used as the key prefix for the
                classification-head entries.
        """
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + '.' if name != '' else ''
        head_prefix = prefix + 'classification_heads.'
        current_head_names = (
            [] if not hasattr(self, 'classification_heads')
            else self.classification_heads.keys()
        )

        # Handle new classification heads present in the state dict.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(head_prefix):
                continue

            head_name = k[len(head_prefix):].split('.')[0]
            num_classes = state_dict[head_prefix + head_name + '.out_proj.weight'].size(0)
            inner_dim = state_dict[head_prefix + head_name + '.dense.weight'].size(0)

            if getattr(self.args, 'load_checkpoint_heads', False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        'deleting classification head ({}) from checkpoint '
                        'not present in current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
                elif (
                    num_classes != self.classification_heads[head_name].out_proj.out_features
                    or inner_dim != self.classification_heads[head_name].dense.out_features
                ):
                    logger.warning(
                        'deleting classification head ({}) from checkpoint '
                        'with different dimensions than current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
        # Deletion is deferred so the dict is not mutated while iterating.
        for k in keys_to_delete:
            del state_dict[k]

        def truncate_emb(key):
            """Drop the last row of ``state_dict[key]`` if the key exists."""
            if key in state_dict:
                state_dict[key] = state_dict[key][:-1, :]

        # When finetuning on translation task, remove last row of
        # embedding matrix that corresponds to mask_idx token.
        # NOTE(review): the mask symbol is assumed to be '<mask>'; an empty
        # string here would make the membership test ignore the mask token.
        # Confirm against the task's dictionary.
        loaded_dict_size = state_dict['encoder.embed_tokens.weight'].size(0)
        if loaded_dict_size == len(self.encoder.dictionary) + 1 and '<mask>' not in self.encoder.dictionary:
            truncate_emb('encoder.embed_tokens.weight')
            truncate_emb('decoder.embed_tokens.weight')
            truncate_emb('encoder.output_projection.weight')
            truncate_emb('decoder.output_projection.weight')

        # When continued pretraining on new set of languages for mbart,
        # add extra lang embeddings at the end of embed_tokens.
        # Note: newly added languages are assumed to have been added at the end.
        if self.args.task == 'multilingual_denoising' and loaded_dict_size < len(self.encoder.dictionary):
            logger.info(
                "Adding extra language embeddings not found in pretrained model for "
                "continued pretraining of MBART on new set of languages."
            )
            # The last checkpoint row is kept as the final row after the new
            # rows are inserted.
            loaded_mask_token_embedding = state_dict['encoder.embed_tokens.weight'][-1, :]

            num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
            embed_dim = state_dict['encoder.embed_tokens.weight'].size(1)

            new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
            nn.init.normal_(
                new_lang_embed_to_add,
                mean=0,
                std=embed_dim ** -0.5
            )
            new_lang_embed_to_add = new_lang_embed_to_add.to(
                dtype=state_dict['encoder.embed_tokens.weight'].dtype,
            )

            # The same freshly initialised rows are used for both tables.
            state_dict['encoder.embed_tokens.weight'] = torch.cat([
                state_dict['encoder.embed_tokens.weight'][:loaded_dict_size-1, :],
                new_lang_embed_to_add,
                loaded_mask_token_embedding.unsqueeze(0)]
            )
            state_dict['decoder.embed_tokens.weight'] = torch.cat([
                state_dict['decoder.embed_tokens.weight'][:loaded_dict_size-1, :],
                new_lang_embed_to_add,
                loaded_mask_token_embedding.unsqueeze(0)]
            )

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, 'classification_heads'):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                key = head_prefix + k
                if key not in state_dict:
                    # The key is passed as a %-argument, so the format
                    # string needs a placeholder for it.
                    logger.info('Overwriting %s', key)
                    state_dict[key] = v


class BARTClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        """Apply dropout, dense layer, activation, dropout and output projection."""
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


@register_model_architecture('bart', 'bart_large')
def bart_large_architecture(args):
    """Set ``bart_large`` defaults for every option not already on *args*."""
    args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4*1024)
    args.encoder_layers = getattr(args, 'encoder_layers', 12)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, 'decoder_layers', 12)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
    args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.)
    args.relu_dropout = getattr(args, 'relu_dropout', 0.)
    args.dropout = getattr(args, 'dropout', 0.1)
    args.max_target_positions = getattr(args, 'max_target_positions', 1024)
    args.max_source_positions = getattr(args, 'max_source_positions', 1024)
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', True)
    args.share_all_embeddings = getattr(args, 'share_all_embeddings', True)

    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, 'no_scale_embedding', True)
    args.layernorm_embedding = getattr(args, 'layernorm_embedding', True)

    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
    args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
    args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)


@register_model_architecture('bart', 'bart_base')
def bart_base_architecture(args):
    """Set the smaller ``bart_base`` sizes, then the ``bart_large`` defaults."""
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4*768)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 12)
    bart_large_architecture(args)


@register_model_architecture('bart', 'mbart_large')
def mbart_large_architecture(args):
    """``bart_large`` defaults, except ``no_scale_embedding`` defaults to False."""
    args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
    bart_large_architecture(args)


@register_model_architecture('bart', 'mbart_base')
def mbart_base_architecture(args):
    """``bart_base`` defaults, except ``no_scale_embedding`` defaults to False."""
    args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
    bart_base_architecture(args)


@register_model_architecture('bart', 'mbart_base_wmt20')
def mbart_base_wmt20_architecture(args):
    """``mbart_base`` defaults, except ``layernorm_embedding`` defaults to False."""
    args.layernorm_embedding = getattr(args, 'layernorm_embedding', False)
    mbart_base_architecture(args)