| | import warnings |
| | from itertools import chain |
| | from typing import Optional |
| |
|
| | import torch |
| | from torch._utils import _get_device_index |
| | from torch.autograd import Function |
| | from torch.nn.parallel import comm |
| |
|
| |
|
| | class Broadcast(Function): |
| | @staticmethod |
| | def forward(ctx, target_gpus, *inputs): |
| | assert all(i.device.type != "cpu" for i in inputs), ( |
| | "Broadcast function not implemented for CPU tensors" |
| | ) |
| | target_gpus = [_get_device_index(x, True) for x in target_gpus] |
| | ctx.target_gpus = target_gpus |
| | if len(inputs) == 0: |
| | return () |
| | ctx.num_inputs = len(inputs) |
| | ctx.input_device = inputs[0].get_device() |
| | outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus) |
| | non_differentiables = [] |
| | for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]): |
| | if not input_requires_grad: |
| | non_differentiables.extend(output[idx] for output in outputs) |
| | ctx.mark_non_differentiable(*non_differentiables) |
| | return tuple(chain.from_iterable(outputs)) |
| |
|
| | @staticmethod |
| | def backward(ctx, *grad_outputs): |
| | return (None,) + ReduceAddCoalesced.apply( |
| | ctx.input_device, ctx.num_inputs, *grad_outputs |
| | ) |
| |
|
| |
|
| | class ReduceAddCoalesced(Function): |
| | @staticmethod |
| | def forward(ctx, destination, num_inputs, *grads): |
| | ctx.target_gpus = [ |
| | grads[i].get_device() for i in range(0, len(grads), num_inputs) |
| | ] |
| |
|
| | grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)] |
| | return comm.reduce_add_coalesced(grads_, destination) |
| |
|
| | @staticmethod |
| | def backward(ctx, *grad_outputs): |
| | return ( |
| | None, |
| | None, |
| | ) + Broadcast.apply(ctx.target_gpus, *grad_outputs) |
| |
|
| |
|
| | class Gather(Function): |
| | @staticmethod |
| | def forward(ctx, target_device, dim, *inputs): |
| | assert all(i.device.type != "cpu" for i in inputs), ( |
| | "Gather function not implemented for CPU tensors" |
| | ) |
| | if target_device == "cpu": |
| | ctx.target_device = "cpu" |
| | else: |
| | target_device = _get_device_index(target_device, True) |
| | ctx.target_device = target_device |
| | ctx.dim = dim |
| | ctx.input_gpus = tuple(i.get_device() for i in inputs) |
| | if all(t.dim() == 0 for t in inputs) and dim == 0: |
| | inputs = tuple(t.view(1) for t in inputs) |
| | warnings.warn( |
| | "Was asked to gather along dimension 0, but all " |
| | "input tensors were scalars; will instead unsqueeze " |
| | "and return a vector." |
| | ) |
| | ctx.unsqueezed_scalar = True |
| | else: |
| | ctx.unsqueezed_scalar = False |
| | ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs) |
| | return comm.gather(inputs, ctx.dim, ctx.target_device) |
| |
|
| | @staticmethod |
| | def backward(ctx, grad_output): |
| | scattered_grads = Scatter.apply( |
| | ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output |
| | ) |
| | if ctx.unsqueezed_scalar: |
| | scattered_grads = tuple(g[0] for g in scattered_grads) |
| | return (None, None) + scattered_grads |
| |
|
| |
|
| | class Scatter(Function): |
| | @staticmethod |
| | def forward(ctx, target_gpus, chunk_sizes, dim, input): |
| | target_gpus = [_get_device_index(x, True) for x in target_gpus] |
| | ctx.dim = dim |
| | ctx.input_device = input.get_device() if input.device.type != "cpu" else -1 |
| | streams = None |
| | if torch.accelerator.is_available() and ctx.input_device == -1: |
| | |
| | streams = [_get_stream(torch.device(device)) for device in target_gpus] |
| | outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams) |
| | |
| | if streams is not None: |
| | for i, output in enumerate(outputs): |
| | with torch.accelerator.device_index(target_gpus[i]): |
| | main_stream = torch.accelerator.current_stream() |
| | main_stream.wait_stream(streams[i]) |
| | output.record_stream(main_stream) |
| | return outputs |
| |
|
| | @staticmethod |
| | def backward(ctx, *grad_output): |
| | return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output) |
| |
|
| |
|
| | |
| | _streams: Optional[list[Optional[torch.Stream]]] = None |
| |
|
| |
|
| | def _get_stream(device: torch.device): |
| | """Get a background stream for copying between CPU and target device.""" |
| | global _streams |
| | if device.type == "cpu" or not torch.accelerator.is_available(): |
| | return None |
| | assert torch.accelerator.current_accelerator().type == device.type |
| | if _streams is None: |
| | _streams = [None] * torch.accelerator.device_count() |
| | if _streams[device.index] is None: |
| | _streams[device.index] = torch.Stream(device.index) |
| | return _streams[device.index] |
| |
|