| |
|
|
| import functools |
| import seqio |
| import tensorflow_datasets as tfds |
| from t5.evaluation import metrics |
| from t5.data import preprocessors |
| import t5 |
| import tensorflow.compat.v1 as tf |
|
|
# Location of the tab-separated corpus file for each dataset split.
tsv_path = dict(
    train="gs://nb-t5x-us-central2/corpus_big/train.tsv",
    validation="gs://nb-t5x-us-central2/corpus_big/eval.tsv",
    test="gs://nb-t5x-us-central2/corpus_big/test.tsv",
)
|
|
# SentencePiece vocabulary shared by the input and target features.
# No extra sentinel ids are requested on top of the model file
# (presumably the "100extra" model already contains them -- TODO confirm).
vocabulary = seqio.SentencePieceVocabulary(
    sentencepiece_model_file=(
        'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model'
    ),
    extra_ids=0,
)
|
|
# Both model features use the same vocabulary and get an EOS token appended.
# Each key receives its own seqio.Feature instance.
DEFAULT_OUTPUT_FEATURES = {
    feature_name: seqio.Feature(vocabulary=vocabulary, add_eos=True)
    for feature_name in ("inputs", "targets")
}
|
|
def sentencefix_preprocessor(ds):
    """Convert parsed TSV examples into the "inputs"/"targets" text format.

    Args:
        ds: A dataset whose examples are dicts holding the string entries
            ``"source"`` and ``"target"``.

    Returns:
        The result of mapping every example of ``ds`` to a dict with the
        keys ``"inputs"`` (from ``"source"``) and ``"targets"`` (from
        ``"target"``), both passed through ``normalize_text``.
    """

    def normalize_text(text):
        """Strip a pair of single quotes from a TensorFlow string.

        Letter case is left untouched.
        """
        # NOTE(review): the pattern is greedy and unanchored, so it removes
        # the first and the last single quote found in the text (when at
        # least two are present), not only quotes that wrap the whole
        # string. Confirm this is the intended behaviour for sentences that
        # contain apostrophes.
        return tf.strings.regex_replace(text, "'(.*)'", r"\1")

    def to_inputs_and_targets(ex):
        """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
        # The original wrapped each value in tf.strings.join([...]) with a
        # single element, which returns the value unchanged; it is omitted.
        return {
            "inputs": normalize_text(ex["source"]),
            "targets": normalize_text(ex["target"]),
        }

    return ds.map(
        to_inputs_and_targets,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
|
|
|
|
# Register the "sentencefix" task: text lines read from the per-split TSV
# files are turned into tokenized "inputs"/"targets" features.
seqio.TaskRegistry.add(
    "sentencefix",
    source=seqio.TextLineDataSource(split_to_filepattern=tsv_path),
    output_features=DEFAULT_OUTPUT_FEATURES,
    preprocessors=[
        # Parse every text line as TSV into "source" and "target" fields.
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["source", "target"],
        ),
        # Rename the fields to "inputs"/"targets" and strip single quotes.
        sentencefix_preprocessor,
        # Tokenize the features and append EOS as configured in
        # output_features.
        seqio.preprocessors.tokenize_and_append_eos,
    ],
)
|
|
|
|