Spaces:
Sleeping
Sleeping
Delete sentence_prediction_dataloader.py
Browse files
sentence_prediction_dataloader.py
DELETED
|
@@ -1,267 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
"""Loads dataset for the sentence prediction (classification) task."""
|
| 16 |
-
import dataclasses
|
| 17 |
-
import functools
|
| 18 |
-
from typing import List, Mapping, Optional, Tuple
|
| 19 |
-
|
| 20 |
-
import tensorflow as tf, tf_keras
|
| 21 |
-
import tensorflow_hub as hub
|
| 22 |
-
|
| 23 |
-
from official.common import dataset_fn
|
| 24 |
-
from official.core import config_definitions as cfg
|
| 25 |
-
from official.core import input_reader
|
| 26 |
-
from official.nlp import modeling
|
| 27 |
-
from official.nlp.data import data_loader
|
| 28 |
-
from official.nlp.data import data_loader_factory
|
| 29 |
-
|
| 30 |
-
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
@dataclasses.dataclass
|
| 34 |
-
class SentencePredictionDataConfig(cfg.DataConfig):
|
| 35 |
-
"""Data config for sentence prediction task (tasks/sentence_prediction)."""
|
| 36 |
-
input_path: str = ''
|
| 37 |
-
global_batch_size: int = 32
|
| 38 |
-
is_training: bool = True
|
| 39 |
-
seq_length: int = 128
|
| 40 |
-
label_type: str = 'int'
|
| 41 |
-
# Whether to include the example id number.
|
| 42 |
-
include_example_id: bool = False
|
| 43 |
-
label_field: str = 'label_ids'
|
| 44 |
-
# Maps the key in TfExample to feature name.
|
| 45 |
-
# E.g 'label_ids' to 'next_sentence_labels'
|
| 46 |
-
label_name: Optional[Tuple[str, str]] = None
|
| 47 |
-
# Either tfrecord, sstable, or recordio.
|
| 48 |
-
file_type: str = 'tfrecord'
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
|
| 52 |
-
class SentencePredictionDataLoader(data_loader.DataLoader):
|
| 53 |
-
"""A class to load dataset for sentence prediction (classification) task."""
|
| 54 |
-
|
| 55 |
-
def __init__(self, params):
|
| 56 |
-
self._params = params
|
| 57 |
-
self._seq_length = params.seq_length
|
| 58 |
-
self._include_example_id = params.include_example_id
|
| 59 |
-
self._label_field = params.label_field
|
| 60 |
-
if params.label_name:
|
| 61 |
-
self._label_name_mapping = dict([params.label_name])
|
| 62 |
-
else:
|
| 63 |
-
self._label_name_mapping = dict()
|
| 64 |
-
|
| 65 |
-
def name_to_features_spec(self):
|
| 66 |
-
"""Defines features to decode. Subclass may override to append features."""
|
| 67 |
-
label_type = LABEL_TYPES_MAP[self._params.label_type]
|
| 68 |
-
name_to_features = {
|
| 69 |
-
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
|
| 70 |
-
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
|
| 71 |
-
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
|
| 72 |
-
self._label_field: tf.io.FixedLenFeature([], label_type),
|
| 73 |
-
}
|
| 74 |
-
if self._include_example_id:
|
| 75 |
-
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
|
| 76 |
-
|
| 77 |
-
return name_to_features
|
| 78 |
-
|
| 79 |
-
def _decode(self, record: tf.Tensor):
|
| 80 |
-
"""Decodes a serialized tf.Example."""
|
| 81 |
-
example = tf.io.parse_single_example(record, self.name_to_features_spec())
|
| 82 |
-
|
| 83 |
-
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
| 84 |
-
# So cast all int64 to int32.
|
| 85 |
-
for name in example:
|
| 86 |
-
t = example[name]
|
| 87 |
-
if t.dtype == tf.int64:
|
| 88 |
-
t = tf.cast(t, tf.int32)
|
| 89 |
-
example[name] = t
|
| 90 |
-
|
| 91 |
-
return example
|
| 92 |
-
|
| 93 |
-
def _parse(self, record: Mapping[str, tf.Tensor]):
|
| 94 |
-
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
|
| 95 |
-
key_mapping = {
|
| 96 |
-
'input_ids': 'input_word_ids',
|
| 97 |
-
'input_mask': 'input_mask',
|
| 98 |
-
'segment_ids': 'input_type_ids'
|
| 99 |
-
}
|
| 100 |
-
ret = {}
|
| 101 |
-
for record_key in record:
|
| 102 |
-
if record_key in key_mapping:
|
| 103 |
-
ret[key_mapping[record_key]] = record[record_key]
|
| 104 |
-
else:
|
| 105 |
-
ret[record_key] = record[record_key]
|
| 106 |
-
|
| 107 |
-
if self._label_field in self._label_name_mapping:
|
| 108 |
-
ret[self._label_name_mapping[self._label_field]] = record[
|
| 109 |
-
self._label_field]
|
| 110 |
-
|
| 111 |
-
return ret
|
| 112 |
-
|
| 113 |
-
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
|
| 114 |
-
"""Returns a tf.dataset.Dataset."""
|
| 115 |
-
reader = input_reader.InputReader(
|
| 116 |
-
dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
|
| 117 |
-
params=self._params,
|
| 118 |
-
decoder_fn=self._decode,
|
| 119 |
-
parser_fn=self._parse)
|
| 120 |
-
return reader.read(input_context)
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
@dataclasses.dataclass
|
| 124 |
-
class SentencePredictionTextDataConfig(cfg.DataConfig):
|
| 125 |
-
"""Data config for sentence prediction task with raw text."""
|
| 126 |
-
# Either set `input_path`...
|
| 127 |
-
input_path: str = ''
|
| 128 |
-
# Either `int` or `float`.
|
| 129 |
-
label_type: str = 'int'
|
| 130 |
-
# ...or `tfds_name` and `tfds_split` to specify input.
|
| 131 |
-
tfds_name: str = ''
|
| 132 |
-
tfds_split: str = ''
|
| 133 |
-
# The name of the text feature fields. The text features will be
|
| 134 |
-
# concatenated in order.
|
| 135 |
-
text_fields: Optional[List[str]] = None
|
| 136 |
-
label_field: str = 'label'
|
| 137 |
-
global_batch_size: int = 32
|
| 138 |
-
seq_length: int = 128
|
| 139 |
-
is_training: bool = True
|
| 140 |
-
# Either build preprocessing with Python code by specifying these values
|
| 141 |
-
# for modeling.layers.BertTokenizer()/SentencepieceTokenizer()....
|
| 142 |
-
tokenization: str = 'WordPiece' # WordPiece or SentencePiece
|
| 143 |
-
# Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
|
| 144 |
-
# file if tokenization is SentencePiece.
|
| 145 |
-
vocab_file: str = ''
|
| 146 |
-
lower_case: bool = True
|
| 147 |
-
# ...or load preprocessing from a SavedModel at this location.
|
| 148 |
-
preprocessing_hub_module_url: str = ''
|
| 149 |
-
# Either tfrecord or sstsable or recordio.
|
| 150 |
-
file_type: str = 'tfrecord'
|
| 151 |
-
include_example_id: bool = False
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
class TextProcessor(tf.Module):
|
| 155 |
-
"""Text features processing for sentence prediction task."""
|
| 156 |
-
|
| 157 |
-
def __init__(self,
|
| 158 |
-
seq_length: int,
|
| 159 |
-
vocab_file: Optional[str] = None,
|
| 160 |
-
tokenization: Optional[str] = None,
|
| 161 |
-
lower_case: Optional[bool] = True,
|
| 162 |
-
preprocessing_hub_module_url: Optional[str] = None):
|
| 163 |
-
if preprocessing_hub_module_url:
|
| 164 |
-
self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
|
| 165 |
-
self._tokenizer = self._preprocessing_hub_module.tokenize
|
| 166 |
-
self._pack_inputs = functools.partial(
|
| 167 |
-
self._preprocessing_hub_module.bert_pack_inputs,
|
| 168 |
-
seq_length=seq_length)
|
| 169 |
-
return
|
| 170 |
-
|
| 171 |
-
if tokenization == 'WordPiece':
|
| 172 |
-
self._tokenizer = modeling.layers.BertTokenizer(
|
| 173 |
-
vocab_file=vocab_file, lower_case=lower_case)
|
| 174 |
-
elif tokenization == 'SentencePiece':
|
| 175 |
-
self._tokenizer = modeling.layers.SentencepieceTokenizer(
|
| 176 |
-
model_file_path=vocab_file,
|
| 177 |
-
lower_case=lower_case,
|
| 178 |
-
strip_diacritics=True) # Strip diacritics to follow ALBERT model
|
| 179 |
-
else:
|
| 180 |
-
raise ValueError('Unsupported tokenization: %s' % tokenization)
|
| 181 |
-
|
| 182 |
-
self._pack_inputs = modeling.layers.BertPackInputs(
|
| 183 |
-
seq_length=seq_length,
|
| 184 |
-
special_tokens_dict=self._tokenizer.get_special_tokens_dict())
|
| 185 |
-
|
| 186 |
-
def __call__(self, segments):
|
| 187 |
-
segments = [self._tokenizer(s) for s in segments]
|
| 188 |
-
# BertTokenizer returns a RaggedTensor with shape [batch, word, subword],
|
| 189 |
-
# and SentencepieceTokenizer returns a RaggedTensor with shape
|
| 190 |
-
# [batch, sentencepiece],
|
| 191 |
-
segments = [
|
| 192 |
-
tf.cast(x.merge_dims(1, -1) if x.shape.rank > 2 else x, tf.int32)
|
| 193 |
-
for x in segments
|
| 194 |
-
]
|
| 195 |
-
return self._pack_inputs(segments)
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
@data_loader_factory.register_data_loader_cls(SentencePredictionTextDataConfig)
|
| 199 |
-
class SentencePredictionTextDataLoader(data_loader.DataLoader):
|
| 200 |
-
"""Loads dataset with raw text for sentence prediction task."""
|
| 201 |
-
|
| 202 |
-
def __init__(self, params):
|
| 203 |
-
if bool(params.tfds_name) != bool(params.tfds_split):
|
| 204 |
-
raise ValueError('`tfds_name` and `tfds_split` should be specified or '
|
| 205 |
-
'unspecified at the same time.')
|
| 206 |
-
if bool(params.tfds_name) == bool(params.input_path):
|
| 207 |
-
raise ValueError('Must specify either `tfds_name` and `tfds_split` '
|
| 208 |
-
'or `input_path`.')
|
| 209 |
-
if not params.text_fields:
|
| 210 |
-
raise ValueError('Unexpected empty text fields.')
|
| 211 |
-
if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
|
| 212 |
-
raise ValueError('Must specify exactly one of vocab_file (with matching '
|
| 213 |
-
'lower_case flag) or preprocessing_hub_module_url.')
|
| 214 |
-
|
| 215 |
-
self._params = params
|
| 216 |
-
self._text_fields = params.text_fields
|
| 217 |
-
self._label_field = params.label_field
|
| 218 |
-
self._label_type = params.label_type
|
| 219 |
-
self._include_example_id = params.include_example_id
|
| 220 |
-
self._text_processor = TextProcessor(
|
| 221 |
-
seq_length=params.seq_length,
|
| 222 |
-
vocab_file=params.vocab_file,
|
| 223 |
-
tokenization=params.tokenization,
|
| 224 |
-
lower_case=params.lower_case,
|
| 225 |
-
preprocessing_hub_module_url=params.preprocessing_hub_module_url)
|
| 226 |
-
|
| 227 |
-
def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
|
| 228 |
-
"""Berts preprocess."""
|
| 229 |
-
segments = [record[x] for x in self._text_fields]
|
| 230 |
-
model_inputs = self._text_processor(segments)
|
| 231 |
-
for key in record:
|
| 232 |
-
if key not in self._text_fields:
|
| 233 |
-
model_inputs[key] = record[key]
|
| 234 |
-
return model_inputs
|
| 235 |
-
|
| 236 |
-
def name_to_features_spec(self):
|
| 237 |
-
name_to_features = {}
|
| 238 |
-
for text_field in self._text_fields:
|
| 239 |
-
name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
|
| 240 |
-
|
| 241 |
-
label_type = LABEL_TYPES_MAP[self._label_type]
|
| 242 |
-
name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
|
| 243 |
-
if self._include_example_id:
|
| 244 |
-
name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
|
| 245 |
-
return name_to_features
|
| 246 |
-
|
| 247 |
-
def _decode(self, record: tf.Tensor):
|
| 248 |
-
"""Decodes a serialized tf.Example."""
|
| 249 |
-
example = tf.io.parse_single_example(record, self.name_to_features_spec())
|
| 250 |
-
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
|
| 251 |
-
# So cast all int64 to int32.
|
| 252 |
-
for name in example:
|
| 253 |
-
t = example[name]
|
| 254 |
-
if t.dtype == tf.int64:
|
| 255 |
-
t = tf.cast(t, tf.int32)
|
| 256 |
-
example[name] = t
|
| 257 |
-
|
| 258 |
-
return example
|
| 259 |
-
|
| 260 |
-
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
|
| 261 |
-
"""Returns a tf.dataset.Dataset."""
|
| 262 |
-
reader = input_reader.InputReader(
|
| 263 |
-
dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
|
| 264 |
-
decoder_fn=self._decode if self._params.input_path else None,
|
| 265 |
-
params=self._params,
|
| 266 |
-
postprocess_fn=self._bert_preprocess)
|
| 267 |
-
return reader.read(input_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|