| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Dataloaders.""" |
| |
|
| |
|
| | import random |
| | import torch |
| | import numpy as np |
| | from torch.utils.data import Dataset |
| | from megatron import get_args |
| | from megatron import mpu |
| |
|
| |
|
def build_pretraining_data_loader(dataset, consumed_samples):
    """Build a dataloader given an input dataset.

    Args:
        dataset: dataset to iterate over. If None, no dataloader is built
            and None is returned.
        consumed_samples: number of samples already consumed in previous
            runs; the sampler resumes from this point.

    Returns:
        A ``torch.utils.data.DataLoader`` driven by a Megatron batch
        sampler chosen from ``args.dataloader_type``, or None when
        ``dataset`` is None.

    Raises:
        Exception: if ``args.dataloader_type`` is neither 'single' nor
            'cyclic'.
    """
    if dataset is None:
        return None
    args = get_args()

    # Megatron sampler: 'single' walks the data sequentially, 'cyclic'
    # reshuffles every epoch.
    if args.dataloader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size(),
            data_sharding=args.data_sharding)
    else:
        raise Exception('{} dataloader type is not supported.'.format(
            args.dataloader_type))

    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=args.num_workers,
                                       pin_memory=True)
| |
|
class MegatronPretrainingSampler:
    """Sequential batch sampler that shards each global batch across
    data-parallel ranks.

    Walks sample indices in order starting at ``consumed_samples`` and,
    for every full global batch (micro_batch_size * data_parallel_size
    indices), yields this rank's contiguous micro-batch slice.
    """

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Validate the configuration up front.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def get_start_end_idx(self):
        # This rank's slice boundaries within one full global batch.
        begin = self.data_parallel_rank * self.micro_batch_size
        return begin, begin + self.micro_batch_size

    def __iter__(self):
        pending = []
        for sample_idx in range(self.consumed_samples, self.total_samples):
            pending.append(sample_idx)
            if len(pending) == self.micro_batch_times_data_parallel_size:
                lo, hi = self.get_start_end_idx()
                yield pending[lo:hi]
                pending = []

        # Emit this rank's share of any trailing partial global batch.
        if pending and not self.drop_last:
            lo, hi = self.get_start_end_idx()
            yield pending[lo:hi]
| |
|
| |
|
class RandomSeedDataset(Dataset):
    """Dataset wrapper that re-seeds RNGs deterministically per sample.

    Before fetching item ``idx``, seeds torch, ``random`` and numpy with
    ``curr_seed + idx`` so per-sample randomness is reproducible;
    ``set_epoch`` shifts ``curr_seed`` so each epoch draws fresh seeds.
    """

    def __init__(self, dataset):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = args.seed
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        # Offset the per-sample seeds by the epoch number.
        self.curr_seed = self.base_seed + epoch

    def __getitem__(self, idx):
        sample_seed = idx + self.curr_seed
        torch.manual_seed(sample_seed)
        random.seed(sample_seed)
        np.random.seed(sample_seed)
        return self.dataset[idx]
| |
|
| |
|
class MegatronPretrainingRandomSampler:
    """Random-order batch sampler for the 'cyclic' dataloader type.

    Shuffles sample indices once per epoch (seeded by the epoch number so
    all ranks agree on the permutation) and yields this rank's
    micro-batches, resuming mid-epoch from ``consumed_samples``. Samples
    that do not fill a final global batch are skipped each epoch.
    """

    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, data_sharding):
        self.dataset = dataset
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.data_sharding = data_sharding
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        # Remainder that cannot form a full global batch.
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Validate the configuration up front.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        # Only full global batches are drawn; derive epoch and mid-epoch
        # position from the running consumed-sample count.
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        if isinstance(self.dataset, RandomSeedDataset):
            self.dataset.set_epoch(self.epoch)

        # Same seed on every rank -> identical permutation everywhere.
        generator = torch.Generator()
        generator.manual_seed(self.epoch)

        if self.data_sharding:
            # Sharded: each rank shuffles within its own contiguous bucket.
            bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
                          * self.micro_batch_size
            bucket_offset = current_epoch_samples // self.data_parallel_size
            start_idx = self.data_parallel_rank * bucket_size
            shuffled = torch.randperm(bucket_size, generator=generator).tolist()
            idx_range = [start_idx + offset for offset in shuffled[bucket_offset:]]
        else:
            # Unsharded: one global permutation, strided across ranks.
            full_bucket_size = (self.total_samples // self.micro_batch_size) \
                               * self.micro_batch_size
            full_bucket_offset = current_epoch_samples
            shuffled = torch.randperm(full_bucket_size, generator=generator).tolist()
            remaining = shuffled[full_bucket_offset:]
            idx_range = remaining[self.data_parallel_rank::self.data_parallel_size]

        batch = []
        for sample_idx in idx_range:
            batch.append(sample_idx)
            if len(batch) == self.micro_batch_size:
                # Advance by the global batch so checkpoint resume works.
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
| |
|