#!/usr/bin/env python

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Multi-process checks for how `Accelerator` prepares and iterates DataLoaders.

The script must run with exactly two processes (`create_accelerator` asserts this) and is driven by `main()`.
"""

import pickle
import tempfile
import warnings
from unittest.mock import Mock

import torch
from torch.utils.data import (
    BatchSampler,
    DataLoader,
    Dataset,
    IterableDataset,
    RandomSampler,
    TensorDataset,
    default_collate,
)

from accelerate.accelerator import Accelerator, DataLoaderConfiguration
from accelerate.utils.dataclasses import DistributedType


NUM_ELEMENTS = 22
NUM_WORKERS = 4
BATCH_SIZE = 4


class DummyDataset(Dataset):
    """
    Map-style dataset of `NUM_ELEMENTS` dict samples.

    Sample `i` is `{"index": i, "label": i % 2, "random_augmentation": <random float>}`. `__getitem__` accepts a
    single int (returns one sample), a slice, or any iterable of ints (returns a list of samples). The iterable form is
    what a `BatchSampler` passed as `sampler=` produces.
    """

    def __len__(self):
        return NUM_ELEMENTS

    def __getitem__(self, index):
        squeeze = False

        if isinstance(index, int):
            index = [index]
            squeeze = True
        elif isinstance(index, slice):
            # Resolve the slice against this dataset's length (there is no `self.size` attribute).
            index = list(range(*index.indices(len(self))))
        else:
            index = list(index)

        batch = [{"index": i, "label": i % 2, "random_augmentation": torch.rand(1).item()} for i in index]

        if squeeze:
            batch = batch[0]

        return batch


class DummyIterableDataset(IterableDataset):
    """Iterable-style dataset that simply yields the elements of `data` in order."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        yield from self.data


def create_accelerator(even_batches=True):
    """
    Build an `Accelerator` whose dataloader configuration uses the given `even_batches` setting.

    Raises:
        AssertionError: if the run does not have exactly two processes.
    """
    dataloader_config = DataLoaderConfiguration(even_batches=even_batches)
    accelerator = Accelerator(dataloader_config=dataloader_config)
    assert accelerator.num_processes == 2, "this script expects that two GPUs are available"
    return accelerator


def create_dataloader(
    accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
):
    """
    Create a simple DataLoader over the values `0..dataset_size-1` and prepare it with `accelerator`.

    Args:
        accelerator: the accelerator used to prepare the DataLoader.
        dataset_size: number of values in the dataset.
        batch_size: per-batch size passed to `DataLoader`.
        iterable: if True, wrap the values in a `DummyIterableDataset`; otherwise in a map-style `TensorDataset`.
        shuffle: if True, permute the values once with `torch.randperm` before building either dataset kind.

    Returns:
        The DataLoader returned by `accelerator.prepare`.
    """
    values = torch.as_tensor(range(dataset_size))
    if shuffle:
        values = values[torch.randperm(values.size(0))]

    if iterable:
        dataset = DummyIterableDataset(values)
    else:
        # Build from `values` (not a fresh range) so `shuffle=True` is honored for map-style datasets as well.
        dataset = TensorDataset(values)

    dl = DataLoader(dataset, batch_size=batch_size)
    dl = accelerator.prepare(dl)

    return dl


def verify_dataloader_batch_sizes(
    accelerator: Accelerator,
    dataset_size: int,
    batch_size: int,
    process_0_expected_batch_sizes: list[int],
    process_1_expected_batch_sizes: list[int],
):
    """
    A helper function for verifying the batch sizes coming from a prepared dataloader in each process.
    """
    dl = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size)

    # `TensorDataset` yields 1-tuples, so each batch is a sequence whose first element holds the values.
    batch_sizes = [len(batch[0]) for batch in dl]

    if accelerator.process_index == 0:
        assert batch_sizes == process_0_expected_batch_sizes
    elif accelerator.process_index == 1:
        assert batch_sizes == process_1_expected_batch_sizes


def test_default_ensures_even_batch_sizes():
    accelerator = create_accelerator()

    # without padding, we would expect a different number of batches
    verify_dataloader_batch_sizes(
        accelerator,
        dataset_size=3,
        batch_size=1,
        process_0_expected_batch_sizes=[1, 1],
        process_1_expected_batch_sizes=[1, 1],
    )

    # without padding, we would expect the same number of batches, but different sizes
    verify_dataloader_batch_sizes(
        accelerator,
        dataset_size=7,
        batch_size=2,
        process_0_expected_batch_sizes=[2, 2],
        process_1_expected_batch_sizes=[2, 2],
    )


def test_can_disable_even_batches():
    accelerator = create_accelerator(even_batches=False)

    verify_dataloader_batch_sizes(
        accelerator,
        dataset_size=3,
        batch_size=1,
        process_0_expected_batch_sizes=[1, 1],
        process_1_expected_batch_sizes=[1],
    )

    verify_dataloader_batch_sizes(
        accelerator,
        dataset_size=7,
        batch_size=2,
        process_0_expected_batch_sizes=[2, 2],
        process_1_expected_batch_sizes=[2, 1],
    )


def test_can_join_uneven_inputs():
    """With uneven batches, process 0 runs two steps and process 1 one step inside `join_uneven_inputs`."""
    accelerator = create_accelerator(even_batches=False)

    model = torch.nn.Linear(1, 1)
    ddp_model = accelerator.prepare(model)

    dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    batch_idxs = []
    with accelerator.join_uneven_inputs([ddp_model]):
        for batch_idx, batch in enumerate(dl):
            output = ddp_model(batch[0].float())
            loss = output.sum()
            loss.backward()
            batch_idxs.append(batch_idx)

    accelerator.wait_for_everyone()

    if accelerator.process_index == 0:
        assert batch_idxs == [0, 1]
    elif accelerator.process_index == 1:
        assert batch_idxs == [0]


def test_join_raises_warning_for_non_ddp_distributed(accelerator):
    """`join_uneven_inputs` should emit a UserWarning when the distributed type is not multi-GPU."""
    with warnings.catch_warnings(record=True) as w:
        # Record every warning, even one already emitted from the same location earlier in the run.
        warnings.simplefilter("always")
        with accelerator.join_uneven_inputs([Mock()]):
            pass

    assert any(
        issubclass(record.category, UserWarning) and "only supported for multi-GPU" in str(record.message)
        for record in w
    ), f"expected a 'only supported for multi-GPU' UserWarning, got: {[str(record.message) for record in w]}"


def test_join_can_override_even_batches():
    """Inside `join_uneven_inputs(even_batches=...)` the override applies; afterwards the default is restored."""
    default_even_batches = True
    overridden_even_batches = False
    accelerator = create_accelerator(even_batches=default_even_batches)
    model = torch.nn.Linear(1, 1)
    ddp_model = accelerator.prepare(model)
    train_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)
    valid_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
        train_dl_overridden_value = train_dl.batch_sampler.even_batches
        valid_dl_overridden_value = valid_dl.batch_sampler.even_batches

    assert train_dl_overridden_value == overridden_even_batches
    assert valid_dl_overridden_value == overridden_even_batches
    assert train_dl.batch_sampler.even_batches == default_even_batches
    assert valid_dl.batch_sampler.even_batches == default_even_batches


def test_join_can_override_for_mixed_type_dataloaders():
    """Overriding `even_batches` must still work when an iterable-style dataloader was also prepared."""
    default_even_batches = True
    overridden_even_batches = False
    accelerator = create_accelerator(even_batches=default_even_batches)
    model = torch.nn.Linear(1, 1)
    ddp_model = accelerator.prepare(model)
    # The returned dataloader is discarded on purpose; presumably the accelerator keeps track of every dataloader
    # it prepared, so `join_uneven_inputs` still sees it — TODO confirm against Accelerator.join_uneven_inputs.
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)
    batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        try:
            with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches):
                batch_dl_overridden_value = batch_dl.batch_sampler.even_batches
        except AttributeError as err:
            # ensure attribute error is not raised when processing iterable dl
            raise AssertionError("join_uneven_inputs raised AttributeError for an iterable dataloader") from err

    assert batch_dl_overridden_value == overridden_even_batches
    assert batch_dl.batch_sampler.even_batches == default_even_batches


def test_join_raises_warning_for_iterable_when_overriding_even_batches():
    """Overriding `even_batches` with an iterable-style dataloader prepared should emit a UserWarning."""
    accelerator = create_accelerator()
    model = torch.nn.Linear(1, 1)
    ddp_model = accelerator.prepare(model)
    create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True)

    with warnings.catch_warnings(record=True) as w:
        # Record every warning, even one already emitted from the same location earlier in the run.
        warnings.simplefilter("always")
        with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
            pass

    assert any(
        issubclass(record.category, UserWarning) and "only supported for map-style datasets" in str(record.message)
        for record in w
    ), f"expected a 'only supported for map-style datasets' UserWarning, got: {[str(record.message) for record in w]}"


def test_pickle_accelerator():
    """An accelerator that has prepared a dataloader should round-trip through pickle with identical state."""
    accelerator = create_accelerator()
    # NOTE(review): `create_dataloader` already calls `accelerator.prepare`, so the dataloader is prepared a second
    # time below — confirm that double preparation is intentional.
    data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
    _ = accelerator.prepare(data_loader)
    pickled_accelerator = pickle.dumps(accelerator)
    unpickled_accelerator = pickle.loads(pickled_accelerator)
    # TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
    assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__


def test_data_loader(data_loader, accelerator):
    """
    Prepare `data_loader` and check that, gathered across processes, one epoch covers every dataset element.
    """
    # Prepare the DataLoader
    data_loader = accelerator.prepare(data_loader)

    all_examples = []
    for batch in data_loader:
        index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"]))
        all_examples.extend(index.detach().cpu().numpy().tolist())

    # Check that every element was iterated (duplicates across processes would shrink the unique count)
    assert len(set(all_examples)) == NUM_ELEMENTS, (
        "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
    )


def test_stateful_dataloader(accelerator):
    """
    Tests that a stateful dataloader can be iterated over, saved after a few batches using `state_dict`, and then
    resumed from the saved state using `load_state_dict`. The result should be the same as the rest of the data
    that was iterated over after saving.
    """
    old_dataloader_config = accelerator.dataloader_config
    try:
        accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
        prepared_dl = create_dataloader(
            accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
        )
        untrained_batches = []
        # Batch count per process: dataset size / (batch size * number of processes)
        total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
        last_batch_num = total_batches - 1
        for step, batch in enumerate(prepared_dl):
            # Capture the state one step before the final batch
            if step == last_batch_num - 1:
                state_dict = prepared_dl.state_dict()
            if step >= last_batch_num:
                # Batches seen after the state was captured; resuming should reproduce them
                untrained_batches.append(batch)
        not_skipped_batches = accelerator.gather(untrained_batches)
        prepared_dl.load_state_dict(state_dict)
        resumed_batches = accelerator.gather(list(prepared_dl))
        for b1, b2 in zip(not_skipped_batches, resumed_batches):
            for v1, v2 in zip(b1, b2):
                assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = old_dataloader_config


def test_stateful_dataloader_save_state(accelerator):
    """
    Tests that a stateful dataloader can be iterated over, saved after a few batches using `Accelerator.save_state`,
    and then resumed from the saved state using `Accelerator.load_state`. The result should be the same as the rest
    of the data that was iterated over after saving.
    """
    old_dataloader_config = accelerator.dataloader_config
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
            prepared_dl = create_dataloader(
                accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
            )
            untrained_batches = []
            # Batch count per process: dataset size / (batch size * number of processes)
            total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
            last_batch_num = total_batches - 1
            for step, batch in enumerate(prepared_dl):
                # Save the state one step before the final batch
                if step == last_batch_num - 1:
                    accelerator.save_state(tmpdir)
                if step >= last_batch_num:
                    # Batches seen after the state was saved; resuming should reproduce them
                    untrained_batches.append(batch)
            not_skipped_batches = accelerator.gather(untrained_batches)
            accelerator.load_state(tmpdir)
            resumed_batches = accelerator.gather(list(prepared_dl))
            for b1, b2 in zip(not_skipped_batches, resumed_batches):
                for v1, v2 in zip(b1, b2):
                    assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
    finally:
        accelerator.dataloader_config = old_dataloader_config


def main():
    """Run every check in sequence; per-process RNG seeds differ (`manual_seed(process_index)`)."""
    accelerator = create_accelerator()
    torch.manual_seed(accelerator.process_index)

    accelerator.print("Test that even_batches variable ensures uniform batches across processes")
    test_default_ensures_even_batch_sizes()

    accelerator.print("Run tests with even_batches disabled")
    test_can_disable_even_batches()

    accelerator.print("Test joining uneven inputs")
    test_can_join_uneven_inputs()

    accelerator.print("Test overriding even_batches when joining uneven inputs")
    test_join_can_override_even_batches()

    accelerator.print("Test overriding even_batches for mixed dataloader types")
    test_join_can_override_for_mixed_type_dataloaders()

    accelerator.print("Test overriding even_batches raises a warning for iterable dataloaders")
    test_join_raises_warning_for_iterable_when_overriding_even_batches()

    accelerator.print("Test join with non DDP distributed raises warning")
    original_state = accelerator.state.distributed_type
    accelerator.state.distributed_type = DistributedType.FSDP
    try:
        test_join_raises_warning_for_non_ddp_distributed(accelerator)
    finally:
        # Always restore the real distributed type so the faked FSDP value cannot leak into later checks.
        accelerator.state.distributed_type = original_state

    accelerator.print("Test pickling an accelerator")
    test_pickle_accelerator()

    dataset = DummyDataset()

    accelerator.print("Test DataLoader with shuffle=False")
    loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)

    accelerator.print("Test DataLoader with shuffle=True")
    loader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)

    accelerator.print("Test DataLoader with batch_sampler")
    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)

    accelerator.print("Test DataLoader with sampler as an instance of `BatchSampler`")
    sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
    loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
    test_data_loader(loader, accelerator)

    accelerator.print("Test stateful DataLoader resumed with load_state_dict")
    test_stateful_dataloader(accelerator)

    accelerator.print("Test stateful DataLoader resumed with save_state/load_state")
    test_stateful_dataloader_save_state(accelerator)

    accelerator.end_training()


if __name__ == "__main__":
    main()