apex / tests /L0 /run_transformer /test_batch_sampler.py
camenduru's picture
thanks to NVIDIA ❤
e828767
import torch
from torch.testing._internal import common_utils
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from apex.transformer.pipeline_parallel.utils import _split_batch_into_microbatch as split_batch_into_microbatch
class MyIterableDataset(Dataset):
    """Map-style dataset over the consecutive integers in ``[start, end)``.

    Despite its name, the class derives from ``Dataset`` and serves samples
    through ``__getitem__``; ``__iter__`` is provided as a convenience.
    """

    def __init__(self, start: int, end: int) -> None:
        super().__init__()
        # The range must be non-empty: ``end`` has to be strictly greater
        # than ``start`` (the message now matches the condition checked).
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end
        # Materialized values so that ``__getitem__`` is a plain list lookup.
        self.samples = list(range(self.start, self.end))

    def __len__(self) -> int:
        """Return the number of samples held by the dataset."""
        return len(self.samples)

    def __iter__(self):
        """Iterate over the sample values in ascending order."""
        return iter(range(self.start, self.end))

    def __getitem__(self, index):
        """Return the sample stored at position ``index``."""
        return self.samples[index]
class MegatronPretrainingRandomSampler:
    """Batch sampler yielding shuffled micro-batches for one data-parallel rank.

    The first ``total_samples - last_batch_size`` indices are split into one
    contiguous bucket per data-parallel rank. Each rank shuffles the offsets
    inside its bucket with a generator seeded by the current epoch, so the
    order is reproducible and does not depend on ``micro_batch_size``.
    The trailing ``last_batch_size`` indices, which do not fill a complete
    step across all ranks, are never sampled.

    Args:
        total_samples: Number of samples in the dataset; must be positive.
        consumed_samples: Samples already consumed across all ranks; used to
            derive the epoch and the position to resume from.
        micro_batch_size: Number of indices per yielded batch; must be
            positive.
        data_parallel_rank: Rank of this sampler; must be smaller than
            ``data_parallel_size``.
        data_parallel_size: Number of data-parallel ranks; must be positive.

    Raises:
        AssertionError: If any of the constraints above is violated.
    """

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        # Sanity checks run first so that a zero size is reported by the
        # assertions instead of a ZeroDivisionError from the modulo below.
        assert total_samples > 0, \
            'no sample to consume: {}'.format(total_samples)
        assert micro_batch_size > 0
        assert data_parallel_size > 0
        assert data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(data_parallel_rank, data_parallel_size)

        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        # Samples consumed by one step across all data-parallel ranks.
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        # Tail of the dataset that cannot fill a complete step.
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

    def __len__(self):
        """Return ``total_samples`` (a sample count, not a batch count)."""
        return self.total_samples

    def __iter__(self):
        """Yield lists of ``micro_batch_size`` dataset indices for this rank.

        Every yielded batch advances ``consumed_samples`` by
        ``micro_batch_size * data_parallel_size``.
        """
        active_total_samples = self.total_samples - self.last_batch_size
        # Without this guard the divisions below raise ZeroDivisionError when
        # the dataset is smaller than one step across all ranks.
        assert active_total_samples > 0, \
            'total_samples ({}) is smaller than micro_batch_size * ' \
            'data_parallel_size ({})'.format(
                self.total_samples, self.micro_batch_times_data_parallel_size)
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # data sharding and random sampling
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
        # Offsets of this rank's bucket that were already consumed this epoch.
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        # Seeding with the epoch makes the permutation reproducible.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # Last batch if not complete will be dropped.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
# Draws 8 samples in total, twice:
# first as 2 batches of 4 samples, then as 4 batches of 2 samples.
class TestBatchSamplerBehavior(common_utils.TestCase):
    """Tests for ``MegatronPretrainingRandomSampler`` and micro-batch splitting."""

    def tearDown(self) -> None:
        torch.cuda.empty_cache()
        super().tearDown()

    def _collect_batches(self, dataset, micro_batch_size, num_batches, num_workers):
        """Return the first ``num_batches`` batches drawn from a fresh loader.

        The global seed is reset before the loader is built so that every
        call starts from the same random state.
        """
        torch.manual_seed(42)
        sampler = MegatronPretrainingRandomSampler(100, 0, micro_batch_size, 0, 1)
        loader = DataLoader(dataset, batch_sampler=sampler, num_workers=num_workers)
        batches = []
        for batch in loader:
            batches.append(batch)
            if len(batches) == num_batches:
                break
        return batches

    def test_batch_sampler_behavior(self):
        """The sample order must not depend on the micro-batch size."""
        dataset = MyIterableDataset(0, 100)
        for num_workers in (1, 2, 4):
            # 2 batches of 4 and 4 batches of 2 both cover the first 8 samples.
            samples = self._collect_batches(
                dataset, micro_batch_size=4, num_batches=2, num_workers=num_workers)
            samples2 = self._collect_batches(
                dataset, micro_batch_size=2, num_batches=4, num_workers=num_workers)
            self.assertEqual(torch.cat(samples), torch.cat(samples2), msg=f"num_workers={num_workers}")

    def test_split_batch(self):
        """A global batch splits into ``global // micro`` micro-batches."""

        # Named differently from the module-level ``MyIterableDataset`` so
        # that the local class does not shadow it.
        class _PairDataset(Dataset):
            """Dataset whose items are a pair of 2-element tensors."""

            def __init__(self, start, end):
                super().__init__()
                assert end > start, "this example code only works with end > start"
                self.start = start
                self.end = end
                self.samples = list(range(self.start, self.end))

            def __len__(self):
                return self.end - self.start

            def __iter__(self):
                return iter(range(self.start, self.end))

            def __getitem__(self, index):
                return (torch.tensor([index, index]), torch.tensor([index // 2, index // 2]))

        dataset = _PairDataset(0, 100)
        torch.manual_seed(42)
        global_batch_size = 16
        loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), num_workers=2)
        batch = next(iter(loader))

        for _micro_batch_size in (1, 2, 4, 8):
            microbatches = list(split_batch_into_microbatch(
                batch,
                _micro_batch_size=_micro_batch_size,
                _global_batch_size=global_batch_size,
            ))
            self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size)
            self.assertEqual(len(microbatches[0][0]), _micro_batch_size)
# Run this module's tests through PyTorch's test runner when executed as a script.
if __name__ == "__main__":
    common_utils.run_tests()