| | |
| |
|
| | import functools |
| | import seqio |
| | import tensorflow_datasets as tfds |
| | from t5.evaluation import metrics |
| | import my_metrics |
| | from t5.data import preprocessors |
| | import t5 |
| | import tensorflow.compat.v1 as tf |
| |
|
# GCS locations of the TSV splits for the Danish hate-speech corpus.
# NOTE(review): "validation" and "test" both point at eval.tsv — confirm intended.
tsv_path = dict(
    train="gs://north-t5x/corpus/danish_hate/train.tsv",
    validation="gs://north-t5x/corpus/danish_hate/eval.tsv",
    test="gs://north-t5x/corpus/danish_hate/eval.tsv",
)
| |
|
# Shared SentencePiece vocabulary (mC4 250k-piece model, no extra sentinel ids).
_SPM_PATH = "gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model"
vocabulary = seqio.SentencePieceVocabulary(_SPM_PATH, extra_ids=0)
| |
|
# Inputs and targets use the same vocabulary; EOS is appended at tokenization.
DEFAULT_OUTPUT_FEATURES = {
    feature_name: seqio.Feature(vocabulary=vocabulary, add_eos=True)
    for feature_name in ("inputs", "targets")
}
| |
|
def categorise_preprocessor(ds):
    """Convert parsed TSV examples into the {"inputs", "targets"} format.

    Args:
        ds: a tf.data.Dataset of dicts containing (at least) the string
            fields "source" and "target", as produced by the parse_tsv
            preprocessor upstream.

    Returns:
        A tf.data.Dataset of {"inputs": ..., "targets": ...} string examples.
    """

    def strip_quotes(text):
        """Remove a pair of enclosing single quotes from a string tensor."""
        # Fixed doc bug: the original claimed this also lowercases — it does
        # not; it only drops surrounding quotes via the capture group.
        return tf.strings.regex_replace(text, "'(.*)'", r"\1")

    def to_inputs_and_targets(ex):
        """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
        # tf.strings.join of a single-element list was a no-op; call the
        # normalizer directly.
        return {
            "inputs": strip_quotes(ex["source"]),
            "targets": strip_quotes(ex["target"]),
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
| |
|
| |
|
# Register the Danish hate-speech classification task.
seqio.TaskRegistry.add(
    "classify_hate",
    source=seqio.TextLineDataSource(split_to_filepattern=tsv_path),
    preprocessors=[
        # TSV columns: pid, id, source (input text), target (label).
        functools.partial(
            t5.data.preprocessors.parse_tsv,
            field_names=["pid", "id", "source", "target"]),
        categorise_preprocessor,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    metric_fns=[metrics.accuracy, my_metrics.f1_macro],
    output_features=DEFAULT_OUTPUT_FEATURES,
)
| |
|
| |
|