Spaces:
Sleeping
Sleeping
| # Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html | |
| # We divide by world_size first before converting to fp16, so it's safer. | |
| from typing import Any, Callable | |
| import torch | |
| import torch.distributed as dist | |
def fp16_compress_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    DDP communication hook that allreduces a gradient bucket in ``torch.float16``.

    The bucket's buffer is divided by the size of the process group and the
    quotient is written directly into a freshly allocated ``float16`` tensor
    (the division happens *before* the cast, which is the safer order for
    half precision). That half-precision tensor is then allreduced
    asynchronously. When the allreduce completes, a chained callback copies
    the result back into the bucket's own buffer, which converts it back to
    the buffer's original data type (such as ``float32``).

    Args:
        process_group: Group to allreduce over. When ``None``, the default
            ``dist.group.WORLD`` group is used.
        bucket: Gradient bucket whose buffer is compressed, reduced and then
            overwritten in place with the reduced values.

    Returns:
        A future that resolves to the bucket's buffer once it holds the
        allreduced gradients.

    Example::
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    if process_group is None:
        group = dist.group.WORLD
    else:
        group = process_group
    num_ranks = group.size()

    # Divide first, then narrow to fp16. Passing the fp16 tensor as ``out``
    # fuses the division and the dtype conversion into a single op.
    grad_buffer = bucket.buffer()
    half_buffer = torch.empty_like(grad_buffer, dtype=torch.float16)
    torch.div(grad_buffer, num_ranks, out=half_buffer)

    work = dist.all_reduce(half_buffer, group=group, async_op=True)
    reduce_future = work.get_future()

    def decompress(fut):
        # Write the reduced fp16 values back into the existing bucket buffer
        # instead of allocating a new tensor, to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        target = bucket.buffer()
        target.copy_(fut.value()[0])
        return target

    # TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case
    # resend with fp32?
    return reduce_future.then(decompress)