import torch
from torch.testing._internal import common_utils
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from apex.transformer.pipeline_parallel.utils import _split_batch_into_microbatch as split_batch_into_microbatch


class MyIterableDataset(Dataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end
        self.samples = list(range(self.start, self.end))

    def __iter__(self):
        return iter(range(self.start, self.end))

    def __getitem__(self, index):
        return self.samples[index]
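
# Note: despite the "Iterable" in its name, DataLoader uses the class above
# as a map-style dataset in these tests: the batch sampler supplies indices
# and samples are fetched via `__getitem__`, so `__iter__` is never called.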


class MegatronPretrainingRandomSampler:
    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        # Any tail of the dataset that cannot fill a full micro-batch on
        # every rank is dropped at the end of each epoch.
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding and random sampling: each data-parallel rank draws
        # from its own contiguous bucket of shuffled indices.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        # Seeding the shuffle with the epoch makes the permutation
        # deterministic across ranks and DataLoader workers.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # The last batch will be dropped if it is not complete.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
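
# Worked example (illustrative, not executed): with total_samples=10,
# consumed_samples=0, micro_batch_size=2, data_parallel_rank=0 and
# data_parallel_size=1, last_batch_size is 0, the epoch-0 shuffle is
# torch.randperm(10) seeded with 0, and the sampler yields five
# micro-batches of 2 indices each, advancing consumed_samples by 2 per
# yielded batch.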


# Draws 8 samples in total: first as two batches of 4, then as four
# batches of 2. Both runs shuffle the same 100-index bucket with the same
# epoch seed, so the concatenated samples must match.
class TestBatchSamplerBehavior(common_utils.TestCase):
    def tearDown(self) -> None:
        torch.cuda.empty_cache()
        super().tearDown()

    def test_batch_sampler_behavior(self):
        dataset = MyIterableDataset(0, 100)

        for num_workers in (1, 2, 4):
            torch.manual_seed(42)
            loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, 4, 0, 1), num_workers=num_workers)
            samples = []
            for i, batch in enumerate(loader):
                samples.append(batch)
                if i == 2 - 1:
                    break

            torch.manual_seed(42)
            loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, 2, 0, 1), num_workers=num_workers)
            samples2 = []
            for i, batch in enumerate(loader):
                samples2.append(batch)
                if i == 4 - 1:
                    break

            self.assertEqual(torch.cat(samples), torch.cat(samples2), msg=f"num_workers={num_workers}")

    def test_split_batch(self):
        class MyIterableDataset(Dataset):
            def __init__(self, start, end):
                super().__init__()
                assert end > start, "this example code only works with end > start"
                self.start = start
                self.end = end
                self.samples = list(range(self.start, self.end))

            def __len__(self):
                return self.end - self.start

            def __iter__(self):
                return iter(range(self.start, self.end))

            def __getitem__(self, index):
                return (torch.tensor([index, index]), torch.tensor([index // 2, index // 2]))
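
        # Each item is a tuple of two 2-element tensors, so the default
        # collate function turns a batch of 16 items into a tuple of two
        # [16, 2] tensors.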
        dataset = MyIterableDataset(0, 100)
        torch.manual_seed(42)
        global_batch_size = 16
        loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), num_workers=2)
        batch = next(iter(loader))

        for _micro_batch_size in (1, 2, 4, 8):
            microbatches = list(split_batch_into_microbatch(
                batch,
                _micro_batch_size=_micro_batch_size,
                _global_batch_size=global_batch_size,
            ))
            self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size)
            self.assertEqual(len(microbatches[0][0]), _micro_batch_size)
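
            # E.g. with _micro_batch_size=4, the split should give
            # 16 // 4 = 4 microbatches, each a pair of [4, 2] tensors
            # (inferred from the assertions above; split_batch_into_microbatch
            # is an apex-internal helper).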


if __name__ == "__main__":
    common_utils.run_tests()