| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """GPT style dataset.""" |
|
|
| import os |
| import time |
|
|
| import numpy as np |
| import torch |
|
|
| from megatron import mpu, print_rank_0 |
| from megatron.data.blendable_dataset import BlendableDataset |
| from megatron.data.dataset_utils import get_datasets_weights_and_num_samples |
| from megatron.data.dataset_utils import get_train_valid_test_split_ |
| from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset |
|
|
|
|
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets.

    Arguments:
        data_prefix: sequence of dataset descriptors. When it holds exactly
            one element, that element is used directly as the dataset
            prefix. Otherwise the whole sequence is handed to
            `get_datasets_weights_and_num_samples` to obtain per-dataset
            prefixes, blending weights and sample counts.
        data_impl, splits_string, seq_length, seed, skip_warmup: forwarded
            unchanged to `_build_train_valid_test_datasets`.
        train_valid_test_num_samples: requested number of samples, indexed
            by split (train, valid, test).

    Returns:
        A `(train, valid, test)` tuple. In the blended case an entry is
        `None` when no underlying dataset exists for that split.
    """

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                seq_length, seed, skip_warmup)

    # Blending dataset.
    # Parse the values.
    prefixes, weights, datasets_train_valid_test_num_samples = \
        get_datasets_weights_and_num_samples(data_prefix,
                                             train_valid_test_num_samples)

    # Build individual datasets. For every split we keep the weights next to
    # the datasets that were actually built: if a prefix yields no dataset
    # for some split, its weight must be dropped too, otherwise the weights
    # handed to BlendableDataset no longer line up with the datasets.
    # NOTE(review): presumably BlendableDataset re-normalizes the weights it
    # receives -- confirm against megatron.data.blendable_dataset.
    datasets_per_split = ([], [], [])
    weights_per_split = ([], [], [])
    for i, prefix in enumerate(prefixes):
        built = _build_train_valid_test_datasets(
            prefix, data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            seq_length, seed, skip_warmup)
        for split, dataset in enumerate(built):
            if dataset:
                datasets_per_split[split].append(dataset)
                weights_per_split[split].append(weights[i])

    # Blend (in train, valid, test order); a split with no datasets is None.
    blending_train_dataset, blending_valid_dataset, blending_test_dataset = (
        BlendableDataset(datasets, split_weights) if datasets else None
        for datasets, split_weights in zip(datasets_per_split,
                                           weights_per_split))

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)
|
|
|
|
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
    """Build the train, valid, and test datasets for a single data prefix.

    The documents of the indexed dataset are partitioned into three
    contiguous ranges according to `splits_string`; a `GPTDataset` is
    created for every non-empty range and `None` is used otherwise.
    """

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl,
                                           skip_warmup)

    num_documents = indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, num_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')
    for position, label in enumerate(('train', 'validation', 'test')):
        first_doc = splits[position]
        end_doc = splits[position + 1]
        print_rank_0(' {}:'.format(label))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(first_doc, end_doc,
                                        end_doc - first_doc))

    # Build one dataset per split, in train / valid / test order.
    built = []
    for position, split_name in enumerate(('train', 'valid', 'test')):
        first_doc = splits[position]
        end_doc = splits[position + 1]
        if not end_doc > first_doc:
            # Empty split: nothing to build.
            built.append(None)
            continue
        documents = np.arange(start=first_doc, stop=end_doc,
                              step=1, dtype=np.int32)
        built.append(GPTDataset(split_name, data_prefix,
                                documents, indexed_dataset,
                                train_valid_test_num_samples[position],
                                seq_length, seed))

    train_dataset, valid_dataset, test_dataset = built
    return (train_dataset, valid_dataset, test_dataset)
|
|
|
|
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Create the indexed dataset for `data_prefix` and log how long it took
    and how many documents it holds."""
    print_rank_0(' > building dataset index ...')

    began_at = time.time()
    dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    elapsed = time.time() - began_at

    print_rank_0(
        ' > finished creating indexed dataset in {:4f} seconds'.format(
            elapsed))
    num_documents = dataset.sizes.shape[0]
    print_rank_0(' number of documents: {}'.format(num_documents))

    return dataset
|
|
|
|
class GPTDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves token sequences cut from an indexed
    dataset, using the doc-idx / sample-idx / shuffle-idx mappings produced
    by `_build_index_mappings`."""

    def __init__(self, name, data_prefix, documents, indexed_dataset,
                 num_samples, seq_length, seed):

        self.name = name
        self.indexed_dataset = indexed_dataset

        # Checks: every requested document must exist in the indexed dataset.
        assert np.min(documents) >= 0
        assert np.max(documents) < indexed_dataset.sizes.shape[0]

        # Build index mappings.
        mappings = _build_index_mappings(
            name, data_prefix, documents, indexed_dataset.sizes,
            num_samples, seq_length, seed)
        self.doc_idx, self.sample_idx, self.shuffle_idx = mappings

    def __len__(self):
        # sample_idx has one more row than there are samples, because
        # sample i spans from row i to row i + 1.
        num_rows = self.sample_idx.shape[0]
        return num_rows - 1

    def __getitem__(self, idx):
        # Translate the requested position into the shuffled sample index.
        sample_pos = self.shuffle_idx[idx]

        # Start and end (document index, offset) pairs of the sample.
        begin = self.sample_idx[sample_pos]
        end = self.sample_idx[sample_pos + 1]
        first_doc, first_offset = begin[0], begin[1]
        last_doc, last_offset = end[0], end[1]

        fetch = self.indexed_dataset.get
        if first_doc == last_doc:
            # The whole sample lives inside a single document.
            tokens = fetch(self.doc_idx[first_doc],
                           offset=first_offset,
                           length=last_offset - first_offset + 1)
        else:
            # Tail of the first document ...
            pieces = [fetch(self.doc_idx[first_doc], offset=first_offset)]
            # ... every document strictly in between ...
            pieces.extend(fetch(self.doc_idx[position])
                          for position in range(first_doc + 1, last_doc))
            # ... and the head of the last document.
            pieces.append(fetch(self.doc_idx[last_doc],
                                length=last_offset + 1))
            tokens = np.concatenate(pieces)

        return {'text': np.array(tokens, dtype=np.int64)}
|
|
|
|
def _build_index_mappings(name, data_prefix, documents, sizes,
                          num_samples, seq_length, seed):
    """Build doc-idx, sample-idx, and shuffle-idx.
    doc-idx: is an array (ordered) of documents to be used in training.
    sample-idx: is the start document index and document offset for each
    training sample.
    shuffle-idx: maps the sample index into a random index into sample-idx.

    The three arrays are cached as `.npy` files whose names are derived from
    `data_prefix`, `name`, `num_samples`, `seq_length` and `seed`. If any of
    the files is missing, global rank 0 builds and saves all three. After a
    synchronizing all-reduce, every rank loads them memory-mapped
    (read-only) and returns `(doc_idx, sample_idx, shuffle_idx)`.
    """
    # Number of tokens in each epoch and number of required epochs.
    tokens_per_epoch = _num_tokens(documents, sizes)
    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
    # RNG state; consumed first by the doc-idx shuffle and then by the
    # shuffle-idx shuffle, so the build order below matters.
    np_rng = np.random.RandomState(seed=seed)

    # Filenames of the index mappings.
    _filename = data_prefix
    _filename += '_{}_indexmap'.format(name)
    _filename += '_{}ns'.format(num_samples)
    _filename += '_{}sl'.format(seq_length)
    _filename += '_{}s'.format(seed)
    doc_idx_filename = _filename + '_doc_idx.npy'
    sample_idx_filename = _filename + '_sample_idx.npy'
    shuffle_idx_filename = _filename + '_shuffle_idx.npy'

    # Build the index mappings if they do not exist (global rank 0 only).
    if torch.distributed.get_rank() == 0:
        if (not os.path.isfile(doc_idx_filename)) or \
           (not os.path.isfile(sample_idx_filename)) or \
           (not os.path.isfile(shuffle_idx_filename)):

            print_rank_0(' > WARNING: could not find index map files, building '
                         'the indices on rank 0 ...')

            # For the last epoch, decide whether to include the entire epoch
            # in the global shuffle or to shuffle it separately.

            # If only one epoch is needed, separating the last epoch is
            # meaningless.
            if num_epochs == 1:
                separate_last_epoch = False
                print(' > only one epoch required, setting '
                      'separate_last_epoch to False', flush=True)

            else:
                # Number of samples that the last epoch has to contribute.
                num_samples_from_epochs_minus_one = (
                    (num_epochs - 1) * tokens_per_epoch - 1) // seq_length
                last_epoch_num_samples = num_samples - \
                                         num_samples_from_epochs_minus_one
                assert last_epoch_num_samples >= 0, \
                    'last epoch number of samples should be non-negative.'
                num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
                assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
                    'last epoch number of samples exceeded max value.'
                # If fewer than 80% of an epoch's samples are needed from the
                # last epoch, treat the last epoch separately; otherwise it
                # is shuffled together with all the other epochs.
                separate_last_epoch = (last_epoch_num_samples <
                                       int(0.80 * num_samples_per_epoch))
                if separate_last_epoch:
                    string = ' > last epoch number of samples ({}) is smaller '\
                             'than 80% of number of samples per epoch ({}), '\
                             'setting separate_last_epoch to True'
                else:
                    string = ' > last epoch number of samples ({}) is larger '\
                             'than 80% of number of samples per epoch ({}), '\
                             'setting separate_last_epoch to False'
                print(string.format(last_epoch_num_samples,
                                    num_samples_per_epoch), flush=True)

            # doc-idx.
            start_time = time.time()
            doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
                                     separate_last_epoch)
            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
            print_rank_0(' > elasped time to build and save doc-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # sample-idx.
            start_time = time.time()
            # Uses the `helpers` module from megatron.data rather than the
            # pure-Python _build_sample_idx defined in this file.
            # NOTE(review): presumably a compiled equivalent of
            # _build_sample_idx -- confirm.
            from megatron.data import helpers
            assert doc_idx.dtype == np.int32
            assert sizes.dtype == np.int32
            sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                                  num_epochs, tokens_per_epoch)
            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
            print_rank_0(' > elasped time to build and save sample-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # shuffle-idx.
            start_time = time.time()
            # With a separated last epoch, only the samples of the first
            # (num_epochs - 1) epochs form the first shuffle range; the
            # remaining samples are shuffled among themselves.
            if separate_last_epoch:
                num_samples_ = num_samples_from_epochs_minus_one
            else:
                num_samples_ = sample_idx.shape[0] - 1
            shuffle_idx = _build_shuffle_idx(num_samples_,
                                             sample_idx.shape[0] - 1, np_rng)
            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
            print_rank_0(' > elasped time to build and save shuffle-idx mapping'
                         ' (seconds): {:4f}'.format(time.time() - start_time))

    # All-reduce over the data-parallel and pipeline-model-parallel groups,
    # used here as a synchronization point before the files are loaded.
    # NOTE(review): presumably used instead of torch.distributed.barrier
    # because of device-index assumptions under model parallelism -- confirm.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
    # Sanity check: the reduced count must equal
    # world size // tensor-model-parallel group size.
    assert counts[0].item() == (
        torch.distributed.get_world_size() //
        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

    # Load mappings (memory-mapped, read-only).
    # NOTE(review): allow_pickle=True permits unpickling if a cache file were
    # replaced by an object array; the arrays saved above are plain integer
    # arrays -- consider allow_pickle=False if the cache location is not
    # trusted.
    start_time = time.time()
    print_rank_0(' > loading doc-idx mapping from {}'.format(
        doc_idx_filename))
    doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0(' > loading sample-idx mapping from {}'.format(
        sample_idx_filename))
    sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0(' > loading shuffle-idx mapping from {}'.format(
        shuffle_idx_filename))
    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    # NOTE(review): this logs sample_idx.shape[0], whereas
    # GPTDataset.__len__ reports sample_idx.shape[0] - 1.
    print_rank_0(' total number of samples: {}'.format(
        sample_idx.shape[0]))
    print_rank_0(' total number of epochs: {}'.format(num_epochs))

    return doc_idx, sample_idx, shuffle_idx
|
|
|
|
| def _num_tokens(documents, sizes): |
| """Total number of tokens in the dataset.""" |
| return np.sum(sizes[documents]) |
|
|
|
|
| def _num_epochs(tokens_per_epoch, seq_length, num_samples): |
| """Based on number of samples and sequence lenght, calculate how many |
| epochs will be needed.""" |
| num_epochs = 0 |
| total_tokens = 0 |
| while True: |
| num_epochs += 1 |
| total_tokens += tokens_per_epoch |
| |
| |
| |
| if ((total_tokens - 1) // seq_length) >= num_samples: |
| return num_epochs |
|
|
|
|
| def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): |
| """Build an array with length = number-of-epochs * number-of-dcuments. |
| Each index is mapped to a corresponding document.""" |
| if not separate_last_epoch or num_epochs == 1: |
| doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] |
| doc_idx[:] = documents |
| doc_idx = doc_idx.reshape(-1) |
| doc_idx = doc_idx.astype(np.int32) |
| np_rng.shuffle(doc_idx) |
| return doc_idx |
|
|
| doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False) |
| doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) |
| return np.concatenate((doc_idx_first, doc_idx_last)) |
|
|
|
|
| def _build_sample_idx(sizes, doc_idx, seq_length, |
| num_epochs, tokens_per_epoch): |
| """Sample index mapping is a 2D array with sizes |
| [number-of-samples + 1, 2] where [..., 0] contains |
| the index into `doc_idx` and [..., 1] is the |
| starting offset in that document.""" |
|
|
| |
| num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length |
| sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) |
|
|
| |
| sample_index = 0 |
| |
| doc_idx_index = 0 |
| |
| doc_offset = 0 |
| |
| sample_idx[sample_index][0] = doc_idx_index |
| sample_idx[sample_index][1] = doc_offset |
| sample_index += 1 |
| while sample_index <= num_samples: |
| |
| remaining_seq_length = seq_length + 1 |
| while remaining_seq_length != 0: |
| |
| doc_id = doc_idx[doc_idx_index] |
| doc_length = sizes[doc_id] - doc_offset |
| |
| remaining_seq_length -= doc_length |
| |
| |
| |
| |
| if remaining_seq_length <= 0: |
| doc_offset += (remaining_seq_length + doc_length - 1) |
| remaining_seq_length = 0 |
| else: |
| |
| doc_idx_index += 1 |
| doc_offset = 0 |
| |
| sample_idx[sample_index][0] = doc_idx_index |
| sample_idx[sample_index][1] = doc_offset |
| sample_index += 1 |
|
|
| return sample_idx |
|
|
|
|
| def _build_shuffle_idx(num_samples, total_size, np_rng): |
| """Build the range [0, size) and shuffle.""" |
| print(' > building shuffle index with split [0, {}) and [{}, {}) ' |
| '...'.format(num_samples, num_samples, total_size), flush=True) |
| |
| dtype_ = np.uint32 |
| if total_size >= (np.iinfo(np.uint32).max - 1): |
| dtype_ = np.int64 |
|
|
| shuffle_idx_first = np.arange(start=0, stop=num_samples, |
| step=1, dtype=dtype_) |
| np_rng.shuffle(shuffle_idx_first) |
| if num_samples == total_size: |
| return shuffle_idx_first |
|
|
| shuffle_idx_last = np.arange(start=num_samples, stop=total_size, |
| step=1, dtype=dtype_) |
| np_rng.shuffle(shuffle_idx_last) |
|
|
| return np.concatenate((shuffle_idx_first, shuffle_idx_last)) |
|
|