| | |
| |
|
| |
|
| | import functools |
| | import seqio |
| | import my_metrics |
| | import tensorflow_datasets as tfds |
| | from t5.evaluation import metrics |
| | from t5.data import preprocessors |
| | |
| | import t5 |
| | import tensorflow.compat.v1 as tf |
| |
|
| |
|
| |
|
def _splits(train, validation, test):
    """Bundle per-split file paths into the split->filepattern dict seqio expects."""
    return {"train": train, "validation": validation, "test": test}


# Parliament speech classification (FRP vs SV), 1998-2016.
tsv_parliament_path = _splits(
    "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/train.tsv",
    "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/dev.tsv",
    "gs://notram-public/finetune_datasets/parliament_speeches_1998_2016_frp_or_sv/test.tsv",
)

# Norwegian summarization corpus; the test set doubles as validation.
tsv_summary_path = _splits(
    "gs://north-t5x/corpus/summary_test/norwegian_train.tsv",
    "gs://north-t5x/corpus/summary_test/test.tsv",
    "gs://north-t5x/corpus/summary_test/test.tsv",
)

# Combined CNN + Norwegian summarization training data; same eval files.
tsv_summary_all_path = _splits(
    "gs://north-t5x/corpus/summary_test/cnn_and_norwegian_train.tsv",
    "gs://north-t5x/corpus/summary_test/test.tsv",
    "gs://north-t5x/corpus/summary_test/test.tsv",
)

# English->Norwegian translation corpus.
tsv_translate_path = _splits(
    "gs://nb-t5x-us-central2/corpus_en_no/train.tsv",
    "gs://nb-t5x-us-central2/corpus_en_no/dev.tsv",
    "gs://nb-t5x-us-central2/corpus_en_no/test.tsv",
)

# NoReC sentiment classification.
tsv_sentiment_path = _splits(
    "gs://notram-public/finetune_datasets/norec_sentiment/train.tsv",
    "gs://notram-public/finetune_datasets/norec_sentiment/dev.tsv",
    "gs://notram-public/finetune_datasets/norec_sentiment/test.tsv",
)

# AngryTweets, JSONL variant; the test set doubles as validation.
json_angry_tweets_path = _splits(
    "gs://notram-public/finetune_datasets/angry_tweets/train.jsonl",
    "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl",
    "gs://notram-public/finetune_datasets/angry_tweets/test.jsonl",
)

# AngryTweets, TSV variant.
tsv_angry_tweets_path = _splits(
    "gs://notram-public/finetune_datasets/angry_tweets/train.tsv",
    "gs://notram-public/finetune_datasets/angry_tweets/test.tsv",
    "gs://notram-public/finetune_datasets/angry_tweets/test.tsv",
)

# DaNE, plain sentence files.
tsv_dane_path = _splits(
    "gs://notram-public/finetune_datasets/dane/train.tsv",
    "gs://notram-public/finetune_datasets/dane/test.tsv",
    "gs://notram-public/finetune_datasets/dane/test.tsv",
)

# DaNE, tokenised files.
tsv_dane_tokens_path = _splits(
    "gs://notram-public/finetune_datasets/dane/train_tokens.tsv",
    "gs://notram-public/finetune_datasets/dane/test_tokens.tsv",
    "gs://notram-public/finetune_datasets/dane/test_tokens.tsv",
)

# DaNE, long tokenised files.
tsv_dane_long_tokens_path = _splits(
    "gs://notram-public/finetune_datasets/dane/train_long_tokens.tsv",
    "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv",
    "gs://notram-public/finetune_datasets/dane/test_long_tokens.tsv",
)
| |
|
# SentencePiece vocabularies shared by the tasks below.
scand_vocabulary = seqio.SentencePieceVocabulary(
    'gs://nb-t5/t5/vocabs/wikipedia/no-da-en-sv-nn-is_32000_unigram.sp.model',
    extra_ids=100)
eng_vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model',
    extra_ids=0)
mt5_vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model',
    extra_ids=0)


def _output_features(vocab):
    """Build the standard inputs/targets seqio feature pair for *vocab*."""
    return {
        "inputs": seqio.Feature(vocabulary=vocab, add_eos=True, required=False),
        "targets": seqio.Feature(vocabulary=vocab, add_eos=True),
    }


DEFAULT_OUTPUT_FEATURES = _output_features(eng_vocabulary)
SCAND_OUTPUT_FEATURES = _output_features(scand_vocabulary)
MT5_OUTPUT_FEATURES = _output_features(mt5_vocabulary)
| |
|
| |
|
| |
|
def categorise_preprocessor(ds):
    """Rename {"source", "target"} example fields to {"inputs", "targets"}.

    Args:
      ds: a tf.data.Dataset of dict examples that contain string fields
        "source" and "target" (e.g. produced by parse_tsv).

    Returns:
      A tf.data.Dataset with the same examples under the keys "inputs"
      and "targets", as expected by seqio tokenization.
    """

    def normalize_text(text):
        """Return *text* unchanged.

        NOTE(review): this was originally intended to lowercase and strip
        quotes, but no transformation is applied — confirm before relying
        on any normalization happening here.
        """
        return text

    def to_inputs_and_targets(ex):
        """Map {"source": ..., "target": ...} -> {"inputs": ..., "targets": ...}."""
        # tf.strings.join on a single element is an identity, so call
        # normalize_text directly.
        return {
            "inputs": normalize_text(ex["source"]),
            "targets": normalize_text(ex["target"]),
        }

    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
| |
|
| |
|
def _add_categorise_task(name, split_paths, field_names, output_features,
                         metric_fns):
    """Register one TSV-backed seqio task using the shared preprocessing chain.

    Args:
      name: registry name for the task.
      split_paths: dict mapping split name -> GCS file pattern.
      field_names: column names passed to parse_tsv; the columns named
        "source" and "target" feed categorise_preprocessor.
      output_features: inputs/targets feature dict selecting the vocabulary.
      metric_fns: list of evaluation metric callables.
    """
    seqio.TaskRegistry.add(
        name,
        source=seqio.TextLineDataSource(split_to_filepattern=split_paths),
        preprocessors=[
            functools.partial(t5.data.preprocessors.parse_tsv,
                              field_names=field_names),
            categorise_preprocessor,
            seqio.preprocessors.tokenize_and_append_eos,
        ],
        metric_fns=metric_fns,
        output_features=output_features,
    )


# Metric bundles shared across the registrations below.
_CLASSIFICATION_METRICS = [metrics.accuracy, my_metrics.f1_macro]
_GENERATION_METRICS = [metrics.accuracy, my_metrics.f1_macro,
                       metrics.bleu, metrics.rouge]

# Classification tasks: TSV columns are "target<TAB>source".
_add_categorise_task("parliament", tsv_parliament_path,
                     ["target", "source"],
                     DEFAULT_OUTPUT_FEATURES, _CLASSIFICATION_METRICS)
_add_categorise_task("sentiment", tsv_sentiment_path,
                     ["target", "source"],
                     DEFAULT_OUTPUT_FEATURES, _CLASSIFICATION_METRICS)
_add_categorise_task("angry_tweets", tsv_angry_tweets_path,
                     ["target", "source"],
                     DEFAULT_OUTPUT_FEATURES, _CLASSIFICATION_METRICS)
# DaNE rows carry three leading columns that are parsed but unused.
_add_categorise_task("dane", tsv_dane_long_tokens_path,
                     ["placeholder1", "placeholder2", "placeholder3",
                      "target", "source"],
                     DEFAULT_OUTPUT_FEATURES, _CLASSIFICATION_METRICS)

# Summarization tasks: TSV columns are "source<TAB>target".
_add_categorise_task("summary_scand", tsv_summary_path,
                     ["source", "target"],
                     SCAND_OUTPUT_FEATURES, _GENERATION_METRICS)
_add_categorise_task("summary", tsv_summary_path,
                     ["source", "target"],
                     MT5_OUTPUT_FEATURES, _GENERATION_METRICS)
_add_categorise_task("summary_all", tsv_summary_all_path,
                     ["source", "target"],
                     MT5_OUTPUT_FEATURES, _GENERATION_METRICS)
_add_categorise_task("summary_all_scand", tsv_summary_all_path,
                     ["source", "target"],
                     SCAND_OUTPUT_FEATURES, _GENERATION_METRICS)
| |
|
| |
|