| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import torch |
| |
|
| | from accelerate import Accelerator, DistributedType |
| |
|
| |
|
class LocalSGD:
    """
    A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently
    on each device, and averages model weights every K steps (K being `local_sgd_steps`).

    It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
    this is a simple implementation that cannot support scenarios such as model parallelism.

    The object is meant to be used as a context manager, calling `step()` once per local update:

        with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=8) as local_sgd:
            for batch in dataloader:
                ...
                local_sgd.step()

    Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
    back to at least:

    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
    arXiv:1606.07365.](https://huggingface.co/papers/1606.07365)

    We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).

    Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
    Learning Representations. No. CONF. 2019.](https://huggingface.co/papers/1805.09767)

    """

    def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
        """
        Constructor.

        Args:
            accelerator (`Accelerator`):
                Accelerator object.
            model (`torch.nn.Module`):
                The model whose parameters we need to average.
            local_sgd_steps (`int`):
                A number of local SGD steps (before model parameters are synchronized). Must be positive whenever
                local SGD is actually enabled.
            enabled (`bool`, *optional*, defaults to `True`):
                Local SGD is disabled if this parameter set to `False`. It is also disabled when
                `accelerator.distributed_type` is `DistributedType.NO`.

        Raises:
            NotImplementedError: If `accelerator.distributed_type` is not one of the supported types listed below.
            ValueError: If local SGD is enabled and `local_sgd_steps` is not positive.
        """
        if accelerator.distributed_type not in [
            DistributedType.NO,
            DistributedType.MULTI_CPU,
            DistributedType.MULTI_GPU,
            DistributedType.MULTI_XPU,
            DistributedType.MULTI_MLU,
            DistributedType.MULTI_HPU,
            DistributedType.MULTI_SDAA,
            DistributedType.MULTI_MUSA,
            DistributedType.MULTI_NPU,
        ]:
            raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
        # With a single process there is nothing to average, so the helper becomes a no-op.
        self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
        self.num_steps = 0
        if self.enabled:
            # Fail early with a clear message: `step()` computes `num_steps % local_sgd_steps`, which would
            # otherwise raise a ZeroDivisionError in the middle of training when the value is 0.
            if local_sgd_steps <= 0:
                raise ValueError(f"`local_sgd_steps` must be a positive integer, got {local_sgd_steps}.")
            self.accelerator = accelerator
            self.model = model
            self.local_sgd_steps = local_sgd_steps

    def __enter__(self):
        """
        Enter the model's `no_sync()` context (when enabled) and return this object.
        """
        if self.enabled:
            # The context object is kept so that `__exit__` can close the very same context.
            self.model_sync_obj = self.model.no_sync()
            self.model_sync_obj.__enter__()

        return self

    def __exit__(self, type, value, tb):
        """
        Average the model parameters one last time and leave the model's `no_sync()` context (when enabled).

        The exception triple is forwarded unchanged to the `no_sync()` context. Nothing is returned, so an
        exception raised inside the `with` body is not suppressed.
        """
        if self.enabled:
            try:
                # Average all models on exit
                self._sync_and_avg_model_params()
            finally:
                # Always close the `no_sync()` context, even if the averaging above raised.
                self.model_sync_obj.__exit__(type, value, tb)

    def step(self):
        """
        This function makes a "step" and synchronizes model parameters if necessary.

        The step counter is incremented on every call, including when local SGD is disabled. Parameters are
        averaged each time the counter reaches a multiple of `local_sgd_steps`.
        """
        self.num_steps += 1
        if not self.enabled:
            return

        if self.num_steps % self.local_sgd_steps == 0:
            self._sync_and_avg_model_params()

    def _sync_and_avg_model_params(self):
        """
        Synchronize + Average model parameters across all GPUs

        Calls `accelerator.wait_for_everyone()` first, then replaces the data of every model parameter with the
        result of `accelerator.reduce(..., reduction="mean")`, inside the `accelerator.autocast()` context.
        """

        self.accelerator.wait_for_everyone()
        with self.accelerator.autocast():
            for param in self.model.parameters():
                param.data = self.accelerator.reduce(param.data, reduction="mean")
| |
|