# Gradient Synchronization

PyTorch's distributed module operates by communicating back and forth between all of the GPUs in your system.
This communication takes time, and ensuring all processes know the states of each other happens at particular trigger points
when using `DistributedDataParallel`.

These trigger points are added to the PyTorch model, specifically its `forward()` and `backward()` methods.
This happens when the model is wrapped with `DistributedDataParallel`:

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

model = nn.Linear(10, 10)
ddp_model = DistributedDataParallel(model)
```

In 🤗 Accelerate this conversion happens automatically when calling [`~Accelerator.prepare`] and passing in your model.

```diff
+ from accelerate import Accelerator
+ accelerator = Accelerator()
  import torch.nn as nn
- from torch.nn.parallel import DistributedDataParallel

  model = nn.Linear(10,10)
+ model = accelerator.prepare(model)
```

## The slowdown in gradient accumulation

You now understand that PyTorch adds hooks to the `forward` and `backward` methods of your PyTorch model when
training in a distributed setup. But how does this risk slowing down your code?

In DDP (distributed data parallel), processes are expected to perform specific operations in a specific order,
and every process must reach these points at roughly the same time before any of them can move on.

The most direct example is when you update the parameters of a model through `optimizer.step()`.
Every instance of the model needs to have computed its gradients in `.backward()`, had them collated across processes, and applied the update
before moving on to the next batch of data.
But when performing gradient accumulation, you accumulate gradients over `n` batches and skip `optimizer.step()` until `n` batches have been reached.
Each `.backward()` call in between still triggers a synchronization, which can cause a significant slowdown since all the processes
communicate with each other more often than needed.

How can you avoid this overhead?
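
To put a rough number on that overhead, here is a toy count in plain Python. It is a back-of-the-envelope sketch, not PyTorch code: `sync_counts` is a hypothetical helper, and it assumes an unmodified DDP loop synchronizes once per batch while only the last backward pass before each optimizer update actually needs to.

```python
import math


def sync_counts(num_batches, accumulation_steps):
    """Return (performed, required) gradient synchronizations for one epoch.

    Toy model only: assumes one synchronization per backward pass in an
    unmodified DDP loop, and one optimizer update every `accumulation_steps`
    batches (plus a final update for any leftover batches).
    """
    # An unmodified loop synchronizes on every backward pass, one per batch.
    performed = num_batches
    # Only the backward pass right before each optimizer update needs to sync.
    required = math.ceil(num_batches / accumulation_steps)
    return performed, required


performed, required = sync_counts(num_batches=100, accumulation_steps=4)
print(f"{performed} syncs performed, {required} required")
```

Under these assumptions, three out of every four synchronizations in this run are communication that the optimizer never needed.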

## Solving the slowdown problem

Since the optimizer does not step on these intermediate batches, their gradients do not need to be synchronized until the last `.backward()` call before `optimizer.step()`.
PyTorch cannot automagically tell when you need to do this, but it does provide a tool to help through the [`no_sync`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel.no_sync) context manager
that is added to your model after converting it to DDP.

Under this context manager, PyTorch will skip synchronizing the gradients when `.backward()` is called, and the first call to `.backward()` outside this
context manager will trigger the synchronization. See an example below:

```python
ddp_model, dataloader = accelerator.prepare(model, dataloader)

for index, batch in enumerate(dataloader):
    inputs, targets = batch
    # Trigger gradient synchronization on the last batch
    if index != (len(dataloader) - 1):
        with ddp_model.no_sync():
            # Gradients only accumulate
            outputs = ddp_model(inputs)
            loss = loss_func(outputs, targets)
            accelerator.backward(loss)
    else:
        # Gradients finally sync
        outputs = ddp_model(inputs)
        loss = loss_func(outputs, targets)
        accelerator.backward(loss)
```

In 🤗 Accelerate, to make this an API that can be called no matter the training device (though it may not do anything if you are not in a distributed system!),
`ddp_model.no_sync` gets replaced with [`~Accelerator.no_sync`] and operates the same way:

```diff
  ddp_model, dataloader = accelerator.prepare(model, dataloader)

  for index, batch in enumerate(dataloader):
      inputs, targets = batch
      # Trigger gradient synchronization on the last batch
      if index != (len(dataloader)-1):
-         with ddp_model.no_sync():
+         with accelerator.no_sync(ddp_model):
              # Gradients only accumulate
              outputs = ddp_model(inputs)
              loss = loss_func(outputs, targets)
              accelerator.backward(loss)
      else:
          # Gradients finally sync
          outputs = ddp_model(inputs)
          loss = loss_func(outputs, targets)
          accelerator.backward(loss)
```

As you may expect, the [`~Accelerator.accumulate`] function wraps around this conditional check by keeping track of the current batch number, leaving you with the final
gradient accumulation API:

```python
ddp_model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    with accelerator.accumulate(ddp_model):
        inputs, targets = batch
        outputs = ddp_model(inputs)
        loss = loss_function(outputs, targets)
        accelerator.backward(loss)
        # The prepared optimizer only steps and zeroes gradients on batches where they were synchronized
        optimizer.step()
        optimizer.zero_grad()
```

As a result, when it comes to API choice, you should use either *`accelerator.accumulate` or `accelerator.no_sync`*.
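
To make the "keeping track of the current batch number" idea concrete, here is a minimal pure-Python sketch of that bookkeeping. It is an illustration under stated assumptions, not Accelerate's actual implementation: `ToyAccumulator` and `sync_schedule` are hypothetical names, and the rule assumed here is "synchronize every `accumulation_steps` batches and on the final batch".

```python
import contextlib


class ToyAccumulator:
    """Hypothetical stand-in that only tracks when gradients should sync."""

    def __init__(self, accumulation_steps, num_batches):
        self.accumulation_steps = accumulation_steps
        self.num_batches = num_batches
        self.step = 0
        self.sync_gradients = False

    @contextlib.contextmanager
    def accumulate(self):
        # Count the batch, then decide whether this backward pass should sync.
        self.step += 1
        is_last_batch = self.step == self.num_batches
        is_boundary = self.step % self.accumulation_steps == 0
        self.sync_gradients = is_boundary or is_last_batch
        yield


def sync_schedule(num_batches, accumulation_steps):
    """Return one boolean per batch: True where gradients would synchronize."""
    accumulator = ToyAccumulator(accumulation_steps, num_batches)
    schedule = []
    for _ in range(num_batches):
        with accumulator.accumulate():
            # A real loop would run forward and backward here.
            schedule.append(accumulator.sync_gradients)
    return schedule


print(sync_schedule(num_batches=5, accumulation_steps=2))
```

In this toy version the training loop never has to compare batch indices itself, which is the convenience the context-manager style is meant to provide.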