import math
import time
import collections
import re

import numpy as np

from fengshen.data.megatron_dataloader.utils import print_rank_0
from fengshen.data.megatron_dataloader.blendable_dataset import BlendableDataset
from fengshen.data.megatron_dataloader.indexed_dataset import \
    make_dataset as make_indexed_dataset


DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPE_BERT_CN_WWM = 'bert_cn_wwm'
DSET_TYPE_BART = 'bart'
DSET_TYPE_COCOLM = 'coco_lm'

DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT,
              DSET_TYPE_T5, DSET_TYPE_BERT_CN_WWM,
              DSET_TYPE_BART, DSET_TYPE_COCOLM]


def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):

    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ...
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()
    # Normalize the weights.
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so that in case the blending dataset does
    # not uniformly distribute the number of samples, we still have samples
    # left to feed to the network.
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005))
             for val in train_valid_test_num_samples])

    return prefixes, weights, datasets_train_valid_test_num_samples
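
# A minimal usage sketch for the helper above (values are hypothetical):
#
#   >>> get_datasets_weights_and_num_samples(
#   ...     ['2', 'corpus-a', '1', 'corpus-b'], [1000, 100, 10])
#   (['corpus-a', 'corpus-b'],
#    [0.666..., 0.333...],
#    [[670, 67, 7], [335, 34, 4]])
#
# Each per-dataset count is ceil(total * weight * 1.005), so the blended
# dataset can draw slightly unevenly without running out of samples.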


def compile_helper():
    """Compile helper function at runtime. Make sure this
    is invoked on a single process."""
    import os
    import subprocess
    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        print("Making C++ dataset helpers module failed, exiting.")
        import sys
        sys.exit(1)
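
# compile_helper() shells out to `make`, so callers typically guard it so
# that only one process compiles. A sketch, assuming torch.distributed is
# already initialized (not shown in this module):
#
#   if torch.distributed.get_rank() == 0:
#       compile_helper()
#   torch.distributed.barrier()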


def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""

    # Number of sentences in the sample.
    n_sentences = len(sample)
    # Make sure we always have two sentences.
    assert n_sentences > 1, 'make sure each sample has at least two sentences.'

    # First part:
    # `a_end` is how many sentences go into segment A.
    a_end = 1
    if n_sentences >= 3:
        # Note that randint in numpy is exclusive of the upper bound.
        a_end = np_rng.randint(1, n_sentences)
    tokens_a = []
    for j in range(a_end):
        tokens_a.extend(sample[j])

    # Second part:
    tokens_b = []
    for j in range(a_end, n_sentences):
        tokens_b.extend(sample[j])

    # Random next: half the time, swap the two segments to create a
    # negative next-sentence-prediction example.
    is_next_random = False
    if np_rng.random() < 0.5:
        is_next_random = True
        tokens_a, tokens_b = tokens_b, tokens_a

    return tokens_a, tokens_b, is_next_random
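
# Example: for sample = [[5, 6], [7], [8, 9]] (three tokenized sentences),
# a_end is drawn from {1, 2}; with a_end == 2 the segments before the random
# swap are tokens_a == [5, 6, 7] and tokens_b == [8, 9]. When the swap
# happens, is_next_random is True and segment B no longer follows segment A.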


def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    assert len_a > 0
    if len_a + len_b <= max_num_tokens:
        return False
    # Always trim the longer of the two segments, removing a token from the
    # front or the back at random so the truncation is not biased.
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
        else:
            len_b -= 1
            tokens = tokens_b
        if np_rng.random() < 0.5:
            del tokens[0]
        else:
            tokens.pop()
    return True
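
# Example: len_a == 5, len_b == 4, max_num_tokens == 7 removes two tokens,
# the first from tokens_a (the longer segment) and the second from tokens_b,
# each from a randomly chosen end, and returns True. If the pair already
# fits, nothing is modified and the function returns False. Note that
# tokens_a/tokens_b are mutated in place.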


def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""

    tokens = []
    tokentypes = []
    # [CLS].
    tokens.append(cls_id)
    tokentypes.append(0)
    # Segment A.
    for token in tokens_a:
        tokens.append(token)
        tokentypes.append(0)
    # [SEP].
    tokens.append(sep_id)
    tokentypes.append(0)
    # Segment B.
    for token in tokens_b:
        tokens.append(token)
        tokentypes.append(1)
    if tokens_b:
        # Trailing [SEP], only when segment B is non-empty.
        tokens.append(sep_id)
        tokentypes.append(1)

    return tokens, tokentypes
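
# Example: with cls_id == 101 and sep_id == 102 (the usual BERT ids),
# tokens_a == [7, 8] and tokens_b == [9] produce
#   tokens     == [101, 7, 8, 102, 9, 102]
#   tokentypes == [  0, 0, 0,   0, 1,   1]
# i.e. the standard BERT layout [CLS] A [SEP] B [SEP] with segment ids.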


MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into WordPieces, the first token carries no
    # marker and any subsequent tokens are prefixed with ##. So a piece that
    # starts with ## is a continuation, not the start, of a word.
    return not piece.startswith("##")
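
# Example: is_start_piece('un') -> True, is_start_piece('##able') -> False,
# so the WordPieces of 'unable' ('un', '##able') form one whole-word group.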


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 tokenizer,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert",
                                 zh_tokenizer=None):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

    cand_indexes = []
    # Record whether each piece is the starting piece of a word (1 means
    # true), so that on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    # Without a Chinese word segmenter, group word pieces into words using
    # the standard BERT '##' continuation marker.
    if zh_tokenizer is None:
        for (i, token) in enumerate(tokens):
            if token == cls_id or token == sep_id:
                token_boundary[i] = 1
                continue
            # Whole Word Masking means masking all of the word pieces that
            # correspond to an original word.
            #
            # Note that Whole Word Masking does *not* change the training
            # code at all -- we still predict each WordPiece independently,
            # softmaxed over the entire vocabulary.
            if (do_whole_word_mask and len(cand_indexes) >= 1 and
                    not is_start_piece(vocab_id_to_token_dict[token])):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
                if is_start_piece(vocab_id_to_token_dict[token]):
                    token_boundary[i] = 1
    else:
        # Chinese whole word masking: BERT's Chinese vocabulary carries few
        # '##' markers, so recover word boundaries by running the external
        # segmenter (e.g. jieba) over the detokenized text.
        raw_tokens = []
        for t in tokens:
            if t != cls_id and t != sep_id:
                raw_tokens.append(t)
        raw_tokens = [vocab_id_to_token_dict[i] for i in raw_tokens]
        # For each starting character, remember the length of the longest
        # segmented word that begins with it.
        word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))
        word_length_dict = {}
        for w in word_list:
            if len(w) < 1:
                continue
            if w[0] not in word_length_dict:
                word_length_dict[w[0]] = len(w)
            elif word_length_dict[w[0]] < len(w):
                word_length_dict[w[0]] = len(w)
        i = 0
        # Greedily match the longest segmented word at each position to
        # build the candidate whole-word spans.
        while i < len(tokens):
            token_id = tokens[i]
            token = vocab_id_to_token_dict[token_id]
            if len(token) == 0 or token_id == cls_id or token_id == sep_id:
                token_boundary[i] = 1
                i += 1
                continue
            word_max_length = 1
            if token[0] in word_length_dict:
                word_max_length = word_length_dict[token[0]]
            j = 0
            word = ''
            word_end = i + 1
            # If the following pieces carry '##' markers, fall back to the
            # old-style WordPiece grouping for this word.
            old_style = False
            while word_end < len(tokens) and \
                    vocab_id_to_token_dict[tokens[word_end]].startswith('##'):
                old_style = True
                word_end += 1
            if not old_style:
                while j < word_max_length and i + j < len(tokens):
                    cur_token = tokens[i + j]
                    word += vocab_id_to_token_dict[cur_token]
                    j += 1
                    if word in word_list:
                        word_end = i + j
            cand_indexes.append([p for p in range(i, word_end)])
            token_boundary[i] = 1
            i = word_end

    output_tokens = list(tokens)
    # For Chinese whole word masking, strip the '##' marker from Chinese
    # pieces so the model input never contains WordPiece markers.
    if masking_style == 'bert-cn-wwm':
        new_token_ids = []
        for token_id in output_tokens:
            token = tokenizer.convert_ids_to_tokens([token_id])[0]
            if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                token = token[2:]
            new_token_id = tokenizer.convert_tokens_to_ids([token])[0]
            new_token_ids.append(new_token_id)
        output_tokens = new_token_ids

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    if not geometric_dist:
        # By default, set the probabilities to favor shorter ngram spans.
        pvals = 1. / np.arange(1, max_ngrams + 1)
        pvals /= pvals.sum(keepdims=True)
        if favor_longer_ngram:
            pvals = pvals[::-1]

    # For each candidate word, pre-compute the lists of 1..max_ngrams
    # consecutive whole-word spans that start at that word.
    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    (masked_lms, masked_spans) = ([], [])
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Skip the current piece if it is covered by lm masking or a
        # previous ngram.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        if not geometric_dist:
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
        else:
            # Sample "n" from a geometric distribution and clip it to
            # max_ngrams. Uses p=0.2 as in the SpanBERT paper
            # (https://arxiv.org/abs/1907.10529; Sec 3.1).
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Repeatedly look for a candidate that does not exceed the maximum
        # number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would still exceed the maximum number
        # of predictions, skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)
            masked_token = None
            if masking_style == "bert":
                # 80% of the time, replace with [MASK].
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep the original token.
                    if np_rng.random() < 0.5:
                        masked_token = tokens[index]
                    # 10% of the time, replace with a random vocabulary id.
                    else:
                        masked_token = vocab_id_list[
                            np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == 'bert-cn-wwm':
                # 80% of the time, replace with [MASK].
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                else:
                    # 10% of the time, keep the original token; strip any
                    # '##' marker first so the kept token matches the
                    # marker-free input built above.
                    if np_rng.random() < 0.5:
                        token_id = tokens[index]
                        token = tokenizer.convert_ids_to_tokens(
                            [token_id])[0]
                        if len(re.findall('##[\u4E00-\u9FA5]', token)) > 0:
                            token = token[2:]
                        new_token_id = tokenizer.convert_tokens_to_ids(
                            [token])[0]
                        masked_token = new_token_id
                    # 10% of the time, replace with a random vocabulary id.
                    else:
                        masked_token = vocab_id_list[np_rng.randint(
                            0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(
                index=index, label=tokens[index]))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        # Optionally select an additional set of ngrams and permute the
        # tokens at those positions; the permuted positions also become
        # prediction targets.
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Skip the current piece if it is covered by lm masking,
            # selection, or a previous ngram.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            # Use the passed-in RNG (np_rng, not the global numpy RNG) so
            # sampling stays reproducible under the per-sample seed.
            n = np_rng.choice(ngrams[:len(cand_index_set)],
                              p=pvals[:len(cand_index_set)] /
                              pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word span would still exceed the maximum
            # number of predictions, skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(
                index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of their first token.
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels,
            token_boundary, masked_spans)
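
# A minimal usage sketch (tokens, ids, and tokenizer are hypothetical): with
# masked_lm_prob=0.15 and masking_style='bert', roughly 15% of the word
# pieces are grouped into 1..max_ngrams whole-word spans and replaced by
# [MASK] 80% of the time, kept 10% of the time, and swapped for a random
# vocabulary id 10% of the time:
#
#   np_rng = np.random.RandomState(seed=1234)
#   (out_tokens, positions, labels, boundary, spans) = \
#       create_masked_lm_predictions(
#           tokens, vocab_id_list, vocab_id_to_token_dict, 0.15,
#           cls_id, sep_id, mask_id, max_predictions_per_seq=20,
#           np_rng=np_rng, tokenizer=tokenizer)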


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                             masked_labels, pad_id, max_seq_length):
    """Pad sequences and convert them to numpy."""

    # Some checks.
    num_tokens = len(tokens)
    padding_length = max_seq_length - num_tokens
    assert padding_length >= 0
    assert len(tokentypes) == num_tokens
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Labels and loss mask.
    labels = [-1] * max_seq_length
    loss_mask = [0] * max_seq_length
    for i in range(len(masked_positions)):
        assert masked_positions[i] < num_tokens
        labels[masked_positions[i]] = masked_labels[i]
        loss_mask[masked_positions[i]] = 1
    labels_np = np.array(labels, dtype=np.int64)
    loss_mask_np = np.array(loss_mask, dtype=np.int64)

    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
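
# Example: tokens=[101, 7, 103, 102], tokentypes=[0, 0, 0, 0],
# masked_positions=[2], masked_labels=[8], pad_id=0, max_seq_length=6 gives
#   tokens_np       == [101, 7, 103, 102, 0, 0]
#   padding_mask_np == [  1, 1,   1,   1, 0, 0]
#   labels_np       == [ -1, -1,  8,  -1, -1, -1]
#   loss_mask_np    == [  0,  0,  1,   0,  0,  0]
# so the MLM loss is computed only at masked positions, and -1 marks
# positions that carry no label.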


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length,
                                    masked_lm_prob, short_seq_prob, seed,
                                    tokenizer,
                                    skip_warmup, binary_head=False,
                                    max_seq_length_dec=None,
                                    dataset_type='standard_bert',
                                    zh_tokenizer=None,
                                    span=None):

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                binary_head,
                                                max_seq_length_dec,
                                                tokenizer,
                                                dataset_type=dataset_type,
                                                zh_tokenizer=zh_tokenizer,
                                                span=span)
    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, binary_head, max_seq_length_dec,
            tokenizer, dataset_type=dataset_type, zh_tokenizer=zh_tokenizer,
            # Forward span as well so COCO-LM style datasets also work in
            # the blending path.
            span=span)
        if train_ds:
            train_datasets.append(train_ds)
        if valid_ds:
            valid_datasets.append(valid_ds)
        if test_ds:
            test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = None
    if train_datasets:
        blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = None
    if valid_datasets:
        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = None
    if test_datasets:
        blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
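
# A minimal call sketch (argument values are hypothetical):
#
#   train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#       data_prefix=['/path/to/corpus_text_sentence'],
#       data_impl='mmap', splits_string='949,50,1',
#       train_valid_test_num_samples=[10000, 1000, 100],
#       max_seq_length=512, masked_lm_prob=0.15, short_seq_prob=0.1,
#       seed=1234, tokenizer=tokenizer, skip_warmup=True,
#       dataset_type=DSET_TYPE_BERT)
#
# Passing a weighted list such as ['2', 'corpus-a', '1', 'corpus-b'] instead
# triggers the blending path above.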


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length,
                                     masked_lm_prob, short_seq_prob, seed,
                                     skip_warmup, binary_head,
                                     max_seq_length_dec,
                                     tokenizer,
                                     dataset_type='standard_bert',
                                     zh_tokenizer=None,
                                     span=None):

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: {}".format(dataset_type))

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    # Get start and end indices of train/valid/test into the doc-idx.
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        from fengshen.data.megatron_dataloader.bert_dataset import BertDataset
        from fengshen.data.megatron_dataloader.bart_dataset import BartDataset
        from fengshen.data.megatron_dataloader.cocolm_dataset import \
            COCOLMDataset
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can restore it.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx.
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            kwargs = dict(
                name=name,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                seed=seed,
            )

            if dataset_type == DSET_TYPE_BERT or \
                    dataset_type == DSET_TYPE_BERT_CN_WWM:
                dataset = BertDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    binary_head=binary_head,
                    tokenizer=tokenizer,
                    masking_style='bert' if dataset_type == DSET_TYPE_BERT
                    else 'bert-cn-wwm',
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_BART:
                dataset = BartDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    zh_tokenizer=zh_tokenizer,
                    **kwargs
                )
            elif dataset_type == DSET_TYPE_COCOLM:
                dataset = COCOLMDataset(
                    indexed_dataset=indexed_dataset,
                    masked_lm_prob=masked_lm_prob,
                    short_seq_prob=short_seq_prob,
                    tokenizer=tokenizer,
                    masking_style='bert',
                    span=span,
                    **kwargs
                )
            else:
                raise NotImplementedError(
                    "Dataset type not fully implemented.")

            # Restore the original pointer so the indexed dataset again
            # covers all documents.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Build the indexed dataset and print basic stats."""

    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))

    print_rank_0(' > indexed dataset stats:')
    print_rank_0('    number of documents: {}'.format(
        indexed_dataset.doc_idx.shape[0] - 1))
    print_rank_0('    number of sentences: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """Get dataset splits from a comma- or '/'-separated string."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    # Pad with zeros and keep exactly three values.
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    # Normalize the weights to fractions.
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    # Turn the fractions into cumulative document indices.
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    # Spread any rounding difference so the last index equals `size`.
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index
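
# Example: get_train_valid_test_split_('949,50,1', 1000) returns
# [0, 949, 999, 1000]: the weights are normalized, turned into cumulative
# document indices, and the boundaries are shifted so the three splits
# exactly cover all 1000 documents.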


def get_samples_mapping(indexed_dataset,
                        data_prefix,
                        num_epochs,
                        max_num_samples,
                        max_seq_length,
                        short_seq_prob,
                        seed,
                        name,
                        binary_head):
    """Get a list that maps a sample index to a starting
    sentence index, end sentence index, and length."""

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping.
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Load the samples mapping. The .npy index map is expected to already
    # exist on disk, produced by an offline preprocessing step.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(
        indexmap_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
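
# Each row of the returned samples_mapping is expected to be a triple
# (start_sentence_index, end_sentence_index, sequence_length) describing one
# training sample; dataset classes such as BertDataset index into it to
# fetch the sentences that make up each sample.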