| | |
| |
|
import collections
import collections.abc

import torch
import torch.cuda as cuda
import torch.nn as nn
from torch.nn.parallel._functions import Gather
| |
|
| |
|
| | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] |
| |
|
| |
|
def async_copy_to(obj, dev, main_stream=None):
    """Recursively copy every tensor in *obj* to CUDA device *dev*.

    Tensors are copied with ``non_blocking=True``; mappings and sequences
    are traversed recursively and rebuilt (sequences come back as lists).
    If *main_stream* is given, each copied tensor is recorded on it so the
    tensor's memory is not reused before that stream has consumed it.

    Args:
        obj: a tensor, a mapping, a sequence, or any other object.
        dev: target CUDA device (index or device object).
        main_stream: optional stream to record copied tensors on.

    Returns:
        A structure mirroring *obj* with all tensors moved to *dev*;
        non-tensor leaves are returned unchanged.
    """
    if torch.is_tensor(obj):
        v = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            v.data.record_stream(main_stream)
        return v
    # collections.abc: the bare collections.Mapping/Sequence aliases were
    # removed in Python 3.10.
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    # str/bytes are Sequences; recursing into their 1-char elements (which
    # are themselves Sequences) never terminates, so treat them as leaves.
    elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, (str, bytes)):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj
| |
|
| |
|
def dict_gather(outputs, target_device, dim=0):
    """
    Gathers variables from different GPUs on a specified device
    (-1 means the CPU), with dictionary support.

    Args:
        outputs: per-device outputs; each element may be a tensor, ``None``,
            a mapping, or a sequence (all elements share one structure).
        target_device: device to gather onto (-1 for CPU).
        dim: dimension along which tensors are concatenated.

    Returns:
        The gathered structure; leaves that are neither tensors, ``None``,
        mappings, nor sequences gather to ``None``.
    """
    def gather_map(outputs):
        out = outputs[0]
        if torch.is_tensor(out):
            # 0-dim tensors cannot be concatenated; give them a batch dim.
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        elif out is None:
            return None
        # collections.abc: the bare collections.Mapping/Sequence aliases
        # were removed in Python 3.10.
        elif isinstance(out, collections.abc.Mapping):
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        # str/bytes are Sequences but zipping their characters is never the
        # intended merge, so they fall through with other scalar leaves.
        elif isinstance(out, collections.abc.Sequence) and not isinstance(out, (str, bytes)):
            return type(out)(map(gather_map, zip(*outputs)))
        # Non-gatherable leaves (ints, strings, ...) have no cross-GPU
        # merge; mirror the original fall-through and yield None.
        return None
    return gather_map(outputs)
| |
|
| |
|
class DictGatherDataParallel(nn.DataParallel):
    """``nn.DataParallel`` whose gather step also understands dict outputs."""

    def gather(self, outputs, output_device):
        """Merge the per-GPU ``outputs`` onto ``output_device``."""
        merged = dict_gather(outputs, output_device, dim=self.dim)
        return merged
| |
|
| |
|
class UserScatteredDataParallel(DictGatherDataParallel):
    """DataParallel variant whose inputs are already scattered by the user.

    The single positional input must be a list with one element per device;
    this class only copies each element onto its device asynchronously.
    """

    def scatter(self, inputs, kwargs, device_ids):
        """Copy the user-provided per-device list onto the devices."""
        assert len(inputs) == 1
        assert len(kwargs) == 0
        per_device = _async_copy_stream(inputs[0], device_ids)
        # Each replica receives its chunk as a single positional argument.
        scattered = [[chunk] for chunk in per_device]
        empty_kwargs = [{} for _ in scattered]
        return scattered, empty_kwargs
| |
|
| |
|
def user_scattered_collate(batch):
    """Identity collate function: hand the batch back untouched so the
    caller (see ``UserScatteredDataParallel``) controls the scattering."""
    return batch
| |
|
| |
|
def _async_copy(inputs, device_ids):
    """Copy ``inputs[i]`` onto ``device_ids[i]`` asynchronously.

    Plain variant without background streams; each copy is issued on the
    device's current stream.  Returns a tuple, one entry per device.
    """
    assert type(inputs) in (tuple, list)
    assert len(inputs) == len(device_ids)

    outputs = []
    for inp, dev in zip(inputs, device_ids):
        with cuda.device(dev):
            outputs.append(async_copy_to(inp, dev))
    return tuple(outputs)
| |
|
| |
|
def _async_copy_stream(inputs, device_ids):
    """Copy ``inputs[i]`` onto ``device_ids[i]`` on a background stream.

    Each copy runs on the device's dedicated copy stream (``_get_stream``);
    the device's current (main) stream is then made to wait on that copy
    stream so later kernels observe the transferred data.  Returns a list,
    one entry per device.
    """
    assert type(inputs) in (tuple, list)
    assert len(inputs) == len(device_ids)

    copy_streams = [_get_stream(dev) for dev in device_ids]
    outputs = []
    for inp, dev, copy_stream in zip(inputs, device_ids, copy_streams):
        with cuda.device(dev):
            main_stream = cuda.current_stream()
            with cuda.stream(copy_stream):
                outputs.append(async_copy_to(inp, dev, main_stream=main_stream))
            # Order the main stream after the background copy completes.
            main_stream.wait_stream(copy_stream)
    return outputs
| |
|
| |
|
"""Adapted from: torch/nn/parallel/_functions.py"""

# Lazily-initialized table of per-device background copy streams; created
# on first use by _get_stream (one slot per visible GPU).
_streams = None
| |
|
| |
|
| | def _get_stream(device): |
| | """Gets a background stream for copying between CPU and GPU""" |
| | global _streams |
| | if device == -1: |
| | return None |
| | if _streams is None: |
| | _streams = [None] * cuda.device_count() |
| | if _streams[device] is None: _streams[device] = cuda.Stream(device) |
| | return _streams[device] |
| |
|