|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
|
|
|
|
|
|
def _allreduce_fut(process_group: dist.ProcessGroup,
                   tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
    """Average ``tensor`` across the process group via an async allreduce.

    The tensor is divided in place by the group size before the reduction,
    so the summed allreduce result equals the group mean. Returns a future
    that resolves to the averaged tensor.
    """
    group = dist.group.WORLD if process_group is None else process_group

    # Pre-divide so that the post-allreduce sum is the average.
    tensor.div_(group.size())

    work = dist.all_reduce(tensor, group=group, async_op=True)

    def _extract(fut):
        # The future's value is a list holding the reduced tensor.
        return fut.value()[0]

    return work.get_future().then(_extract)
|
|
|
|
|
|
|
|
def allreduce_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    """
    This DDP communication hook simply calls ``allreduce`` on the
    ``GradBucket`` tensor. Once the gradients are aggregated across all
    workers, the chained callback takes the mean and returns the result.
    If a user registers this hook, DDP's results are expected to match the
    case where no hook was registered, so it does not change DDP behavior.
    It can serve as a reference, or be modified to log useful information
    or for other purposes, without affecting DDP behavior.

    Example::
        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
    """
    grad_buffer = bucket.buffer()
    return _allreduce_fut(process_group, grad_buffer)
|
|
|
|
|
|
|
|
def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    """
    This DDP communication hook implements a simple gradient compression
    scheme: the ``GradBucket`` tensor is cast to half-precision floating
    point (``torch.float16``) and divided by the process group size, then
    the ``float16`` gradients are allreduced. Once the compressed gradients
    have been allreduced, the chained ``decompress`` callback casts the
    result back to the input data type (such as ``float32``).

    Example::
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    group = dist.group.WORLD if process_group is None else process_group

    # Cast to half precision and pre-divide so the allreduce sum is the mean.
    compressed = bucket.buffer().to(torch.float16).div_(group.size())

    reduce_fut = dist.all_reduce(
        compressed, group=group, async_op=True).get_future()

    def decompress(fut):
        # Copying into the bucket's buffer restores the original dtype.
        restored = bucket.buffer()
        restored.copy_(fut.value()[0])
        return restored

    return reduce_fut.then(decompress)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bf16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    """
    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

    This DDP communication hook implements a simple gradient compression
    scheme: the ``GradBucket`` tensor is cast to the half-precision
    `Brain floating point format <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_ (``torch.bfloat16``)
    and divided by the process group size, then the ``bfloat16`` gradients
    are allreduced. Once the compressed gradients have been allreduced, the
    chained ``decompress`` callback casts the result back to the input data
    type (such as ``float32``).

    Example::
        >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
    """
    group = dist.group.WORLD if process_group is None else process_group

    # Cast to bfloat16 and pre-divide so the allreduce sum is the mean.
    compressed = bucket.buffer().to(torch.bfloat16).div_(group.size())

    reduce_fut = dist.all_reduce(
        compressed, group=group, async_op=True).get_future()

    def decompress(fut):
        # Copying into the bucket's buffer restores the original dtype.
        restored = bucket.buffer()
        restored.copy_(fut.value()[0])
        return restored

    return reduce_fut.then(decompress)
|
|
|
|
|
|
|
|
def fp16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    """
    This wrapper casts the input gradient tensor of a given DDP communication
    hook to half-precision floating point format (``torch.float16``), and
    casts the resulting tensor of the given hook back to the input data type,
    such as ``float32``.

    Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``.

    Example::
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
    """

    def fp16_compress_wrapper_hook(
        hook_state,
        bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
        # Compress the bucket in place before handing it to the wrapped hook.
        bucket.set_buffer(bucket.buffer().to(torch.float16))

        inner_fut = hook(hook_state, bucket)

        def decompress(fut):
            # Copying into the bucket's buffer restores the original dtype.
            restored = bucket.buffer()
            restored.copy_(fut.value())
            return restored

        # Decompress after the wrapped hook's future completes.
        return inner_fut.then(decompress)

    return fp16_compress_wrapper_hook
|
|
|
|
|
|
|
|
def bf16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
    """
    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.

    This wrapper casts the input gradient tensor of a given DDP communication
    hook to the half-precision
    `Brain floating point format <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_ (``torch.bfloat16``),
    and casts the resulting tensor of the given hook back to the input data
    type, such as ``float32``.

    Therefore, ``bf16_compress_hook`` is equivalent to ``bf16_compress_wrapper(allreduce_hook)``.

    Example::
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
    """

    def bf16_compress_wrapper_hook(
        hook_state,
        bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
        # Compress the bucket in place before handing it to the wrapped hook.
        bucket.set_buffer(bucket.buffer().to(torch.bfloat16))

        inner_fut = hook(hook_state, bucket)

        def decompress(fut):
            # Copying into the bucket's buffer restores the original dtype.
            restored = bucket.buffer()
            restored.copy_(fut.value())
            return restored

        # Decompress after the wrapped hook's future completes.
        return inner_fut.then(decompress)

    return bf16_compress_wrapper_hook
|
|
|