Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Script to pre-process pre-training data into tfrecords.""" | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import json | |
| import os | |
| import random | |
| from absl import app | |
| from absl import flags | |
| import absl.logging as _logging # pylint: disable=unused-import | |
| import numpy as np | |
| import tensorflow.google as tf | |
| from official.nlp.xlnet import preprocess_utils | |
| import sentencepiece as spm | |
# Special symbols and their integer ids. A symbol's position in the tuple
# below is its id, so the order must not be changed.
special_symbols = {
    symbol: symbol_id
    for symbol_id, symbol in enumerate((
        "<unk>", "<s>", "</s>", "<cls>", "<sep>",
        "<pad>", "<mask>", "<eod>", "<eop>",
    ))
}

# Recorded as "vocab_size" in corpus_info.json.
VOCAB_SIZE = 32000

# Shortcuts for the ids that the rest of this file refers to by name.
UNK_ID, CLS_ID, SEP_ID, MASK_ID, EOD_ID = (
    special_symbols[symbol]
    for symbol in ("<unk>", "<cls>", "<sep>", "<mask>", "<eod>"))
def _int64_feature(values):
  """Wrap an iterable of integers in a `tf.train.Feature` (int64 list)."""
  int64_list = tf.train.Int64List(value=values)
  return tf.train.Feature(int64_list=int64_list)
def _float_feature(values):
  """Wrap an iterable of floats in a `tf.train.Feature` (float list)."""
  float_list = tf.train.FloatList(value=values)
  return tf.train.Feature(float_list=float_list)
def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
                    mask_alpha=5, mask_beta=1, reuse_len=None, uncased=False,
                    fixed_num_predict=None):
  """Build a file name that encodes the preprocessing configuration.

  The result has the layout
  `<prefix>.bsz-<B>.seqlen-<S>.[reuse-<R>.][uncased.]<bi|uni>.alpha-<A>.beta-<B>.[fnp-<N>.]<suffix>`
  where the bracketed parts are present only when the matching option is set.

  Args:
    prefix: leading part of the file name.
    bsz_per_host: batch size per host.
    seq_len: sequence length.
    bi_data: truthy selects the "bi" tag, falsy selects "uni".
    suffix: trailing part of the file name (e.g. "json" or "tfrecords").
    mask_alpha: value written after "alpha-".
    mask_beta: value written after "beta-".
    reuse_len: if not None, adds a "reuse-<reuse_len>." part.
    uncased: if truthy, adds an "uncased." part.
    fixed_num_predict: if not None, adds a "fnp-<fixed_num_predict>." part.

  Returns:
    The formatted file name as a string.
  """
  reuse_part = "" if reuse_len is None else "reuse-{}.".format(reuse_len)
  uncased_part = "uncased." if uncased else ""
  direction_part = "bi" if bi_data else "uni"
  fnp_part = ("" if fixed_num_predict is None
              else "fnp-{}.".format(fixed_num_predict))

  name_parts = [
      "{}.bsz-{}.seqlen-{}.".format(prefix, bsz_per_host, seq_len),
      reuse_part,
      uncased_part,
      direction_part,
      ".alpha-{}.beta-{}.".format(mask_alpha, mask_beta),
      fnp_part,
      "{}".format(suffix),
  ]
  return "".join(name_parts)
def _create_data(idx, input_paths):
  """Tokenize the files in `input_paths` and write them as one tfrecord file.

  Each input file becomes one shard: a flat int64 array of token ids plus a
  parallel boolean array whose value flips after every line, so a change of
  value marks a sentence boundary. The shards are shuffled with a seed
  derived from `FLAGS.task` and `FLAGS.pass_id`, concatenated, and handed to
  `create_tfrecords`.

  Args:
    idx: task index; used in log messages and in the output file basename.
    input_paths: iterable of input file paths. Every non-blank line is treated
      as one sentence; a blank line is encoded as `EOD_ID` when
      `FLAGS.use_eod` is set and is skipped otherwise.

  Returns:
    A dict with keys "filenames" (list holding the written tfrecord path) and
    "num_batch" (number of batches written).

  Raises:
    ValueError: if none of the input files yields any token.
  """
  # Load sentence-piece model
  sp = spm.SentencePieceProcessor()
  sp.Load(FLAGS.sp_path)

  input_shards = []
  total_line_cnt = 0
  for input_path in input_paths:
    input_data, sent_ids = [], []
    sent_id, line_cnt = True, 0
    tf.logging.info("Processing %s", input_path)
    # Context manager: the handle is closed even if encoding a line fails.
    with tf.gfile.Open(input_path) as reader:
      for line in reader:
        if line_cnt % 100000 == 0:
          tf.logging.info("Loading line %d", line_cnt)
        line_cnt += 1

        if not line.strip():
          if FLAGS.use_eod:
            # The EOD token gets its own sentence id.
            sent_id = not sent_id
            cur_sent = [EOD_ID]
          else:
            continue
        else:
          if FLAGS.from_raw_text:
            cur_sent = preprocess_utils.preprocess_text(
                line.strip(), lower=FLAGS.uncased)
            cur_sent = preprocess_utils.encode_ids(sp, cur_sent)
          else:
            cur_sent = list(map(int, line.strip().split()))

        input_data.extend(cur_sent)
        sent_ids.extend([sent_id] * len(cur_sent))
        sent_id = not sent_id

    tf.logging.info("Finish with line %d", line_cnt)
    if line_cnt == 0:
      continue
    total_line_cnt += line_cnt

    if not input_data:
      # Lines were read but produced no token (e.g. only blank lines with
      # EOD disabled). An empty shard would break the `sent_ids[-1]` lookup
      # below, so drop it.
      tf.logging.info("Skip %s: no token found", input_path)
      continue

    # `bool` instead of the `np.bool` alias, which was removed from NumPy.
    input_shards.append((np.array(input_data, dtype=np.int64),
                         np.array(sent_ids, dtype=bool)))

  tf.logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)

  if not input_shards:
    raise ValueError(
        "[Task {}] no token found in any of the input files: {}".format(
            idx, list(input_paths)))

  tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")

  # Randomly shuffle input shards (with a fixed but distinct random seed)
  np.random.seed(100 * FLAGS.task + FLAGS.pass_id)

  perm_indices = np.random.permutation(len(input_shards))
  tf.logging.info("Using perm indices %s for pass %d",
                  perm_indices.tolist(), FLAGS.pass_id)

  input_data_list, sent_ids_list = [], []
  prev_sent_id = None
  for perm_idx in perm_indices:
    input_data, sent_ids = input_shards[perm_idx]
    # Make sure `sent_ids[0] == not prev_sent_id`, so the junction between two
    # shards is still seen as a sentence boundary.
    if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
      sent_ids = np.logical_not(sent_ids)

    # append to temporary list
    input_data_list.append(input_data)
    sent_ids_list.append(sent_ids)

    # update `prev_sent_id`
    prev_sent_id = sent_ids[-1]

  input_data = np.concatenate(input_data_list)
  sent_ids = np.concatenate(sent_ids_list)

  file_name, cur_num_batch = create_tfrecords(
      save_dir=tfrecord_dir,
      basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
      data=[input_data, sent_ids],
      bsz_per_host=FLAGS.bsz_per_host,
      seq_len=FLAGS.seq_len,
      bi_data=FLAGS.bi_data,
      sp=sp,
  )

  record_info = {
      "filenames": [file_name],
      "num_batch": cur_num_batch
  }

  return record_info
def create_data(_):
  """Entry point: preprocess this task's share of the input files.

  Creates the output directories, writes `corpus_info.json` (only from task 0
  on pass 0), selects every `FLAGS.num_task`-th file matching
  `FLAGS.input_glob` starting at `FLAGS.task`, converts those files to
  tfrecords and writes a `record_info-*.json` file describing the result.

  Args:
    _: unused; present because `app.run` passes the remaining argv.
  """
  # On non-TPU runs the core count is forced to one. Do this before the
  # validation below so the check sees the value that is actually used.
  if not FLAGS.use_tpu:
    FLAGS.num_core_per_host = 1  # forced to be one

  # Validate FLAGS
  assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0, (
      "bsz_per_host must be divisible by num_core_per_host")

  # Make workdirs
  if not tf.gfile.Exists(FLAGS.save_dir):
    tf.gfile.MakeDirs(FLAGS.save_dir)

  tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
  if not tf.gfile.Exists(tfrecord_dir):
    tf.gfile.MakeDirs(tfrecord_dir)

  # Create and dump corpus_info from task 0
  if FLAGS.task == 0 and FLAGS.pass_id == 0:
    corpus_info = {
        "vocab_size": VOCAB_SIZE,
        "bsz_per_host": FLAGS.bsz_per_host,
        "num_core_per_host": FLAGS.num_core_per_host,
        "seq_len": FLAGS.seq_len,
        "reuse_len": FLAGS.reuse_len,
        "uncased": FLAGS.uncased,
        "bi_data": FLAGS.bi_data,
        "mask_alpha": FLAGS.mask_alpha,
        "mask_beta": FLAGS.mask_beta,
        "num_predict": FLAGS.num_predict,
        "use_eod": FLAGS.use_eod,
        "sp_path": FLAGS.sp_path,
        "input_glob": FLAGS.input_glob,
    }
    corpus_info_path = os.path.join(FLAGS.save_dir, "corpus_info.json")
    with tf.gfile.Open(corpus_info_path, "w") as fp:
      json.dump(corpus_info, fp)

  # Interleavely split the work into FLAGS.num_task splits
  file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
  tf.logging.info("Use glob: %s", FLAGS.input_glob)
  tf.logging.info("Find %d files: %s", len(file_paths), file_paths)

  task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
  if not task_file_paths:
    tf.logging.info("Exit: task %d has no file to process.", FLAGS.task)
    return

  tf.logging.info("Task %d process %d files: %s",
                  FLAGS.task, len(task_file_paths), task_file_paths)
  record_info = _create_data(FLAGS.task, task_file_paths)

  # The record info file name carries the same configuration tags as the
  # tfrecord file, so readers can glob for a matching configuration.
  record_prefix = "record_info-{}-{}-{}".format(
      FLAGS.split, FLAGS.task, FLAGS.pass_id)
  record_name = format_filename(
      prefix=record_prefix,
      bsz_per_host=FLAGS.bsz_per_host,
      seq_len=FLAGS.seq_len,
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      reuse_len=FLAGS.reuse_len,
      bi_data=FLAGS.bi_data,
      suffix="json",
      uncased=FLAGS.uncased,
      fixed_num_predict=FLAGS.num_predict)
  record_info_path = os.path.join(tfrecord_dir, record_name)

  with tf.gfile.Open(record_info_path, "w") as fp:
    json.dump(record_info, fp)
def batchify(data, bsz_per_host, sent_ids=None):
  """Reshape a flat stream into `bsz_per_host` equally long rows.

  Trailing elements that do not fill a complete column are dropped.

  Args:
    data: 1-D numpy array.
    bsz_per_host: number of rows in the result.
    sent_ids: optional 1-D numpy array parallel to `data`; reshaped the same
      way when given.

  Returns:
    The reshaped `data` with shape [bsz_per_host, len(data) // bsz_per_host],
    or a tuple `(data, sent_ids)` of two such arrays when `sent_ids` is given.
  """
  steps_per_row = len(data) // bsz_per_host
  usable_len = bsz_per_host * steps_per_row

  batched_data = data[:usable_len].reshape(bsz_per_host, steps_per_row)
  if sent_ids is None:
    return batched_data

  batched_sent_ids = sent_ids[:usable_len].reshape(bsz_per_host, steps_per_row)
  return batched_data, batched_sent_ids
def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
  """Split two segments from `data` starting from the index `begin_idx`.

  Segment A always starts at `begin_idx`. With label 1, segment B is the text
  that directly follows A (`b_begin == a_end`); with label 0, B is taken from
  a randomly drawn position in `data` and widened to sentence boundaries.
  Both segments are then trimmed from the right until their combined length
  is at most `tot_len`.

  Args:
    data: 1-D array of token ids (its length is read from `data.shape[0]`).
    sent_ids: sequence parallel to `data`; a change of value between two
      neighbouring entries marks a sentence boundary.
    begin_idx: index in `data` where segment A starts.
    tot_len: maximum combined length of segments A and B.
    extend_target: if True, also return target slices for A and B.

  Returns:
    None if `begin_idx + tot_len >= data_len`, or if `extend_target` is set
    and `a_end` or `b_end` reaches `data_len`. Otherwise the list
    `[a_data, b_data, label, new_begin]`, followed by `[a_target, b_target]`
    when `extend_target` is True.
  """
  data_len = data.shape[0]
  if begin_idx + tot_len >= data_len:
    tf.logging.info("[_split_a_and_b] returns None: "
                    "begin_idx %d + tot_len %d >= data_len %d",
                    begin_idx, tot_len, data_len)
    return None

  # Collect the sentence boundaries that lie less than `tot_len` tokens after
  # `begin_idx`. The scan stops at the first boundary at or beyond `tot_len`,
  # or at the end of `data`; `end_idx` is left at that stopping position.
  end_idx = begin_idx + 1
  cut_points = []
  while end_idx < data_len:
    if sent_ids[end_idx] != sent_ids[end_idx - 1]:
      if end_idx - begin_idx >= tot_len: break
      cut_points.append(end_idx)
    end_idx += 1

  a_begin = begin_idx
  # Without any cut point only label 0 is possible; otherwise the label is a
  # coin flip.
  if len(cut_points) == 0 or random.random() < 0.5:
    label = 0
    if len(cut_points) == 0:
      a_end = end_idx
    else:
      a_end = random.choice(cut_points)

    b_len = max(1, tot_len - (a_end - a_begin))
    # (zihangd): `data_len - 1` to account for extend_target
    b_begin = random.randint(0, data_len - 1 - b_len)
    b_end = b_begin + b_len
    # Widen B to the left until it starts on a sentence boundary.
    while b_begin > 0 and sent_ids[b_begin - 1] == sent_ids[b_begin]:
      b_begin -= 1
    # Widen B to the right until it ends on a sentence boundary.
    # (zihangd): `data_len - 1` to account for extend_target
    while b_end < data_len - 1 and sent_ids[b_end - 1] == sent_ids[b_end]:
      b_end += 1

    new_begin = a_end
  else:
    label = 1
    a_end = random.choice(cut_points)
    b_begin = a_end
    b_end = end_idx

    new_begin = b_end

  # Trim the longer segment from the right, one token at a time, until both
  # fit into `tot_len`.
  while a_end - a_begin + b_end - b_begin > tot_len:
    if a_end - a_begin > b_end - b_begin:
      # delete the right side only for the LM objective
      a_end -= 1
    else:
      b_end -= 1

  ret = [data[a_begin: a_end], data[b_begin: b_end], label, new_begin]

  if extend_target:
    if a_end >= data_len or b_end >= data_len:
      tf.logging.info("[_split_a_and_b] returns None: "
                      "a_end %d or b_end %d >= data_len %d",
                      a_end, b_end, data_len)
      return None
    # A's target is A shifted right by one token; B's target is B plus one
    # extra trailing token (so it is one element longer than `b_data`).
    a_target = data[a_begin + 1: a_end + 1]
    b_target = data[b_begin: b_end + 1]
    ret.extend([a_target, b_target])

  return ret
| def _is_start_piece(piece): | |
| special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~')) | |
| if (piece.startswith("▁") or piece.startswith("<") | |
| or piece in special_pieces): | |
| return True | |
| else: | |
| return False | |
| def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None): | |
| """Sample `goal_num_predict` tokens for partial prediction. | |
| About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens.""" | |
| seg_len = len(seg) | |
| mask = np.array([False] * seg_len, dtype=np.bool) | |
| num_predict = 0 | |
| ngrams = np.arange(1, max_gram + 1, dtype=np.int64) | |
| pvals = 1. / np.arange(1, max_gram + 1) | |
| pvals /= pvals.sum(keepdims=True) | |
| if reverse: | |
| seg = np.flip(seg, 0) | |
| cur_len = 0 | |
| while cur_len < seg_len: | |
| if goal_num_predict is not None and num_predict >= goal_num_predict: break | |
| n = np.random.choice(ngrams, p=pvals) | |
| if goal_num_predict is not None: | |
| n = min(n, goal_num_predict - num_predict) | |
| ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta | |
| l_ctx = np.random.choice(ctx_size) | |
| r_ctx = ctx_size - l_ctx | |
| # Find the start position of a complete token | |
| beg = cur_len + l_ctx | |
| while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())): | |
| beg += 1 | |
| if beg >= seg_len: | |
| break | |
| # Find the end position of the n-gram (start pos of the n+1-th gram) | |
| end = beg + 1 | |
| cnt_ngram = 1 | |
| while end < seg_len: | |
| cnt_ngram += 1 | |
| if cnt_ngram > n: | |
| break | |
| end += 1 | |
| if end >= seg_len: | |
| break | |
| # Update | |
| mask[beg:end] = True | |
| num_predict += end - beg | |
| cur_len = end + r_ctx | |
| while goal_num_predict is not None and num_predict < goal_num_predict: | |
| i = np.random.randint(seg_len) | |
| if not mask[i]: | |
| mask[i] = True | |
| num_predict += 1 | |
| if reverse: | |
| mask = np.flip(mask, 0) | |
| return mask | |
| def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5, | |
| goal_num_predict=None): | |
| """Sample `goal_num_predict` tokens for partial prediction. | |
| About `mask_beta` tokens are chosen in a context of `mask_alpha` tokens.""" | |
| seg_len = len(seg) | |
| mask = np.array([False] * seg_len, dtype=np.bool) | |
| num_predict = 0 | |
| ngrams = np.arange(1, max_gram + 1, dtype=np.int64) | |
| pvals = 1. / np.arange(1, max_gram + 1) | |
| pvals /= pvals.sum(keepdims=True) | |
| if reverse: | |
| seg = np.flip(seg, 0) | |
| cur_len = 0 | |
| while cur_len < seg_len: | |
| if goal_num_predict is not None and num_predict >= goal_num_predict: break | |
| n = np.random.choice(ngrams, p=pvals) | |
| if goal_num_predict is not None: | |
| n = min(n, goal_num_predict - num_predict) | |
| ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta | |
| l_ctx = np.random.choice(ctx_size) | |
| r_ctx = ctx_size - l_ctx | |
| # Find the start position of a complete token | |
| beg = cur_len + l_ctx | |
| while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())): | |
| beg += 1 | |
| if beg >= seg_len: | |
| break | |
| # Find the end position of the n-gram (start pos of the n+1-th gram) | |
| end = beg | |
| cnt_ngram = 0 | |
| while end < seg_len: | |
| if _is_start_piece(sp.IdToPiece(seg[end].item())): | |
| cnt_ngram += 1 | |
| if cnt_ngram > n: | |
| break | |
| # select current piece | |
| mask[end] = True | |
| # update the end pointer and increment num_predict | |
| end += 1 | |
| num_predict += 1 | |
| if goal_num_predict is not None and num_predict >= goal_num_predict: | |
| break | |
| cur_len = end + r_ctx | |
| while goal_num_predict is not None and num_predict < goal_num_predict: | |
| i = np.random.randint(seg_len) | |
| if not mask[i]: | |
| mask[i] = True | |
| num_predict += 1 | |
| if reverse: | |
| mask = np.flip(mask, 0) | |
| return mask | |
def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
                     bi_data, sp):
  """Write the token stream to a tfrecord file, one example per batch row.

  The stream is reshaped into `bsz_per_host` rows (with `bi_data`, half of the
  rows are the reversed copies of the other half). The rows are then walked in
  steps of `FLAGS.reuse_len`. For every step and row one example of length
  `seq_len` is built from `reuse_len` reused tokens followed by
  `A [SEP] B [SEP] [CLS]`, with A and B chosen by `_split_a_and_b`.

  Args:
    save_dir: directory the tfrecord file is written to.
    basename: prefix of the output file name.
    data: pair `[token_ids, sent_ids]` of parallel 1-D numpy arrays.
    bsz_per_host: number of rows / examples per batch.
    seq_len: length of every written example.
    bi_data: whether to add reversed rows.
    sp: object exposing `IdToPiece(id)`, forwarded to `_sample_mask`.

  Returns:
    A tuple `(save_path, num_batch)`: the path of the written file and the
    number of complete batches it contains.
  """
  data, sent_ids = data[0], data[1]

  num_core = FLAGS.num_core_per_host
  bsz_per_core = bsz_per_host // num_core

  if bi_data:
    assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
    fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)

    fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
    fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)

    # Backward rows are the forward rows read right-to-left.
    bwd_data = fwd_data[:, :, :, ::-1]
    bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]

    data = np.concatenate(
        [fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
    sent_ids = np.concatenate(
        [fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
  else:
    data, sent_ids = batchify(data, bsz_per_host, sent_ids)

  tf.logging.info("Raw data shape %s.", data.shape)

  # Validate before the output file is created, so a bad configuration does
  # not leave an open writer and an empty file behind.
  reuse_len = FLAGS.reuse_len
  # [sep] x 2 + [cls]
  assert reuse_len < seq_len - 3

  # The split of the prediction budget between the reused part and the
  # A/B part does not depend on the row, so compute it once.
  if FLAGS.num_predict is None:
    num_predict_0 = num_predict_1 = None
  else:
    num_predict_1 = FLAGS.num_predict // 2
    num_predict_0 = FLAGS.num_predict - num_predict_1

  file_name = format_filename(
      prefix=basename,
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      bi_data=bi_data,
      suffix="tfrecords",
      mask_alpha=FLAGS.mask_alpha,
      mask_beta=FLAGS.mask_beta,
      reuse_len=FLAGS.reuse_len,
      uncased=FLAGS.uncased,
      fixed_num_predict=FLAGS.num_predict
  )
  save_path = os.path.join(save_dir, file_name)

  data_len = data.shape[1]
  sep_array = np.array([SEP_ID], dtype=np.int64)
  cls_array = np.array([CLS_ID], dtype=np.int64)

  num_batch = 0
  record_writer = tf.python_io.TFRecordWriter(save_path)
  tf.logging.info("Start writing %s.", save_path)
  try:
    i = 0
    while i + seq_len <= data_len:
      if num_batch % 500 == 0:
        tf.logging.info("Processing batch %d", num_batch)

      all_ok = True
      features = []
      for idx in range(bsz_per_host):
        inp = data[idx, i: i + reuse_len]
        tgt = data[idx, i + 1: i + reuse_len + 1]

        results = _split_a_and_b(
            data[idx],
            sent_ids[idx],
            begin_idx=i + reuse_len,
            tot_len=seq_len - reuse_len - 3,
            extend_target=True)
        if results is None:
          tf.logging.info("Break out with seq idx %d", i)
          all_ok = False
          break

        # unpack the results
        (a_data, b_data, label, _, a_target, b_target) = tuple(results)

        # sample ngram spans to predict
        reverse = bi_data and (idx // (bsz_per_core // 2)) % 2 == 1
        mask_0 = _sample_mask(sp, inp, reverse=reverse,
                              goal_num_predict=num_predict_0)
        mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
                                                  sep_array, cls_array]),
                              reverse=reverse, goal_num_predict=num_predict_1)

        # concatenate data
        cat_data = np.concatenate([inp, a_data, sep_array, b_data,
                                   sep_array, cls_array])
        # Segment ids: 0 for the reused part, A and its [SEP]; 1 for B and
        # its [SEP]; 2 for [CLS].
        seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
                  [1] * b_data.shape[0] + [1] + [2])
        assert cat_data.shape[0] == seq_len
        assert mask_0.shape[0] == seq_len // 2
        assert mask_1.shape[0] == seq_len // 2

        # the last two CLS's are not used, just for padding purposes
        tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
        assert tgt.shape[0] == seq_len

        is_masked = np.concatenate([mask_0, mask_1], 0)
        if FLAGS.num_predict is not None:
          assert np.sum(is_masked) == FLAGS.num_predict

        feature = {
            "input": _int64_feature(cat_data),
            "is_masked": _int64_feature(is_masked),
            "target": _int64_feature(tgt),
            "seg_id": _int64_feature(seg_id),
            "label": _int64_feature([label]),
        }
        features.append(feature)

      # A batch is written only if every row produced an example; the first
      # incomplete batch ends the file.
      if not all_ok:
        break

      assert len(features) == bsz_per_host
      for feature in features:
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        record_writer.write(example.SerializeToString())
      num_batch += 1

      i += reuse_len
  finally:
    # Close the writer even when an assertion or exception interrupts the loop.
    record_writer.close()

  tf.logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)

  return save_path, num_batch
| ################ | |
| # get_input_fn # | |
| ################ | |
def _convert_example(example, use_bfloat16):
  """Cast int64 into int32 and float32 to bfloat16 if use_bfloat16.

  Sparse values are first converted to dense. `example` is modified in place;
  nothing is returned.

  Args:
    example: dict mapping feature names to tensors.
    use_bfloat16: whether float32 tensors are cast to bfloat16.
  """
  # Iterate over a snapshot because the dict entries are reassigned below.
  for name, tensor in list(example.items()):
    if tf.keras.backend.is_sparse(tensor):
      tensor = tf.sparse.to_dense(tensor)
    if tensor.dtype == tf.int64:
      tensor = tf.cast(tensor, tf.int32)
    if use_bfloat16 and tensor.dtype == tf.float32:
      tensor = tf.cast(tensor, tf.bfloat16)
    example[name] = tensor
def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
                           host_id, num_core_per_host, bsz_per_core):
  """Build the input dataset for one host from a list of tfrecord files.

  The files are divided evenly among hosts; the last host also takes the
  remainder. The host's files are shuffled at file level, read, cached as raw
  records, parsed, repeated, batched and prefetched.

  Args:
    parser: function mapping one serialized record to a feature dict.
    file_names: list of all tfrecord file paths.
    split: must be "train".
    num_batch: not used by this function.
    num_hosts: total number of hosts.
    host_id: index of the current host.
    num_core_per_host: number of cores per host; scales the prefetch size.
    bsz_per_core: batch size of the returned dataset.

  Returns:
    A `tf.data.Dataset` of batched, parsed examples.
  """
  # Slice of the file list handled by this host.
  files_per_host = len(file_names) // num_hosts
  first_file_id = host_id * files_per_host
  if host_id == num_hosts - 1:
    last_file_id = len(file_names)
  else:
    last_file_id = first_file_id + files_per_host
  file_paths = file_names[first_file_id:last_file_id]
  tf.logging.info("Host %d handles %d files", host_id, len(file_paths))

  assert split == "train"
  dataset = tf.data.Dataset.from_tensor_slices(file_paths)

  # File-level shuffle. A sample-level shuffle is not possible here because it
  # would violate the consecutive requirement of the data stream.
  if len(file_paths) > 1:
    dataset = dataset.shuffle(len(file_paths))

  dataset = tf.data.TFRecordDataset(dataset)

  # Preprocessing is done online, so parsing the same record twice gives
  # different results and caching parsed data would not help; it would also
  # use a lot of memory and can lead to a container OOM. Cache the raw,
  # non-parsed records instead.
  dataset = dataset.cache()
  dataset = dataset.map(parser)
  dataset = dataset.repeat()
  dataset = dataset.batch(bsz_per_core, drop_remainder=True)
  return dataset.prefetch(num_core_per_host * bsz_per_core)
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
  """Sample a factorization order and build the matching attention mask.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected
      for partial prediction.
    perm_size: the length of longest permutation. Could be set to be reuse_len.
      Should not be larger than reuse_len or there will be data leaks.
    seq_len: int, sequence length.

  Returns:
    A tuple `(perm_mask, new_targets, target_mask, inputs_k, inputs_q)`:
      perm_mask: float32 Tensor in shape [seq_len, seq_len]; 1 at [i, j]
        means position i cannot attend to position j.
      new_targets: `inputs[0]` followed by `targets` without its last element.
      target_mask: float32 Tensor in shape [seq_len]; 1 for positions that
        are selected for prediction and are not [SEP]/[CLS].
      inputs_k: `inputs`, unchanged.
      inputs_q: the same tensor as `target_mask`.
  """
  # Permutation indices: split the positions into chunks of `perm_size` and
  # shuffle the within-chunk offsets (the shuffle acts on the first axis of
  # the transposed [perm_size, num_chunks] layout).
  order = tf.range(seq_len, dtype=tf.int64)
  order = tf.transpose(tf.reshape(order, [-1, perm_size]))
  order = tf.random_shuffle(order)
  order = tf.reshape(tf.transpose(order), [-1])

  # Functional tokens are [SEP] and [CLS].
  is_sep = tf.equal(inputs, SEP_ID)
  is_cls = tf.equal(inputs, CLS_ID)
  non_func_tokens = tf.logical_not(tf.logical_or(is_sep, is_cls))

  not_selected = tf.logical_not(is_masked)
  non_mask_tokens = tf.logical_and(not_selected, non_func_tokens)
  masked_or_func_tokens = tf.logical_not(non_mask_tokens)

  # Non-masked (and non-functional) tokens get the smallest index (-1):
  # (1) they can be seen by all other positions
  # (2) they cannot see masked positions, so there won't be information leak
  minus_one = -tf.ones([seq_len], dtype=tf.int64)
  rev_index = tf.where(non_mask_tokens, minus_one, order)

  # `target_mask`: non-functional and masked tokens
  # 1: use mask as input and have loss
  # 0: use token (or [SEP], [CLS]) as input and do not have loss
  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
  target_mask = tf.cast(target_tokens, tf.float32)

  # `perm_mask`: target tokens cannot see themselves, every other position
  # can (its own index is bumped by one for the comparison below).
  shifted_index = rev_index + 1
  self_rev_index = tf.where(target_tokens, rev_index, shifted_index)

  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
  # 0: can attend if i > j or j is non-masked
  not_after = self_rev_index[:, None] <= rev_index[None, :]
  perm_mask = tf.cast(
      tf.logical_and(not_after, masked_or_func_tokens), tf.float32)

  # new target: [next token] for LM and [curr token] (self) for PLM
  new_targets = tf.concat([inputs[0: 1], targets[: -1]], axis=0)

  # `inputs_k` is the plain input, `inputs_q` is the target mask.
  return perm_mask, new_targets, target_mask, inputs, target_mask
def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
                num_batch, seq_len, reuse_len, perm_size, mask_alpha,
                mask_beta, use_bfloat16=False, num_predict=None):
  """Create the training dataset from preprocessed tfrecord files.

  Args:
    params: dict with "batch_size" (per-core batch size) and, when
      `num_hosts > 1`, "context" whose `current_host` gives the host id.
    num_hosts: total number of hosts.
    num_core_per_host: number of cores per host.
    split: data split name, forwarded to `parse_files_to_dataset`.
    file_names: list of tfrecord file paths.
    num_batch: forwarded to `parse_files_to_dataset`.
    seq_len: length of every stored example.
    reuse_len: length of the reused (first) part of every example.
    perm_size: permutation chunk size passed to `_local_perm`; must not exceed
      `reuse_len` or `seq_len - reuse_len`.
    mask_alpha: not used by this function.
    mask_beta: not used by this function.
    use_bfloat16: whether float32 features are cast to bfloat16.
    num_predict: if not None, targets are gathered into fixed-size tensors of
      this length (zero-padded) and a "target_mapping" feature is added.

  Returns:
    The dataset built by `parse_files_to_dataset`.
  """
  bsz_per_core = params["batch_size"]
  host_id = params["context"].current_host if num_hosts > 1 else 0

  #### Function used to parse tfrecord
  def parser(record):
    """Parse one serialized example and derive the permutation features."""
    record_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # retrieve serialized example
    example = tf.parse_single_example(serialized=record, features=record_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    # The reused part and the rest are permuted independently.
    head = _local_perm(
        inputs[:reuse_len],
        target[:reuse_len],
        is_masked[:reuse_len],
        perm_size,
        reuse_len)
    tail = _local_perm(
        inputs[reuse_len:],
        target[reuse_len:],
        is_masked[reuse_len:],
        perm_size,
        non_reuse_len)
    head_perm_mask, head_target, head_target_mask, head_k, head_q = head
    tail_perm_mask, tail_target, tail_target_mask, tail_k, tail_q = tail

    # Rows of the reused part get ones (cannot attend) for all columns of the
    # rest; rows of the rest get zeros (can attend) for the reused columns.
    head_perm_mask = tf.concat(
        [head_perm_mask, tf.ones([reuse_len, non_reuse_len])], axis=1)
    tail_perm_mask = tf.concat(
        [tf.zeros([non_reuse_len, reuse_len]), tail_perm_mask], axis=1)
    perm_mask = tf.concat([head_perm_mask, tail_perm_mask], axis=0)

    target = tf.concat([head_target, tail_target], axis=0)
    target_mask = tf.concat([head_target_mask, tail_target_mask], axis=0)
    input_k = tf.concat([head_k, tail_k], axis=0)
    input_q = tf.concat([head_q, tail_q], axis=0)

    if num_predict is None:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])
    else:
      positions = tf.range(seq_len, dtype=tf.int64)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      positions = tf.boolean_mask(positions, bool_target_mask)

      ##### extra padding due to CLS/SEP introduced after prepro
      actual_num_predict = tf.shape(positions)[0]
      pad_len = num_predict - actual_num_predict

      ##### target_mapping
      target_mapping = tf.one_hot(positions, seq_len, dtype=tf.float32)
      mapping_pad = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, mapping_pad], axis=0)
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      ##### target
      target = tf.boolean_mask(target, bool_target_mask)
      target_pad = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, target_pad], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      ##### target mask
      target_mask = tf.concat(
          [tf.ones([actual_num_predict], dtype=tf.float32),
           tf.zeros([pad_len], dtype=tf.float32)],
          axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for name, tensor in example.items():
      tf.logging.info("%s: %s", name, tensor)

    return example

  # Get dataset
  return parse_files_to_dataset(
      parser=parser,
      file_names=file_names,
      split=split,
      num_batch=num_batch,
      num_hosts=num_hosts,
      host_id=host_id,
      num_core_per_host=num_core_per_host,
      bsz_per_core=bsz_per_core)
def get_input_fn(
    tfrecord_dir,
    split,
    bsz_per_host,
    seq_len,
    reuse_len,
    bi_data,
    num_hosts=1,
    num_core_per_host=1,
    perm_size=None,
    mask_alpha=None,
    mask_beta=None,
    uncased=False,
    num_passes=None,
    use_bfloat16=False,
    num_predict=None):
  """Collect the record info files and build an input function.

  Args:
    tfrecord_dir: comma-separated list of directories holding the tfrecord
      files and their `record_info-*.json` files.
    split: data split name; part of the record info file name.
    bsz_per_host: batch size per host; part of the record info file name.
    seq_len: sequence length; part of the record info file name.
    reuse_len: reuse length; part of the record info file name.
    bi_data: whether the data is bidirectional; part of the file name.
    num_hosts: total number of hosts.
    num_core_per_host: number of cores per host.
    perm_size: permutation chunk size, forwarded to `get_dataset`.
    mask_alpha: part of the record info file name.
    mask_beta: part of the record info file name.
    uncased: part of the record info file name.
    num_passes: if not None, limits how many passes / files are used.
    use_bfloat16: forwarded to `get_dataset`.
    num_predict: part of the record info file name; forwarded to
      `get_dataset`.

  Returns:
    A tuple `(input_fn, record_info)`. `record_info` is a dict with the total
    "num_batch" and the list of "filenames"; `input_fn(params)` returns the
    dataset built from those files.
  """
  # Merge all record infos into a single one
  record_glob_base = format_filename(
      prefix="record_info-{}-*".format(split),
      bsz_per_host=bsz_per_host,
      seq_len=seq_len,
      bi_data=bi_data,
      suffix="json",
      mask_alpha=mask_alpha,
      mask_beta=mask_beta,
      reuse_len=reuse_len,
      uncased=uncased,
      fixed_num_predict=num_predict)

  record_info = {"num_batch": 0, "filenames": []}

  tfrecord_dirs = tfrecord_dir.split(",")
  tf.logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)

  for idx, record_dir in enumerate(tfrecord_dirs):
    record_glob = os.path.join(record_dir, record_glob_base)
    tf.logging.info("[%d] Record glob: %s", idx, record_glob)

    record_paths = sorted(tf.gfile.Glob(record_glob))
    tf.logging.info("[%d] Num of record info path: %d",
                    idx, len(record_paths))

    dir_num_batch = 0
    dir_filenames = []

    for record_info_path in record_paths:
      if num_passes is not None:
        record_info_name = os.path.basename(record_info_path)
        fields = record_info_name.split(".")[0].split("-")
        pass_id = int(fields[-1])
        # NOTE(review): names written by `create_data` in this file
        # ("record_info-<split>-<task>-<pass>") split into four fields, so
        # this five-field check looks like it never matches them; confirm.
        if len(fields) == 5 and pass_id >= num_passes:
          tf.logging.info("Skip pass %d: %s", pass_id, record_info_name)
          continue

      with tf.gfile.Open(record_info_path, "r") as fp:
        info = json.load(fp)

      listed_files = info["filenames"]
      if num_passes is None:
        dir_num_batch += info["num_batch"]
        dir_filenames += listed_files
      else:
        # Keep at most `num_passes` files and scale the batch count by the
        # fraction of files kept.
        eff_num_passes = min(num_passes, len(listed_files))
        ratio = eff_num_passes / len(listed_files)
        dir_num_batch += int(info["num_batch"] * ratio)
        dir_filenames += listed_files[:eff_num_passes]

    # Point every file name at `record_dir`, whatever directory was recorded.
    dir_filenames = [
        os.path.join(record_dir, os.path.basename(filename))
        for filename in dir_filenames
    ]

    tf.logging.info("[Dir %d] Number of chosen batches: %s",
                    idx, dir_num_batch)
    tf.logging.info("[Dir %d] Number of chosen files: %s",
                    idx, len(dir_filenames))
    tf.logging.info(dir_filenames)

    # add this directory's totals to the global `record_info`
    record_info["num_batch"] += dir_num_batch
    record_info["filenames"] += dir_filenames

  tf.logging.info("Total number of batches: %d",
                  record_info["num_batch"])
  tf.logging.info("Total number of files: %d",
                  len(record_info["filenames"]))
  tf.logging.info(record_info["filenames"])

  def input_fn(params):
    """Return the dataset for `params["batch_size"]` examples per core."""
    assert params["batch_size"] * num_core_per_host == bsz_per_host

    return get_dataset(
        params=params,
        num_hosts=num_hosts,
        num_core_per_host=num_core_per_host,
        split=split,
        file_names=record_info["filenames"],
        num_batch=record_info["num_batch"],
        seq_len=seq_len,
        reuse_len=reuse_len,
        perm_size=perm_size,
        mask_alpha=mask_alpha,
        mask_beta=mask_beta,
        use_bfloat16=use_bfloat16,
        num_predict=num_predict)

  return input_fn, record_info
if __name__ == "__main__":
  # Command-line flags of the preprocessing script. `FLAGS` is read as a
  # module-level name by the preprocessing functions in this file.
  FLAGS = flags.FLAGS
  flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
  flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
  flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")

  flags.DEFINE_integer("seq_len", 512,
                       help="Sequence length.")
  flags.DEFINE_integer("reuse_len", 256,
                       help="Number of token that can be reused as memory. "
                       "Could be half of `seq_len`.")
  flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
  flags.DEFINE_bool("bi_data", True,
                    help="whether to create bidirectional data")
  flags.DEFINE_integer("mask_alpha", default=6,
                       help="How many tokens to form a group.")
  flags.DEFINE_integer("mask_beta", default=1,
                       help="How many tokens to mask within each group.")
  flags.DEFINE_bool("use_eod", True,
                    help="whether to append EOD at the end of a doc.")
  flags.DEFINE_bool("from_raw_text", True,
                    help="Whether the input is raw text or encoded ids.")
  flags.DEFINE_integer("num_predict", default=85,
                       help="Num of tokens to predict.")

  flags.DEFINE_string("input_glob", "data/example/*.txt",
                      help="Input file glob.")
  flags.DEFINE_string("sp_path", "", help="Path to the sentence piece model.")
  flags.DEFINE_string("save_dir", "proc_data/example",
                      help="Directory for saving the processed data.")
  flags.DEFINE_enum("split", "train", ["train", "dev", "test"],
                    help="Save the data as which split.")

  # The two help fragments are concatenated, so the first one needs a
  # trailing space.
  flags.DEFINE_integer("pass_id", 0, help="ID of the current pass. "
                       "Different passes sample different negative segment.")
  flags.DEFINE_integer("num_task", 1, help="Number of total tasks.")
  flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
                       "using multiple workers to identify each worker.")

  tf.logging.set_verbosity(tf.logging.INFO)
  app.run(create_data)