|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Methods related to input datasets and readers."""
|
|
|
| import functools
|
| import sys
|
|
|
| from absl import logging
|
|
|
| import tensorflow as tf
|
| from tensorflow import estimator as tf_estimator
|
| import tensorflow_datasets as tfds
|
| import tensorflow_text as tftext
|
|
|
| from layers import projection_layers
|
| from utils import misc_utils
|
|
|
|
|
def imdb_reviews(features, _):
  """Extracts the (text, label) pair from an imdb_reviews example dict.

  The second argument (runner_config) is unused for this dataset.
  """
  text = features["text"]
  label = features["label"]
  return text, label
|
|
|
|
|
def civil_comments(features, runner_config):
  """Builds (text, multi-label tensor) for the civil_comments dataset.

  Each configured label column is a per-example score; scores are rounded
  to the nearest of {0.0, 1.0} to form binary targets.
  """
  label_names = runner_config["model_config"]["labels"]
  columns = [features[name] for name in label_names]
  stacked = tf.stack(columns, axis=1)
  # floor(x + 0.5) rounds each score to the nearest whole value.
  rounded = tf.floor(stacked + 0.5)
  return features["text"], rounded
|
|
|
|
|
def goemotions(features, runner_config):
  """Builds (comment text, float32 multi-label tensor) for goemotions."""
  label_names = runner_config["model_config"]["labels"]
  # Stack one column per configured label, then cast to float for the loss.
  stacked = tf.stack([features[name] for name in label_names], axis=1)
  return features["comment_text"], tf.cast(stacked, tf.float32)
|
|
|
|
|
def create_input_fn(runner_config, mode, drop_remainder):
  """Returns an input function to use in the instantiation of tf.estimator.*.

  Args:
    runner_config: Dict with keys "dataset" (a TFDS dataset name that must
      also match a processor function defined in this module, e.g.
      imdb_reviews / civil_comments / goemotions) and "model_config", which
      must contain "labels" and at least one of "max_seq_len" or
      "gbst_max_token_len".
    mode: A tf_estimator.ModeKeys value; PREDICT is not supported.
    drop_remainder: Whether to drop the final partial batch.

  Returns:
    An input_fn(params) callable returning a tf.data.Dataset of feature
    dicts with keys "projection", "seq_length", "token_ids", "token_len"
    and "label".
  """

  def _post_processor(features, batch_size):
    """Post process the data to a form expected by model_fn."""
    # Dispatch to the per-dataset processor defined at module level whose
    # name equals the dataset name.
    data_processor = getattr(sys.modules[__name__], runner_config["dataset"])
    text, label = data_processor(features, runner_config)
    model_config = runner_config["model_config"]
    if "max_seq_len" in model_config:
      max_seq_len = model_config["max_seq_len"]
      logging.info("Truncating text to have at most %d tokens", max_seq_len)
      text = misc_utils.random_substr(text, max_seq_len)
    text = tf.reshape(text, [batch_size])
    num_classes = len(model_config["labels"])
    label = tf.reshape(label, [batch_size, num_classes])
    prxlayer = projection_layers.ProjectionLayer(model_config, mode)
    projection, seq_length = prxlayer(text)
    # Bug fix: the original read the local `max_seq_len` here, which is
    # unbound (NameError) when "max_seq_len" is absent from model_config.
    # Look the value up from model_config instead so a missing key raises
    # a clear KeyError; behavior is unchanged whenever either key exists.
    if "gbst_max_token_len" in model_config:
      gbst_max_token_len = model_config["gbst_max_token_len"]
    else:
      gbst_max_token_len = model_config["max_seq_len"]
    # Split each string into raw bytes, padded/truncated to a fixed width.
    byte_int = tftext.ByteSplitter().split(text).to_tensor(
        default_value=0, shape=[batch_size, gbst_max_token_len])
    token_ids = tf.cast(byte_int, tf.int32)
    token_len = tf.strings.length(text)
    # Shift real byte values up by 3 (padding positions, where the mask is
    # 0, stay 0) — presumably to reserve ids 0..2 for special tokens.
    mask = tf.cast(
        tf.sequence_mask(token_len, maxlen=gbst_max_token_len), tf.int32)
    mask *= 3
    token_ids += mask
    return {
        "projection": projection,
        "seq_length": seq_length,
        "token_ids": token_ids,
        "token_len": token_len,
        "label": label
    }

  def _input_fn(params):
    """Method to be used for reading the data."""
    assert mode != tf_estimator.ModeKeys.PREDICT
    split = "train" if mode == tf_estimator.ModeKeys.TRAIN else "test"
    ds = tfds.load(runner_config["dataset"], split=split)
    ds = ds.batch(params["batch_size"], drop_remainder=drop_remainder)
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    # NOTE(review): shuffling happens after batching, so whole batches are
    # shuffled rather than individual examples — confirm this is intended.
    ds = ds.shuffle(buffer_size=100)
    # Single pass for eval; repeat indefinitely for training.
    ds = ds.repeat(count=1 if mode == tf_estimator.ModeKeys.EVAL else None)
    ds = ds.map(
        functools.partial(_post_processor, batch_size=params["batch_size"]),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=False)
    return ds

  return _input_fn
|
|
|