| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import pickle |
| import tempfile |
| import warnings |
| from unittest.mock import Mock |
|
|
| import torch |
| from torch.utils.data import ( |
| BatchSampler, |
| DataLoader, |
| Dataset, |
| IterableDataset, |
| RandomSampler, |
| TensorDataset, |
| default_collate, |
| ) |
|
|
| from accelerate.accelerator import Accelerator, DataLoaderConfiguration |
| from accelerate.utils.dataclasses import DistributedType |
|
|
|
|
# Size of the dummy map-style dataset used by DummyDataset and the epoch-coverage checks.
NUM_ELEMENTS = 22
# DataLoader worker count and batch size shared by the loader variants exercised in main().
NUM_WORKERS = 4
BATCH_SIZE = 4
|
|
|
|
class DummyDataset(Dataset):
    """Map-style dataset of NUM_ELEMENTS synthetic examples.

    Each example is a dict holding its integer ``index``, a parity ``label``
    (``index % 2``) and a fresh ``random_augmentation`` sample, so repeated
    reads of the same index produce different augmentation values.

    ``__getitem__`` accepts an int (returns a single dict), a slice, or any
    iterable of indices (both return a list of dicts).
    """

    def __len__(self):
        return NUM_ELEMENTS

    def __getitem__(self, index):
        squeeze = False

        if isinstance(index, int):
            index = [index]
            squeeze = True
        elif isinstance(index, slice):
            # Bug fix: the original used the nonexistent attribute ``self.size``,
            # so any slice access raised AttributeError. The dataset length is
            # what ``len(self)`` reports.
            index = list(range(*index.indices(len(self))))
        else:
            index = list(index)

        batch = [{"index": i, "label": i % 2, "random_augmentation": torch.rand(1).item()} for i in index]

        # An int index unwraps the single-element batch back into one dict.
        return batch[0] if squeeze else batch
|
|
|
|
class DummyIterableDataset(IterableDataset):
    """Minimal iterable-style dataset that streams a pre-built sequence."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)
|
|
|
|
def create_accelerator(even_batches=True):
    """Build an Accelerator with the requested ``even_batches`` behavior.

    The whole script assumes exactly two processes, so fail fast otherwise.
    """
    accelerator = Accelerator(
        dataloader_config=DataLoaderConfiguration(even_batches=even_batches)
    )
    assert accelerator.num_processes == 2, "this script expects that two GPUs are available"
    return accelerator
|
|
|
|
def create_dataloader(
    accelerator: "Accelerator", dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
):
    """
    Create a simple DataLoader to use during the test cases.

    Args:
        accelerator: object whose ``prepare`` wraps the dataloader for distributed use.
        dataset_size: number of scalar elements in the underlying dataset.
        batch_size: per-process batch size passed to ``DataLoader``.
        iterable: when True, wrap the values in ``DummyIterableDataset`` instead of a
            map-style ``TensorDataset``.
        shuffle: when True, permute the values once before building the dataset.

    Returns:
        The dataloader returned by ``accelerator.prepare``.
    """
    values = torch.as_tensor(range(dataset_size))
    if shuffle:
        values = values[torch.randperm(values.size(0))]
    if iterable:
        dataset = DummyIterableDataset(values)
    else:
        # Bug fix: build the dataset from the (possibly shuffled) ``values``.
        # The original rebuilt the unshuffled range here, so ``shuffle=True``
        # was silently ignored for map-style datasets.
        dataset = TensorDataset(values)

    dl = DataLoader(dataset, batch_size=batch_size)
    return accelerator.prepare(dl)
|
|
|
|
def verify_dataloader_batch_sizes(
    accelerator: Accelerator,
    dataset_size: int,
    batch_size: int,
    process_0_expected_batch_sizes: list[int],
    process_1_expected_batch_sizes: list[int],
):
    """
    A helper function for verifying the batch sizes coming from a prepared dataloader in each process
    """
    dl = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size)
    observed = [len(batch[0]) for batch in dl]

    expected_by_rank = {
        0: process_0_expected_batch_sizes,
        1: process_1_expected_batch_sizes,
    }
    expected = expected_by_rank.get(accelerator.process_index)
    if expected is not None:
        assert observed == expected
|
|
|
|
def test_default_ensures_even_batch_sizes():
    """With the default even_batches=True, both processes see identical batch sizes."""
    accelerator = create_accelerator()

    # (dataset_size, batch_size, expected batch sizes on BOTH ranks)
    for dataset_size, batch_size, expected in ((3, 1, [1, 1]), (7, 2, [2, 2])):
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=dataset_size,
            batch_size=batch_size,
            process_0_expected_batch_sizes=expected,
            process_1_expected_batch_sizes=expected,
        )
|
|
|
|
def test_can_disable_even_batches():
    """With even_batches disabled, the trailing rank receives the smaller remainder."""
    accelerator = create_accelerator(even_batches=False)

    # (dataset_size, batch_size, rank-0 expectation, rank-1 expectation)
    for dataset_size, batch_size, rank0_expected, rank1_expected in (
        (3, 1, [1, 1], [1]),
        (7, 2, [2, 2], [2, 1]),
    ):
        verify_dataloader_batch_sizes(
            accelerator,
            dataset_size=dataset_size,
            batch_size=batch_size,
            process_0_expected_batch_sizes=rank0_expected,
            process_1_expected_batch_sizes=rank1_expected,
        )
|
|
|
|
def test_can_join_uneven_inputs():
    """join_uneven_inputs lets ranks run differing step counts without deadlocking DDP."""
    accelerator = create_accelerator(even_batches=False)

    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    seen_steps = []
    with accelerator.join_uneven_inputs([ddp_model]):
        for step, batch in enumerate(dl):
            ddp_model(batch[0].float()).sum().backward()
            seen_steps.append(step)

    accelerator.wait_for_everyone()

    # Rank 0 gets two batches of the 3-element dataset, rank 1 only one.
    expected = {0: [0, 1], 1: [0]}.get(accelerator.process_index)
    if expected is not None:
        assert seen_steps == expected
|
|
|
|
def test_join_raises_warning_for_non_ddp_distributed(accelerator):
    """join_uneven_inputs should emit a UserWarning when the distributed type is not DDP."""
    with warnings.catch_warnings(record=True) as caught:
        with accelerator.join_uneven_inputs([Mock()]):
            pass

    last = caught[-1]
    assert issubclass(last.category, UserWarning)
    assert "only supported for multi-GPU" in str(last.message)
|
|
|
|
def test_join_can_override_even_batches():
    """even_batches can be overridden inside join_uneven_inputs and is restored on exit."""
    accelerator = create_accelerator(even_batches=True)
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    loaders = [create_dataloader(accelerator, dataset_size=3, batch_size=1) for _ in range(2)]

    with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
        inside_values = [dl.batch_sampler.even_batches for dl in loaders]

    # Inside the context the override applies; afterwards the default is restored.
    assert inside_values == [False, False]
    assert all(dl.batch_sampler.even_batches == True for dl in loaders)  # noqa: E712
|
|
|
|
def test_join_can_override_for_mixed_type_dataloaders():
    """Overriding even_batches must work when both iterable and map-style loaders are prepared.

    The iterable dataloader has no ``batch_sampler`` to patch, so the override must
    skip it gracefully instead of raising.
    """
    default_even_batches = True
    overridden_even_batches = False
    accelerator = create_accelerator(even_batches=default_even_batches)
    model = torch.nn.Linear(1, 1)
    ddp_model = accelerator.prepare(model)
    # The iterable dataloader is prepared only so it is tracked by the accelerator.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)
    batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with warnings.catch_warnings():
        # The iterable dataloader triggers an expected warning; silence it here.
        warnings.filterwarnings("ignore")
        try:
            with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
                batch_dl_overridden_value = batch_dl.batch_sampler.even_batches
        except AttributeError as err:
            # Improvement: the original raised a bare AssertionError, hiding both
            # the reason and the underlying traceback.
            raise AssertionError(
                "Overriding even_batches should ignore iterable dataloaders instead of raising AttributeError"
            ) from err

    assert batch_dl_overridden_value == overridden_even_batches
    assert batch_dl.batch_sampler.even_batches == default_even_batches
|
|
|
|
def test_join_raises_warning_for_iterable_when_overriding_even_batches():
    """Overriding even_batches with only an iterable dataloader prepared should warn."""
    accelerator = create_accelerator()
    ddp_model = accelerator.prepare(torch.nn.Linear(1, 1))
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)

    with warnings.catch_warnings(record=True) as caught:
        with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
            pass

    last = caught[-1]
    assert issubclass(last.category, UserWarning)
    assert "only supported for map-style datasets" in str(last.message)
|
|
|
|
def test_pickle_accelerator():
    """An Accelerator with a prepared dataloader must survive a pickle round-trip."""
    accelerator = create_accelerator()
    accelerator.prepare(create_dataloader(accelerator, dataset_size=32, batch_size=4))
    restored = pickle.loads(pickle.dumps(accelerator))
    # Compare the underlying state dicts rather than object identity.
    assert accelerator.state.__dict__ == restored.state.__dict__
|
|
|
|
def test_data_loader(data_loader, accelerator):
    """Check that one epoch of a prepared dataloader covers every dataset element exactly once."""
    # Prepare the DataLoader for distributed iteration.
    data_loader = accelerator.prepare(data_loader)

    all_examples = []
    # Idiom fix: the original enumerated the loader but never used the index.
    for batch in data_loader:
        index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"]))
        all_examples.extend(index.detach().cpu().numpy().tolist())

    # Sort the gathered indices for a stable comparison.
    sorted_all_examples = sorted(all_examples)

    # Deduplicate: every one of the NUM_ELEMENTS indices must appear.
    assert len(set(sorted_all_examples)) == NUM_ELEMENTS, (
        "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
    )
|
|
|
|
def _test_stateful_dataloader_resume(accelerator, iterable):
    """
    Helper: iterate a stateful dataloader, capture its state a few batches in via
    `state_dict`/`load_state_dict`, then resume and check the replayed batches match
    exactly what the first pass had not yet consumed at the checkpoint.

    Checkpointing early (step 2) leaves many batches outstanding, which exposes any
    off-by-one in state restoration. Both iterable and map-style datasets are covered
    by the caller to exercise the different state_dict code paths.
    """
    previous_config = accelerator.dataloader_config
    try:
        accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
        dl = create_dataloader(
            accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True
        )
        checkpoint_step = 2
        remaining = []
        for step, batch in enumerate(dl):
            if step == checkpoint_step:
                saved_state = dl.state_dict()
            elif step > checkpoint_step:
                remaining.append(batch)
        not_skipped_batches = accelerator.gather(remaining)

        dl.load_state_dict(saved_state)
        resumed_batches = accelerator.gather([batch for batch in dl])

        assert len(not_skipped_batches) == len(resumed_batches), (
            f"Expected {len(not_skipped_batches)} batches after resume, got {len(resumed_batches)}"
        )
        for b1, b2 in zip(not_skipped_batches, resumed_batches):
            for v1, v2 in zip(b1, b2):
                assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        # Always restore the caller's dataloader configuration.
        accelerator.dataloader_config = previous_config
|
|
|
|
def test_stateful_dataloader(accelerator):
    """
    Tests that a stateful dataloader can be iterated over, saved after a few batches using `load_state_dict`, and then
    resumed from the saved state.

    The result should be the same as the rest of the data that iterated over after saving.
    """
    for iterable in (True, False):
        _test_stateful_dataloader_resume(accelerator, iterable=iterable)
|
|
|
|
def _test_stateful_dataloader_save_state_resume(accelerator, iterable):
    """
    Helper: iterate a stateful dataloader, save state after a few batches using `Accelerator.save_state`,
    resume, and verify the resumed batches match what was originally unseen.
    """
    # Remember the current dataloader config so it can be restored even if an assertion fails.
    old_dataloader_config = accelerator.dataloader_config
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
            prepared_dl = create_dataloader(
                accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=iterable, shuffle=True
            )
            untrained_batches = []
            # Checkpoint early so many batches remain, exposing off-by-one resume errors.
            save_step = 2
            for step, batch in enumerate(prepared_dl):
                if step == save_step:
                    # Persist full training state (including dataloader position) to disk.
                    accelerator.save_state(tmpdir)
                if step > save_step:
                    # Everything after the checkpoint is what a resumed run must replay.
                    untrained_batches.append(batch)
            not_skipped_batches = accelerator.gather(untrained_batches)
            # Restore from disk; iteration should pick up right after save_step.
            accelerator.load_state(tmpdir)
            resumed_batches = []
            for batch in prepared_dl:
                resumed_batches.append(batch)
            resumed_batches = accelerator.gather(resumed_batches)
            assert len(not_skipped_batches) == len(resumed_batches), (
                f"Expected {len(not_skipped_batches)} batches after resume, got {len(resumed_batches)}"
            )
            # Batches must match element-wise, not just in count.
            for b1, b2 in zip(not_skipped_batches, resumed_batches):
                for v1, v2 in zip(b1, b2):
                    assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = old_dataloader_config
|
|
|
|
def test_stateful_dataloader_save_state(accelerator):
    """
    Tests that a stateful dataloader can be iterated over, saved after a few batches using `Accelerator.save_state`,
    and then resumed from the saved state.

    The result should be the same as the rest of the data that iterated over after saving.
    """
    for iterable in (True, False):
        _test_stateful_dataloader_save_state_resume(accelerator, iterable=iterable)
|
|
|
|
def main():
    accelerator = create_accelerator()
    # Seed per-rank so randomness differs across processes.
    torch.manual_seed(accelerator.process_index)

    accelerator.print("Test that even_batches variable ensures uniform batches across processes")
    test_default_ensures_even_batch_sizes()

    accelerator.print("Run tests with even_batches disabled")
    test_can_disable_even_batches()

    accelerator.print("Test joining uneven inputs")
    test_can_join_uneven_inputs()

    accelerator.print("Test overriding even_batches when joining uneven inputs")
    test_join_can_override_even_batches()

    accelerator.print("Test overriding even_batches for mixed dataloader types")
    test_join_can_override_for_mixed_type_dataloaders()

    accelerator.print("Test overriding even_batches raises a warning for iterable dataloaders")
    test_join_raises_warning_for_iterable_when_overriding_even_batches()

    accelerator.print("Test join with non DDP distributed raises warning")
    # Temporarily masquerade as FSDP to trigger the non-DDP warning path.
    original_state = accelerator.state.distributed_type
    accelerator.state.distributed_type = DistributedType.FSDP
    test_join_raises_warning_for_non_ddp_distributed(accelerator)
    accelerator.state.distributed_type = original_state

    accelerator.print("Test pickling an accelerator")
    test_pickle_accelerator()

    dataset = DummyDataset()

    def make_batch_sampler():
        return BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)

    # Exercise the epoch-coverage check over several DataLoader configurations.
    loader_cases = [
        (
            "Test DataLoader with shuffle=False",
            lambda: DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS),
        ),
        (
            "Test DataLoader with shuffle=True",
            lambda: DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS),
        ),
        (
            "Test DataLoader with batch_sampler",
            lambda: DataLoader(dataset, batch_sampler=make_batch_sampler(), num_workers=NUM_WORKERS),
        ),
        (
            "Test DataLoader with sampler as an instance of `BatchSampler`",
            lambda: DataLoader(
                dataset,
                sampler=make_batch_sampler(),
                batch_size=None,
                collate_fn=default_collate,
                num_workers=NUM_WORKERS,
            ),
        ),
    ]
    for message, build_loader in loader_cases:
        accelerator.print(message)
        test_data_loader(build_loader(), accelerator)

    test_stateful_dataloader(accelerator)
    test_stateful_dataloader_save_state(accelerator)

    accelerator.end_training()
|
|
|
|
# Entry point: this script is launched under a 2-process distributed runner.
if __name__ == "__main__":
    main()
|
|