| # Copyright 2024 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Create LM TF examples for XLNet.""" | |
| import dataclasses | |
| import json | |
| import math | |
| import os | |
| import random | |
| from typing import Iterable, Mapping, List, Optional, Tuple | |
| import unicodedata | |
| # Import libraries | |
| from absl import app | |
| from absl import flags | |
| from absl import logging | |
| import numpy as np | |
| import tensorflow as tf, tf_keras | |
| from official.nlp.tools import tokenization | |
# Special token string -> fixed integer vocabulary ID. `<eod>` is appended at
# document boundaries during preprocessing; `<sep>` and `<cls>` are inserted
# when assembling training instances below.
special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}
FLAGS = flags.FLAGS

# --- Sequence construction flags. ---
flags.DEFINE_integer("seq_length", 512,
                     help="Sequence length.")
flags.DEFINE_integer("reuse_length", 256,
                     help="Number of token that can be reused as memory. "
                     "Could be half of `seq_len`.")

# --- Input/output locations. ---
flags.DEFINE_string("input_file", None,
                    "Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
    "save_dir", None,
    "Directory for saving processed data.")
flags.DEFINE_string("sp_model_file", "",
                    "The path to the model used by sentence piece tokenizer.")

# --- Data construction options. ---
flags.DEFINE_bool("use_eod_token", True,
                  "Whether or not to include EOD tokens.")
flags.DEFINE_bool("bi_data", True, "Whether or not to use bi-directional data.")
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

# --- Batch geometry (affects how the token stream is reshaped). ---
flags.DEFINE_integer("per_host_batch_size", 32, "Batch size per host.")
flags.DEFINE_integer("num_cores_per_host", 16,
                     "The number of (TPU) cores per host.")

# --- Output naming and sharding across tasks/passes. ---
flags.DEFINE_string("prefix", "", "Filename prefix.")
flags.DEFINE_string("suffix", "", "Filename suffix.")
flags.DEFINE_integer("task_id", None,
                     "The id of the current task.")
flags.DEFINE_integer("num_tasks", None,
                     "The total number of tasks.")
flags.DEFINE_integer("num_passes", 1, "The number of times to run the script.")
# The @dataclasses.dataclass decorator was missing: the class only declares
# annotated fields (no __init__), yet it is constructed with keyword arguments
# (TrainingInstance(data=..., ...)) later in this file, which would raise
# TypeError. `dataclasses` is imported above for exactly this purpose.
@dataclasses.dataclass
class TrainingInstance:
  """Representation of a single XLNet Pretraining instance."""
  # Token IDs of the full sequence (reuse segment + A + SEP + B + SEP + CLS).
  data: Iterable[int]
  # Segment IDs aligned with `data` (0 for reuse/A, 1 for B, 2 for CLS).
  segment_ids: Iterable[int]
  # Indices into `data` marking whole-word boundaries.
  boundary_indices: Iterable[int]
  # 0 if segments A and B are contiguous text, 1 if B was sampled elsewhere.
  label: int

  def to_feature(self) -> Mapping[str, tf.train.Feature]:
    """Converts this instance into a mapping of `tf.train.Feature`s."""
    feat = lambda x: tf.train.Feature(int64_list=tf.train.Int64List(value=x))
    return dict(
        input_word_ids=feat(self.data),
        input_type_ids=feat(self.segment_ids),
        boundary_indices=feat(self.boundary_indices),
        label=feat([self.label]))

  def to_example(self) -> tf.train.Example:
    """Serializes this instance into a `tf.train.Example`."""
    return tf.train.Example(
        features=tf.train.Features(feature=self.to_feature()))

  def __str__(self):
    def seq_to_str(seq):
      return " ".join([str(x) for x in seq])

    s = ""
    s += "tokens: %s\n" % seq_to_str(self.data)
    s += "segment_ids: %s\n" % seq_to_str(self.segment_ids)
    s += "boundary_indices: %s\n" % seq_to_str(self.boundary_indices)
    s += "label: %s\n" % self.label
    s += "\n"
    return s

  def __repr__(self):
    return self.__str__()
| def _preprocess_line(line: str, do_lower_case: bool = False) -> str: | |
| """Preprocesses an individual raw text line. | |
| This function will: | |
| - Remove extraneous spaces. | |
| - Replace `` with ", and '' with ". | |
| - Replaces accents. | |
| - Applies lower casing. | |
| Args: | |
| line: The input line to preprocess. | |
| do_lower_case: Whether or not to lower case the text. | |
| Returns: | |
| The preprocessed line. | |
| """ | |
| line = " ".join(line.split()) | |
| line = line.replace("``", "\"").replace("''", "\"") | |
| # Replace accents. | |
| line = unicodedata.normalize("NFKD", line) | |
| line = "".join([c for c in line if not unicodedata.combining(c)]) | |
| if do_lower_case: | |
| line = line.lower() | |
| return line | |
def preprocess_and_tokenize_input_files(
    input_files: Iterable[str],
    tokenizer: tokenization.FullSentencePieceTokenizer,
    use_eod: bool = True,
    do_lower_case: bool = False,
    log_example_freq: int = 100000) -> List[Tuple[np.array, np.array]]:
  """Preprocesses and encodes raw text from input files.

  This function preprocesses raw text and encodes them into tokens using a
  `SentencePieceModel` tokenization method. This also provides the sentence
  indicator for each token.

  Args:
    input_files: The list of input file names.
    tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    use_eod: Whether or not to use an EOD indicator. If `False`, then EOD is
      not included.
    do_lower_case: Whether or not to apply lower casing during raw text
      preprocessing.
    log_example_freq: The optional field for how many lines to process before
      emitting an info log.

  Returns:
    The preprocessed list. Each entry in the list is a tuple consisting of
    the token IDs and the sentence IDs.
  """
  all_data = []
  eod_symbol = special_symbols["<eod>"]

  total_number_of_lines = 0

  # Input file format:
  # (1) One sentence per line. These should ideally be actual sentences, not
  # entire paragraphs or arbitrary spans of text. (Because we use the
  # sentence boundaries for the "next sentence prediction" task).
  # (2) Blank lines between documents. Document boundaries are needed so
  # that the "next sentence prediction" task doesn't span between documents.
  for input_file in input_files:
    line_count = 0
    logging.info("Preprocessing %s", input_file)

    all_tokens = []
    all_sentence_ids = []

    # Boolean indicator that alternates per sentence; downstream code detects
    # sentence boundaries by looking for flips in this value.
    sentence_id = True

    with tf.io.gfile.GFile(input_file, "rb") as reader:
      while True:
        line = tokenization.convert_to_unicode(reader.readline())
        if not line:
          break

        line_count += 1
        if line_count % log_example_freq == 0:
          logging.info("Loading line %d", line_count)

        line = line.strip()

        if not line:
          # Blank line = document boundary: emit an EOD token (if enabled)
          # and flip the sentence indicator; otherwise skip the line entirely.
          if use_eod:
            token_ids = [eod_symbol]
            sentence_id = not sentence_id
          else:
            continue
        else:
          preprocessed_line = _preprocess_line(
              line=line, do_lower_case=do_lower_case)
          token_ids = tokenization.encode_ids(
              sp_model=tokenizer.sp_model, text=preprocessed_line)

        # Every token in this line shares the same sentence indicator.
        all_tokens.extend(token_ids)
        all_sentence_ids.extend([sentence_id] * len(token_ids))
        sentence_id = not sentence_id
      logging.info("Finished processing %s. Number of lines: %d",
                   input_file, line_count)

    # Skip files that contributed no lines at all.
    if line_count == 0:
      continue

    total_number_of_lines += line_count

    all_tokens = np.array(all_tokens, dtype=np.int64)
    all_sentence_ids = np.array(all_sentence_ids, dtype=bool)

    all_data.append((all_tokens, all_sentence_ids))

  logging.info("Completed text preprocessing. Total number of lines: %d",
               total_number_of_lines)
  return all_data
| def _reshape_to_batch_dimensions( | |
| tokens: np.array, | |
| sentence_ids: np.array, | |
| per_host_batch_size: int) -> Tuple[np.array, np.array]: | |
| """Truncates and reshapes input data with a batch major dimension. | |
| Args: | |
| tokens: The input token ids. This should have the same shape as | |
| `sentence_ids`. | |
| sentence_ids: The input sentence ids. This should have the same shape as | |
| `token_ids`. | |
| per_host_batch_size: The target per-host batch size. | |
| Returns: | |
| The tuple of reshaped tokens and sentence_ids. | |
| """ | |
| num_steps = len(tokens) // per_host_batch_size | |
| truncated_data_length = num_steps * per_host_batch_size | |
| logging.info("per_host_batch_size: %d", per_host_batch_size) | |
| logging.info("num_steps: %d", num_steps) | |
| def truncate_and_reshape(a): | |
| return a[:truncated_data_length].reshape((per_host_batch_size, num_steps)) | |
| return (truncate_and_reshape(tokens), truncate_and_reshape(sentence_ids)) | |
def _create_a_and_b_segments(
    tokens: np.array,
    sentence_ids: np.array,
    begin_index: int,
    total_length: int,
    no_cut_probability: float = 0.5) -> Optional[Tuple[np.array, np.array, int]]:
  """Splits segments A and B from a single instance of tokens and sentence ids.

  Args:
    tokens: The 1D input token ids. This represents an individual entry within a
      batch.
    sentence_ids: The 1D input sentence ids. This represents an individual entry
      within a batch. This should be the same length as `tokens`.
    begin_index: The reference beginning index to split data.
    total_length: The target combined length of segments A and B.
    no_cut_probability: The probability of not cutting a segment despite
      a cut possibly existing.

  Returns:
    A tuple consisting of A data, B data, and label, or `None` if there is
    not enough data remaining after `begin_index`.
  """
  data_length = tokens.shape[0]
  if begin_index + total_length >= data_length:
    logging.info("[_create_segments]: begin_index %d + total_length %d >= "
                 "data_length %d", begin_index, total_length, data_length)
    return None

  end_index = begin_index + 1
  cut_indices = []

  # Identify all indices where sentence IDs change from one to the next.
  while end_index < data_length:
    if sentence_ids[end_index] != sentence_ids[end_index - 1]:
      if end_index - begin_index >= total_length:
        break
      cut_indices.append(end_index)
    end_index += 1

  a_begin = begin_index

  if not cut_indices or random.random() < no_cut_probability:
    # Segments A and B are contained within the same sentence.
    label = 0
    if not cut_indices:
      # No sentence boundary found: A spans the whole window.
      a_end = end_index
    else:
      a_end = random.choice(cut_indices)

    # Sample B from an arbitrary position elsewhere in the sequence, then
    # snap its start/end outward to the enclosing sentence boundaries.
    b_length = max(1, total_length - (a_end - a_begin))
    b_begin = random.randint(0, data_length - 1 - b_length)
    b_end = b_begin + b_length

    while b_begin > 0 and sentence_ids[b_begin - 1] == sentence_ids[b_begin]:
      b_begin -= 1
    while (b_end < data_length - 1 and
           sentence_ids[b_end - 1] == sentence_ids[b_end]):
      b_end += 1
  else:
    # Segments A and B are different sentences: B directly follows A.
    label = 1
    a_end = random.choice(cut_indices)
    b_begin = a_end
    b_end = end_index

  # Trim the longer of the two segments until the pair fits `total_length`.
  while a_end - a_begin + b_end - b_begin > total_length:
    if a_end - a_begin > b_end - b_begin:
      # Delete only the right side for the LM objective.
      a_end -= 1
    else:
      b_end -= 1

  if a_end >= data_length or b_end >= data_length:
    logging.info("[_create_segments]: a_end %d or b_end %d >= data_length %d",
                 a_end, b_end, data_length)
    return None

  a_data = tokens[a_begin: a_end]
  b_data = tokens[b_begin: b_end]
  return a_data, b_data, label
| def _is_functional_piece(piece: str) -> bool: | |
| return piece != "<unk>" and piece.startswith("<") and piece.endswith(">") | |
| def _is_start_piece(piece: str) -> bool: | |
| special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~')) | |
| if (piece.startswith("▁") or piece in special_pieces): | |
| return True | |
| else: | |
| return False | |
def _get_boundary_indices(
    data: np.array,
    tokenizer: tokenization.FullSentencePieceTokenizer) -> np.array:
  """Gets the boundary indices of whole words."""
  pieces = tokenizer.convert_ids_to_tokens(data.tolist())
  # Collect the index of every piece that starts a word (and is not a
  # functional token such as <sep>), then close the list with the length.
  boundaries = [
      index for index, piece in enumerate(pieces)
      if _is_start_piece(piece) and not _is_functional_piece(piece)
  ]
  boundaries.append(len(data))
  return boundaries
def _convert_tokens_to_instances(
    tokens: np.array,
    sentence_ids: np.array,
    per_host_batch_size: int,
    seq_length: int,
    reuse_length: int,
    bi_data: bool,
    tokenizer: tokenization.FullSentencePieceTokenizer,
    num_cores_per_host: int = 0,
    logging_frequency: int = 500) -> List[TrainingInstance]:
  """Converts tokens and sentence IDs into individual training instances.

  The format of data in the XLNet pretraining task is very similar to the
  BERT pretraining task. Two segments A and B are randomly sampled, and the
  contatenation of A and B into a single sequence is used to perform
  language modeling.

  To create an XLNet Pretraining instance from a single long sequence, S:
  - Create a segment of length `reuse_length`. This first segment represents
    past tokens. During modeling, this segment is used to cache obtained
    content representations for the segment recurrence mechanism.
  - Similar to BERT, create a segment of length `seq_length` - `reuse_length`
    composed of A and B segments.
    For XLNet, the order is "A", "SEP", "B", "SEP", "CLS".

  Args:
    tokens: All tokens concatenated into a single list.
    sentence_ids: All sentence IDs concatenated into a single list.
    per_host_batch_size: The target batch size per host.
    seq_length: The max sequence length.
    reuse_length: The number of tokens to use from the previous segment.
    bi_data: Whether or not to use bidirectional data.
    tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    num_cores_per_host: The number of cores per host. This is required if
      `bi_data` = `True`.
    logging_frequency: The frequency at which to log status updates.

  Returns:
    A list of `TrainingInstance` objects.
  """
  instances = []

  per_core_batch_size = (per_host_batch_size // num_cores_per_host
                         if bi_data else None)

  if bi_data:
    logging.info("Bi-directional data enabled.")
    assert per_host_batch_size % (2 * num_cores_per_host) == 0
    # Half of the host batch carries forward text; the other half carries the
    # same text reversed along the step axis.
    forward_tokens, forward_sentence_ids = _reshape_to_batch_dimensions(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=per_host_batch_size // 2)
    forward_data_shape = (num_cores_per_host, 1, per_core_batch_size // 2, -1)

    forward_tokens = forward_tokens.reshape(forward_data_shape)
    forward_sentence_ids = forward_sentence_ids.reshape(forward_data_shape)

    backwards_tokens = forward_tokens[:, :, :, ::-1]
    backwards_sentence_ids = forward_sentence_ids[:, :, :, ::-1]

    tokens = np.concatenate([forward_tokens, backwards_tokens], 1).reshape(
        per_host_batch_size, -1)
    # Bug fix: concatenate along axis 1 to match the `tokens` concatenation
    # above. The previous default (axis 0) interleaved the per-core chunks in
    # a different order, so after reshaping to (per_host_batch_size, -1) the
    # sentence IDs no longer lined up with their tokens whenever
    # num_cores_per_host > 1.
    sentence_ids = np.concatenate(
        [forward_sentence_ids, backwards_sentence_ids], 1).reshape(
            per_host_batch_size, -1)
  else:
    logging.info("Bi-directional data disabled.")
    tokens, sentence_ids = _reshape_to_batch_dimensions(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=per_host_batch_size)

  logging.info("Tokens shape: %s", tokens.shape)

  data_length = tokens.shape[1]
  sep = np.array([special_symbols["<sep>"]], dtype=np.int64)
  cls = np.array([special_symbols["<cls>"]], dtype=np.int64)
  # 2 sep, 1 cls
  num_special_tokens = 3

  data_index = 0
  batch_number = 0
  step_size = reuse_length if reuse_length else seq_length
  num_batches = math.ceil(data_length / step_size)

  while data_index + seq_length <= data_length:
    if batch_number % logging_frequency == 0:
      logging.info("Processing batch %d of %d", batch_number, num_batches)

    for batch_index in range(per_host_batch_size):
      # The first `reuse_length` tokens act as the recurrence memory segment.
      previous_segment_tokens = tokens[
          batch_index, data_index: data_index + reuse_length]

      results = _create_a_and_b_segments(
          tokens=tokens[batch_index],
          sentence_ids=sentence_ids[batch_index],
          begin_index=data_index + reuse_length,
          total_length=seq_length - reuse_length - num_special_tokens)

      if results is None:
        logging.info("Stopping at data index: %d", data_index)
        break
      a_data, b_data, label = results

      # Layout: [memory] A <sep> B <sep> <cls>.
      data = np.concatenate(
          [previous_segment_tokens, a_data, sep, b_data, sep, cls])
      a_length = a_data.shape[0]
      b_length = b_data.shape[0]
      # Segment IDs: 0 for memory+A (+ its sep), 1 for B (+ its sep), 2 for cls.
      segment_ids = ([0] * (reuse_length + a_length) + [0]
                     + [1] * b_length + [1] + [2])
      boundary_indices = _get_boundary_indices(tokenizer=tokenizer,
                                               data=data)
      assert len(data) == seq_length
      assert len(segment_ids) == seq_length
      assert len(boundary_indices) > 0  # pylint: disable=g-explicit-length-test

      instances.append(TrainingInstance(
          data=data,
          segment_ids=segment_ids,
          boundary_indices=boundary_indices,
          label=label))
    batch_number += 1
    data_index += step_size
  return instances
def write_instances_to_tfrecord(
    instances: Iterable[TrainingInstance],
    save_path: str):
  """Writes instances to TFRecord.

  Args:
    instances: The `TrainingInstance`s to serialize.
    save_path: Destination path for the TFRecord file.
  """
  logging.info("Start writing to %s.", save_path)
  # Use the writer as a context manager so the underlying file is flushed and
  # closed even if serialization of an instance raises partway through
  # (the original leaked the handle on error).
  with tf.io.TFRecordWriter(save_path) as record_writer:
    for i, instance in enumerate(instances):
      if i < 5:
        # Log the first few instances for manual inspection.
        logging.info("Instance %d: %s", i, str(instance))
      record_writer.write(instance.to_example().SerializeToString())
  logging.info("Done writing %s.", save_path)
def shuffle_and_combine_preprocessed_data(
    all_data: List[Tuple[np.array, np.array]]) -> Tuple[np.array, np.array]:
  """Shuffles and combines preprocessed token/sentence IDs from documents."""
  order = np.random.permutation(len(all_data))

  token_chunks = []
  sentence_id_chunks = []
  last_sentence_id = None
  for doc_index in order:
    tokens, sentence_ids = all_data[doc_index]
    # pylint: disable=g-explicit-length-test
    if len(tokens) == 0:
      continue
    # If two adjacent documents would meet with the same sentence indicator,
    # flip the incoming document's IDs so the alternation stays consistent.
    if last_sentence_id is not None and sentence_ids[0] == last_sentence_id:
      sentence_ids = np.logical_not(sentence_ids)

    token_chunks.append(tokens)
    sentence_id_chunks.append(sentence_ids)
    last_sentence_id = sentence_ids[-1]

  return np.concatenate(token_chunks), np.concatenate(sentence_id_chunks)
def get_tfrecord_name(
    per_host_batch_size: int,
    num_cores_per_host: int,
    seq_length: int,
    bi_data: bool,
    reuse_length: int,
    do_lower_case: bool,
    use_eod_token: bool,
    prefix: str = "",
    suffix: str = "",
    pass_id: int = 0,
    num_passes: int = 1,
    task_id: Optional[int] = None,
    num_tasks: Optional[int] = None) -> str:
  """Formats the resulting TFRecord name based on provided inputs.

  Args:
    per_host_batch_size: Batch size per host, encoded as "bs-<n>".
    num_cores_per_host: Core count, encoded as "cores-<n>".
    seq_length: Sequence length, encoded as "seqlen-<n>".
    bi_data: Encoded as "bi" when True, "uni" otherwise.
    reuse_length: Encoded as "reuse-<n>", or "memless" when 0.
    do_lower_case: Encoded as "uncased" when True, "cased" otherwise.
    use_eod_token: Appends "eod" when True.
    prefix: Optional leading filename component.
    suffix: Optional trailing filename component.
    pass_id: Index of the current pass, used for shard numbering.
    num_passes: Total number of passes.
    task_id: Optional index of the current task.
    num_tasks: Optional total number of tasks.

  Returns:
    The TFRecord file name, with a "-<shard>-of-<total>" suffix when
    multiple passes and/or tasks are in play.
  """
  # Annotations fixed to Optional[int]: task_id/num_tasks legitimately
  # default to None and are None-checked below.
  components = []
  if prefix:
    components.append(prefix)
  components.append("seqlen-{}".format(seq_length))
  # reuse_length == 0 means no memory segment is carried between steps.
  if reuse_length == 0:
    components.append("memless")
  else:
    components.append("reuse-{}".format(reuse_length))
  components.append("bs-{}".format(per_host_batch_size))
  components.append("cores-{}".format(num_cores_per_host))
  components.append("uncased" if do_lower_case else "cased")
  if use_eod_token:
    components.append("eod")
  components.append("bi" if bi_data else "uni")
  if suffix:
    components.append(suffix)

  s = "_".join(components) + ".tfrecord"
  if num_passes == 1 and task_id is None:
    return s

  if task_id is None:
    # Multi-pass but single-task run: treat as task 0 of 1.
    num_tasks = 1
    task_id = 0

  current_shard = task_id * num_passes + pass_id
  total_shards = num_tasks * num_passes
  return s + "-{}-of-{}".format(current_shard, total_shards)
def create_tfrecords(
    tokenizer: tokenization.FullSentencePieceTokenizer,
    input_file_or_files: str,
    use_eod_token: bool,
    do_lower_case: bool,
    per_host_batch_size: int,
    seq_length: int,
    reuse_length: int,
    bi_data: bool,
    num_cores_per_host: int,
    save_dir: str,
    prefix: str = "",
    suffix: str = "",
    num_tasks: Optional[int] = None,
    task_id: Optional[int] = None,
    num_passes: int = 1):
  """Runs the end-to-end preprocessing pipeline.

  Globs and (optionally) shards the input files, tokenizes them, and for each
  pass shuffles documents, builds training instances, and writes one TFRecord
  (plus a corpus-info JSON on task 0).

  Args:
    tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    input_file_or_files: Comma-separated list of input file glob patterns.
    use_eod_token: Whether or not to include EOD tokens.
    do_lower_case: Whether or not to lower case the input text.
    per_host_batch_size: The target per-host batch size.
    seq_length: The max sequence length.
    reuse_length: The number of tokens reusable as memory.
    bi_data: Whether or not to use bi-directional data.
    num_cores_per_host: The number of cores per host.
    save_dir: Directory for the output TFRecord(s) and corpus info.
    prefix: Optional output filename prefix.
    suffix: Optional output filename suffix.
    num_tasks: Optional total number of parallel tasks.
    task_id: Optional index of this task; selects a shard of the input files.
    num_passes: Number of shuffled passes over the data.
  """
  logging.info("Input configuration:")
  logging.info("input file(s): %s", input_file_or_files)
  logging.info("use_eod_token: %s", use_eod_token)
  logging.info("do_lower_case: %s", do_lower_case)
  logging.info("per_host_batch_size: %d", per_host_batch_size)
  logging.info("seq_length: %d", seq_length)
  logging.info("reuse_length: %d", reuse_length)
  logging.info("bi_data: %s", bi_data)
  logging.info("num_cores_per_host: %d", num_cores_per_host)
  logging.info("save_dir: %s", save_dir)
  if task_id is not None and num_tasks is not None:
    logging.info("task_id: %d", task_id)
    logging.info("num_tasks: %d", num_tasks)

  input_files = []
  for input_pattern in input_file_or_files.split(","):
    input_files.extend(tf.io.gfile.glob(input_pattern))

  logging.info("*** Reading from input files ***")
  for input_file in input_files:
    logging.info("  %s", input_file)

  # NOTE(review): reproducibility depends on np.random.seed being set by the
  # caller (see __main__ below) — confirm when calling this from elsewhere.
  logging.info("Shuffling the files with a fixed random seed.")
  np.random.shuffle(input_files)
  if num_tasks is not None:
    assert task_id is not None
    logging.info("Total number of input files: %d", len(input_files))
    logging.info("Splitting into %d shards of %d files each.",
                 num_tasks, len(input_files) // num_tasks)
    # Strided slice: every task takes every num_tasks-th file.
    input_files = input_files[task_id::num_tasks]

  all_data = preprocess_and_tokenize_input_files(
      input_files=input_files,
      tokenizer=tokenizer,
      use_eod=use_eod_token,
      do_lower_case=do_lower_case)
  for pass_id in range(num_passes):
    logging.info("Beginning pass %d of %d", pass_id, num_passes)
    tokens, sentence_ids = shuffle_and_combine_preprocessed_data(all_data)

    assert len(tokens) == len(sentence_ids)

    filename = get_tfrecord_name(
        per_host_batch_size=per_host_batch_size,
        num_cores_per_host=num_cores_per_host,
        seq_length=seq_length,
        bi_data=bi_data,
        use_eod_token=use_eod_token,
        reuse_length=reuse_length,
        do_lower_case=do_lower_case,
        prefix=prefix,
        suffix=suffix,
        pass_id=pass_id,
        num_passes=num_passes,
        num_tasks=num_tasks,
        task_id=task_id)
    save_path = os.path.join(save_dir, filename)
    # NOTE(review): os.path.exists only works for local paths, while reads
    # above go through tf.io.gfile — confirm behavior for non-local save_dir.
    if os.path.exists(save_path):
      # If the path already exists, then we were probably preempted but
      # previously wrote this file.
      logging.info("%s already exists, skipping this batch.", save_path)
    else:
      instances = _convert_tokens_to_instances(
          tokenizer=tokenizer,
          tokens=tokens,
          sentence_ids=sentence_ids,
          per_host_batch_size=per_host_batch_size,
          seq_length=seq_length,
          reuse_length=reuse_length,
          bi_data=bi_data,
          num_cores_per_host=num_cores_per_host)

      write_instances_to_tfrecord(instances=instances, save_path=save_path)

  # Only the first task writes the shared corpus metadata file.
  if task_id is None or task_id == 0:
    corpus_info = {
        "vocab_size": 32000,
        "per_host_batch_size": per_host_batch_size,
        "num_cores_per_host": num_cores_per_host,
        "seq_length": seq_length,
        "reuse_length": reuse_length,
        "do_lower_case": do_lower_case,
        "bi_data": bi_data,
        "use_eod_token": use_eod_token,
    }
    corpus_fname = os.path.basename(filename) + ".json"
    corpus_destination = os.path.join(save_dir, corpus_fname)
    logging.info("Saving corpus info to %s", corpus_destination)

    with tf.io.gfile.GFile(corpus_destination, "w") as fp:
      json.dump(corpus_info, fp)
def main(_):
  """Builds the SentencePiece tokenizer from flags and runs the pipeline."""
  tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)

  create_tfrecords(
      tokenizer=tokenizer,
      input_file_or_files=FLAGS.input_file,
      use_eod_token=FLAGS.use_eod_token,
      do_lower_case=FLAGS.do_lower_case,
      per_host_batch_size=FLAGS.per_host_batch_size,
      seq_length=FLAGS.seq_length,
      reuse_length=FLAGS.reuse_length,
      bi_data=FLAGS.bi_data,
      num_cores_per_host=FLAGS.num_cores_per_host,
      save_dir=FLAGS.save_dir,
      prefix=FLAGS.prefix,
      suffix=FLAGS.suffix,
      num_tasks=FLAGS.num_tasks,
      task_id=FLAGS.task_id,
      num_passes=FLAGS.num_passes)
if __name__ == "__main__":
  # Fixed seed so the file shuffle and per-pass document permutation are
  # reproducible across runs.
  np.random.seed(0)
  logging.set_verbosity(logging.INFO)
  app.run(main)