| """Encoding Data Parallel"""
|
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

torch_ver = torch.__version__[:3]

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 'patch_replication_callback']


def allreduce(*inputs):
    """Cross GPU all-reduce autograd operation for calculating the mean and
    variance in SyncBN.
    """
    return AllReduce.apply(*inputs)
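
# A minimal usage sketch (an assumption, not taken from the original source): the first
# argument gives how many tensors each GPU contributes, followed by that many tensors
# per GPU in device order; `xsum0`/`xsqsum0` live on cuda:0 and `xsum1`/`xsqsum1` on
# cuda:1 and are purely illustrative names.
#
#   >>> out = allreduce(2, xsum0, xsqsum0, xsum1, xsqsum1)
#   >>> # `out` is a flat tuple grouped per GPU:
#   >>> # (sum_on_gpu0, sqsum_on_gpu0, sum_on_gpu1, sqsum_on_gpu1)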
|
|
|
class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        # Regroup the flat argument list into one chunk of `num_inputs` tensors per GPU.
        inputs = [inputs[i:i + num_inputs]
                  for i in range(0, len(inputs), num_inputs)]
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        # Sum all chunks onto the first GPU, then broadcast the result back to every GPU.
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        # The backward of an all-reduce is the same all-reduce applied to the gradients.
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                  for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        # The leading None matches the non-tensor `num_inputs` argument of forward.
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])


class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        # Sum the per-GPU tensors onto the first listed device.
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        # Broadcast the incoming gradient back to every contributing GPU.
        return Broadcast.apply(ctx.target_gpus, gradOutput)


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices, chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device, and each
    replica handles a portion of the input. During the backward pass,
    gradients from each replica are summed into the original module.
    Note that the outputs are not gathered; please use the compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        # Outputs stay on their respective GPUs; gathering is left to the criterion.
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        return modules


class DataParallelCriterion(DataParallel):
    """
    Calculate the loss on multiple GPUs, which balances the memory usage for
    semantic segmentation.

    The targets are split across the specified devices by chunking in
    the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        # `inputs` are the per-GPU outputs of DataParallelModel; only the targets
        # still need to be scattered across the devices.
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *targets[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
        # Average the per-GPU losses into a single loss on the first device.
        return Reduce.apply(*outputs) / len(outputs)
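
# A minimal end-to-end training-step sketch (an assumption about intended usage, not
# taken from the original source); `model`, `criterion`, `optimizer`, `x`, and `target`
# are placeholder names.
#
#   >>> net = DataParallelModel(model, device_ids=[0, 1])
#   >>> parallel_criterion = DataParallelCriterion(criterion, device_ids=[0, 1])
#   >>> outputs = net(x)                            # per-GPU outputs, not gathered
#   >>> loss = parallel_criterion(outputs, target)  # loss averaged across GPUs
#   >>> optimizer.zero_grad()
#   >>> loss.backward()
#   >>> optimizer.step()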
|
|
|
|
|
def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            # Propagate the caller's grad mode into the worker thread.
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            if not isinstance(input, tuple):
                input = (input,)
            if not isinstance(target, tuple):
                target = (target,)
            with torch.cuda.device(device):
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target, kwargs, device))
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # Single-device case: run the worker inline in the current thread.
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
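
# Illustrative call shape for `_criterion_parallel_apply` (hypothetical tensors and
# criterion replicas; in practice DataParallelCriterion.forward builds these arguments):
#
#   >>> losses = _criterion_parallel_apply(
#   ...     [crit_gpu0, crit_gpu1],            # one criterion replica per GPU
#   ...     [pred_gpu0, pred_gpu1],            # per-GPU model outputs
#   ...     [(target_gpu0,), (target_gpu1,)])  # per-GPU target tuples
#   >>> # `losses` is a list with one loss tensor per GPU, in input order.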
|
|
|