| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from copy import deepcopy |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch.optim import AdamW |
| | from torch.optim.lr_scheduler import LambdaLR |
| | from torch.utils.data import DataLoader |
| |
|
| | from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin |
| | from accelerate.state import GradientState |
| | from accelerate.test_utils import RegressionDataset, RegressionModel |
| | from accelerate.utils import DistributedType, set_seed |
| |
|
| |
|
def check_model_parameters(model_a, model_b, did_step, iteration, **kwargs):
    """Assert that the gradients of `model_a` and `model_b` are (or are not) in sync.

    Args:
        model_a: Reference model whose parameter gradients are compared.
        model_b: Second model (typically the DDP-wrapped copy) to compare against.
        did_step: If True, gradients are expected to match; if False, to differ.
        iteration: Current iteration index, included in the failure message.
        **kwargs: Extra tolerances forwarded to `torch.allclose` (e.g. `rtol`).

    Raises:
        AssertionError: If the gradient sync state does not match `did_step`.
    """
    for param, grad_param in zip(model_a.parameters(), model_b.parameters()):
        if not param.requires_grad:
            continue
        if not did_step:
            # Grads should not be in sync: `allclose` returns a plain bool, so a
            # simple `not` is the idiomatic negation (never `is False`).
            assert not torch.allclose(param.grad, grad_param.grad, **kwargs), (
                f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})"
            )
        else:
            # Grads should be in sync after a synchronizing step.
            assert torch.allclose(param.grad, grad_param.grad, **kwargs), (
                f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})"
            )
| |
|
| |
|
def step_model(model, input, target, accelerator, do_backward=True):
    """Run one forward/backward pass of `model` with an MSE loss.

    When `do_backward` is True, backprop is delegated to
    `accelerator.backward`; otherwise the loss is scaled down by the number
    of gradient accumulation steps and backpropagated directly.
    """
    model.train()
    prediction = model(input)
    loss = F.mse_loss(prediction, target.to(prediction.device))
    if do_backward:
        # Let Accelerate drive the backward pass (handles its own scaling).
        accelerator.backward(loss)
    else:
        # Manual scaling for gradient accumulation, then plain autograd.
        loss = loss / accelerator.gradient_accumulation_steps
        loss.backward()
| |
|
| |
|
def get_training_setup(accelerator, sched=False):
    """Returns everything needed to perform basic training.

    Builds a `RegressionModel` plus a deep copy of it, a length-80
    `RegressionDataset` dataloader (batch size 16), and prepares the copy
    with `accelerator.prepare`.

    Args:
        accelerator: The `Accelerator` used to prepare the DDP objects.
        sched: When True, also build AdamW optimizers and `LambdaLR`
            schedulers for both models and return the 7-tuple
            `(model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched)`;
            otherwise return `(model, ddp_model, dataloader)`.
    """
    set_seed(42)
    model = RegressionModel()
    ddp_model = deepcopy(model)
    dset = RegressionDataset(length=80)
    dataloader = DataLoader(dset, batch_size=16)
    model.to(accelerator.device)
    if sched:
        opt = AdamW(params=model.parameters(), lr=1e-3)
        ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3)
        # Use a distinct name so the `sched` flag parameter is not shadowed
        # by the scheduler object (the original reused `sched` for both).
        lr_sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65)
        ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65)
        # Make a copy of `model` to prepare for DDP alongside its optimizer/scheduler
        ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader)
        return (model, opt, lr_sched, dataloader, ddp_model, ddp_opt, ddp_sched)
    ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader)
    return model, ddp_model, dataloader
| |
|
| |
|
def test_noop_sync(accelerator):
    # Test when on a single device (DistributedType.NO) that `no_sync` is a no-op:
    # grads of the base model and the "DDP" copy must always match.
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch for all iterations
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform our initial ground-truth step on the non-"DDP" model
        step_model(model, input, target, accelerator)
        # Alternate between accumulating locally (inside `no_sync`) and syncing
        if iteration % 2 == 0:
            # Accumulate grads locally
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Sync grads
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # Since `no_sync` is a noop here, grads should always be in sync
        check_model_parameters(model, ddp_model, True, iteration)
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            assert torch.allclose(param.grad, ddp_param.grad), (
                f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
            )

        # Shuffle ddp_input on each iteration to vary the next pass
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
| |
|
| |
|
def test_distributed_sync(accelerator):
    """Check that in a distributed setup `no_sync` actually skips gradient
    synchronization: grads should differ on `no_sync` iterations and match
    on normal iterations."""
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Use a single batch for all iterations
    ddp_input, ddp_target = next(iter(dataloader)).values()
    for iteration in range(3):
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform our initial ground-truth step on the non-"DDP" model
        step_model(model, input, target, accelerator)
        # Alternate between accumulating locally (inside `no_sync`) and syncing
        if iteration % 2 == 0:
            # Accumulate grads locally
            with accelerator.no_sync(ddp_model):
                step_model(ddp_model, ddp_input, ddp_target, accelerator)
        else:
            # Sync grads
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # DDP model and base model should only be in sync on syncing iterations
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            if iteration % 2 == 0:
                # Grads should not be in sync (`allclose` returns a bool; use
                # `not ...` instead of the `is False` identity anti-pattern)
                assert not torch.allclose(param.grad, ddp_param.grad), (
                    f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
                )
            else:
                # Grads should be in sync
                assert torch.allclose(param.grad, ddp_param.grad), (
                    f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
                )

        # Shuffle ddp_input on each iteration to vary the next pass
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
| |
|
| |
|
def test_distributed_sync_multiple_fwd(accelerator):
    """Check `no_sync` with multiple forward passes before any backward:
    grads stay local for all but the last backward, which is forced to sync
    via `trigger_sync_in_backward`."""
    model, ddp_model, dataloader = get_training_setup(accelerator)
    # Do multiple forward passes while accumulating losses, then backprop
    losses = []
    num_iterations = 3
    for iteration in range(num_iterations):
        ddp_input, ddp_target = next(iter(dataloader)).values()

        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)

        # Perform our initial ground-truth step on the non-"DDP" model
        step_model(model, input, target, accelerator)

        # Accumulate forward passes (no backward yet) under `no_sync`
        with accelerator.no_sync(ddp_model):
            ddp_output = ddp_model(ddp_input)
            loss = F.mse_loss(ddp_output, ddp_target.to(ddp_output.device))
            losses.append(loss)

    # Now backprop each stored loss; only the last one should synchronize
    for iteration in range(num_iterations):
        loss = losses[iteration]

        if iteration < num_iterations - 1:
            # Accumulate grads locally
            accelerator.backward(loss)

            # Grads should not yet be in sync (`allclose` returns a bool;
            # use `not ...` instead of the `is False` identity anti-pattern)
            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not param.requires_grad:
                    continue
                assert not torch.allclose(param.grad, ddp_param.grad), (
                    f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
                )

        else:
            # Force a gradient sync on the final backward
            with accelerator.trigger_sync_in_backward(ddp_model):
                accelerator.backward(loss)

            # After the synchronizing backward, grads must match
            for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
                if not param.requires_grad:
                    continue
                assert torch.allclose(param.grad, ddp_param.grad), (
                    f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
                )
| |
|
| |
|
def test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False):
    """Check that `accelerator.accumulate` syncs gradients only every
    `num_steps` (=2) batches, at the end of the dataloader, or on every batch
    when `sync_each_batch` is set."""
    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )
    # Test that the context manager behaves properly
    model, ddp_model, dataloader = get_training_setup(accelerator)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform our initial ground-truth step on the non-"DDP" model
        step_model(model, input, target, accelerator, False)
        # Do "gradient accumulation" via the context manager
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)

        # Grads are in sync on every second batch, the last batch,
        # or always when `sync_each_batch` is set
        for param, ddp_param in zip(model.parameters(), ddp_model.parameters()):
            if not param.requires_grad:
                continue
            if ((iteration + 1) % 2 == 0) or (iteration == len(dataloader) - 1) or sync_each_batch:
                # Grads should be in sync (`allclose` returns a bool; test it
                # directly instead of the `is True` identity anti-pattern)
                assert torch.allclose(param.grad, ddp_param.grad), (
                    f"Gradients not in sync when they should be at iteration {iteration}:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})"
                )
            else:
                # Grads should not be in sync
                assert not torch.allclose(param.grad, ddp_param.grad), (
                    f"Gradients in sync when they should not be at iteration {iteration}:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})"
                )

        # Shuffle ddp_input on each iteration to vary the next pass
        torch.manual_seed(1337 + iteration)
        ddp_input = ddp_input[torch.randperm(len(ddp_input))]
    # Clear the singleton gradient state so the next test starts clean
    GradientState._reset_state()
| |
|
| |
|
def test_gradient_accumulation_with_opt_and_scheduler(
    split_batches=False, dispatch_batches=False, sync_each_batch=False
):
    # Verify `accumulate` when an optimizer and LR scheduler are in the loop:
    # the manually-stepped baseline and the prepared DDP objects must keep
    # identical learning rates and (in distributed runs) matching grads.
    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
    dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
    accelerator = Accelerator(
        dataloader_config=dataloader_config,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )
    # Test that the context manager behaves properly
    model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True)
    for iteration, batch in enumerate(dataloader):
        ddp_input, ddp_target = batch.values()
        # Gather the distributed inputs and targets for the base model
        input, target = accelerator.gather((ddp_input, ddp_target))
        input, target = input.to(accelerator.device), target.to(accelerator.device)
        # Perform our initial ground-truth step on the non-"DDP" model
        model.train()
        ddp_model.train()
        step_model(model, input, target, accelerator, False)
        opt.step()

        # Step the baseline scheduler only on "real" optimizer steps
        # (every 2nd batch or the final batch)
        if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)):
            if split_batches:
                sched.step()
            else:
                # Without split_batches the prepared scheduler steps once per
                # process, so mirror that here to keep LRs aligned
                for _ in range(accelerator.num_processes):
                    sched.step()

        # Perform gradient accumulation under the context manager
        with accelerator.accumulate(ddp_model):
            step_model(ddp_model, ddp_input, ddp_target, accelerator)
            ddp_opt.step()
            ddp_sched.step()

        # Learning rates should be identical between the two optimizers
        assert opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"], (
            f"Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]['lr']}\nDDP opt: {ddp_opt.param_groups[0]['lr']}\n"
        )
        did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader))
        if accelerator.num_processes > 1:
            # Grads match only when a sync happened (or sync_each_batch forces it)
            check_model_parameters(
                model,
                ddp_model,
                did_step or sync_each_batch,
                iteration,
                rtol=1e-3,
            )

        if did_step:
            # Flush gradients after each real optimizer step
            opt.zero_grad()
            ddp_opt.zero_grad()

        # Re-seed so both models see the same shuffling on the next pass
        torch.manual_seed(1337 + iteration)
    # Clear the singleton gradient state so the next test starts clean
    GradientState._reset_state()
| |
|
| |
|
def test_dataloader_break():
    """Check that `GradientState` tracks the active dataloader correctly,
    including when a second dataloader is iterated inside the first one."""
    accelerator = Accelerator()
    outer_dset = RegressionDataset(length=80)
    outer_loader = DataLoader(outer_dset, batch_size=16)
    inner_dset = RegressionDataset(length=96)
    inner_loader = DataLoader(inner_dset, batch_size=16)
    outer_loader, inner_loader = accelerator.prepare(outer_loader, inner_loader)

    # Nothing is active before iteration begins
    assert accelerator.gradient_state.active_dataloader is None
    for outer_idx, _ in enumerate(outer_loader):
        # The outer loader is the active one while it is being consumed
        assert accelerator.gradient_state.active_dataloader is outer_loader
        if outer_idx < len(outer_loader) - 1:
            assert not accelerator.gradient_state.end_of_dataloader
            if outer_idx == 1:
                # Nested iteration: the inner loader takes over as active
                for inner_idx, _ in enumerate(inner_loader):
                    assert accelerator.gradient_state.active_dataloader is inner_loader
                    if inner_idx < len(inner_loader) - 1:
                        assert not accelerator.gradient_state.end_of_dataloader
                    else:
                        assert accelerator.gradient_state.end_of_dataloader
        else:
            assert accelerator.gradient_state.end_of_dataloader
    # After exhaustion, no dataloader remains active
    assert accelerator.gradient_state.active_dataloader is None
| |
|
| |
|
def main():
    # Entry point: run each sync/accumulation test appropriate for the
    # current distributed setup, printing a banner from the main process.
    accelerator = Accelerator()
    state = accelerator.state
    if state.local_process_index == 0:
        print("**Test `accumulate` gradient accumulation with dataloader break**")
    if state.distributed_type != DistributedType.XLA:
        test_dataloader_break()
    if state.distributed_type == DistributedType.NO:
        # Single-device run: `no_sync` should be a no-op
        if state.local_process_index == 0:
            print("**Test NOOP `no_sync` context manager**")
        test_noop_sync(accelerator)
    if state.distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_CPU,
        DistributedType.MULTI_HPU,
    ):
        # Multi-process runs: `no_sync` should actually skip grad sync
        if state.local_process_index == 0:
            print("**Test Distributed `no_sync` context manager**")
        test_distributed_sync(accelerator)
        if state.local_process_index == 0:
            print("**Test Distributed `no_sync` context manager with multiple forwards**")
        test_distributed_sync_multiple_fwd(accelerator)
    if state.distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_HPU,
    ):
        # Exercise every combination of dataloader/accumulation options
        for split_batch in [True, False]:
            for dispatch_batches in [True, False]:
                for sync_each_batch in [True, False]:
                    if state.local_process_index == 0:
                        print(
                            "**Test `accumulate` gradient accumulation, ",
                            f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                        )
                    test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch)

    # Currently will break on TPU - needs to be fixed before re-enabling there
    if state.local_process_index == 0:
        print(
            "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
            "`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**",
        )
    test_gradient_accumulation_with_opt_and_scheduler()
    if state.distributed_type in (
        DistributedType.MULTI_GPU,
        DistributedType.MULTI_NPU,
        DistributedType.MULTI_MLU,
        DistributedType.MULTI_SDAA,
        DistributedType.MULTI_MUSA,
        DistributedType.MULTI_HPU,
    ):
        for split_batch in [True, False]:
            for dispatch_batches in [True, False]:
                for sync_each_batch in [True, False]:
                    # The all-False combination was already run above
                    if not split_batch and not dispatch_batches and not sync_each_batch:
                        continue
                    if state.local_process_index == 0:
                        print(
                            "**Test `accumulate` gradient accumulation with optimizer and scheduler, ",
                            f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
                        )
                    test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)
    state.destroy_process_group()
| |
|
| |
|
def _mp_fn(index):
    # Spawn-process entry point (e.g. xla_spawn for TPUs): each worker runs `main`.
    main()
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|