# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """ RoBERTa: A Robustly Optimized BERT Pretraining Approach. """ import logging import torch import torch.nn as nn import torch.nn.functional as F from fairseq import utils from fairseq.models import ( FairseqEncoder, register_model, register_model_architecture, ) from fairseq.models.roberta import ( RobertaModel, RobertaEncoder, RobertaLMHead, RobertaClassificationHead, ) from fairseq.modules import ( LayerNorm, TransformerSentenceEncoder, ) from fairseq.model_parallel.modules import ( ModelParallelTransformerSentenceEncoder, ) from fairseq.modules.transformer_sentence_encoder import init_bert_params try: from fairseq.model_parallel.megatron.mpu import ( copy_to_model_parallel_region, gather_from_model_parallel_region, ColumnParallelLinear, RowParallelLinear, ) has_megatron_submodule = True except (ImportError, ModuleNotFoundError): has_megatron_submodule = False logger = logging.getLogger(__name__) @register_model('model_parallel_roberta') class ModelParallelRobertaModel(RobertaModel): def __init__(self, args, encoder): super().__init__(args, encoder) self.classification_heads = nn.ModuleDict() @staticmethod def add_args(parser): super(ModelParallelRobertaModel, ModelParallelRobertaModel).add_args(parser) @classmethod def build_model(cls, args, task): """Build a new model instance.""" # make sure all arguments are present base_architecture(args) if not hasattr(args, 'max_positions'): args.max_positions = args.tokens_per_sample encoder = ModelParallelRobertaEncoder(args, task.source_dictionary) return cls(args, encoder) def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs): if classification_head_name is not None: features_only = True x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs) if classification_head_name is not None: x = self.classification_heads[classification_head_name](x) return x, extra def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): """Register a classification head.""" if name in self.classification_heads: prev_num_classes = self.classification_heads[name].out_proj.out_features prev_inner_dim = self.classification_heads[name].dense.out_features if num_classes != prev_num_classes or inner_dim != prev_inner_dim: logger.warning( 're-registering head "{}" with num_classes {} (prev: {}) ' 'and inner_dim {} (prev: {})'.format( name, num_classes, prev_num_classes, inner_dim, prev_inner_dim ) ) self.classification_heads[name] = ModelParallelRobertaClassificationHead( self.args.encoder_embed_dim, inner_dim or self.args.encoder_embed_dim, num_classes, self.args.pooler_activation_fn, self.args.pooler_dropout, ) class ModelParallelRobertaLMHead(nn.Module): """Head for masked language modeling.""" def __init__(self, embed_dim, output_dim, activation_fn, weight=None): super().__init__() self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True) self.activation_fn = utils.get_activation_fn(activation_fn) self.layer_norm = LayerNorm(embed_dim) if weight is None: weight = nn.Linear(embed_dim, output_dim, bias=False).weight self.weight = weight self.bias = nn.Parameter(torch.zeros(output_dim)) def forward(self, features, masked_tokens=None, **kwargs): # Only project the unmasked tokens while training, # saves both memory and computation if masked_tokens is not None: features = features[masked_tokens, :] x = self.dense(features) x = self.activation_fn(x) x = self.layer_norm(x) features = copy_to_model_parallel_region(features) # project back to size of vocabulary with bias x = F.linear(x, self.weight) x = gather_from_model_parallel_region(x).contiguous() x = x + self.bias return x class ModelParallelRobertaClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout): super().__init__() self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True) self.activation_fn = utils.get_activation_fn(activation_fn) self.dropout = nn.Dropout(p=pooler_dropout) self.out_proj = nn.Linear(inner_dim, num_classes) def forward(self, features, **kwargs): x = features[:, 0, :] # take token (equiv. to [CLS]) x = self.dropout(x) x = self.dense(x) x = self.activation_fn(x) x = self.dropout(x) x = self.out_proj(x) return x class ModelParallelRobertaEncoder(FairseqEncoder): """RoBERTa encoder. Implements the :class:`~fairseq.models.FairseqDecoder` interface required by :class:`~fairseq.models.FairseqLanguageModel`. """ def __init__(self, args, dictionary): super().__init__(dictionary) self.args = args # RoBERTa is a sentence encoder model, so users will intuitively trim # encoder layers. However, the implementation uses the fairseq decoder, # so we fix here. if args.encoder_layers_to_keep: args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) args.decoder_layers_to_keep = args.encoder_layers_to_keep args.encoder_layers_to_keep = None self.sentence_encoder = ModelParallelTransformerSentenceEncoder( padding_idx=dictionary.pad(), vocab_size=len(dictionary), num_encoder_layers=args.encoder_layers, embedding_dim=args.encoder_embed_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=args.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, layerdrop=args.encoder_layerdrop, max_seq_len=args.max_positions, num_segments=0, encoder_normalize_before=False, apply_bert_init=False, activation_fn=args.activation_fn, ) self.lm_head = ModelParallelRobertaLMHead( embed_dim=args.encoder_embed_dim, output_dim=len(dictionary), activation_fn=args.activation_fn, weight=self.sentence_encoder.embed_tokens.weight, ) def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused): """ Args: src_tokens (LongTensor): input tokens of shape `(batch, src_len)` features_only (bool, optional): skip LM head and just return features. If True, the output will be of shape `(batch, src_len, embed_dim)`. return_all_hiddens (bool, optional): also return all of the intermediate hidden states (default: False). Returns: tuple: - the LM output of shape `(batch, src_len, vocab)` - a dictionary of additional data, where 'inner_states' is a list of hidden states. Note that the hidden states have shape `(src_len, batch, vocab)`. """ x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens) if not features_only: x = self.output_layer(x, masked_tokens=masked_tokens) return x, extra def extract_features(self, src_tokens, return_all_hiddens=False, **unused): inner_states, _ = self.sentence_encoder( src_tokens, last_state_only=not return_all_hiddens, ) features = inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C return features, {'inner_states': inner_states if return_all_hiddens else None} def output_layer(self, features, masked_tokens=None, **unused): return self.lm_head(features, masked_tokens) def max_positions(self): """Maximum output length supported by the encoder.""" return self.args.max_positions @register_model_architecture('model_parallel_roberta', 'model_parallel_roberta') def base_architecture(args): args.encoder_layers = getattr(args, 'encoder_layers', 12) args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768) args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072) args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12) args.activation_fn = getattr(args, 'activation_fn', 'gelu') args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') args.dropout = getattr(args, 'dropout', 0.1) args.attention_dropout = getattr(args, 'attention_dropout', 0.1) args.activation_dropout = getattr(args, 'activation_dropout', 0.0) args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None) args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0) @register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_base') def roberta_base_architecture(args): base_architecture(args) @register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_large') def roberta_large_architecture(args): args.encoder_layers = getattr(args, 'encoder_layers', 24) args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096) args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) base_architecture(args)