import functools


import seqio
import tensorflow as tf
import t5.data
from datasets import load_dataset, load_from_disk
from t5.data import postprocessors
from t5.data import preprocessors
from t5.evaluation import metrics
from seqio import FunctionDataSource, utils

|
# Short alias; the task below registers itself into this global seqio registry.
TaskRegistry = seqio.TaskRegistry


# SentencePiece vocabulary loaded from a local `spiece.model` file.
# extra_ids=0: no synthetic sentinel ids are appended by seqio.
# NOTE(review): span corruption normally requires sentinel ids — presumably the
# spiece.model already contains them; confirm before training.
vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)


# Output features shared by tasks in this file. "inputs" is marked
# required=False because the pretraining task below registers only "targets";
# span corruption derives the model inputs from the targets.
DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=vocabulary, add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=vocabulary, add_eos=True)
}
|
|
|
|
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
    """Yield the raw text of each example in ``dataset[split]``, forever.

    Args:
        split: split name to iterate (e.g. "train" or "validation").
        shuffle: if True, shuffle the dataset before iterating.
        seed: optional shuffle seed. ``0`` is a valid seed.
        column: name of the example field to yield.
        dataset: mapping from split name to an iterable of example dicts
            (e.g. a HuggingFace ``DatasetDict``).

    Yields:
        The ``column`` value of each example, cycling over the split endlessly.
    """
    if shuffle:
        # Fix: the original `if seed:` treated seed=0 as "no seed" and silently
        # dropped determinism; compare against None explicitly instead.
        if seed is not None:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    while True:
        for item in dataset[str(split)]:
            yield item[column]
|
|
|
|
def dataset_fn(split, shuffle_files, seed=None, dataset=None):
    """Wrap `gen_dataset` as an endless tf.data.Dataset of scalar strings."""
    example_generator = functools.partial(
        gen_dataset, split, shuffle_files, seed, dataset=dataset)
    # NOTE(review): the TensorSpec name comes from the module-level global
    # `dataset_name`, which is assigned later in this file — it is only safe
    # because seqio calls this function after module import completes; confirm.
    signature = tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
    return tf.data.Dataset.from_generator(
        example_generator, output_signature=signature)
|
|
|
|
@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Assign the value from the dataset to target_key in key_map"""
    out = dict(key_map)
    out[target_key] = x
    return out
|
|
|
|
| |
# Path to a HuggingFace dataset previously saved with `save_to_disk`; also
# reused (below, in dataset_fn) as the TensorSpec name.
dataset_name = "/researchdisk/lm_training_dataset_full"
dataset_params = {"from_disk_path": dataset_name}


# Load from local disk when a path is configured; otherwise fall back to
# downloading via `load_dataset` with the remaining parameters.
if "from_disk_path" in dataset_params:
    dataset = load_from_disk(dataset_params.get("from_disk_path"))
else:
    dataset = load_dataset(**dataset_params)


# Per-split example counts, handed to seqio so it can report dataset sizes
# without iterating the data.
dataset_shapes = {"train": dataset["train"].num_rows, "validation": dataset["validation"].num_rows}
# Register the Finnish LM pretraining task using the T5 span-corruption
# objective over the raw text loaded above.
TaskRegistry.add(
    "pretrain_finnish",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=("train", "validation"),
        # The generator is infinite and non-deterministic, so caching the
        # materialized dataset is disallowed.
        caching_permitted=False,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        # Wrap each raw string x into {"inputs": None, "targets": x}.
        functools.partial(
            target_to_key, key_map={
                "inputs": None,
                "targets": None,
            }, target_key="targets"),
        seqio.preprocessors.tokenize,

        # Span corruption derives model "inputs" from corrupted "targets".
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    # Only "targets" is declared; span corruption supplies the inputs feature.
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[metrics.accuracy]
)