import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    LayerNorm,
    SinusoidalPositionalEmbedding,
    TransformerSentenceEncoder,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params


logger = logging.getLogger(__name__)

@register_model("masked_lm")
class MaskedLMModel(FairseqEncoderModel):
    """
    Class for training a Masked Language Model. It also supports an
    additional sentence level prediction if the sent-loss argument is set.
    """

    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

        # if specified, apply BERT initialization on the model. We need to
        # explicitly call this to make sure that the output embeddings and
        # projection layers are also correctly initialized
        if getattr(args, "apply_bert_init", False):
            self.apply(init_bert_params)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Arguments related to dropout
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--act-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN",
        )

        # Arguments related to hidden states and self-attention
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )

        # Arguments related to input and output embeddings
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--share-encoder-input-output-embed",
            action="store_true",
            help="share encoder input and output embeddings",
        )
        parser.add_argument(
            "--encoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the encoder",
        )
        parser.add_argument(
            "--no-token-positional-embeddings",
            action="store_true",
            help="if set, disables positional embeddings (outside self attention)",
        )
        parser.add_argument(
            "--num-segment", type=int, metavar="N", help="num segments in the input"
        )
        parser.add_argument(
            "--max-positions", type=int, help="number of positional embeddings to learn"
        )

        # Arguments related to sentence level prediction
        parser.add_argument(
            "--sentence-class-num",
            type=int,
            metavar="N",
            help="number of classes for sentence task",
        )
        parser.add_argument(
            "--sent-loss",
            action="store_true",
            help="if set, calculate sentence level predictions",
        )

        # Arguments related to parameter initialization
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )

        # misc params
        parser.add_argument(
            "--activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use for the pooler layer",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )

    def forward(self, src_tokens, segment_labels=None, **kwargs):
        return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs)

    def max_positions(self):
        return self.encoder.max_positions

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        logger.info(args)

        encoder = MaskedLMEncoder(args, task.dictionary)
        return cls(args, encoder)
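

# The wrapper above delegates all computation to MaskedLMEncoder below; note
# that build_model falls back to args.tokens_per_sample (supplied by the task)
# when --max-positions is not given explicitly.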


class MaskedLMEncoder(FairseqEncoder):
    """
    Encoder for Masked Language Modelling.
    """

    def __init__(self, args, dictionary):
        super().__init__(dictionary)

        self.padding_idx = dictionary.pad()
        self.vocab_size = len(dictionary)
        self.max_positions = args.max_positions

        self.sentence_encoder = TransformerSentenceEncoder(
            padding_idx=self.padding_idx,
            vocab_size=self.vocab_size,
            num_encoder_layers=args.encoder_layers,
            embedding_dim=args.encoder_embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.act_dropout,
            max_seq_len=self.max_positions,
            num_segments=args.num_segment,
            use_position_embeddings=not args.no_token_positional_embeddings,
            encoder_normalize_before=args.encoder_normalize_before,
            apply_bert_init=args.apply_bert_init,
            activation_fn=args.activation_fn,
            learned_pos_embedding=args.encoder_learned_pos,
        )

        self.share_input_output_embed = args.share_encoder_input_output_embed
        self.embed_out = None
        self.sentence_projection_layer = None
        self.sentence_out_dim = args.sentence_class_num
        self.lm_output_learned_bias = None

        # remove_head is set to true during fine-tuning
        self.load_softmax = not getattr(args, "remove_head", False)

        self.masked_lm_pooler = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn)

        self.lm_head_transform_weight = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.activation_fn = utils.get_activation_fn(args.activation_fn)
        self.layer_norm = LayerNorm(args.encoder_embed_dim)

        # the output bias, vocabulary projection and optional sentence head
        # are only created when the LM head is kept
        if self.load_softmax:
            self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size))

            if not self.share_input_output_embed:
                self.embed_out = nn.Linear(
                    args.encoder_embed_dim, self.vocab_size, bias=False
                )

            if args.sent_loss:
                self.sentence_projection_layer = nn.Linear(
                    args.encoder_embed_dim, self.sentence_out_dim, bias=False
                )

    def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused):
        """
        Forward pass for Masked LM encoder. This first computes the token
        embedding using the token embedding matrix, position embeddings (if
        specified) and segment embeddings (if specified). Here we assume that
        the sentence representation corresponds to the output of the
        classification_token (see bert_task or cross_lingual_lm task for more
        details).

        Args:
            - src_tokens: B x T matrix representing sentences
            - segment_labels: B x T matrix representing segment label for tokens

        Returns:
            - a tuple of the following:
                - logits for predictions in format B x T x C to be used in
                  softmax afterwards
                - a dictionary of additional data, where 'pooled_output' contains
                  the representation for classification_token, 'inner_states'
                  is a list of internal model states used to compute the
                  predictions (similar to ELMo), and 'sentence_logits' is the
                  prediction logit for the NSP task, only computed if this is
                  specified in the input arguments.
        """

        inner_states, sentence_rep = self.sentence_encoder(
            src_tokens,
            segment_labels=segment_labels,
        )

        # the last inner state is T x B x C; convert to B x T x C
        x = inner_states[-1].transpose(0, 1)

        # project only the masked tokens when a mask is provided
        if masked_tokens is not None:
            x = x[masked_tokens, :]
        x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x)))

        pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep))

        # project back to size of vocabulary
        if self.share_input_output_embed and hasattr(
            self.sentence_encoder.embed_tokens, "weight"
        ):
            x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
        elif self.embed_out is not None:
            x = self.embed_out(x)
        if self.lm_output_learned_bias is not None:
            x = x + self.lm_output_learned_bias

        sentence_logits = None
        if self.sentence_projection_layer:
            sentence_logits = self.sentence_projection_layer(pooled_output)

        return x, {
            "inner_states": inner_states,
            "pooled_output": pooled_output,
            "sentence_logits": sentence_logits,
        }
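
    # Note on `masked_tokens` (an illustrative sketch; `mask_idx` below is a
    # hypothetical mask-token index supplied by the criterion/task, not defined
    # in this file): when a B x T boolean mask is passed, only the masked
    # positions are projected to the vocabulary, so the returned logits have
    # shape (num_masked, vocab_size) rather than B x T x V, e.g.
    #
    #   masked_tokens = src_tokens.eq(mask_idx)
    #   logits, extra = encoder(src_tokens, masked_tokens=masked_tokens)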

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        if isinstance(
            self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding
        ):
            state_dict[
                name + ".sentence_encoder.embed_positions._float_tensor"
            ] = torch.FloatTensor(1)
        if not self.load_softmax:
            for k in list(state_dict.keys()):
                if (
                    "embed_out.weight" in k
                    or "sentence_projection_layer.weight" in k
                    or "lm_output_learned_bias" in k
                ):
                    del state_dict[k]
        return state_dict


@register_model_architecture("masked_lm", "masked_lm")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.act_dropout = getattr(args, "act_dropout", 0.0)

    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)

    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.share_encoder_input_output_embed = getattr(
        args, "share_encoder_input_output_embed", False
    )
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.num_segment = getattr(args, "num_segment", 2)

    args.sentence_class_num = getattr(args, "sentence_class_num", 2)
    args.sent_loss = getattr(args, "sent_loss", False)

    args.apply_bert_init = getattr(args, "apply_bert_init", False)

    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
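

# The architecture functions below only fill in missing attributes with
# getattr, so values passed on the command line always win. Selecting
# --arch bert_base, bert_large or xlm_base applies the corresponding defaults
# and then falls through to base_architecture for everything else.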


@register_model_architecture("masked_lm", "bert_base")
def bert_base_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.share_encoder_input_output_embed = getattr(
        args, "share_encoder_input_output_embed", True
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
    args.num_segment = getattr(args, "num_segment", 2)

    args.encoder_layers = getattr(args, "encoder_layers", 12)

    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)

    args.sentence_class_num = getattr(args, "sentence_class_num", 2)
    args.sent_loss = getattr(args, "sent_loss", True)

    args.apply_bert_init = getattr(args, "apply_bert_init", True)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    base_architecture(args)


@register_model_architecture("masked_lm", "bert_large")
def bert_large_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    bert_base_architecture(args)


@register_model_architecture("masked_lm", "xlm_base")
def xlm_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.share_encoder_input_output_embed = getattr(
        args, "share_encoder_input_output_embed", True
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
    args.num_segment = getattr(args, "num_segment", 1)

    args.encoder_layers = getattr(args, "encoder_layers", 6)

    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)

    args.sent_loss = getattr(args, "sent_loss", False)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.apply_bert_init = getattr(args, "apply_bert_init", True)
    base_architecture(args)
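

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; it assumes a working
    # fairseq install, and the tiny dictionary / hyperparameters below are
    # made up for the example rather than taken from the original module).
    import argparse

    from fairseq.data import Dictionary

    dictionary = Dictionary()
    for symbol in ["hello", "world"]:
        dictionary.add_symbol(symbol)

    # any attribute set here wins; base_architecture fills in the rest
    args = argparse.Namespace(
        max_positions=16,
        encoder_embed_dim=64,
        encoder_ffn_embed_dim=128,
        encoder_layers=2,
        encoder_attention_heads=4,
    )
    base_architecture(args)

    encoder = MaskedLMEncoder(args, dictionary)
    src_tokens = torch.randint(0, len(dictionary), (2, 8))
    segment_labels = torch.zeros_like(src_tokens)
    logits, extra = encoder(src_tokens, segment_labels=segment_labels)
    # logits are B x T x V; extra holds 'pooled_output', 'inner_states' and a
    # None 'sentence_logits', since sent_loss defaults to False here
    print(logits.shape)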