|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
BART: Denoising Sequence-to-Sequence Pre-training for |
|
|
Natural Language Generation, Translation, and Comprehension |
|
|
""" |
|
|
|
|
|
import logging |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from fairseq import utils |
|
|
from fairseq.models import ( |
|
|
register_model, |
|
|
register_model_architecture, |
|
|
) |
|
|
from fairseq.models.transformer import TransformerModel |
|
|
from fairseq.modules.transformer_sentence_encoder import init_bert_params |
|
|
|
|
|
from .hub_interface import BARTHubInterface |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@register_model('bart')
class BARTModel(TransformerModel):
    """Transformer encoder-decoder with optional sentence-classification heads.

    On top of ``TransformerModel`` this class keeps a ``classification_heads``
    ``nn.ModuleDict`` whose entries can be applied to the decoder features in
    :meth:`forward`, and it reconciles those heads (and the embedding tables)
    with checkpoints in :meth:`upgrade_state_dict_named`.
    """

    @classmethod
    def hub_models(cls):
        """Return a dict mapping pretrained model names to their archive URLs."""
        return {
            'bart.base': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz',
            'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz',
            'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz',
            'bart.large.cnn': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz',
            'bart.large.xsum': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz',
        }

    def __init__(self, args, encoder, decoder):
        """Build the model and start with an empty set of classification heads."""
        super().__init__(args, encoder, decoder)

        # Run init_bert_params over every submodule of the freshly built model.
        self.apply(init_bert_params)

        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        """Add the parent model's arguments plus the pooler options to ``parser``."""
        super(BARTModel, BARTModel).add_args(parser)
        parser.add_argument(
            '--pooler-dropout', type=float, metavar='D',
            help='dropout probability in the masked_lm pooler layers'
        )
        parser.add_argument(
            '--pooler-activation-fn',
            choices=utils.get_available_activation_fns(),
            help='activation function to use for pooler layer'
        )

    @property
    def supported_targets(self):
        """Set of target kinds this model supports (only ``'self'``)."""
        return {'self'}

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens,
        features_only=False, classification_head_name=None, **kwargs
    ):
        """Run encoder and decoder, optionally followed by a classification head.

        Args:
            src_tokens: source tokens fed to the encoder.
            src_lengths: source lengths, forwarded to the encoder.
            prev_output_tokens: decoder input tokens.
            features_only: forwarded to the decoder; forced to ``True`` when a
                classification head is requested.
            classification_head_name: key into ``self.classification_heads``;
                when given, the head's output replaces the decoder output.
            **kwargs: forwarded to both the encoder and the decoder.

        Returns:
            A tuple ``(x, extra)`` where ``extra`` is the decoder's second
            return value and ``x`` is either the decoder output or the
            classification head output.
        """
        if classification_head_name is not None:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            **kwargs,
        )
        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            **kwargs,
        )

        if classification_head_name is not None:
            # Pick the decoder features at the positions where the source
            # token is EOS, regroup them per example and keep the last one.
            # NOTE(review): the view() assumes every example contains the
            # same number of EOS tokens - confirm against the data pipeline.
            sentence_representation = x[
                src_tokens.eq(self.encoder.dictionary.eos()), :
            ].view(x.size(0), -1, x.size(-1))[:, -1, :]
            x = self.classification_heads[classification_head_name](
                sentence_representation
            )
        return x, extra

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file='model.pt',
        data_name_or_path='.',
        bpe='gpt2',
        **kwargs,
    ):
        """Load a pretrained model and wrap it in a ``BARTHubInterface``.

        The arguments are passed to ``hub_utils.from_pretrained`` together
        with ``cls.hub_models()`` as the archive map and
        ``load_checkpoint_heads=True``.
        """
        # Imported locally, as in the original code.
        from fairseq import hub_utils
        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return BARTHubInterface(x['args'], x['task'], x['models'][0])

    def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
        """Register a classification head.

        An existing head with the same name is replaced; a warning is logged
        when the replacement has a different ``num_classes`` or ``inner_dim``.

        Args:
            name: key under which the head is stored.
            num_classes: output size of the head.
            inner_dim: hidden size of the head; falls back to
                ``args.encoder_embed_dim`` when falsy.
            **kwargs: accepted and ignored.
        """
        logger.info("Registering classification head: {0}".format(name))
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    'and inner_dim {} (prev: {})'.format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = BARTClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )

    def upgrade_state_dict_named(self, state_dict, name):
        """Reconcile a checkpoint ``state_dict`` with the current model in place.

        After delegating to the parent class this method:

        1. registers classification heads found in the checkpoint (when
           ``args.load_checkpoint_heads`` is set) or drops checkpoint heads
           that are missing from, or shaped differently than, the model;
        2. drops the last embedding row when the checkpoint has exactly one
           more row than the dictionary and the dictionary has no ``<mask>``;
        3. for the ``multilingual_denoising`` task, inserts freshly
           initialised embedding rows before the last row when the dictionary
           is larger than the checkpoint;
        4. copies the parameters of heads that exist in the model but not in
           the checkpoint into ``state_dict``.

        Args:
            state_dict: checkpoint state dict, modified in place.
            name: module name used to build the key prefix for the heads.
        """
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + '.' if name != '' else ''
        head_prefix = prefix + 'classification_heads.'
        # A live keys view: heads registered inside the loop below become
        # visible to the later membership tests.
        current_head_names = [] if not hasattr(self, 'classification_heads') else \
            self.classification_heads.keys()

        # Step 1: reconcile the heads stored in the checkpoint.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(head_prefix):
                continue

            head_name = k[len(head_prefix):].split('.')[0]
            num_classes = state_dict[head_prefix + head_name + '.out_proj.weight'].size(0)
            inner_dim = state_dict[head_prefix + head_name + '.dense.weight'].size(0)

            if getattr(self.args, 'load_checkpoint_heads', False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        'deleting classification head ({}) from checkpoint '
                        'not present in current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
                elif (
                    num_classes != self.classification_heads[head_name].out_proj.out_features
                    or inner_dim != self.classification_heads[head_name].dense.out_features
                ):
                    logger.warning(
                        'deleting classification head ({}) from checkpoint '
                        'with different dimensions than current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
        # Deletion is deferred so the dict is not mutated while iterating.
        for k in keys_to_delete:
            del state_dict[k]

        def truncate_emb(key):
            # Drop the last row of the named matrix, if it is present.
            if key in state_dict:
                state_dict[key] = state_dict[key][:-1, :]

        # Step 2: the checkpoint has one extra row and the dictionary has no
        # '<mask>' symbol, so drop that trailing row from every table.
        # NOTE(review): the embedding keys below are not built from `prefix`;
        # confirm this method is only called with an empty `name`.
        loaded_dict_size = state_dict['encoder.embed_tokens.weight'].size(0)
        if loaded_dict_size == len(self.encoder.dictionary) + 1 and '<mask>' not in self.encoder.dictionary:
            truncate_emb('encoder.embed_tokens.weight')
            truncate_emb('decoder.embed_tokens.weight')
            truncate_emb('encoder.output_projection.weight')
            truncate_emb('decoder.output_projection.weight')

        # Step 3: grow the embedding tables for extra dictionary entries.
        if self.args.task == 'multilingual_denoising' and loaded_dict_size < len(self.encoder.dictionary):
            logger.info(
                "Adding extra language embeddings not found in pretrained model for "\
                "continued pretraining of MBART on new set of languages."
            )
            loaded_embed = state_dict['encoder.embed_tokens.weight']
            # The last checkpoint row is kept as the last row of the new table.
            loaded_mask_token_embedding = loaded_embed[-1, :]

            num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
            embed_dim = loaded_embed.size(1)

            new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
            nn.init.normal_(
                new_lang_embed_to_add,
                mean=0,
                std=embed_dim ** -0.5
            )
            # Match both dtype and device of the checkpoint tensor; torch.cat
            # below would fail if the new rows lived on a different device.
            new_lang_embed_to_add = new_lang_embed_to_add.to(
                dtype=loaded_embed.dtype,
                device=loaded_embed.device,
            )

            # Encoder and decoder receive the same new rows and the same
            # trailing row taken from the encoder table.
            state_dict['encoder.embed_tokens.weight'] = torch.cat([
                state_dict['encoder.embed_tokens.weight'][:loaded_dict_size-1, :],
                new_lang_embed_to_add,
                loaded_mask_token_embedding.unsqueeze(0)]
            )
            state_dict['decoder.embed_tokens.weight'] = torch.cat([
                state_dict['decoder.embed_tokens.weight'][:loaded_dict_size-1, :],
                new_lang_embed_to_add,
                loaded_mask_token_embedding.unsqueeze(0)]
            )

        # Step 4: heads that exist in the model but not in the checkpoint are
        # added to the state dict with their current parameters.
        if hasattr(self, 'classification_heads'):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                head_key = head_prefix + k
                if head_key not in state_dict:
                    # A '%s' placeholder is required: passing the key as a
                    # bare extra argument makes logging fail to format.
                    logger.info('Overwriting %s', head_key)
                    state_dict[head_key] = v
|
|
|
|
|
|
|
|
class BARTClassificationHead(nn.Module):
    """Two-layer classifier applied to a sentence-level feature vector."""

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
    ):
        """Create the projection layers.

        Args:
            input_dim: size of the incoming features.
            inner_dim: size of the hidden projection.
            num_classes: size of the output projection.
            activation_fn: name resolved through ``utils.get_activation_fn``.
            pooler_dropout: dropout probability used before each projection.
        """
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation and state-dict ordering are unaffected.
        self.dense = nn.Linear(in_features=input_dim, out_features=inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(in_features=inner_dim, out_features=num_classes)

    def forward(self, features, **kwargs):
        """Map ``features`` to class scores; extra keyword arguments are ignored."""
        # dropout -> dense -> activation
        hidden = self.activation_fn(self.dense(self.dropout(features)))
        # dropout -> output projection
        return self.out_proj(self.dropout(hidden))
|
|
|
|
|
|
|
|
@register_model_architecture('bart', 'bart_large')
def bart_large_architecture(args):
    """Fill in the ``bart_large`` default for every option missing from ``args``.

    Options already present on ``args`` keep their value. The order matters:
    several decoder defaults are read from encoder values set earlier.
    """

    def fill(option, fallback):
        # Keep an existing value, otherwise install the fallback.
        setattr(args, option, getattr(args, option, fallback))

    # Encoder.
    fill('encoder_embed_path', None)
    fill('encoder_embed_dim', 1024)
    fill('encoder_ffn_embed_dim', 4*1024)
    fill('encoder_layers', 12)
    fill('encoder_attention_heads', 16)
    fill('encoder_normalize_before', False)
    fill('encoder_learned_pos', True)

    # Decoder; the embedding sizes default to the encoder's.
    fill('decoder_embed_path', None)
    fill('decoder_embed_dim', args.encoder_embed_dim)
    fill('decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
    fill('decoder_layers', 12)
    fill('decoder_attention_heads', 16)
    fill('decoder_normalize_before', False)
    fill('decoder_learned_pos', True)

    # Regularisation and sequence limits.
    fill('attention_dropout', 0.)
    fill('relu_dropout', 0.)
    fill('dropout', 0.1)
    fill('max_target_positions', 1024)
    fill('max_source_positions', 1024)
    fill('adaptive_softmax_cutoff', None)
    fill('adaptive_softmax_dropout', 0)

    # Embedding sharing.
    fill('share_decoder_input_output_embed', True)
    fill('share_all_embeddings', True)

    fill('decoder_output_dim', args.decoder_embed_dim)
    fill('decoder_input_dim', args.decoder_embed_dim)

    fill('no_scale_embedding', True)
    fill('layernorm_embedding', True)

    # Activations and pooler settings.
    fill('activation_fn', 'gelu')
    fill('pooler_activation_fn', 'tanh')
    fill('pooler_dropout', 0.0)
|
|
|
|
|
|
|
|
@register_model_architecture('bart', 'bart_base')
def bart_base_architecture(args):
    """Apply the ``bart_base`` defaults, then fall back to ``bart_large`` for the rest."""
    base_defaults = (
        ('encoder_embed_dim', 768),
        ('encoder_ffn_embed_dim', 4*768),
        ('encoder_layers', 6),
        ('encoder_attention_heads', 12),
        ('decoder_layers', 6),
        ('decoder_attention_heads', 12),
    )
    # Only options missing from ``args`` receive the base-size value.
    for option, fallback in base_defaults:
        setattr(args, option, getattr(args, option, fallback))
    # Everything not set above takes the large-architecture default.
    bart_large_architecture(args)
|
|
|
|
|
|
|
|
@register_model_architecture('bart', 'mbart_large')
def mbart_large_architecture(args):
    """``bart_large`` defaults, except ``no_scale_embedding`` defaults to ``False``."""
    # Set before delegating so the large defaults do not override it.
    scale_flag = getattr(args, 'no_scale_embedding', False)
    args.no_scale_embedding = scale_flag
    bart_large_architecture(args)
|
|
|
|
|
|
|
|
@register_model_architecture('bart', 'mbart_base')
def mbart_base_architecture(args):
    """``bart_base`` defaults, except ``no_scale_embedding`` defaults to ``False``."""
    # Set before delegating so the later defaults do not override it.
    scale_flag = getattr(args, 'no_scale_embedding', False)
    args.no_scale_embedding = scale_flag
    bart_base_architecture(args)
|
|
|
|
|
|
|
|
@register_model_architecture('bart', 'mbart_base_wmt20')
def mbart_base_wmt20_architecture(args):
    """``mbart_base`` defaults, except ``layernorm_embedding`` defaults to ``False``."""
    # Set before delegating so the later defaults do not override it.
    layernorm_flag = getattr(args, 'layernorm_embedding', False)
    args.layernorm_embedding = layernorm_flag
    mbart_base_architecture(args)
|
|
|