# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loads dataset for the BERT pretraining task."""
import dataclasses
from typing import Mapping, Optional

from absl import logging
import numpy as np
import tensorflow as tf, tf_keras

from official.common import dataset_fn
from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  use_next_sentence_label: bool = True
  use_position_id: bool = False
  # Historically, BERT implementations take `input_ids` and `segment_ids` as
  # feature names. Inside the TF Model Garden implementation, the Keras model
  # inputs are set as `input_word_ids` and `input_type_ids`. When
  # `use_v2_feature_names` is True, the data loader assumes the tf.Examples
  # use `input_word_ids` and `input_type_ids` as keys.
  use_v2_feature_names: bool = False
  file_type: str = 'tfrecord'
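

# Example (a minimal sketch, not part of the original file): overriding the
# defaults above to read tf.Examples written with the v2 feature names. The
# input pattern below is hypothetical.
#
#   config = BertPretrainDataConfig(
#       input_path='/data/pretrain/*.tfrecord',  # hypothetical path
#       seq_length=128,
#       max_predictions_per_seq=20,
#       use_v2_feature_names=True)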


@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader(data_loader.DataLoader):
  """A class to load dataset for the BERT pretraining task."""

  def __init__(self, params):
    """Inits `BertPretrainDataLoader` class.

    Args:
      params: A `BertPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._use_next_sentence_label = params.use_next_sentence_label
    self._use_position_id = params.use_position_id

  def _name_to_features(self):
    name_to_features = {
        'input_mask':
            tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'masked_lm_positions':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
        'masked_lm_ids':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
        'masked_lm_weights':
            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
    }
    if self._params.use_v2_feature_names:
      name_to_features.update({
          'input_word_ids':
              tf.io.FixedLenFeature([self._seq_length], tf.int64),
          'input_type_ids':
              tf.io.FixedLenFeature([self._seq_length], tf.int64),
      })
    else:
      name_to_features.update({
          'input_ids':
              tf.io.FixedLenFeature([self._seq_length], tf.int64),
          'segment_ids':
              tf.io.FixedLenFeature([self._seq_length], tf.int64),
      })
    if self._use_next_sentence_label:
      name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature(
          [1], tf.int64)
    if self._use_position_id:
      name_to_features['position_ids'] = tf.io.FixedLenFeature(
          [self._seq_length], tf.int64)
    return name_to_features

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = self._name_to_features()
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example
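
  # A minimal sketch (not in the original file) of writing one tf.Example
  # that `_decode` above can parse, using the default (v1) feature names.
  # `tokens`, `positions`, `ids`, and `weights` are hypothetical Python lists
  # of the right lengths (`seq_length` and `max_predictions_per_seq`):
  #
  #   features = {
  #       'input_ids': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=tokens)),
  #       'input_mask': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=[1] * len(tokens))),
  #       'segment_ids': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=[0] * len(tokens))),
  #       'masked_lm_positions': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=positions)),
  #       'masked_lm_ids': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=ids)),
  #       'masked_lm_weights': tf.train.Feature(
  #           float_list=tf.train.FloatList(value=weights)),
  #       'next_sentence_labels': tf.train.Feature(
  #           int64_list=tf.train.Int64List(value=[0])),
  #   }
  #   example = tf.train.Example(
  #       features=tf.train.Features(feature=features))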

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x = {
        'input_mask': record['input_mask'],
        'masked_lm_positions': record['masked_lm_positions'],
        'masked_lm_ids': record['masked_lm_ids'],
        'masked_lm_weights': record['masked_lm_weights'],
    }
    if self._params.use_v2_feature_names:
      x['input_word_ids'] = record['input_word_ids']
      x['input_type_ids'] = record['input_type_ids']
    else:
      x['input_word_ids'] = record['input_ids']
      x['input_type_ids'] = record['segment_ids']
    if self._use_next_sentence_label:
      x['next_sentence_labels'] = record['next_sentence_labels']
    if self._use_position_id:
      x['position_ids'] = record['position_ids']

    return x

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)
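

# End-to-end usage (a hedged sketch, not part of the original file): build
# the loader from a config and iterate over the resulting tf.data.Dataset.
# The input path is hypothetical.
#
#   params = BertPretrainDataConfig(input_path='/data/pretrain/*.tfrecord')
#   dataset = BertPretrainDataLoader(params).load()
#   for batch in dataset.take(1):
#     print(batch['input_word_ids'].shape)  # (512, 512) with the defaults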


@dataclasses.dataclass
class XLNetPretrainDataConfig(cfg.DataConfig):
  """Data config for XLNet pretraining task.

  Attributes:
    input_path: See base class.
    global_batch_size: See base class.
    is_training: See base class.
    seq_length: The length of each sequence.
    max_predictions_per_seq: The number of predictions per sequence.
    reuse_length: The number of tokens in a previous segment to reuse. This
      should be the same value used during pretrain data creation.
    sample_strategy: The strategy used to sample factorization permutations.
      Possible values: 'single_token', 'whole_word', 'token_span',
      'word_span'.
    min_num_tokens: The minimum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    max_num_tokens: The maximum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    min_num_words: The minimum number of words to sample in a span. This is
      used when `sample_strategy` is 'word_span'.
    max_num_words: The maximum number of words to sample in a span. This is
      used when `sample_strategy` is 'word_span'.
    permutation_size: The length of the longest permutation. This can be set
      to `reuse_length`. This should NOT be greater than `reuse_length`,
      otherwise this may introduce data leaks.
    leak_ratio: The percentage of masked tokens that are leaked.
    segment_sep_id: The ID of the SEP token used when preprocessing the
      dataset.
    segment_cls_id: The ID of the CLS token used when preprocessing the
      dataset.
  """
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  reuse_length: int = 256
  sample_strategy: str = 'word_span'
  min_num_tokens: int = 1
  max_num_tokens: int = 5
  min_num_words: int = 1
  max_num_words: int = 5
  permutation_size: int = 256
  leak_ratio: float = 0.1
  segment_sep_id: int = 4
  segment_cls_id: int = 3
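
# Note (added commentary, not in the original file): `_parse` below requires
# that `reuse_length` and `seq_length - reuse_length` are both multiples of
# `permutation_size`. The defaults (512, 256, 256) satisfy this; for example,
# a seq_length of 384 with reuse_length 256 and permutation_size 256 would
# not, since the non-reuse segment would be 128 tokens long.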


@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
class XLNetPretrainDataLoader(data_loader.DataLoader):
  """A class to load dataset for the XLNet pretraining task."""

  def __init__(self, params: XLNetPretrainDataConfig):
    """Inits `XLNetPretrainDataLoader` class.

    Args:
      params: An `XLNetPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._reuse_length = params.reuse_length
    self._num_replicas_in_sync = None
    self._permutation_size = params.permutation_size
    self._sep_id = params.segment_sep_id
    self._cls_id = params.segment_cls_id
    self._sample_strategy = params.sample_strategy
    self._leak_ratio = params.leak_ratio

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'boundary_indices': tf.io.VarLenFeature(tf.int64),
    }
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x = {}
    inputs = record['input_word_ids']
    x['input_type_ids'] = record['input_type_ids']

    if self._sample_strategy in ['whole_word', 'word_span']:
      boundary = tf.sparse.to_dense(record['boundary_indices'])
    else:
      boundary = None

    input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)

    if self._reuse_length > 0:
      if self._permutation_size > self._reuse_length:
        logging.warning(
            '`permutation_size` is greater than `reuse_length` (%d > %d). '
            'This may introduce data leakage.', self._permutation_size,
            self._reuse_length)

      # Enable the memory mechanism.
      # Permute the reuse and non-reuse segments separately.
      non_reuse_len = self._seq_length - self._reuse_length
      if not (self._reuse_length % self._permutation_size == 0 and
              non_reuse_len % self._permutation_size == 0):
        raise ValueError('`reuse_length` and `seq_length` should both be '
                         'a multiple of `permutation_size`.')

      # Creates permutation mask and target mask for the first reuse_len
      # tokens. The tokens in this part are reused from the last sequence.
      perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
          inputs=inputs[:self._reuse_length],
          input_mask=input_mask[:self._reuse_length])

      # Creates permutation mask and target mask for the rest of tokens in
      # current example, which are concatenation of two new segments.
      perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
          inputs[self._reuse_length:], input_mask[self._reuse_length:])

      perm_mask_0 = tf.concat([
          perm_mask_0,
          tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)
      ], axis=1)
      perm_mask_1 = tf.concat([
          tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
          perm_mask_1
      ], axis=1)
      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
      tokens = tf.concat([tokens_0, tokens_1], axis=0)
      masked_tokens = tf.concat([masked_0, masked_1], axis=0)
    else:
      # Disable the memory mechanism.
      if self._seq_length % self._permutation_size != 0:
        raise ValueError('`seq_length` should be a multiple of '
                         '`permutation_size`.')
      # Permute the entire sequence together.
      perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
          inputs=inputs, input_mask=input_mask)

    x['permutation_mask'] = tf.reshape(perm_mask,
                                       [self._seq_length, self._seq_length])
    x['input_word_ids'] = tokens
    x['masked_tokens'] = masked_tokens

    target = tokens
    if self._max_predictions_per_seq is not None:
      indices = tf.range(self._seq_length, dtype=tf.int32)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      # Account for extra padding due to CLS/SEP.
      actual_num_predict = tf.shape(indices)[0]
      pad_len = self._max_predictions_per_seq - actual_num_predict

      target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
      paddings = tf.zeros([pad_len, self._seq_length],
                          dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      x['target_mapping'] = tf.reshape(
          target_mapping, [self._max_predictions_per_seq, self._seq_length])

      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      x['target'] = tf.reshape(target, [self._max_predictions_per_seq])

      target_mask = tf.concat([
          tf.ones([actual_num_predict], dtype=tf.int32),
          tf.zeros([pad_len], dtype=tf.int32)
      ], axis=0)
      x['target_mask'] = tf.reshape(target_mask,
                                    [self._max_predictions_per_seq])
    else:
      x['target'] = tf.reshape(target, [self._seq_length])
      x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
    return x

  def _index_pair_to_mask(self, begin_indices: tf.Tensor,
                          end_indices: tf.Tensor,
                          inputs: tf.Tensor) -> tf.Tensor:
    """Converts beginning and end indices into an actual mask."""
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    all_indices = tf.where(
        non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
        tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
    candidate_matrix = tf.cast(
        tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
                       all_indices[None, :] < end_indices[:, None]),
        tf.float32)
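    # Added commentary (not in the original): `candidate_matrix[k, j]` is 1
    # when position j falls inside the k-th sampled span. Flattening it and
    # taking a running sum assigns every candidate position a global count in
    # span order, so `masked_matrix` keeps only the first
    # `max_predictions_per_seq` candidate positions and drops the rest.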
    cumsum_matrix = tf.reshape(
        tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
    masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
                            tf.float32)
    target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
    return tf.cast(target_mask, tf.bool)

  def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples individual tokens as prediction targets."""
    all_indices = tf.range(self._seq_length, dtype=tf.int32)
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
    masked_pos = tf.random.shuffle(non_func_indices)
    masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])
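    # Added commentary (not in the original): scatter 1s at the sampled
    # positions through a 1 x seq_length SparseTensor, densify, and squeeze
    # to get a flat boolean target mask.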
    sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
    sparse_indices = tf.cast(sparse_indices, tf.int64)
    sparse_indices = tf.sparse.SparseTensor(
        sparse_indices,
        values=tf.ones_like(masked_pos),
        dense_shape=(1, self._seq_length))
    target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)
    return tf.squeeze(tf.cast(target_mask, tf.bool))

  def _whole_word_mask(self, inputs: tf.Tensor,
                       boundary: tf.Tensor) -> tf.Tensor:
    """Samples whole words as prediction targets."""
    pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
    cand_pair_indices = tf.random.shuffle(
        pair_indices)[:self._max_predictions_per_seq]
    begin_indices = cand_pair_indices[:, 0]
    end_indices = cand_pair_indices[:, 1]

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples token spans as prediction targets."""
    min_num_tokens = self._params.min_num_tokens
    max_num_tokens = self._params.max_num_tokens
    mask_alpha = self._seq_length / self._max_predictions_per_seq
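    # Added commentary (not in the original): `mask_alpha` is the budget of
    # input tokens per predicted token. With the defaults (seq_length=512,
    # max_predictions_per_seq=76) it is 512 / 76, about 6.7, i.e. roughly one
    # prediction for every 6.7 tokens, which spaces the sampled spans out.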
    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)

    # Sample span lengths from a zipf distribution.
    span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
    probs /= np.sum(probs)
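    # Worked example (added, not in the original): with min_num_tokens=1 and
    # max_num_tokens=5, the unnormalized weights are
    # [1/2, 1/3, 1/4, 1/5, 1/6], so shorter spans are sampled more often.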
    logits = tf.constant(np.log(probs), dtype=tf.float32)
    span_lens = tf.random.categorical(
        logits=logits[None],
        num_samples=self._max_predictions_per_seq,
        dtype=tf.int32,
    )[0] + min_num_tokens

    # Sample the ratio [0.0, 1.0) of left context lengths.
    span_lens_float = tf.cast(span_lens, tf.float32)
    left_ratio = tf.random.uniform(
        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
    left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
    left_ctx_len = round_to_int(left_ctx_len)

    # Compute the offset from left start to the right end.
    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

    # Get the actual begin and end indices.
    begin_indices = (
        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
    end_indices = begin_indices + span_lens

    # Remove out of range indices.
    valid_idx_mask = end_indices < self._seq_length
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    # Shuffle valid indices.
    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
    begin_indices = tf.gather(begin_indices, order)
    end_indices = tf.gather(end_indices, order)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
    """Samples whole word spans as prediction targets."""
    min_num_words = self._params.min_num_words
    max_num_words = self._params.max_num_words

    # Note: 1.2 is the token-to-word ratio.
    mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)

    # Sample span lengths from a zipf distribution.
    span_len_seq = np.arange(min_num_words, max_num_words + 1)
    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
    probs /= np.sum(probs)
    logits = tf.constant(np.log(probs), dtype=tf.float32)

    # Sample `num_predict` words here: note that this is over sampling.
    span_lens = tf.random.categorical(
        logits=logits[None],
        num_samples=self._max_predictions_per_seq,
        dtype=tf.int32,
    )[0] + min_num_words

    # Sample the ratio [0.0, 1.0) of left context lengths.
    span_lens_float = tf.cast(span_lens, tf.float32)
    left_ratio = tf.random.uniform(
        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
    left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
    left_ctx_len = round_to_int(left_ctx_len)
    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

    begin_indices = (
        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
    end_indices = begin_indices + span_lens

    # Remove out of range indices.
    max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
    valid_idx_mask = end_indices < max_boundary_index
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    begin_indices = tf.gather(boundary, begin_indices)
    end_indices = tf.gather(boundary, end_indices)

    # Shuffle valid indices.
    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
    begin_indices = tf.gather(begin_indices, order)
    end_indices = tf.gather(end_indices, order)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _online_sample_mask(self, inputs: tf.Tensor,
                          boundary: tf.Tensor) -> tf.Tensor:
    """Samples target positions for predictions.

    Descriptions of each strategy:
      - 'single_token': Samples individual tokens as prediction targets.
      - 'token_span': Samples spans of tokens as prediction targets.
      - 'whole_word': Samples individual words as prediction targets.
      - 'word_span': Samples spans of words as prediction targets.

    Args:
      inputs: The input tokens.
      boundary: The `int` Tensor of indices indicating whole word boundaries.
        This is used in 'whole_word' and 'word_span'.

    Returns:
      The sampled `bool` input mask.

    Raises:
      `ValueError`: if `max_predictions_per_seq` is not set or if boundary is
        not provided for 'whole_word' and 'word_span' sample strategies.
    """
    if self._max_predictions_per_seq is None:
      raise ValueError('`max_predictions_per_seq` must be set.')

    if boundary is None and 'word' in self._sample_strategy:
      raise ValueError('`boundary` must be provided for {} strategy'.format(
          self._sample_strategy))

    if self._sample_strategy == 'single_token':
      return self._single_token_mask(inputs)
    elif self._sample_strategy == 'token_span':
      return self._token_span_mask(inputs)
    elif self._sample_strategy == 'whole_word':
      return self._whole_word_mask(inputs, boundary)
    elif self._sample_strategy == 'word_span':
      return self._word_span_mask(inputs, boundary)
    else:
      raise NotImplementedError('Invalid sample strategy.')

  def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
    """Samples a permutation of the factorization order.

    Args:
      inputs: the input tokens.
      input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
        the token is selected for partial prediction.

    Returns:
      perm_mask: An `int32` Tensor of shape [seq_length, seq_length]
        consisting of 0s and 1s. If perm_mask[i][j] == 0, then the i-th token
        (in original order) cannot attend to the j-th token.
      target_mask: An `int32` Tensor of shape [seq_len] consisting of 0s and
        1s. If target_mask[i] == 1, then the i-th token needs to be predicted
        and the mask will be used as input. This token will be included in
        the loss. If target_mask[i] == 0, then the token (or [SEP], [CLS])
        will be used as input. This token will not be included in the loss.
      tokens: int32 Tensor of shape [seq_length].
      masked_tokens: int32 Tensor of shape [seq_length].
    """
    factorization_length = tf.shape(inputs)[0]

    # Generate permutation indices.
    index = tf.range(factorization_length, dtype=tf.int32)
    index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
    index = tf.random.shuffle(index)
    index = tf.reshape(tf.transpose(index), [-1])

    input_mask = tf.cast(input_mask, tf.bool)
    # Non-functional tokens are everything except [SEP] and [CLS].
    non_func_tokens = tf.logical_not(
        tf.logical_or(
            tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
    masked_tokens = tf.logical_and(input_mask, non_func_tokens)
    non_masked_or_func_tokens = tf.logical_not(masked_tokens)

    smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)
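    # Added commentary (not in the original): -2 is a sentinel smaller than
    # any real permutation index. Tokens that may attend to themselves
    # (context tokens and leaked masked tokens) get to_index = -2 and
    # from_index = -1 below, so every token can attend to them; the remaining
    # masked tokens keep their permutation index and can only be attended
    # from positions later in the sampled factorization order.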
    # Similar to BERT, randomly leak some masked tokens.
    if self._leak_ratio > 0:
      leak_tokens = tf.logical_and(
          masked_tokens,
          tf.random.uniform([factorization_length], maxval=1.0) <
          self._leak_ratio)
      can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
    else:
      can_attend_self = non_masked_or_func_tokens

    to_index = tf.where(can_attend_self, smallest_index, index)
    from_index = tf.where(can_attend_self, to_index + 1, to_index)

    # For masked tokens, can attend if i > j.
    # For context tokens, always can attend each other.
    can_attend = from_index[:, None] > to_index[None, :]
    perm_mask = tf.cast(can_attend, tf.int32)

    # Only masked tokens are included in the loss.
    target_mask = tf.cast(masked_tokens, tf.int32)
    return perm_mask, target_mask, inputs, masked_tokens

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset."""
    if input_context:
      self._num_replicas_in_sync = input_context.num_replicas_in_sync
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)
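

# End-to-end usage (a hedged sketch, not part of the original file): the
# default 'word_span' strategy requires `boundary_indices` in the records,
# while 'token_span' does not. The input path is hypothetical.
#
#   params = XLNetPretrainDataConfig(
#       input_path='/data/xlnet_pretrain/*.tfrecord',
#       sample_strategy='token_span')  # no word boundaries needed
#   dataset = XLNetPretrainDataLoader(params).load()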