Spaces:
Sleeping
Sleeping
Delete pretrain_dataloader.py
Browse files- pretrain_dataloader.py +0 -589
pretrain_dataloader.py
DELETED
|
@@ -1,589 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
"""Loads dataset for the BERT pretraining task."""
|
| 16 |
-
import dataclasses
|
| 17 |
-
from typing import Mapping, Optional
|
| 18 |
-
|
| 19 |
-
from absl import logging
|
| 20 |
-
|
| 21 |
-
import numpy as np
|
| 22 |
-
import tensorflow as tf, tf_keras
|
| 23 |
-
from official.common import dataset_fn
|
| 24 |
-
from official.core import config_definitions as cfg
|
| 25 |
-
from official.core import input_reader
|
| 26 |
-
from official.nlp.data import data_loader
|
| 27 |
-
from official.nlp.data import data_loader_factory
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for the BERT pretraining task (tasks/masked_lm).

  Attributes:
    input_path: Overrides the base-class default with an empty string.
    global_batch_size: Overrides the base-class default with 512.
    is_training: Overrides the base-class default with `True`.
    seq_length: Length of the fixed-size per-token features (`input_mask`,
      the word/type id features and, when enabled, `position_ids`).
    max_predictions_per_seq: Length of the fixed-size `masked_lm_positions`,
      `masked_lm_ids` and `masked_lm_weights` features.
    use_next_sentence_label: Whether the `next_sentence_labels` feature is
      parsed from each example and forwarded to the model.
    use_position_id: Whether the `position_ids` feature is parsed from each
      example and forwarded to the model.
    use_v2_feature_names: Selects which feature names the tf.Examples are
      expected to use; see the comment next to the field.
    file_type: Name handed to `dataset_fn.pick_dataset_fn` by the data loader
      to choose the dataset constructor.
  """
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  use_next_sentence_label: bool = True
  use_position_id: bool = False
  # Historically, BERT implementations take `input_ids` and `segment_ids` as
  # feature names. Inside the TF Model Garden implementation, the Keras model
  # inputs are set as `input_word_ids` and `input_type_ids`. When
  # `use_v2_feature_names` is True, the data loader assumes the tf.Examples
  # use `input_word_ids` and `input_type_ids` as keys.
  use_v2_feature_names: bool = False
  file_type: str = 'tfrecord'
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader(data_loader.DataLoader):
  """Builds the input dataset for the BERT pretraining task."""

  def __init__(self, params):
    """Stores the config and caches the fields read while parsing.

    Args:
      params: A `BertPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._use_next_sentence_label = params.use_next_sentence_label
    self._use_position_id = params.use_position_id

  def _name_to_features(self):
    """Returns the feature spec used to parse one serialized tf.Example.

    The word/type id features are keyed by the v2 names (`input_word_ids`,
    `input_type_ids`) or the historical names (`input_ids`, `segment_ids`)
    depending on `use_v2_feature_names`. `next_sentence_labels` and
    `position_ids` are only included when the matching config flag is set.
    """
    seq_len = self._seq_length
    num_predictions = self._max_predictions_per_seq

    def int_feature(length):
      return tf.io.FixedLenFeature([length], tf.int64)

    if self._params.use_v2_feature_names:
      ids_key, type_key = 'input_word_ids', 'input_type_ids'
    else:
      ids_key, type_key = 'input_ids', 'segment_ids'

    spec = {
        'input_mask': int_feature(seq_len),
        'masked_lm_positions': int_feature(num_predictions),
        'masked_lm_ids': int_feature(num_predictions),
        'masked_lm_weights':
            tf.io.FixedLenFeature([num_predictions], tf.float32),
    }
    spec[ids_key] = int_feature(seq_len)
    spec[type_key] = int_feature(seq_len)
    if self._use_next_sentence_label:
      spec['next_sentence_labels'] = int_feature(1)
    if self._use_position_id:
      spec['position_ids'] = int_feature(seq_len)
    return spec

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example into a dict of tensors.

    Args:
      record: The serialized tf.Example to parse.

    Returns:
      A dict keyed by feature name in which every int64 tensor has been cast
      to int32 and every other tensor is passed through unchanged.
    """
    parsed = tf.io.parse_single_example(record, self._name_to_features())

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
    # so all int64 features are narrowed to int32.
    return {
        name: tf.cast(tensor, tf.int32) if tensor.dtype == tf.int64 else tensor
        for name, tensor in parsed.items()
    }

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Selects the decoded features and renames them for the model.

    Args:
      record: The decoded features, as produced by `_decode`.

    Returns:
      A dict that always exposes the token ids as `input_word_ids` and the
      segment ids as `input_type_ids`, regardless of which feature names the
      examples use, plus the masked-LM features and the optional
      `next_sentence_labels` / `position_ids`.
    """
    if self._params.use_v2_feature_names:
      ids_key, type_key = 'input_word_ids', 'input_type_ids'
    else:
      ids_key, type_key = 'input_ids', 'segment_ids'

    passthrough_keys = (
        'input_mask',
        'masked_lm_positions',
        'masked_lm_ids',
        'masked_lm_weights',
    )
    features = {key: record[key] for key in passthrough_keys}
    features['input_word_ids'] = record[ids_key]
    features['input_type_ids'] = record[type_key]

    optional_keys = (
        ('next_sentence_labels', self._use_next_sentence_label),
        ('position_ids', self._use_position_id),
    )
    for key, enabled in optional_keys:
      if enabled:
        features[key] = record[key]

    return features

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Reads the configured input through an `InputReader`.

    Args:
      input_context: Optional distribution input context, forwarded as-is to
        `InputReader.read`.

    Returns:
      Whatever `InputReader.read` returns -- presumably a `tf.data.Dataset`.
    """
    make_dataset = dataset_fn.pick_dataset_fn(self._params.file_type)
    reader = input_reader.InputReader(
        params=self._params,
        dataset_fn=make_dataset,
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
@dataclasses.dataclass
class XLNetPretrainDataConfig(cfg.DataConfig):
  """Data config for the XLNet pretraining task.

  Attributes:
    input_path: See base class.
    global_batch_size: See base class.
    is_training: See base class.
    seq_length: The length of each sequence.
    max_predictions_per_seq: The number of predictions per sequence.
    reuse_length: The number of tokens in a previous segment to reuse. This
      should be the same value used during pretrain data creation. A value
      that is not positive disables the memory mechanism in the data loader.
    sample_strategy: The strategy used to sample prediction targets. Possible
      values: 'single_token', 'whole_word', 'token_span', 'word_span'.
    min_num_tokens: The minimum number of tokens to sample in a span. Only
      used when `sample_strategy` is 'token_span'.
    max_num_tokens: The maximum number of tokens to sample in a span. Only
      used when `sample_strategy` is 'token_span'.
    min_num_words: The minimum number of words to sample in a span. Only used
      when `sample_strategy` is 'word_span'.
    max_num_words: The maximum number of words to sample in a span. Only used
      when `sample_strategy` is 'word_span'.
    permutation_size: The length of the longest permutation. This can be set
      to `reuse_length`. It should NOT be greater than `reuse_length`,
      otherwise this may introduce data leaks.
    leak_ratio: The fraction, in [0, 1], of masked tokens that are leaked,
      i.e. allowed to attend to themselves.
    segment_sep_id: The ID of the SEP token used when preprocessing the
      dataset.
    segment_cls_id: The ID of the CLS token used when preprocessing the
      dataset.
  """
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  reuse_length: int = 256
  sample_strategy: str = 'word_span'
  min_num_tokens: int = 1
  max_num_tokens: int = 5
  min_num_words: int = 1
  max_num_words: int = 5
  permutation_size: int = 256
  leak_ratio: float = 0.1
  segment_sep_id: int = 4
  segment_cls_id: int = 3
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
class XLNetPretrainDataLoader(data_loader.DataLoader):
  """A class to load dataset for xlnet pretraining task."""

  def __init__(self, params: XLNetPretrainDataConfig):
    """Inits `XLNetPretrainDataLoader` class.

    Args:
      params: A `XLNetPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._reuse_length = params.reuse_length
    # Filled in by `load` when an input context is supplied; nothing else in
    # this class reads it.
    self._num_replicas_in_sync = None
    self._permutation_size = params.permutation_size
    self._sep_id = params.segment_sep_id
    self._cls_id = params.segment_cls_id
    self._sample_strategy = params.sample_strategy
    self._leak_ratio = params.leak_ratio

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example.

    Args:
      record: The serialized tf.Example to parse.

    Returns:
      A dict with `input_word_ids`, `input_type_ids` (both of length
      `seq_length`) and the variable-length `boundary_indices`, with every
      int64 value cast to int32.
    """
    name_to_features = {
        'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'boundary_indices': tf.io.VarLenFeature(tf.int64),
    }
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model.

    Args:
      record: The decoded features, as produced by `_decode`.

    Returns:
      A dict with the keys `input_type_ids`, `permutation_mask`,
      `input_word_ids`, `masked_tokens`, `target` and `target_mask`, plus
      `target_mapping` when `max_predictions_per_seq` is set.

    Raises:
      ValueError: If the segment lengths are not multiples of
        `permutation_size`, or if `_online_sample_mask` rejects the
        configuration.
    """
    x = {}

    inputs = record['input_word_ids']
    x['input_type_ids'] = record['input_type_ids']

    # Word boundaries are only needed by the word-level sampling strategies.
    if self._sample_strategy in ['whole_word', 'word_span']:
      boundary = tf.sparse.to_dense(record['boundary_indices'])
    else:
      boundary = None

    input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)

    if self._reuse_length > 0:
      if self._permutation_size > self._reuse_length:
        logging.warning(
            '`permutation_size` is greater than `reuse_length` (%d > %d). '
            'This may introduce data leakage.', self._permutation_size,
            self._reuse_length)

      # Enable the memory mechanism.
      # Permute the reuse and non-reuse segments separately.
      non_reuse_len = self._seq_length - self._reuse_length
      if not (self._reuse_length % self._permutation_size == 0 and
              non_reuse_len % self._permutation_size == 0):
        raise ValueError('`reuse_length` and `seq_length` should both be '
                         'a multiple of `permutation_size`.')

      # Creates permutation mask and target mask for the first reuse_len
      # tokens. The tokens in this part are reused from the last sequence.
      perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
          inputs=inputs[:self._reuse_length],
          input_mask=input_mask[:self._reuse_length])

      # Creates permutation mask and target mask for the rest of tokens in
      # current example, which are concatenation of two new segments.
      perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
          inputs[self._reuse_length:], input_mask[self._reuse_length:])

      # A 1 in the permutation mask means "can attend". Reused tokens never
      # attend to the new segment; new tokens always attend to reused tokens.
      perm_mask_0 = tf.concat([
          perm_mask_0,
          tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)
      ],
                              axis=1)
      perm_mask_1 = tf.concat([
          tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
          perm_mask_1
      ],
                              axis=1)
      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
      tokens = tf.concat([tokens_0, tokens_1], axis=0)
      masked_tokens = tf.concat([masked_0, masked_1], axis=0)
    else:
      # Disable the memory mechanism.
      if self._seq_length % self._permutation_size != 0:
        raise ValueError('`seq_length` should be a multiple of '
                         '`permutation_size`.')
      # Permute the entire sequence together
      perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
          inputs=inputs, input_mask=input_mask)
    x['permutation_mask'] = tf.reshape(perm_mask,
                                       [self._seq_length, self._seq_length])
    x['input_word_ids'] = tokens
    x['masked_tokens'] = masked_tokens

    target = tokens
    if self._max_predictions_per_seq is not None:
      indices = tf.range(self._seq_length, dtype=tf.int32)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      # account for extra padding due to CLS/SEP.
      actual_num_predict = tf.shape(indices)[0]
      pad_len = self._max_predictions_per_seq - actual_num_predict

      # One one-hot row per predicted position, zero-padded to a fixed
      # number of rows.
      target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
      paddings = tf.zeros([pad_len, self._seq_length],
                          dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      x['target_mapping'] = tf.reshape(
          target_mapping, [self._max_predictions_per_seq, self._seq_length])

      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      x['target'] = tf.reshape(target, [self._max_predictions_per_seq])

      # 1 for real prediction slots, 0 for the padded ones.
      target_mask = tf.concat([
          tf.ones([actual_num_predict], dtype=tf.int32),
          tf.zeros([pad_len], dtype=tf.int32)
      ],
                              axis=0)
      x['target_mask'] = tf.reshape(target_mask,
                                    [self._max_predictions_per_seq])
    else:
      # NOTE: `_online_sample_mask` raises when `max_predictions_per_seq` is
      # None, so this branch is not reached through the call above.
      x['target'] = tf.reshape(target, [self._seq_length])
      x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
    return x

  def _index_pair_to_mask(self, begin_indices: tf.Tensor,
                          end_indices: tf.Tensor,
                          inputs: tf.Tensor) -> tf.Tensor:
    """Converts beginning and end indices into an actual mask.

    Args:
      begin_indices: The inclusive start position of every candidate span.
      end_indices: The exclusive end position of every candidate span.
      inputs: The input tokens; SEP and CLS positions are never selected.

    Returns:
      A `bool` Tensor of length `seq_length` marking the selected positions.
      Spans are consumed in the given order and selection stops once
      `max_predictions_per_seq` positions have been marked.
    """
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    # Functional tokens get index -1 so that they fall outside every span.
    all_indices = tf.where(
        non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
        tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
    # candidate_matrix[s, p] is 1 when position p lies inside span s.
    candidate_matrix = tf.cast(
        tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
                       all_indices[None, :] < end_indices[:, None]), tf.float32)
    # Running count of selected positions across spans, in row-major order.
    cumsum_matrix = tf.reshape(
        tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
    masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
                            tf.float32)
    target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
    return tf.cast(target_mask, tf.bool)

  def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples individual tokens as prediction targets.

    Args:
      inputs: The input tokens.

    Returns:
      A `bool` Tensor of length `seq_length` with at most
      `max_predictions_per_seq` randomly chosen non-SEP/CLS positions set.
    """
    all_indices = tf.range(self._seq_length, dtype=tf.int32)
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    non_func_indices = tf.boolean_mask(all_indices, non_func_mask)

    masked_pos = tf.random.shuffle(non_func_indices)
    # The sparse-to-dense conversion below is fed sorted indices.
    masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])

    sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
    sparse_indices = tf.cast(sparse_indices, tf.int64)

    sparse_indices = tf.sparse.SparseTensor(
        sparse_indices,
        values=tf.ones_like(masked_pos),
        dense_shape=(1, self._seq_length))

    target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)

    # Only drop the leading singleton dimension, so the result stays rank 1
    # even when `seq_length` is 1.
    return tf.squeeze(tf.cast(target_mask, tf.bool), axis=0)

  def _whole_word_mask(self, inputs: tf.Tensor,
                       boundary: tf.Tensor) -> tf.Tensor:
    """Samples whole words as prediction targets.

    Args:
      inputs: The input tokens.
      boundary: The word boundary indices; consecutive entries are used as
        the [begin, end) token range of one word.

    Returns:
      The sampled `bool` mask of length `seq_length`.
    """
    pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
    cand_pair_indices = tf.random.shuffle(
        pair_indices)[:self._max_predictions_per_seq]
    begin_indices = cand_pair_indices[:, 0]
    end_indices = cand_pair_indices[:, 1]

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _sample_span_bounds(self, min_len: int, max_len: int,
                          mask_alpha: float):
    """Samples non-overlapping spans laid out from left to right.

    Shared by the 'token_span' and 'word_span' strategies; the unit of the
    returned indices (tokens or words) is whatever unit the lengths are in.

    Args:
      min_len: The smallest span length that can be sampled.
      max_len: The largest span length that can be sampled.
      mask_alpha: The average number of units of context per masked unit.

    Returns:
      A `(begin_indices, end_indices)` tuple of int32 Tensors, each of length
      `max_predictions_per_seq`. Indices may exceed the sequence; callers
      filter out-of-range spans.
    """
    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)

    # Span lengths are sampled with probability proportional to
    # 1 / (length + 1).
    span_len_seq = np.arange(min_len, max_len + 1)
    probs = 1.0 / (span_len_seq + 1)
    probs /= np.sum(probs)
    logits = tf.constant(np.log(probs), dtype=tf.float32)

    # Note that this over-samples: one span per allowed prediction.
    span_lens = tf.random.categorical(
        logits=logits[None],
        num_samples=self._max_predictions_per_seq,
        dtype=tf.int32,
    )[0] + min_len

    # Sample the ratio [0.0, 1.0) of left context lengths
    span_lens_float = tf.cast(span_lens, tf.float32)
    left_ratio = tf.random.uniform(
        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
    left_ctx_len = round_to_int(
        left_ratio * span_lens_float * (mask_alpha - 1))

    # Compute the offset from left start to the right end
    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

    # Get the actual begin and end indices
    begin_indices = (
        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
    end_indices = begin_indices + span_lens
    return begin_indices, end_indices

  def _shuffle_index_pairs(self, begin_indices: tf.Tensor,
                           end_indices: tf.Tensor):
    """Applies one random permutation to both index tensors."""
    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
    return tf.gather(begin_indices, order), tf.gather(end_indices, order)

  def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples token spans as prediction targets.

    Args:
      inputs: The input tokens.

    Returns:
      The sampled `bool` mask of length `seq_length`.
    """
    mask_alpha = self._seq_length / self._max_predictions_per_seq
    begin_indices, end_indices = self._sample_span_bounds(
        self._params.min_num_tokens, self._params.max_num_tokens, mask_alpha)

    # Remove out of range indices
    valid_idx_mask = end_indices < self._seq_length
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    # Shuffle valid indices
    begin_indices, end_indices = self._shuffle_index_pairs(
        begin_indices, end_indices)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
    """Sample whole word spans as prediction targets.

    Args:
      inputs: The input tokens.
      boundary: The word boundary indices, used to turn sampled word
        positions into token positions.

    Returns:
      The sampled `bool` mask of length `seq_length`.
    """
    # Note: 1.2 is the token-to-word ratio
    mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
    begin_indices, end_indices = self._sample_span_bounds(
        self._params.min_num_words, self._params.max_num_words, mask_alpha)

    # Remove out of range indices
    max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
    valid_idx_mask = end_indices < max_boundary_index
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    # Translate word positions into token positions.
    begin_indices = tf.gather(boundary, begin_indices)
    end_indices = tf.gather(boundary, end_indices)

    # Shuffle valid indices
    begin_indices, end_indices = self._shuffle_index_pairs(
        begin_indices, end_indices)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _online_sample_mask(self, inputs: tf.Tensor,
                          boundary: Optional[tf.Tensor]) -> tf.Tensor:
    """Samples target positions for predictions.

    Descriptions of each strategy:
      - 'single_token': Samples individual tokens as prediction targets.
      - 'token_span': Samples spans of tokens as prediction targets.
      - 'whole_word': Samples individual words as prediction targets.
      - 'word_span': Samples spans of words as prediction targets.

    Args:
      inputs: The input tokens.
      boundary: The `int` Tensor of indices indicating whole word boundaries.
        This is used in 'whole_word' and 'word_span', and may be `None` for
        the other strategies.

    Returns:
      The sampled `bool` input mask.

    Raises:
      `ValueError`: if `max_predictions_per_seq` is not set or if boundary is
        not provided for 'whole_word' and 'word_span' sample strategies.
      `NotImplementedError`: if the sample strategy is not one of the four
        listed above.
    """
    if self._max_predictions_per_seq is None:
      raise ValueError('`max_predictions_per_seq` must be set.')

    if boundary is None and 'word' in self._sample_strategy:
      raise ValueError('`boundary` must be provided for {} strategy'.format(
          self._sample_strategy))

    if self._sample_strategy == 'single_token':
      return self._single_token_mask(inputs)
    elif self._sample_strategy == 'token_span':
      return self._token_span_mask(inputs)
    elif self._sample_strategy == 'whole_word':
      return self._whole_word_mask(inputs, boundary)
    elif self._sample_strategy == 'word_span':
      return self._word_span_mask(inputs, boundary)
    else:
      raise NotImplementedError('Invalid sample strategy.')

  def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
    """Samples a permutation of the factorization order.

    Below, `length` is the number of tokens in `inputs`, which must be a
    multiple of `permutation_size`.

    Args:
      inputs: the input tokens.
      input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
        then this means select for partial prediction.

    Returns:
      perm_mask: An `int32` Tensor of shape [length, length] consisting of 0s
        and 1s. If perm_mask[i][j] == 0, then this means that the i-th token
        (in original order) cannot attend to the jth attention token.
      target_mask: An `int32` Tensor of shape [length] consisting of 0s and
        1s. If target_mask[i] == 1, then the i-th token needs to be predicted
        and the mask will be used as input. This token will be included in
        the loss. If target_mask[i] == 0, then the token (or [SEP], [CLS])
        will be used as input. This token will not be included in the loss.
      tokens: The `inputs` Tensor, returned unchanged.
      masked_tokens: A `bool` Tensor of shape [length] that is `True` for
        tokens selected by `input_mask` that are neither [SEP] nor [CLS].
    """
    factorization_length = tf.shape(inputs)[0]
    # Generate permutation indices; each block of `permutation_size`
    # positions is permuted with the same random order.
    index = tf.range(factorization_length, dtype=tf.int32)
    index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
    index = tf.random.shuffle(index)
    index = tf.reshape(tf.transpose(index), [-1])

    input_mask = tf.cast(input_mask, tf.bool)

    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(
            tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
    masked_tokens = tf.logical_and(input_mask, non_func_tokens)
    non_masked_or_func_tokens = tf.logical_not(masked_tokens)

    # Smaller than every permutation index, so that context tokens are
    # visible to everyone.
    smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)

    # Similar to BERT, randomly leak some masked tokens
    if self._leak_ratio > 0:
      leak_tokens = tf.logical_and(
          masked_tokens,
          tf.random.uniform([factorization_length], maxval=1.0) <
          self._leak_ratio)
      can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
    else:
      can_attend_self = non_masked_or_func_tokens
    to_index = tf.where(can_attend_self, smallest_index, index)
    from_index = tf.where(can_attend_self, to_index + 1, to_index)

    # For masked tokens, can attend if i > j
    # For context tokens, always can attend each other
    can_attend = from_index[:, None] > to_index[None, :]

    perm_mask = tf.cast(can_attend, tf.int32)

    # Only masked tokens are included in the loss
    target_mask = tf.cast(masked_tokens, tf.int32)

    return perm_mask, target_mask, inputs, masked_tokens

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Reads the configured input through an `InputReader`.

    Args:
      input_context: Optional distribution input context. When given, its
        `num_replicas_in_sync` is recorded on the loader before reading.

    Returns:
      Whatever `InputReader.read` returns -- presumably a `tf.data.Dataset`.
    """
    if input_context:
      self._num_replicas_in_sync = input_context.num_replicas_in_sync
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|