julse's picture
Upload 551 files
be611b4 verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.roberta import (
RobertaModel,
RobertaEncoder,
RobertaLMHead,
RobertaClassificationHead,
)
from fairseq.modules import (
LayerNorm,
TransformerSentenceEncoder,
)
from fairseq.model_parallel.modules import (
ModelParallelTransformerSentenceEncoder,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
try:
from fairseq.model_parallel.megatron.mpu import (
copy_to_model_parallel_region,
gather_from_model_parallel_region,
ColumnParallelLinear,
RowParallelLinear,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
# Module-level logger; register_classification_head uses it to warn when a
# head is re-registered with different dimensions.
logger = logging.getLogger(__name__)
@register_model('model_parallel_roberta')
class ModelParallelRobertaModel(RobertaModel):
    """RoBERTa model whose encoder and classification heads use model-parallel layers."""

    def __init__(self, args, encoder):
        super().__init__(args, encoder)
        # Registry of sentence-level heads, keyed by head name.
        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser (delegates to RobertaModel)."""
        # Zero-argument super() is unavailable inside a staticmethod, hence the
        # explicit two-argument form.
        super(ModelParallelRobertaModel, ModelParallelRobertaModel).add_args(parser)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present
        base_architecture(args)

        if not hasattr(args, 'max_positions'):
            args.max_positions = args.tokens_per_sample

        encoder = ModelParallelRobertaEncoder(args, task.source_dictionary)
        return cls(args, encoder)

    def forward(self, src_tokens, features_only=False, return_all_hiddens=False,
                classification_head_name=None, **kwargs):
        """Run the encoder and, optionally, a registered classification head.

        Args:
            src_tokens: input token ids passed straight to the encoder.
            features_only: skip the LM head and return encoder features.
            return_all_hiddens: also return all intermediate hidden states.
            classification_head_name: if given, forces ``features_only=True``
                and applies the named head to the encoder features.
            **kwargs: forwarded to the encoder.

        Returns:
            ``(x, extra)``: the LM logits, features, or head output, plus the
            encoder's extra-output dictionary.
        """
        if classification_head_name is not None:
            features_only = True

        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
        """Register a classification head.

        If a head with the same name already exists it is replaced. A warning
        is logged when its dimensions differ from the requested ones.

        Args:
            name: key under which the head is stored.
            num_classes: number of output logits.
            inner_dim: hidden width; defaults to ``args.encoder_embed_dim``.
        """
        # The width actually used to build the head; comparing against this
        # (rather than a possibly-None ``inner_dim``) avoids spurious warnings.
        effective_inner_dim = inner_dim or self.args.encoder_embed_dim

        if name in self.classification_heads:
            prev_head = self.classification_heads[name]
            prev_num_classes = prev_head.out_proj.out_features
            # ``dense`` is a Megatron ColumnParallelLinear, which stores its
            # width as ``output_size`` (it has no ``out_features``); fall back
            # to ``out_features`` in case a plain nn.Linear is ever used.
            dense = prev_head.dense
            prev_inner_dim = getattr(dense, 'output_size', getattr(dense, 'out_features', None))
            if num_classes != prev_num_classes or effective_inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    'and inner_dim {} (prev: {})'.format(
                        name, num_classes, prev_num_classes, effective_inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = ModelParallelRobertaClassificationHead(
            self.args.encoder_embed_dim,
            effective_inner_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )
class ModelParallelRobertaLMHead(nn.Module):
    """Head for masked language modeling.

    Args:
        embed_dim: width of the incoming features.
        output_dim: full vocabulary size (length of ``bias``).
        activation_fn: name resolved via ``utils.get_activation_fn``.
        weight: projection weight to reuse; the encoder passes its token
            embedding weight. When ``None``, a fresh ``nn.Linear`` weight of
            shape ``(output_dim, embed_dim)`` is created.
    """

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        super().__init__()
        self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        if weight is None:
            weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        self.weight = weight
        # Bias spans the full vocabulary and is added after the gather below.
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        """Project features to vocabulary logits.

        Args:
            features: encoder output (assumes (batch, seq, dim) — TODO confirm).
            masked_tokens: optional mask; when given, only the selected
                positions are projected.

        Returns:
            Logits gathered across model-parallel ranks, with ``bias`` added.
        """
        # Only project the unmasked tokens while training,
        # saves both memory and computation
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        x = self.dense(features)
        x = self.activation_fn(x)
        x = self.layer_norm(x)

        # Megatron's copy_to_model_parallel_region is identity in forward and
        # all-reduces the gradient in backward. It must wrap the tensor that
        # actually feeds the (presumably vocab-sharded) projection, i.e. ``x``.
        # Otherwise each rank only back-propagates its own shard's gradient
        # into dense/layer_norm.
        x = copy_to_model_parallel_region(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight)
        x = gather_from_model_parallel_region(x).contiguous()
        x = x + self.bias
        return x
class ModelParallelRobertaClassificationHead(nn.Module):
    """Sentence-level classification head that reads the first token's features.

    ``dense`` is a ``ColumnParallelLinear`` built with ``gather_output=True``;
    ``out_proj`` is a plain ``nn.Linear``.

    Args:
        input_dim: width of the incoming features.
        inner_dim: width of the hidden projection.
        num_classes: number of output logits.
        activation_fn: name resolved via ``utils.get_activation_fn``.
        pooler_dropout: dropout probability used before and after ``dense``.
    """

    def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout):
        super().__init__()
        # Construction order is kept fixed so parameter initialisation
        # consumes the RNG in the same sequence.
        self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        """Return class logits computed from position 0 of ``features``."""
        # Position 0 holds the <s> token (RoBERTa's equivalent of [CLS]).
        pooled = features[:, 0, :]
        pipeline = (
            self.dropout,
            self.dense,
            self.activation_fn,
            self.dropout,
            self.out_proj,
        )
        for stage in pipeline:
            pooled = stage(pooled)
        return pooled
class ModelParallelRobertaEncoder(FairseqEncoder):
    """RoBERTa encoder built from model-parallel transformer layers.

    Implements the :class:`~fairseq.models.FairseqEncoder` interface. It wraps a
    ``ModelParallelTransformerSentenceEncoder`` plus a
    ``ModelParallelRobertaLMHead`` whose projection weight is the sentence
    encoder's token-embedding weight.

    Args:
        args: model configuration namespace. Note that ``encoder_layers`` and
            the ``*_layers_to_keep`` fields may be mutated here.
        dictionary: source dictionary; supplies the padding index and the
            vocabulary size.
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

        # RoBERTa is a sentence encoder, so users intuitively ask to trim
        # *encoder* layers. The original implementation handled layer trimming
        # through the decoder options, so the request is moved over to
        # ``decoder_layers_to_keep`` and the encoder is sized accordingly.
        # NOTE(review): the actual pruning happens outside this class — confirm
        # where ``decoder_layers_to_keep`` is consumed.
        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
            args.decoder_layers_to_keep = args.encoder_layers_to_keep
            args.encoder_layers_to_keep = None

        self.sentence_encoder = ModelParallelTransformerSentenceEncoder(
            padding_idx=dictionary.pad(),
            vocab_size=len(dictionary),
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            layerdrop=args.encoder_layerdrop,
            max_seq_len=args.max_positions,
            num_segments=0,
            encoder_normalize_before=False,
            apply_bert_init=False,
            activation_fn=args.activation_fn,
        )
        # The LM head reuses the token-embedding weight for its output projection.
        self.lm_head = ModelParallelRobertaLMHead(
            embed_dim=args.encoder_embed_dim,
            output_dim=len(dictionary),
            activation_fn=args.activation_fn,
            weight=self.sentence_encoder.embed_tokens.weight,
        )

    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            masked_tokens (optional): passed to the LM head; when given, only
                the selected positions are projected to the vocabulary.

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`, or the
                  features when *features_only* is True
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states (or None unless
                  *return_all_hiddens*). Note that the hidden states have
                  shape `(src_len, batch, embed_dim)`.
        """
        x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
        if not features_only:
            x = self.output_layer(x, masked_tokens=masked_tokens)
        return x, extra

    def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
        """Run the sentence encoder and return batch-first final-layer features."""
        inner_states, _ = self.sentence_encoder(
            src_tokens,
            last_state_only=not return_all_hiddens,
        )
        features = inner_states[-1].transpose(0, 1)  # T x B x C -> B x T x C
        return features, {'inner_states': inner_states if return_all_hiddens else None}

    def output_layer(self, features, masked_tokens=None, **unused):
        """Apply the LM head to *features*."""
        return self.lm_head(features, masked_tokens)

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions
@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta')
def base_architecture(args):
    """Populate default RoBERTa-base hyper-parameters on ``args``.

    Each attribute keeps its current value when already present. Otherwise it
    receives the default listed below. Every attribute is (re)assigned.
    """
    defaults = (
        ('encoder_layers', 12),
        ('encoder_embed_dim', 768),
        ('encoder_ffn_embed_dim', 3072),
        ('encoder_attention_heads', 12),
        ('activation_fn', 'gelu'),
        ('pooler_activation_fn', 'tanh'),
        ('dropout', 0.1),
        ('attention_dropout', 0.1),
        ('activation_dropout', 0.0),
        ('pooler_dropout', 0.0),
        ('encoder_layers_to_keep', None),
        ('encoder_layerdrop', 0.0),
    )
    for attr_name, fallback in defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_base')
def roberta_base_architecture(args):
    """RoBERTa-base sizing: identical to the default ``base_architecture``."""
    return base_architecture(args)
@register_model_architecture('model_parallel_roberta', 'model_parallel_roberta_large')
def roberta_large_architecture(args):
    """RoBERTa-large sizing: 24 layers, 1024-dim embeddings, 4096-dim FFN, 16 heads.

    Only attributes missing from ``args`` receive these sizes. Remaining
    defaults are then filled in by ``base_architecture``.
    """
    large_sizes = {
        'encoder_layers': 24,
        'encoder_embed_dim': 1024,
        'encoder_ffn_embed_dim': 4096,
        'encoder_attention_heads': 16,
    }
    for key, size in large_sizes.items():
        setattr(args, key, getattr(args, key, size))
    base_architecture(args)