| | |
| | import torch |
| | from torch.nn.parallel._functions import _get_stream |
| |
|
| |
|
def scatter(input, devices, streams=None):
    """Scatter tensor(s) across multiple GPUs (or keep them on CPU).

    Args:
        input (list | torch.Tensor): A tensor, or a (possibly nested) list
            of tensors, to distribute across ``devices``.
        devices (list[int]): Target device ids; ``[-1]`` means CPU.
        streams (list | None): Optional CUDA copy streams, one per device.
            ``None`` entries fall back to the current stream.

    Returns:
        list | torch.Tensor: Mirrors the structure of ``input``. Tensors are
        copied to their target GPU, or unsqueezed along a new leading dim on
        the CPU path so shapes match the per-device-chunk convention.

    Raises:
        TypeError: If ``input`` is neither a list nor a ``torch.Tensor``.
    """
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, list):
        # Each device gets a contiguous chunk of ceil(len/num_devices)
        # elements; recurse so every element lands on its mapped
        # device/stream pair.
        chunk_size = (len(input) - 1) // len(devices) + 1
        outputs = [
            scatter(input[i], [devices[i // chunk_size]],
                    [streams[i // chunk_size]]) for i in range(len(input))
        ]
        return outputs
    elif isinstance(input, torch.Tensor):
        output = input.contiguous()
        # An empty tensor needs no background copy stream.
        stream = streams[0] if output.numel() > 0 else None
        if devices != [-1]:
            # Issue the host-to-device copy on the supplied stream so it can
            # overlap with other work; callers must later synchronize (see
            # ``synchronize_stream``).
            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
                output = output.cuda(devices[0], non_blocking=True)
        else:
            # CPU path: add a leading dim so the result looks like one
            # per-device chunk, matching the GPU-scattered layout.
            output = output.unsqueeze(0)
        return output
    else:
        # TypeError is the idiomatic exception for an unsupported input
        # type; it remains catchable by existing ``except Exception`` users.
        raise TypeError(f'Unknown type {type(input)}.')
| |
|
| |
|
def synchronize_stream(output, devices, streams):
    """Make each device's main stream wait for its background copy stream.

    Args:
        output (list | torch.Tensor): Output(s) previously produced by
            ``scatter`` with the same ``devices``/``streams``.
        devices (list[int]): Device ids used for the scatter.
        streams (list): Copy streams, one per device.

    Raises:
        TypeError: If ``output`` is neither a list nor a ``torch.Tensor``.
    """
    if isinstance(output, list):
        # Use the same ceil-based chunking as ``scatter`` so every element
        # is synchronized against the stream it was actually copied on. The
        # previous floor-based chunk size (len(output) // len(devices))
        # skipped trailing elements and shifted the device mapping whenever
        # len(output) was not divisible by len(devices).
        chunk_size = (len(output) - 1) // len(devices) + 1
        for i in range(len(output)):
            synchronize_stream(output[i], [devices[i // chunk_size]],
                               [streams[i // chunk_size]])
    elif isinstance(output, torch.Tensor):
        # Empty tensors were never copied on a side stream, so there is
        # nothing to wait for.
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                # Keep the tensor's memory alive until work queued on
                # main_stream completes.
                output.record_stream(main_stream)
    else:
        raise TypeError(f'Unknown type {type(output)}.')
| |
|
| |
|
def get_input_device(input):
    """Return the GPU id of the first CUDA tensor found in ``input``.

    Recurses into nested lists; returns ``-1`` when every tensor lives on
    the CPU (or the structure is empty). Raises for unsupported types.
    """
    if isinstance(input, torch.Tensor):
        return input.get_device() if input.is_cuda else -1
    if isinstance(input, list):
        # First non-CPU device id wins; default to -1 when none is found.
        found = (get_input_device(item) for item in input)
        return next((dev for dev in found if dev != -1), -1)
    raise Exception(f'Unknown type {type(input)}.')
| |
|
| |
|
class Scatter:
    """Static scatter entry point used to distribute inputs across GPUs."""

    @staticmethod
    def forward(target_gpus, input):
        """Scatter ``input`` to ``target_gpus`` and return a tuple of chunks.

        When the input lives on the CPU and the targets are real GPUs, the
        copies run on per-device background streams and are synchronized
        before returning.
        """
        copy_streams = None
        # Background streams are only useful for a CPU -> GPU transfer.
        if get_input_device(input) == -1 and target_gpus != [-1]:
            copy_streams = [_get_stream(gpu) for gpu in target_gpus]

        scattered = scatter(input, target_gpus, copy_streams)
        # Block each device's main stream on its copy stream so downstream
        # kernels see completed transfers.
        if copy_streams is not None:
            synchronize_stream(scattered, target_gpus, copy_streams)

        return tuple(scattered)
| |
|