File size: 5,428 Bytes
0e83290 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
"""
/*****************************************************************************/
BatchNorm2dSync: batch normalization synchronized across multiple GPUs.
Code referenced from: https://github.com/mapillary/inplace_abn
/*****************************************************************************/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.cuda.comm as comm
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ._csrc import _backend
def _count_samples(x):
    """Return the product of all dimensions of ``x`` except dimension 1.

    ``x`` only needs a ``size()`` method returning a sequence of integer
    extents. An object with no dimensions, or with only dimension 1
    contributing, yields 1.
    """
    extents = list(x.size())
    total = 1
    # Skip index 1 and multiply everything else together.
    for extent in extents[:1] + extents[2:]:
        total *= extent
    return total
class BatchNorm2dSyncFunc(Function):
    """Autograd function for batch normalization with cross-device statistics.

    One instance of this function runs per device. The instance flagged as
    master collects the partial sums produced by the other instances through
    ``master_queue``, reduces them, and sends the global result back through
    one queue per worker. The heavy lifting (partial sums, normalization and
    its gradient) is delegated to the compiled ``_backend`` extension.

    NOTE(review): the number of workers is taken from
    ``master_queue.maxsize``, so the caller is assumed to create that queue
    with ``maxsize`` equal to the number of non-master devices - confirm
    against the module that builds ``extra``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, running_mean, running_var,
                extra, compute_stats=True, momentum=0.1, eps=1e-05):
        """Normalize ``x`` using either synchronized batch or running stats.

        Args:
            ctx: autograd context; also used to carry the queues and
                hyper-parameters over to ``backward``.
            x: input tensor; dimension 1 is treated as the channel axis when
                counting samples.
            weight: optional scale tensor; the affine transform is only
                enabled when both ``weight`` and ``bias`` are given.
            bias: optional shift tensor (see ``weight``).
            running_mean: tensor updated in place with the batch mean when
                ``compute_stats`` is true, otherwise used as the mean.
            running_var: tensor updated in place with the unbiased batch
                variance when ``compute_stats`` is true, otherwise used as
                the variance.
            extra: ``None`` or a dict with the keys ``"is_master"`` and
                ``"master_queue"``, plus ``"worker_queues"`` and
                ``"worker_ids"`` on the master or ``"worker_queue"`` on a
                worker. Required whenever ``compute_stats`` is true.
            compute_stats: compute and synchronize batch statistics when
                true; use the running statistics as-is when false.
            momentum: weight of the new statistics in the running update.
            eps: value forwarded to the backend for numerical stability.

        Returns:
            The tensor returned by ``_backend.syncbn_forward``.

        Raises:
            ValueError: if statistics are requested but the total number of
                values per channel across all devices is not greater than 1.
        """
        # Save context
        if extra is not None:
            ctx.is_master = extra["is_master"]
            ctx.master_queue = extra["master_queue"]
            if ctx.is_master:
                ctx.worker_queues = extra["worker_queues"]
                ctx.worker_ids = extra["worker_ids"]
            else:
                ctx.worker_queue = extra["worker_queue"]
        ctx.compute_stats = compute_stats
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.affine = weight is not None and bias is not None

        if ctx.compute_stats:
            # Total number of values per channel over all devices; this
            # presumes every device holds an input of the same size.
            N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
            # Explicit check instead of ``assert``: asserts are stripped
            # under ``python -O`` and N == 1 would then divide by zero in
            # the unbiased variance below.
            if N <= 1:
                raise ValueError(
                    "Expected more than 1 value per channel when computing "
                    "batch statistics, got {}".format(N))

            # 1. compute sum(x) and sum(x^2)
            xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())

            if ctx.is_master:
                xsums, xsqsums = [xsum], [xsqsum]
                # master : gather all sum(x) and sum(x^2) from slaves
                for _ in range(ctx.master_queue.maxsize):
                    xsum_w, xsqsum_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    xsums.append(xsum_w)
                    xsqsums.append(xsqsum_w)
                xsum = comm.reduce_add(xsums)
                xsqsum = comm.reduce_add(xsqsums)

                mean = xsum / N
                sumvar = xsqsum - xsum * mean
                # Biased variance normalizes the batch; the unbiased one
                # feeds the running estimate.
                var = sumvar / N
                uvar = sumvar / (N - 1)

                # master : broadcast global mean, variance to all slaves
                tensors = comm.broadcast_coalesced(
                    (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
                # tensors[0] is the master's own copy, so skip it.
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                # slave : send sum(x) and sum(x^2) to master
                ctx.master_queue.put((xsum, xsqsum))
                # slave : get global mean and variance
                mean, uvar, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
            ctx.N = N
            # Tensors are only saved on this path, so ``backward`` requires
            # a forward pass that ran with ``compute_stats`` enabled.
            ctx.save_for_backward(x, weight, bias, mean, var)
        else:
            mean, var = running_mean, running_var

        # do batch norm forward
        z = _backend.syncbn_forward(x, weight, bias, mean, var,
                                    ctx.affine, ctx.eps)
        return z

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        """Compute gradients using gradient sums synchronized across devices.

        Args:
            ctx: autograd context populated by ``forward``.
            dz: gradient with respect to the output of ``forward``.

        Returns:
            A 9-tuple matching the arguments of ``forward``: gradients for
            ``x``, ``weight`` and ``bias`` as returned by
            ``_backend.syncbn_backward``, followed by ``None`` for the six
            remaining arguments.
        """
        x, weight, bias, mean, var = ctx.saved_tensors
        dz = dz.contiguous()

        # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
        sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat(
            dz, x, mean, var, ctx.eps)

        if ctx.is_master:
            sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
            # master : gather from slaves
            for _ in range(ctx.master_queue.maxsize):
                sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                sum_dzs.append(sum_dz_w)
                sum_dz_xhats.append(sum_dz_xhat_w)

            # master : compute global stats
            sum_dz = comm.reduce_add(sum_dzs)
            sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
            # Turn the global sums into means over the N values counted in
            # the forward pass.
            sum_dz /= ctx.N
            sum_dz_xhat /= ctx.N

            # master : broadcast global stats
            tensors = comm.broadcast_coalesced(
                (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
            # tensors[0] is the master's own copy, so skip it.
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            # slave : send to master
            ctx.master_queue.put((sum_dz, sum_dz_xhat))
            # slave : get global stats
            sum_dz, sum_dz_xhat = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # do batch norm backward
        dx, dweight, dbias = _backend.syncbn_backward(
            dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat,
            ctx.affine, ctx.eps)

        # One entry per forward argument; only x, weight and bias receive
        # gradients.
        return dx, dweight, dbias, \
            None, None, None, None, None, None
# Functional entry point: invokes BatchNorm2dSyncFunc through its ``apply``
# method, taking the same arguments as ``forward`` minus ``ctx``.
batchnorm2d_sync = BatchNorm2dSyncFunc.apply
# Only the functional alias is part of the module's declared public API.
__all__ = ["batchnorm2d_sync"]
|