| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import torch |
|
|
| from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState |
| from accelerate.utils import is_hpu_available |
|
|
|
|
class MockModel(torch.nn.Module):
    """Tiny deterministic module whose output (and thus gradient) depends on the rank."""

    def __init__(self):
        super().__init__()
        # Fixed seed so every process starts from identical parameter values.
        torch.manual_seed(0)
        self.p = torch.nn.Parameter(torch.randn(40, 20))

    def forward(self, x, rank):
        # The rank-dependent exponent makes each process produce different
        # gradients, so cross-rank gradient averaging is observable.
        return self.p * (x ** (1 + rank))
|
|
|
|
| def _run_and_get_grads(model, rank): |
| torch.manual_seed(2024) |
| input = torch.randn(40, 20) |
| output = model(input, rank) |
| output.mean().backward() |
| param = next(model.parameters()) |
| return param.grad |
|
|
|
|
def test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option):
    """Verify that gradients under a DDP comm hook match plain DDP gradients."""
    hook_kwargs = DistributedDataParallelKwargs(
        comm_hook=comm_hook,
        comm_wrapper=comm_wrapper,
        comm_state_option=comm_state_option,
    )
    accelerator = Accelerator(kwargs_handlers=[hook_kwargs])
    local_rank = accelerator.local_process_index

    # Gradients produced with the communication hook installed via Accelerate.
    hooked_model = accelerator.prepare(MockModel())
    hook_grads = _run_and_get_grads(hooked_model, local_rank)

    # Baseline: vanilla torch DDP with no comm hook attached.
    baseline_model = torch.nn.parallel.DistributedDataParallel(
        MockModel().to(accelerator.device),
        device_ids=[local_rank],
        output_device=local_rank,
    )
    reference_grads = _run_and_get_grads(baseline_model, local_rank)

    # Compression hooks (fp16/bf16/PowerSGD) are lossy, hence the loose tolerances.
    torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-2, atol=1e-2)
|
|
|
|
def main():
    """Sweep DDP comm hook / wrapper / state combinations and verify each one.

    Must be launched under a distributed launcher (e.g. ``accelerate launch``)
    so that a process group exists for DDP.
    """
    # fp16/bf16 compression hooks are unsupported on HPU. The availability check
    # and the set are loop-invariant, so compute them once before the sweep.
    if is_hpu_available():
        hpu_unsupported = {DDPCommunicationHookType.FP16, DDPCommunicationHookType.BF16}
    else:
        hpu_unsupported = set()

    for comm_hook, comm_wrapper, comm_state_option in [
        (DDPCommunicationHookType.NO, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.FP16, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.BF16, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.FP16, {}),
        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.BF16, {}),
        (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.NO, {"matrix_approximation_rank": 2}),
        (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.NO, {}),
        (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.FP16, {}),
        (DDPCommunicationHookType.BATCHED_POWER_SGD, DDPCommunicationHookType.BF16, {}),
    ]:
        if comm_hook in hpu_unsupported or comm_wrapper in hpu_unsupported:
            print(f"Skipping test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper} on HPU")
            continue

        print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}")
        test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)
    # Tear down the process group once the whole sweep has finished.
    PartialState().destroy_process_group()
|
|
|
|
# Entry point: run the full comm-hook sweep when executed as a script
# (intended to be launched via a distributed launcher such as `accelerate launch`).
if __name__ == "__main__":
    main()
|
|