| """Transformer-based BERT encoder network."""
|
|
|
|
|
| from typing import Any, Callable, Optional, Union, Tuple
|
| from absl import logging
|
| import tensorflow as tf, tf_keras
|
|
|
| from official.modeling import tf_utils
|
| from official.nlp.modeling import layers
|
|
|
|
|
| _Initializer = Union[str, tf_keras.initializers.Initializer]
|
| _Activation = Union[str, Callable[..., Any]]
|
|
|
| _approx_gelu = lambda x: tf_keras.activations.gelu(x, approximate=True)
|
|
|
|
|
| class TokenDropBertEncoder(tf_keras.layers.Layer):
|
| """Bi-directional Transformer-based encoder network with token dropping.
|
|
|
| During pretraining, we drop unimportant tokens starting from an intermediate
|
| layer in the model, to make the model focus on important tokens more
|
| efficiently with its limited computational resources. The dropped tokens are
|
| later picked up by the last layer of the model, so that the model still
|
| produces full-length sequences. This approach reduces the pretraining cost of
|
| BERT by 25% while achieving better overall fine-tuning performance on standard
|
| downstream tasks.
|
|
|
| Args:
|
| vocab_size: The size of the token vocabulary.
|
| hidden_size: The size of the transformer hidden layers.
|
| num_layers: The number of transformer layers.
|
| num_attention_heads: The number of attention heads for each transformer. The
|
| hidden size must be divisible by the number of attention heads.
|
    max_sequence_length: The maximum sequence length that this encoder can
      consume. If None, it is inferred from the input sequence length. This
      determines the variable shape for positional embeddings.
|
| type_vocab_size: The number of types that the 'type_ids' input can take.
|
| inner_dim: The output dimension of the first Dense layer in a two-layer
|
| feedforward network for each transformer.
|
| inner_activation: The activation for the first Dense layer in a two-layer
|
| feedforward network for each transformer.
|
| output_dropout: Dropout probability for the post-attention and output
|
| dropout.
|
| attention_dropout: The dropout rate to use for the attention layers within
|
| the transformer layers.
|
    token_loss_init_value: The initial loss value assigned to a token before
      it has ever been masked and predicted.
|
    token_loss_beta: The moving-average factor used when updating the running
      average loss value of a token.
|
    token_keep_k: The number of tokens to keep in the intermediate layers. The
      rest are dropped in those layers.
|
    token_allow_list: The list of token-ids that should never be dropped. In
      the BERT English vocab, token-ids from 1 to 998 contain special tokens
      such as [CLS] and [SEP]. By default, token_allow_list contains these
      special tokens.
|
    token_deny_list: The list of token-ids that should always be dropped. In
      the BERT English vocab, token-id=0 means [PAD]. By default,
      token_deny_list contains only [PAD].
|
    initializer: The initializer to use for all weights in this encoder.
|
| output_range: The sequence output range, [0, output_range), by slicing the
|
| target sequence of the last transformer layer. `None` means the entire
|
| target sequence will attend to the source sequence, which yields the full
|
| output.
|
| embedding_width: The width of the word embeddings. If the embedding width is
|
| not equal to hidden size, embedding parameters will be factorized into two
|
| matrices in the shape of ['vocab_size', 'embedding_width'] and
|
| ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
|
| smaller than 'hidden_size').
|
| embedding_layer: An optional Layer instance which will be called to generate
|
| embeddings for the input word IDs.
|
    norm_first: Whether to normalize inputs to the attention and intermediate
      dense layers. If set to False, the outputs of the attention and
      intermediate dense layers are normalized instead.
|
| with_dense_inputs: Whether to accept dense embeddings as the input.
|
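  A minimal usage sketch (assumes `word_ids`, `input_mask` and `type_ids` are
  precomputed `[batch, seq_len]` int32 tensors; the constructor arguments
  shown are illustrative):

    encoder = TokenDropBertEncoder(vocab_size=30522, token_keep_k=256)
    outputs = encoder(dict(
        input_word_ids=word_ids,
        input_mask=input_mask,
        input_type_ids=type_ids))
    sequence_output = outputs['sequence_output']  # [batch, seq_len, hidden]
    pooled_output = outputs['pooled_output']      # [batch, hidden]
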
| """
|
|
|
| def __init__(
|
| self,
|
| vocab_size: int,
|
| hidden_size: int = 768,
|
| num_layers: int = 12,
|
| num_attention_heads: int = 12,
|
| max_sequence_length: int = 512,
|
| type_vocab_size: int = 16,
|
| inner_dim: int = 3072,
|
| inner_activation: _Activation = _approx_gelu,
|
| output_dropout: float = 0.1,
|
| attention_dropout: float = 0.1,
|
| token_loss_init_value: float = 10.0,
|
| token_loss_beta: float = 0.995,
|
| token_keep_k: int = 256,
|
| token_allow_list: Tuple[int, ...] = (100, 101, 102, 103),
|
| token_deny_list: Tuple[int, ...] = (0,),
|
| initializer: _Initializer = tf_keras.initializers.TruncatedNormal(
|
| stddev=0.02),
|
| output_range: Optional[int] = None,
|
| embedding_width: Optional[int] = None,
|
| embedding_layer: Optional[tf_keras.layers.Layer] = None,
|
| norm_first: bool = False,
|
| with_dense_inputs: bool = False,
|
| **kwargs):
|
|
|
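    # Pop deprecated or renamed constructor arguments kept for backwards
    # compatibility with older encoder APIs.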
| if 'dict_outputs' in kwargs:
|
| kwargs.pop('dict_outputs')
|
| if 'return_all_encoder_outputs' in kwargs:
|
| kwargs.pop('return_all_encoder_outputs')
|
| if 'intermediate_size' in kwargs:
|
| inner_dim = kwargs.pop('intermediate_size')
|
| if 'activation' in kwargs:
|
| inner_activation = kwargs.pop('activation')
|
| if 'dropout_rate' in kwargs:
|
| output_dropout = kwargs.pop('dropout_rate')
|
| if 'attention_dropout_rate' in kwargs:
|
| attention_dropout = kwargs.pop('attention_dropout_rate')
|
| super().__init__(**kwargs)
|
|
|
| if output_range is not None:
|
      logging.warning(
          '`output_range` is available as an argument for `call()`. '
          'Passing `output_range` to `__init__` is deprecated.')
|
|
|
| activation = tf_keras.activations.get(inner_activation)
|
| initializer = tf_keras.initializers.get(initializer)
|
|
|
| if embedding_width is None:
|
| embedding_width = hidden_size
|
|
|
| if embedding_layer is None:
|
| self._embedding_layer = layers.OnDeviceEmbedding(
|
| vocab_size=vocab_size,
|
| embedding_width=embedding_width,
|
| initializer=tf_utils.clone_initializer(initializer),
|
| name='word_embeddings')
|
| else:
|
| self._embedding_layer = embedding_layer
|
|
|
| self._position_embedding_layer = layers.PositionEmbedding(
|
| initializer=tf_utils.clone_initializer(initializer),
|
| max_length=max_sequence_length,
|
| name='position_embedding')
|
|
|
| self._type_embedding_layer = layers.OnDeviceEmbedding(
|
| vocab_size=type_vocab_size,
|
| embedding_width=embedding_width,
|
| initializer=tf_utils.clone_initializer(initializer),
|
| use_one_hot=True,
|
| name='type_embeddings')
|
|
|
| self._embedding_norm_layer = tf_keras.layers.LayerNormalization(
|
| name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
|
|
|
| self._embedding_dropout = tf_keras.layers.Dropout(
|
| rate=output_dropout, name='embedding_dropout')
|
|
|
|
|
|
|
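    # Project the embeddings to 'hidden_size' when a smaller 'embedding_width'
    # is used (factorized embedding parameterization).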
| self._embedding_projection = None
|
| if embedding_width != hidden_size:
|
| self._embedding_projection = tf_keras.layers.EinsumDense(
|
| '...x,xy->...y',
|
| output_shape=hidden_size,
|
| bias_axes='y',
|
| kernel_initializer=tf_utils.clone_initializer(initializer),
|
| name='embedding_projection')
|
|
|
|
|
|
|
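    # Seed the running per-token importance scores: allow-listed tokens get a
    # large positive score so they are always kept, deny-listed tokens a large
    # negative score so they are always dropped.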
    init_importance = tf.constant(token_loss_init_value, shape=(vocab_size,))
    if token_allow_list:
      init_importance = tf.tensor_scatter_nd_update(
          tensor=init_importance,
          indices=[[x] for x in token_allow_list],
          updates=[1.0e4] * len(token_allow_list))
    if token_deny_list:
      init_importance = tf.tensor_scatter_nd_update(
          tensor=init_importance,
          indices=[[x] for x in token_deny_list],
          updates=[-1.0e4] * len(token_deny_list))
|
| self._token_importance_embed = layers.TokenImportanceWithMovingAvg(
|
| vocab_size=vocab_size,
|
| init_importance=init_importance,
|
| moving_average_beta=token_loss_beta)
|
|
|
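    # Selects the top-k most important token positions to keep in the middle
    # layers.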
| self._token_separator = layers.SelectTopK(top_k=token_keep_k)
|
| self._transformer_layers = []
|
| self._num_layers = num_layers
|
| self._attention_mask_layer = layers.SelfAttentionMask(
|
| name='self_attention_mask')
|
| for i in range(num_layers):
|
| layer = layers.TransformerEncoderBlock(
|
| num_attention_heads=num_attention_heads,
|
| inner_dim=inner_dim,
|
| inner_activation=inner_activation,
|
| output_dropout=output_dropout,
|
| attention_dropout=attention_dropout,
|
| norm_first=norm_first,
|
| kernel_initializer=tf_utils.clone_initializer(initializer),
|
| name='transformer/layer_%d' % i)
|
| self._transformer_layers.append(layer)
|
|
|
| self._pooler_layer = tf_keras.layers.Dense(
|
| units=hidden_size,
|
| activation='tanh',
|
| kernel_initializer=tf_utils.clone_initializer(initializer),
|
| name='pooler_transform')
|
|
|
| self._config = {
|
| 'vocab_size': vocab_size,
|
| 'hidden_size': hidden_size,
|
| 'num_layers': num_layers,
|
| 'num_attention_heads': num_attention_heads,
|
| 'max_sequence_length': max_sequence_length,
|
| 'type_vocab_size': type_vocab_size,
|
| 'inner_dim': inner_dim,
|
| 'inner_activation': tf_keras.activations.serialize(activation),
|
| 'output_dropout': output_dropout,
|
| 'attention_dropout': attention_dropout,
|
| 'token_loss_init_value': token_loss_init_value,
|
| 'token_loss_beta': token_loss_beta,
|
| 'token_keep_k': token_keep_k,
|
| 'token_allow_list': token_allow_list,
|
| 'token_deny_list': token_deny_list,
|
| 'initializer': tf_keras.initializers.serialize(initializer),
|
| 'output_range': output_range,
|
| 'embedding_width': embedding_width,
|
| 'embedding_layer': embedding_layer,
|
| 'norm_first': norm_first,
|
| 'with_dense_inputs': with_dense_inputs,
|
| }
|
| if with_dense_inputs:
|
| self.inputs = dict(
|
| input_word_ids=tf_keras.Input(shape=(None,), dtype=tf.int32),
|
| input_mask=tf_keras.Input(shape=(None,), dtype=tf.int32),
|
| input_type_ids=tf_keras.Input(shape=(None,), dtype=tf.int32),
|
| dense_inputs=tf_keras.Input(
|
| shape=(None, embedding_width), dtype=tf.float32),
|
| dense_mask=tf_keras.Input(shape=(None,), dtype=tf.int32),
|
| dense_type_ids=tf_keras.Input(shape=(None,), dtype=tf.int32),
|
| )
|
| else:
|
| self.inputs = dict(
|
| input_word_ids=tf_keras.Input(shape=(None,), dtype=tf.int32),
|
| input_mask=tf_keras.Input(shape=(None,), dtype=tf.int32),
|
| input_type_ids=tf_keras.Input(shape=(None,), dtype=tf.int32))
|
|
|
| def call(self, inputs, output_range: Optional[tf.Tensor] = None):
|
| if isinstance(inputs, dict):
|
| word_ids = inputs.get('input_word_ids')
|
| mask = inputs.get('input_mask')
|
| type_ids = inputs.get('input_type_ids')
|
|
|
| dense_inputs = inputs.get('dense_inputs', None)
|
| dense_mask = inputs.get('dense_mask', None)
|
| dense_type_ids = inputs.get('dense_type_ids', None)
|
| else:
|
| raise ValueError('Unexpected inputs type to %s.' % self.__class__)
|
|
|
| word_embeddings = self._embedding_layer(word_ids)
|
|
|
| if dense_inputs is not None:
|
|
|
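      # Concatenate dense features at the end of the token sequence.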
| word_embeddings = tf.concat([word_embeddings, dense_inputs], axis=1)
|
| type_ids = tf.concat([type_ids, dense_type_ids], axis=1)
|
| mask = tf.concat([mask, dense_mask], axis=1)
|
|
|
|
|
| position_embeddings = self._position_embedding_layer(word_embeddings)
|
| type_embeddings = self._type_embedding_layer(type_ids)
|
|
|
| embeddings = word_embeddings + position_embeddings + type_embeddings
|
| embeddings = self._embedding_norm_layer(embeddings)
|
| embeddings = self._embedding_dropout(embeddings)
|
|
|
| if self._embedding_projection is not None:
|
| embeddings = self._embedding_projection(embeddings)
|
|
|
| attention_mask = self._attention_mask_layer(embeddings, mask)
|
|
|
| encoder_outputs = []
|
| x = embeddings
|
|
|
|
|
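    # Look up the running importance of each input token and split the
    # positions into kept (top-k) and dropped.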
| token_importance = self._token_importance_embed(word_ids)
|
| selected, not_selected = self._token_separator(token_importance)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
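    # Phase 1: the first (num_layers // 2 - 1) layers process the full
    # sequence, so every token is encoded.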
    for layer in self._transformer_layers[:self._num_layers // 2 - 1]:
      x = layer([x, attention_mask])
      encoder_outputs.append(x)
|
|
|
|
|
|
|
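    # Gather the kept tokens and build two masks: one restricted to the kept
    # tokens, and one that lets the kept tokens attend to the full sequence
    # one more time before the drop.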
| x_selected = tf.gather(x, selected, batch_dims=1, axis=1)
|
| mask_selected = tf.gather(mask, selected, batch_dims=1, axis=1)
|
| attention_mask_token_drop = self._attention_mask_layer(
|
| x_selected, mask_selected)
|
|
|
| x_not_selected = tf.gather(x, not_selected, batch_dims=1, axis=1)
|
| mask_not_selected = tf.gather(mask, not_selected, batch_dims=1, axis=1)
|
| attention_mask_token_pass = self._attention_mask_layer(
|
| x_selected, tf.concat([mask_selected, mask_not_selected], axis=1))
|
| x_all = tf.concat([x_selected, x_not_selected], axis=1)
|
|
|
|
|
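    # Phase transition: one layer in which the kept tokens (queries) attend to
    # the full sequence (keys/values), absorbing information from the tokens
    # about to be dropped.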
    x_selected = self._transformer_layers[self._num_layers // 2 - 1](
        [x_selected, x_all, attention_mask_token_pass])
    encoder_outputs.append(x_selected)
|
|
|
|
|
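    # Phase 2: the remaining middle layers run on the kept tokens only, which
    # is where the compute savings come from.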
    for layer in self._transformer_layers[self._num_layers // 2:-1]:
      x_selected = layer([x_selected, attention_mask_token_drop])
      encoder_outputs.append(x_selected)
|
|
|
|
|
|
|
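    # Phase 3: reinsert the dropped tokens (cast in case mixed precision
    # changed the compute dtype) and restore the original token order.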
| x_not_selected = tf.cast(x_not_selected, dtype=x_selected.dtype)
|
| x = tf.concat([x_selected, x_not_selected], axis=1)
|
| indices = tf.concat([selected, not_selected], axis=1)
|
| reverse_indices = tf.argsort(indices)
|
| x = tf.gather(x, reverse_indices, batch_dims=1, axis=1)
|
|
|
|
|
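    # The last layer processes the full-length sequence again, so the encoder
    # emits an output for every position.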
| x = self._transformer_layers[-1]([x, attention_mask],
|
| output_range=output_range)
|
| encoder_outputs.append(x)
|
|
|
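    # Pool by transforming the hidden state of the first ([CLS]) token.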
    last_encoder_output = encoder_outputs[-1]
    first_token_tensor = last_encoder_output[:, 0, :]
    pooled_output = self._pooler_layer(first_token_tensor)

    return dict(
        sequence_output=last_encoder_output,
        pooled_output=pooled_output,
        encoder_outputs=encoder_outputs)
|
|
|
| def record_mlm_loss(self, mlm_ids: tf.Tensor, mlm_losses: tf.Tensor):
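    """Records per-token MLM losses to update the running token importance.

    Typically called once per pretraining step with the masked token ids and
    the corresponding per-token losses from the masked-LM head.
    """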
|
| self._token_importance_embed.update_token_importance(
|
| token_ids=mlm_ids, importance=mlm_losses)
|
|
|
| def get_embedding_table(self):
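    """Returns the word embedding table of the encoder."""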
|
| return self._embedding_layer.embeddings
|
|
|
| def get_embedding_layer(self):
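    """Returns the word embedding layer of the encoder."""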
|
| return self._embedding_layer
|
|
|
| def get_config(self):
|
| return dict(self._config)
|
|
|
| @property
|
| def transformer_layers(self):
|
| """List of Transformer layers in the encoder."""
|
| return self._transformer_layers
|
|
|
| @property
|
| def pooler_layer(self):
|
| """The pooler dense layer after the transformer layers."""
|
| return self._pooler_layer
|
|
|
| @classmethod
|
| def from_config(cls, config, custom_objects=None):
|
| if 'embedding_layer' in config and config['embedding_layer'] is not None:
|
      warn_string = (
          'You are reloading a model that was saved with a '
          'potentially-shared embedding layer object. If you continue to '
          'train this model, the embedding layer will no longer be shared. '
          'To work around this, load the model outside of the Keras API.')
      print('WARNING: ' + warn_string)
      logging.warning(warn_string)
|
|
|
| return cls(**config)
|
|
|