# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """BERT library to process data for cross lingual sentence retrieval task.""" | |
| import os | |
| from absl import logging | |
| from official.nlp.data import classifier_data_lib | |
| from official.nlp.tools import tokenization | |


class BuccProcessor(classifier_data_lib.DataProcessor):
  """Processor for the XTREME BUCC data set."""

  supported_languages = ["de", "fr", "ru", "zh"]

  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
    super(BuccProcessor, self).__init__(process_text_fn)
    self.languages = BuccProcessor.supported_languages

  def get_dev_examples(self, data_dir, file_pattern):
    # Dev examples get the guid prefix "sample", BUCC's name for its dev split.
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
        "sample")

  def get_test_examples(self, data_dir, file_pattern):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
        "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "BUCC"

  def _create_examples(self, lines, set_type):
    """Creates examples for the dev and test sets."""
    examples = []
    for i, line in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      example_id = int(line[0].split("-")[1])
      text_a = self.process_text_fn(line[1])
      examples.append(
          classifier_data_lib.InputExample(
              guid=guid, text_a=text_a, example_id=example_id))
    return examples
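
# For reference, each BUCC TSV row is expected to carry an id column and a
# text column, e.g. ("de-000001", "Das ist ein Beispielsatz."), so
# _create_examples above recovers the integer sentence id from the id column.
# (The exact id format is an assumption based on the XTREME BUCC release.)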


class TatoebaProcessor(classifier_data_lib.DataProcessor):
  """Processor for the XTREME Tatoeba data set."""

  supported_languages = [
      "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
      "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
      "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
  ]

  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
    super(TatoebaProcessor, self).__init__(process_text_fn)
    self.languages = TatoebaProcessor.supported_languages

  def get_test_examples(self, data_dir, file_path):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_path)), "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "TATOEBA"

  def _create_examples(self, lines, set_type):
    """Creates examples for the test set."""
    examples = []
    for i, line in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      text_a = self.process_text_fn(line[0])
      examples.append(
          classifier_data_lib.InputExample(
              guid=guid, text_a=text_a, example_id=i))
    return examples
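
# Tatoeba rows are assumed to be single-column (one sentence per line), so the
# enumeration index doubles as the example id. The XTREME Tatoeba retrieval
# split is test-only, hence no get_dev_examples here; pass only
# test_data_output_path to generate_sentence_retrevial_tf_record for this
# processor.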


def generate_sentence_retrevial_tf_record(processor,
                                          data_dir,
                                          tokenizer,
                                          eval_data_output_path=None,
                                          test_data_output_path=None,
                                          max_seq_length=128):
  """Generates the tf records for retrieval tasks.

  Args:
    processor: Input processor object to be used for generating data. Subclass
      of `DataProcessor`.
    data_dir: Directory that contains the eval/test data to process. File
      names must match the pattern built below, e.g. "de-en.dev.de" for BUCC
      or "de-en.de" for Tatoeba.
    tokenizer: The tokenizer to be applied on the data.
    eval_data_output_path: Directory to which the processed tf records for
      evaluation will be saved.
    test_data_output_path: Directory to which the processed tf records for
      testing will be saved.
    max_seq_length: Maximum sequence length of the generated eval/test data.

  Returns:
    A dictionary containing input meta data.
  """
  assert eval_data_output_path or test_data_output_path

  if processor.get_processor_name() == "BUCC":
    path_pattern = "{}-en.{{}}.{}"
  elif processor.get_processor_name() == "TATOEBA":
    path_pattern = "{}-en.{}"
  else:
    raise ValueError(
        "Unsupported processor: %s" % processor.get_processor_name())
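
  # For illustration: with lang_a="de" and lang_b="en", the BUCC pattern above
  # resolves to "de-en.{}.en" (the remaining slot is filled with "dev"/"test"
  # by the processor), while the Tatoeba pattern resolves to "de-en.en".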

  meta_data = {
      "processor_type": processor.get_processor_name(),
      "max_seq_length": max_seq_length,
      "number_eval_data": {},
      "number_test_data": {},
  }
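
  # "number_eval_data" and "number_test_data" map pair keys such as "de-en.de"
  # or "de-en.en" to example counts; the loop below fills them in.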
| logging.info("Start to process %s task data", processor.get_processor_name()) | |
| for lang_a in processor.languages: | |
| for lang_b in [lang_a, "en"]: | |
| if eval_data_output_path: | |
| eval_input_data_examples = processor.get_dev_examples( | |
| data_dir, os.path.join(path_pattern.format(lang_a, lang_b))) | |
| num_eval_data = len(eval_input_data_examples) | |
| logging.info("Processing %d dev examples of %s-en.%s", num_eval_data, | |
| lang_a, lang_b) | |
| output_file = os.path.join( | |
| eval_data_output_path, | |
| "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev")) | |
| classifier_data_lib.file_based_convert_examples_to_features( | |
| eval_input_data_examples, None, max_seq_length, tokenizer, | |
| output_file, None) | |
| meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data | |

      if test_data_output_path:
        test_input_data_examples = processor.get_test_examples(
            data_dir, path_pattern.format(lang_a, lang_b))
        num_test_data = len(test_input_data_examples)
        logging.info("Processing %d test examples of %s-en.%s", num_test_data,
                     lang_a, lang_b)
        output_file = os.path.join(
            test_data_output_path,
            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
        classifier_data_lib.file_based_convert_examples_to_features(
            test_input_data_examples, None, max_seq_length, tokenizer,
            output_file, None)
        meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data

  return meta_data
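

# A minimal, hypothetical usage sketch (not part of the original module): the
# vocab file and directory paths below are placeholders to adapt to your
# setup. FullTokenizer comes from official.nlp.tools.tokenization.
if __name__ == "__main__":
  sketch_tokenizer = tokenization.FullTokenizer(
      vocab_file="/path/to/vocab.txt", do_lower_case=False)
  sketch_meta_data = generate_sentence_retrevial_tf_record(
      processor=BuccProcessor(),
      data_dir="/path/to/bucc_data",
      tokenizer=sketch_tokenizer,
      eval_data_output_path="/path/to/eval_output",
      test_data_output_path="/path/to/test_output",
      max_seq_length=128)
  logging.info("Generated meta data: %s", sketch_meta_data)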