| # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Any, Optional, Tuple | |
| import torch | |
| from torch import distributed | |
def print_if_rank0(*args, **kwargs):
    """Print only on the process whose distributed rank is 0.

    If ``torch.distributed`` is unavailable or no default process group has
    been initialized (e.g. a plain single-process run), the message is
    printed unconditionally instead of raising from ``get_rank()``.

    Args:
        *args: Positional arguments forwarded to :func:`print`.
        **kwargs: Keyword arguments forwarded to :func:`print`
            (``sep``, ``end``, ``file``, ``flush``).
    """
    # ``is_available()`` is checked first so that ``is_initialized`` is never
    # touched on builds where it might be missing.
    in_process_group = distributed.is_available() and distributed.is_initialized()
    if in_process_group and distributed.get_rank() != 0:
        return
    print(*args, **kwargs)
class AllGatherGrad(torch.autograd.Function):
    """``all_gather`` that lets gradients flow back to the local input.

    Adapted from PyTorch Lightning. Call it as
    ``AllGatherGrad.apply(tensor, group)``.
    """

    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        group: Optional["torch.distributed.ProcessGroup"] = None,
    ) -> torch.Tensor:
        """Gather ``tensor`` from every rank of ``group`` and stack the results.

        Args:
            ctx: Autograd context; the process group is stored on it so that
                ``backward`` reduces over the same group.
            tensor: Local tensor to contribute. One ``zeros_like(tensor)``
                receive buffer is allocated per rank.
            group: Process group to gather over; ``None`` selects the default
                group.

        Returns:
            A tensor with a new leading dimension of length
            ``get_world_size(group)``; entry ``i`` holds what
            ``torch.distributed.all_gather`` placed in slot ``i``.
        """
        ctx.group = group
        # Size the receive list by the world size of *this* group: for a
        # subgroup, the default group's size would give all_gather a list of
        # the wrong length.
        world_size = torch.distributed.get_world_size(group=group)
        gathered_tensor = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
        return torch.stack(gathered_tensor, dim=0)

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Sum the output gradients across ranks and return this rank's slice.

        Args:
            ctx: Autograd context carrying the ``group`` saved in ``forward``.
            *grad_output: Gradient(s) with respect to the stacked output.

        Returns:
            ``(grad_for_tensor, None)``; ``None`` is the gradient slot for the
            non-differentiable ``group`` argument.
        """
        grad = torch.cat(grad_output)
        torch.distributed.all_reduce(
            grad,
            op=torch.distributed.ReduceOp.SUM,
            async_op=False,
            group=ctx.group,
        )
        # Slot i of the stacked output corresponds to rank i *within the
        # group*, so index with the group-relative rank, not the global one.
        return grad[torch.distributed.get_rank(group=ctx.group)], None