# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """BERT finetuning task dataset generator.""" | |
import functools
import json
import os

# Import libraries
from absl import app
from absl import flags
import tensorflow as tf, tf_keras

from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib
from official.nlp.tools import tokenization

FLAGS = flags.FLAGS

flags.DEFINE_enum(
    "fine_tuning_task_type", "classification",
    ["classification", "regression", "squad", "retrieval", "tagging"],
    "The name of the BERT fine-tuning task for which data "
    "will be generated.")

# BERT classification specific flags.
flags.DEFINE_string(
    "input_data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_enum(
    "classification_task_name", "MNLI", [
        "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
        "AX-g", "SUPERGLUE-RTE", "CB", "BoolQ", "WIC"
    ], "The name of the task on which to train the BERT classifier. "
    "XTREME-XNLI differs from XNLI in: 1. the format of the input .tsv "
    "files; 2. the dev set for XTREME is English only, while for XNLI it "
    "combines all languages. The same applies to PAWS-X.")

# MNLI task-specific flag.
flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
                  "The type of MNLI dataset.")

# XNLI task-specific flag.
flags.DEFINE_string(
    "xnli_language", "en",
    "Language of training data for XNLI task. If the value is 'all', the "
    "data of all languages will be used for training.")

# PAWS-X task-specific flag.
flags.DEFINE_string(
    "pawsx_language", "en",
    "Language of training data for PAWS-X task. If the value is 'all', the "
    "data of all languages will be used for training.")

# XTREME classification specific flags. Only used in XtremePawsx and
# XtremeXnli.
flags.DEFINE_string(
    "translated_input_data_dir", None,
    "The translated input data dir. Should contain the .tsv files (or other "
    "data files) for the task.")

# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                  "The name of the sentence retrieval task for scoring.")

# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
                  "The name of the BERT tagging (token classification) task.")

flags.DEFINE_bool("tagging_only_use_en_train", True,
                  "Whether to use only English training data in tagging.")

# BERT SQuAD task-specific flags.
flags.DEFINE_string(
    "squad_data_file", None,
    "The input data file for generating training data for the BERT SQuAD "
    "task.")

flags.DEFINE_string(
    "translated_squad_data_folder", None,
    "The translated data folder for generating training data for the BERT "
    "SQuAD task.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool(
    "version_2_with_negative", False,
    "If true, the SQuAD examples contain some that do not have an answer.")

flags.DEFINE_bool(
    "xlnet_format", False,
    "If true, then data will be preprocessed in a paragraph, query, class "
    "order instead of the BERT-style class, paragraph, query order.")

# XTREME specific flags.
flags.DEFINE_bool("only_use_en_dev", True,
                  "Whether to use only English dev data.")

# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "train_data_output_path", None,
    "The path in which generated training input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "eval_data_output_path", None,
    "The path in which generated evaluation input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "test_data_output_path", None,
    "The path in which generated test input data will be written as tf"
    " records. If None, do not generate test data. Must be a pattern template"
    " like test_{}.tfrecords if the processor has language-specific test "
    "data.")
| flags.DEFINE_string("meta_data_file_path", None, | |
| "The path in which input meta data will be written.") | |
| flags.DEFINE_bool( | |
| "do_lower_case", True, | |
| "Whether to lower case the input text. Should be True for uncased " | |
| "models and False for cased models.") | |
| flags.DEFINE_integer( | |
| "max_seq_length", 128, | |
| "The maximum total input sequence length after WordPiece tokenization. " | |
| "Sequences longer than this will be truncated, and sequences shorter " | |
| "than this will be padded.") | |
| flags.DEFINE_string("sp_model_file", "", | |
| "The path to the model used by sentence piece tokenizer.") | |
| flags.DEFINE_enum( | |
| "tokenization", "WordPiece", ["WordPiece", "SentencePiece"], | |
| "Specifies the tokenizer implementation, i.e., whether to use WordPiece " | |
| "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, " | |
| "while ALBERT uses SentencePiece tokenizer.") | |

flags.DEFINE_string(
    "tfds_params", "",
    "Comma-separated list of TFDS parameter assignments for generic "
    "classification data import (for more details see the TfdsProcessor "
    "class documentation).")


def generate_classifier_dataset():
  """Generates classifier dataset and returns input meta data."""
  if FLAGS.classification_task_name in [
      "COLA",
      "WNLI",
      "SST-2",
      "MRPC",
      "QQP",
      "STS-B",
      "MNLI",
      "QNLI",
      "RTE",
      "AX",
      "SUPERGLUE-RTE",
      "CB",
      "BoolQ",
      "WIC",
  ]:
    # For these tasks, input_data_dir is only valid together with tfds_params.
    assert not FLAGS.input_data_dir or FLAGS.tfds_params
  else:
    # Other tasks need either input_data_dir (plus a task name) or tfds_params.
    assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
            FLAGS.tfds_params)

  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    processors = {
        "ax": classifier_data_lib.AxProcessor,
        "cola": classifier_data_lib.ColaProcessor,
        "imdb": classifier_data_lib.ImdbProcessor,
        "mnli": functools.partial(
            classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
        "mrpc": classifier_data_lib.MrpcProcessor,
        "qnli": classifier_data_lib.QnliProcessor,
        "qqp": classifier_data_lib.QqpProcessor,
        "rte": classifier_data_lib.RteProcessor,
        "sst-2": classifier_data_lib.SstProcessor,
        "sts-b": classifier_data_lib.StsBProcessor,
        "xnli": functools.partial(
            classifier_data_lib.XnliProcessor, language=FLAGS.xnli_language),
        "paws-x": functools.partial(
            classifier_data_lib.PawsxProcessor,
            language=FLAGS.pawsx_language),
        "wnli": classifier_data_lib.WnliProcessor,
        "xtreme-xnli": functools.partial(
            classifier_data_lib.XtremeXnliProcessor,
            translated_data_dir=FLAGS.translated_input_data_dir,
            only_use_en_dev=FLAGS.only_use_en_dev),
        "xtreme-paws-x": functools.partial(
            classifier_data_lib.XtremePawsxProcessor,
            translated_data_dir=FLAGS.translated_input_data_dir,
            only_use_en_dev=FLAGS.only_use_en_dev),
        "ax-g": classifier_data_lib.AXgProcessor,
        "superglue-rte": classifier_data_lib.SuperGLUERTEProcessor,
        "cb": classifier_data_lib.CBProcessor,
        "boolq": classifier_data_lib.BoolQProcessor,
        "wic": classifier_data_lib.WiCProcessor,
    }
    task_name = FLAGS.classification_task_name.lower()
    if task_name not in processors:
      raise ValueError("Task not found: %s" % (task_name,))

    processor = processors[task_name](process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        FLAGS.input_data_dir,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)


def generate_regression_dataset():
  """Generates regression dataset and returns input meta data."""
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    raise ValueError("No data processor found for the given regression task.")


def generate_squad_dataset():
  """Generates SQuAD training dataset and returns input meta data."""
  assert FLAGS.squad_data_file
  if FLAGS.tokenization == "WordPiece":
    return squad_lib_wp.generate_tf_record_from_json_file(
        input_file_path=FLAGS.squad_data_file,
        vocab_file_path=FLAGS.vocab_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,
        doc_stride=FLAGS.doc_stride,
        version_2_with_negative=FLAGS.version_2_with_negative,
        xlnet_format=FLAGS.xlnet_format)
  else:
    assert FLAGS.tokenization == "SentencePiece"
    return squad_lib_sp.generate_tf_record_from_json_file(
        input_file_path=FLAGS.squad_data_file,
        sp_model_file=FLAGS.sp_model_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,
        doc_stride=FLAGS.doc_stride,
        xlnet_format=FLAGS.xlnet_format,
        version_2_with_negative=FLAGS.version_2_with_negative)


def generate_retrieval_dataset():
  """Generates retrieval test and dev datasets and returns input meta data."""
  assert FLAGS.input_data_dir and FLAGS.retrieval_task_name
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  processors = {
      "bucc": sentence_retrieval_lib.BuccProcessor,
      "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
  }

  task_name = FLAGS.retrieval_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  processor = processors[task_name](process_text_fn=processor_text_fn)
  # Note: "retrevial" matches the function's spelling in sentence_retrieval_lib.
  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, FLAGS.max_seq_length)


def generate_tagging_dataset():
  """Generates tagging dataset."""
  processors = {
      "panx": functools.partial(
          tagging_data_lib.PanxProcessor,
          only_use_en_train=FLAGS.tagging_only_use_en_train,
          only_use_en_dev=FLAGS.only_use_en_dev),
      "udpos": functools.partial(
          tagging_data_lib.UdposProcessor,
          only_use_en_train=FLAGS.tagging_only_use_en_train,
          only_use_en_dev=FLAGS.only_use_en_dev),
  }
  task_name = FLAGS.tagging_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  elif FLAGS.tokenization == "SentencePiece":
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
  else:
    raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)

  processor = processors[task_name]()
  return tagging_data_lib.generate_tf_record_from_data_file(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
      FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, processor_text_fn)


def main(_):
  if FLAGS.tokenization == "WordPiece":
    if not FLAGS.vocab_file:
      raise ValueError(
          "FLAG vocab_file for word-piece tokenizer is not specified.")
  else:
    assert FLAGS.tokenization == "SentencePiece"
    if not FLAGS.sp_model_file:
      raise ValueError(
          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")

  if FLAGS.fine_tuning_task_type != "retrieval":
    flags.mark_flag_as_required("train_data_output_path")

  if FLAGS.fine_tuning_task_type == "classification":
    input_meta_data = generate_classifier_dataset()
  elif FLAGS.fine_tuning_task_type == "regression":
    input_meta_data = generate_regression_dataset()
  elif FLAGS.fine_tuning_task_type == "retrieval":
    input_meta_data = generate_retrieval_dataset()
  elif FLAGS.fine_tuning_task_type == "squad":
    input_meta_data = generate_squad_dataset()
  else:
    assert FLAGS.fine_tuning_task_type == "tagging"
    input_meta_data = generate_tagging_dataset()

  tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
  with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
    writer.write(json.dumps(input_meta_data, indent=4) + "\n")
| if __name__ == "__main__": | |
| flags.mark_flag_as_required("meta_data_file_path") | |
| app.run(main) | |