| |
| import torch |
| from torch.nn.parallel._functions import _get_stream |
|
|
|
|
def scatter(input, devices, streams=None):
    """Scatters tensor across multiple GPUs.

    A list is split into consecutive chunks of
    ``ceil(len(input) / len(devices))`` items and every item is scattered
    recursively to the single device (and stream) that owns its chunk.
    A tensor is made contiguous and then either copied to ``devices[0]``
    with a non-blocking copy or, when ``devices == [-1]``, returned with an
    extra leading dimension of size 1 instead of being copied.

    Args:
        input (list | torch.Tensor): Tensor, or (nested) list of tensors.
        devices (list[int]): Target device indices. ``[-1]`` selects the
            branch that performs no CUDA copy.
        streams (list | None): One entry per device, handed to
            ``torch.cuda.stream`` around the copy. Defaults to a list of
            ``None`` with the same length as ``devices``.

    Returns:
        list | torch.Tensor: Result with the same nesting as ``input``.

    Raises:
        Exception: If ``input`` (or a nested item) is neither a list nor a
            tensor.
    """
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, torch.Tensor):
        tensor = input.contiguous()
        # Empty tensors are never associated with a copy stream.
        copy_stream = None if tensor.numel() <= 0 else streams[0]
        if devices != [-1]:
            target = devices[0]
            with torch.cuda.device(target), torch.cuda.stream(copy_stream):
                return tensor.cuda(target, non_blocking=True)
        # No copy requested: only prepend a leading dimension of size 1.
        return tensor.unsqueeze(0)

    if not isinstance(input, list):
        raise Exception(f'Unknown type {type(input)}.')

    # Ceil division: item ``k`` belongs to device ``k // chunk_size``.
    chunk_size = (len(input) - 1) // len(devices) + 1
    scattered = []
    for position, item in enumerate(input):
        owner = position // chunk_size
        scattered.append(scatter(item, [devices[owner]], [streams[owner]]))
    return scattered
|
|
|
|
def synchronize_stream(output, devices, streams):
    """Make each device's current stream wait for its copy stream.

    Walks ``output`` with the same item-to-device mapping that ``scatter``
    uses (consecutive chunks of ``ceil(len(output) / len(devices))`` items),
    so every tensor is synchronized against the device and stream it was
    actually copied with. For each non-empty tensor, the current stream of
    the owning device waits on the copy stream and the tensor is recorded
    on that current stream.

    Args:
        output (list | torch.Tensor): Tensor, or (nested) list of tensors.
        devices (list[int]): Device indices, one per chunk of ``output``.
        streams (list): Copy streams, parallel to ``devices``.

    Raises:
        Exception: If ``output`` (or a nested item) is neither a list nor a
            tensor.
    """
    if isinstance(output, list):
        if not output:
            # Nothing to synchronize; also keeps chunk_size from being 0.
            return
        # Ceil division, matching scatter(). Floor division here would pair
        # items with the wrong device, or skip them entirely, whenever the
        # item count is not a multiple of the device count.
        chunk_size = (len(output) - 1) // len(devices) + 1
        for position, item in enumerate(output):
            owner = position // chunk_size
            synchronize_stream(item, [devices[owner]], [streams[owner]])
    elif isinstance(output, torch.Tensor):
        # Empty tensors are skipped entirely.
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                output.record_stream(main_stream)
    else:
        raise Exception(f'Unknown type {type(output)}.')
|
|
|
|
def get_input_device(input):
    """Return the device index of the first CUDA tensor inside ``input``.

    Args:
        input (list | torch.Tensor): Tensor, or (nested) list of tensors.

    Returns:
        int: ``get_device()`` of the first CUDA tensor encountered in a
        depth-first walk, or ``-1`` when no CUDA tensor is found (including
        for an empty list).

    Raises:
        Exception: If ``input`` (or a nested item visited before a CUDA
            tensor is found) is neither a list nor a tensor.
    """
    if isinstance(input, torch.Tensor):
        if input.is_cuda:
            return input.get_device()
        return -1

    if not isinstance(input, list):
        raise Exception(f'Unknown type {type(input)}.')

    # The search stops at the first element that reports a CUDA device.
    for element in input:
        found = get_input_device(element)
        if found != -1:
            return found
    return -1
|
|
|
|
class Scatter:
    """Namespace for the scatter entry point used to distribute inputs."""

    @staticmethod
    def forward(target_gpus, input):
        """Scatter ``input`` to ``target_gpus`` and return the pieces.

        When ``input`` holds no CUDA tensor and ``target_gpus`` is not
        ``[-1]``, one stream per target device is obtained from
        ``_get_stream`` and used for the copies; afterwards those streams
        are synchronized via ``synchronize_stream``.

        Args:
            target_gpus (list[int]): Target device indices.
            input (list | torch.Tensor): Tensor, or (nested) list of tensors.

        Returns:
            tuple: ``tuple()`` applied to the result of ``scatter``. For a
            list result this is its items; for a tensor result, ``tuple()``
            iterates over the tensor's first dimension.
        """
        source_is_cpu = get_input_device(input) == -1
        if source_is_cpu and target_gpus != [-1]:
            # NOTE(review): some PyTorch releases reportedly expect a
            # torch.device rather than an int here -- confirm against the
            # installed version.
            streams = [_get_stream(gpu) for gpu in target_gpus]
        else:
            streams = None

        outputs = scatter(input, target_gpus, streams)

        # Streams only exist when the copies were issued on them above.
        if streams is not None:
            synchronize_stream(outputs, target_gpus, streams)

        return tuple(outputs)
|
|