Spaces:
Sleeping
Sleeping
Delete pretrain_dynamic_dataloader.py
Browse files- pretrain_dynamic_dataloader.py +0 -223
pretrain_dynamic_dataloader.py
DELETED
|
@@ -1,223 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
"""Dataset loader for the pre-training with dynamic sequence length."""
|
| 16 |
-
from typing import Optional, Tuple
|
| 17 |
-
|
| 18 |
-
import dataclasses
|
| 19 |
-
import tensorflow as tf, tf_keras
|
| 20 |
-
|
| 21 |
-
from official.core import config_definitions as cfg
|
| 22 |
-
from official.core import input_reader
|
| 23 |
-
from official.nlp.data import data_loader_factory
|
| 24 |
-
from official.nlp.data import pretrain_dataloader
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
@dataclasses.dataclass
|
| 28 |
-
class BertPretrainDataConfig(cfg.DataConfig):
|
| 29 |
-
"""Data config for BERT pretraining task (tasks/masked_lm)."""
|
| 30 |
-
input_path: str = ''
|
| 31 |
-
global_batch_size: int = 512
|
| 32 |
-
is_training: bool = True
|
| 33 |
-
seq_bucket_lengths: Tuple[int, ...] = (128, 256, 384, 512,)
|
| 34 |
-
# TODO(rxsang): `seq_bucket_window_scale` is only useful when round robin
|
| 35 |
-
# tf.data service is disabled. Deprecate this flag once we always enable round
|
| 36 |
-
# robin tf.data service.
|
| 37 |
-
seq_bucket_window_scale: int = 8
|
| 38 |
-
use_next_sentence_label: bool = True
|
| 39 |
-
use_position_id: bool = False
|
| 40 |
-
deterministic: bool = False
|
| 41 |
-
enable_tf_data_service: bool = False
|
| 42 |
-
enable_round_robin_tf_data_service: bool = False
|
| 43 |
-
tf_data_service_job_name: str = 'bert_pretrain'
|
| 44 |
-
use_v2_feature_names: bool = False
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
|
| 48 |
-
class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
|
| 49 |
-
"""Dataset loader for bert-style pretraining with dynamic sequenece length.
|
| 50 |
-
|
| 51 |
-
Bucketizes the input id features by the seq_bucket_lengths and features are
|
| 52 |
-
padded to the bucket boundaries. The mask features are usually short than
|
| 53 |
-
input id features and can also be dynamic. We require the mask feature lengths
|
| 54 |
-
within a bucket must be the same. For example, with [128, 256] buckets,
|
| 55 |
-
the mask features for bucket 128 should always have the length as X and
|
| 56 |
-
features for bucket 256 should always have the length as Y.
|
| 57 |
-
|
| 58 |
-
The dataloader does not filter out empty masks. Make sure to handle this
|
| 59 |
-
in the model.
|
| 60 |
-
"""
|
| 61 |
-
|
| 62 |
-
def __init__(self, params):
|
| 63 |
-
self._params = params
|
| 64 |
-
if len(params.seq_bucket_lengths) < 1:
|
| 65 |
-
raise ValueError('The seq_bucket_lengths cannot be empty.')
|
| 66 |
-
self._seq_bucket_lengths = params.seq_bucket_lengths
|
| 67 |
-
self._seq_bucket_window_scale = params.seq_bucket_window_scale
|
| 68 |
-
self._global_batch_size = params.global_batch_size
|
| 69 |
-
self._use_next_sentence_label = params.use_next_sentence_label
|
| 70 |
-
self._use_position_id = params.use_position_id
|
| 71 |
-
self._drop_remainder = params.drop_remainder
|
| 72 |
-
self._enable_tf_data_service = params.enable_tf_data_service
|
| 73 |
-
self._enable_round_robin_tf_data_service = (
|
| 74 |
-
params.enable_round_robin_tf_data_service)
|
| 75 |
-
self._mask_keys = [
|
| 76 |
-
'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'
|
| 77 |
-
]
|
| 78 |
-
|
| 79 |
-
def _decode(self, record: tf.Tensor):
|
| 80 |
-
"""Decodes a serialized tf.Example."""
|
| 81 |
-
name_to_features = {
|
| 82 |
-
'input_mask': tf.io.VarLenFeature(tf.int64),
|
| 83 |
-
'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
|
| 84 |
-
'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
|
| 85 |
-
'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
|
| 86 |
-
}
|
| 87 |
-
if self._params.use_v2_feature_names:
|
| 88 |
-
input_ids_key = 'input_word_ids'
|
| 89 |
-
segment_key = 'input_type_ids'
|
| 90 |
-
name_to_features.update({
|
| 91 |
-
input_ids_key: tf.io.VarLenFeature(tf.int64),
|
| 92 |
-
segment_key: tf.io.VarLenFeature(tf.int64),
|
| 93 |
-
})
|
| 94 |
-
else:
|
| 95 |
-
input_ids_key = 'input_ids'
|
| 96 |
-
segment_key = 'segment_ids'
|
| 97 |
-
name_to_features.update({
|
| 98 |
-
input_ids_key: tf.io.VarLenFeature(tf.int64),
|
| 99 |
-
segment_key: tf.io.VarLenFeature(tf.int64),
|
| 100 |
-
})
|
| 101 |
-
if self._use_next_sentence_label:
|
| 102 |
-
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
|
| 103 |
-
tf.int64)
|
| 104 |
-
dynamic_keys = [input_ids_key, 'input_mask', segment_key]
|
| 105 |
-
if self._use_position_id:
|
| 106 |
-
name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
|
| 107 |
-
dynamic_keys.append('position_ids')
|
| 108 |
-
|
| 109 |
-
example = tf.io.parse_single_example(record, name_to_features)
|
| 110 |
-
for key in dynamic_keys + self._mask_keys:
|
| 111 |
-
example[key] = tf.sparse.to_dense(example[key])
|
| 112 |
-
|
| 113 |
-
# Truncate padded data after the first non pad in the
|
| 114 |
-
# sequence length dimension.
|
| 115 |
-
# Pad before the first non pad from the back should not be removed.
|
| 116 |
-
mask = tf.math.greater(
|
| 117 |
-
tf.math.cumsum(example[input_ids_key], reverse=True), 0)
|
| 118 |
-
for key in dynamic_keys:
|
| 119 |
-
example[key] = tf.boolean_mask(example[key], mask)
|
| 120 |
-
|
| 121 |
-
# masked_lm_ids should be 0 padded.
|
| 122 |
-
# Change mask features to -1 padding so that we can differentiate
|
| 123 |
-
# padding from data or from bucketizing.
|
| 124 |
-
mask = tf.math.not_equal(example['masked_lm_ids'], 0)
|
| 125 |
-
example['masked_lm_ids'] = tf.where(
|
| 126 |
-
mask, example['masked_lm_ids'],
|
| 127 |
-
-tf.ones(tf.shape(example['masked_lm_ids']), dtype=example[key].dtype))
|
| 128 |
-
|
| 129 |
-
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
| 130 |
-
# So cast all int64 to int32.
|
| 131 |
-
# tf.data service uses dataset graph fingerprint to distinguish input
|
| 132 |
-
# pipeline jobs, thus we sort the keys here to make sure they are generated
|
| 133 |
-
# in a deterministic order each time the dataset function is traced.
|
| 134 |
-
for name in sorted(list(example.keys())):
|
| 135 |
-
t = example[name]
|
| 136 |
-
if t.dtype == tf.int64:
|
| 137 |
-
t = tf.cast(t, tf.int32)
|
| 138 |
-
example[name] = t
|
| 139 |
-
|
| 140 |
-
return example
|
| 141 |
-
|
| 142 |
-
def _bucketize_and_batch(
|
| 143 |
-
self,
|
| 144 |
-
dataset,
|
| 145 |
-
input_context: Optional[tf.distribute.InputContext] = None):
|
| 146 |
-
"""Bucketize by sequence length and batch the datasets."""
|
| 147 |
-
per_replica_batch_size = input_context.get_per_replica_batch_size(
|
| 148 |
-
self._global_batch_size) if input_context else self._global_batch_size
|
| 149 |
-
|
| 150 |
-
def element_length_func(example, seq_len_dim):
|
| 151 |
-
return tf.shape(example['input_word_ids'])[seq_len_dim]
|
| 152 |
-
|
| 153 |
-
bucket_boundaries = [length + 1 for length in self._seq_bucket_lengths]
|
| 154 |
-
bucket_batch_sizes = [per_replica_batch_size] * (len(bucket_boundaries) + 1)
|
| 155 |
-
|
| 156 |
-
# Bucketize and batch the dataset with per replica batch size first.
|
| 157 |
-
dataset = dataset.apply(
|
| 158 |
-
tf.data.experimental.bucket_by_sequence_length(
|
| 159 |
-
lambda example: tf.cast(element_length_func(example, 0), tf.int32),
|
| 160 |
-
bucket_boundaries,
|
| 161 |
-
bucket_batch_sizes,
|
| 162 |
-
pad_to_bucket_boundary=True,
|
| 163 |
-
drop_remainder=self._drop_remainder))
|
| 164 |
-
if input_context:
|
| 165 |
-
window_size = input_context.num_replicas_in_sync
|
| 166 |
-
if self._enable_tf_data_service and (
|
| 167 |
-
not self._enable_round_robin_tf_data_service):
|
| 168 |
-
# If tf.data service is enabled but round-robin behavior is not enabled,
|
| 169 |
-
# different TPU workers may fetch data from one tf.data service worker
|
| 170 |
-
# in different speed. We set the window size to be
|
| 171 |
-
# `seq_bucket_window_scale` larger to leave buffer if some workers are
|
| 172 |
-
# fetching data faster than others, so all the data within the same
|
| 173 |
-
# global batch can still have more chances to be in the same bucket.
|
| 174 |
-
window_size *= self._seq_bucket_window_scale
|
| 175 |
-
|
| 176 |
-
# Group `num_replicas_in_sync` batches from same bucket together, so all
|
| 177 |
-
# replicas can get the same sequence length for one global step.
|
| 178 |
-
dataset = dataset.apply(
|
| 179 |
-
tf.data.experimental.group_by_window(
|
| 180 |
-
key_func=lambda example: tf.cast( # pylint: disable=g-long-lambda
|
| 181 |
-
element_length_func(example, 1), tf.int64),
|
| 182 |
-
reduce_func=lambda _, x: tf.data.Dataset.from_tensors(x),
|
| 183 |
-
window_size=window_size))
|
| 184 |
-
dataset = dataset.flat_map(lambda x: x)
|
| 185 |
-
|
| 186 |
-
def _remove_pads_from_bucketize(features):
|
| 187 |
-
# All mask features must have the same effective length.
|
| 188 |
-
# The real masked ids padding token is -1 and 0 comes from
|
| 189 |
-
# bucket_by_sequence_length.
|
| 190 |
-
mask = tf.math.not_equal(features['masked_lm_ids'], 0)
|
| 191 |
-
|
| 192 |
-
mask_per_example = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
|
| 193 |
-
normalized = tf.cast(
|
| 194 |
-
mask_per_example / tf.math.reduce_max(mask_per_example), tf.int32)
|
| 195 |
-
assert_op = tf.debugging.assert_equal(
|
| 196 |
-
tf.math.reduce_sum(normalized), per_replica_batch_size,
|
| 197 |
-
'Number of non padded mask tokens is not the same for each example '
|
| 198 |
-
'in the same sequence length.')
|
| 199 |
-
with tf.control_dependencies([assert_op]):
|
| 200 |
-
for key in self._mask_keys:
|
| 201 |
-
features[key] = tf.reshape(
|
| 202 |
-
tf.boolean_mask(
|
| 203 |
-
features[key], mask), [per_replica_batch_size, -1])
|
| 204 |
-
# Revert masked_lm_ids to be 0-padded.
|
| 205 |
-
mask = tf.math.not_equal(features['masked_lm_ids'], -1)
|
| 206 |
-
features['masked_lm_ids'] = tf.where(
|
| 207 |
-
mask, features['masked_lm_ids'],
|
| 208 |
-
tf.zeros(
|
| 209 |
-
tf.shape(features['masked_lm_ids']),
|
| 210 |
-
dtype=features['masked_lm_ids'].dtype))
|
| 211 |
-
return features
|
| 212 |
-
|
| 213 |
-
dataset = dataset.map(_remove_pads_from_bucketize)
|
| 214 |
-
return dataset
|
| 215 |
-
|
| 216 |
-
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
|
| 217 |
-
"""Returns a tf.dataset.Dataset."""
|
| 218 |
-
reader = input_reader.InputReader(
|
| 219 |
-
params=self._params,
|
| 220 |
-
decoder_fn=self._decode,
|
| 221 |
-
parser_fn=self._parse,
|
| 222 |
-
transform_and_batch_fn=self._bucketize_and_batch)
|
| 223 |
-
return reader.read(input_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|