| import functools |
| import seqio |
| import tensorflow as tf |
| import t5.data |
| from datasets import load_dataset, load_from_disk |
| from t5.data import postprocessors |
| from t5.data import preprocessors |
| from t5.evaluation import metrics |
| from seqio import FunctionDataSource, utils |
|
|
| from ul2_objective import ul2_objective |
|
|
| |
# Denoiser configurations. Each *_SPAN_LENGTHS list lines up index-by-index
# with the *_CORRUPT_RATES list of the same family; the two families are
# concatenated (R first, then X) when the task is registered further down.

# Mean noise-span lengths for the "R" denoiser family.
R_DENOISER_SPAN_LENGTHS = [
    3.0,
    8.0,
]
# Mean noise-span lengths for the "X" denoiser family.
X_DENOISER_SPAN_LENGTHS = [
    3.0,
    8.0,
    64.0,
    64.0,
]
# Noise densities (corruption rates) for the "R" denoiser family.
R_DENOISER_CORRUPT_RATES = [
    0.15,
    0.15,
]
# Noise densities (corruption rates) for the "X" denoiser family.
X_DENOISER_CORRUPT_RATES = [
    0.5,
    0.5,
    0.15,
    0.5,
]

# Task-prefix strings, one per denoiser family (R, X and S respectively).
R_DENOISER_TOKEN_PREFIX = "[NLU]"
X_DENOISER_TOKEN_PREFIX = "[NLG]"
S_DENOISER_TOKEN_PREFIX = "[S2S]"
|
|
# Short alias for the seqio task registry.
TaskRegistry = seqio.TaskRegistry

# SentencePiece vocabulary loaded from the relative path 'spiece.model'.
# NOTE(review): extra_ids=0 adds no extra ids on top of the model file;
# presumably the sentinel tokens are already part of 'spiece.model' - confirm.
vocabulary = seqio.SentencePieceVocabulary("spiece.model", extra_ids=0)

# Feature specs shared by tasks in this module. Both features use the same
# vocabulary and get an EOS appended; "inputs" is marked as not required.
DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=vocabulary,
        add_eos=True,
        required=False,
    ),
    "targets": seqio.Feature(
        vocabulary=vocabulary,
        add_eos=True,
    ),
}
|
|
|
|
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
    """Endlessly yield the non-None values of ``column`` from one split.

    Args:
        split: Name of the split to read; converted with ``str()`` before it
            is used to index ``dataset``.
        shuffle: When true, ``dataset.shuffle(seed=seed)`` is called once
            before iteration starts.
        seed: Seed forwarded to ``dataset.shuffle``. ``None`` leaves the
            choice of seed to the dataset; any other value, including ``0``,
            is passed through unchanged.
        column: Key read from every item of the split.
        dataset: Mapping-like object indexed by split name whose splits
            iterate over items indexable by ``column``. Although it defaults
            to ``None`` it must be supplied by the caller.

    Yields:
        ``item[column]`` for every item whose value is not ``None``. Once the
        split is exhausted iteration restarts from its beginning, so the
        generator never finishes on its own as long as a pass produces at
        least one value.
    """
    if shuffle:
        # Forward the seed as-is: a truthiness test would silently drop an
        # explicit seed of 0 and make the shuffle non-reproducible.
        dataset = dataset.shuffle(seed=seed)

    split_items = dataset[str(split)]
    while True:
        produced_any = False
        for item in split_items:
            value = item[column]
            if value is not None:
                produced_any = True
                yield value
        if not produced_any:
            # An empty split (or one whose column is entirely None) would
            # otherwise spin here forever without yielding and hang the
            # consumer; end the generator instead.
            return
|
|
|
|
def dataset_fn(split, shuffle_files, seed=None, dataset=None):
    """Wrap ``gen_dataset`` in a ``tf.data.Dataset`` of scalar strings.

    Args:
        split: Split name, forwarded to ``gen_dataset``.
        shuffle_files: Forwarded as ``gen_dataset``'s ``shuffle`` argument.
        seed: Forwarded as ``gen_dataset``'s ``seed`` argument.
        dataset: Forwarded as ``gen_dataset``'s ``dataset`` argument.

    Returns:
        A ``tf.data.Dataset`` built with ``from_generator`` whose elements are
        scalar ``tf.string`` tensors.
    """
    # Bind all arguments now; from_generator receives a zero-argument callable.
    text_generator = functools.partial(
        gen_dataset, split, shuffle_files, seed, dataset=dataset)
    # The module-level ``dataset_name`` is looked up when this function is
    # called and is used only as the name of the tensor spec.
    text_signature = tf.TensorSpec(
        shape=(), dtype=tf.string, name=dataset_name)
    return tf.data.Dataset.from_generator(
        text_generator, output_signature=text_signature)
|
|
|
|
@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Return a copy of ``key_map`` in which ``target_key`` maps to ``x``.

    ``key_map`` itself is left unmodified; every other entry is carried over
    unchanged. The function is wrapped by ``utils.map_over_dataset``
    (presumably so it is applied to each dataset element - confirm against
    seqio).
    """
    overrides = {target_key: x}
    return {**key_map, **overrides}
|
|
|
|
# Location of the training corpus; also reused by dataset_fn as a tensor name.
dataset_name = "/researchdisk/lm_training_dataset_full"
dataset_params = {"from_disk_path": dataset_name}

# A "from_disk_path" entry selects load_from_disk; otherwise the parameters
# are passed straight through to load_dataset as keyword arguments.
dataset = (
    load_from_disk(dataset_params.get("from_disk_path"))
    if "from_disk_path" in dataset_params
    else load_dataset(**dataset_params)
)

# Row counts per split, handed to the data source as num_input_examples.
dataset_shapes = {
    split_name: dataset[split_name].num_rows
    for split_name in ("train", "validation")
}
|
|
# Register the "pretrain_finnish_ul2" task.
TaskRegistry.add(
    "pretrain_finnish_ul2",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(dataset_fn, dataset=dataset),
        splits=("train", "validation"),
        caching_permitted=False,
        num_input_examples=dataset_shapes,
    ),
    preprocessors=[
        # Place each raw value under "targets"; "inputs" starts out as None.
        functools.partial(
            target_to_key,
            key_map={"inputs": None, "targets": None},
            target_key="targets",
        ),
        seqio.preprocessors.tokenize,
        functools.partial(
            ul2_objective,
            shard_ds=False,
            use_prefix_lm_task=True,
            # 0.4 split evenly over the R configurations, 0.4 split evenly
            # over the X configurations, and 0.2 for one final entry
            # (presumably the prefix-LM task enabled above - confirm against
            # ul2_objective).
            rates=(
                [0.4 / len(R_DENOISER_SPAN_LENGTHS)]
                * len(R_DENOISER_SPAN_LENGTHS)
                + [0.4 / len(X_DENOISER_SPAN_LENGTHS)]
                * len(X_DENOISER_SPAN_LENGTHS)
                + [0.2]
            ),
            # R configurations first, then X, in both lists.
            mean_noise_span_lengths=(
                R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS
            ),
            noise_densities=(
                R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES
            ),
            # One prefix per entry of ``rates``, in the same order.
            optional_task_prefixes=(
                [R_DENOISER_TOKEN_PREFIX] * len(R_DENOISER_SPAN_LENGTHS)
                + [X_DENOISER_TOKEN_PREFIX] * len(X_DENOISER_SPAN_LENGTHS)
                + [S_DENOISER_TOKEN_PREFIX]
            ),
            reserved_for_packing=1,
        ),
        seqio.preprocessors.append_eos_after_trim,
    ],
    # Only the "targets" feature spec is declared for this task.
    output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
    metric_fns=[metrics.accuracy],
)
|
|