| | import pdb |
| | from os import path |
| | import torch |
| | import torch.distributed as dist |
| | import torch.autograd as autograd |
| | import torch.cuda.comm as comm |
| | from torch.autograd.function import once_differentiable |
| | from torch.utils.cpp_extension import load |
| |
|
# Directory containing the native sources, located next to this module.
_src_path = path.join(path.dirname(path.abspath(__file__)), "src")

# Build (if needed) and load the native extension through
# torch.utils.cpp_extension.load; this runs at import time.
_backend = load(
    name="inplace_abn",
    sources=[
        path.join(_src_path, "inplace_abn.cpp"),
        path.join(_src_path, "inplace_abn_cpu.cpp"),
        path.join(_src_path, "inplace_abn_cuda.cu"),
        path.join(_src_path, "inplace_abn_cuda_half.cu"),
    ],
    extra_cflags=["-O3"],
    extra_cuda_cflags=["--expt-extended-lambda"],
)
| |
|
| | |
# Activation identifiers accepted through the ``activation`` argument.
# NOTE: ACT_RELU is exported, but the activation dispatch helpers in this file
# (_act_forward / _act_backward) have no branch for it, so selecting it
# applies no activation.
ACT_RELU = "relu"
ACT_LEAKY_RELU = "leaky_relu"  # default activation of the autograd functions
ACT_ELU = "elu"
ACT_NONE = "none"  # explicit "no activation"
| |
|
| |
|
def _check(fn, *args, **kwargs):
    """Invoke ``fn`` and raise if it reports failure.

    Args:
        fn: Callable whose return value is interpreted as a success flag.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Raises:
        RuntimeError: If ``fn`` returns a falsy value.
    """
    if fn(*args, **kwargs):
        return
    raise RuntimeError("CUDA Error encountered in {}".format(fn))
| |
|
| |
|
def _broadcast_shape(x):
    """Return a shape list that keeps only dimension 1 of ``x``.

    Every dimension of ``x.size()`` is replaced by 1 except the one at
    index 1, which keeps its original extent.

    Args:
        x: Object exposing ``size()`` that returns an iterable of extents.

    Returns:
        list: One entry per dimension of ``x``.
    """
    return [extent if axis == 1 else 1 for axis, extent in enumerate(x.size())]
| |
|
| |
|
def _reduce(x):
    """Sum ``x`` over every dimension except dimension 1.

    Args:
        x: Tensor with at least two dimensions.

    Returns:
        Tensor of shape ``(x.size(1),)``.
    """
    shape = x.size()
    if len(shape) == 2:
        # Only dimension 0 has to be collapsed.
        return x.sum(dim=0)

    n, c = shape[0:2]
    # Flatten all trailing dimensions into one, then collapse it and dim 0.
    flattened = x.contiguous().view((n, c, -1))
    return flattened.sum(2).sum(0)
| |
|
| |
|
def _count_samples(x):
    """Return the product of all extents of ``x`` except dimension 1.

    Args:
        x: Object exposing ``size()`` that returns an iterable of extents.

    Returns:
        The product of the extents; 1 when there is nothing to multiply.
    """
    total = 1
    for axis, extent in enumerate(x.size()):
        if axis == 1:
            # Dimension 1 is excluded from the count.
            continue
        total *= extent
    return total
| |
|
| |
|
def _act_forward(ctx, x):
    """Run the activation selected by ``ctx.activation`` on ``x``.

    The backend call's return value is not used here, so the backend is
    presumably expected to modify ``x`` in place.

    Args:
        ctx: Object carrying ``activation`` and, for leaky ReLU, ``slope``.
        x: Tensor handed to the backend activation kernel.
    """
    kind = ctx.activation
    if kind == ACT_LEAKY_RELU:
        _backend.leaky_relu_forward(x, ctx.slope)
        return
    if kind == ACT_ELU:
        _backend.elu_forward(x)
    # ACT_NONE and any other value (including ACT_RELU) leave ``x`` untouched.
| |
|
| |
|
def _act_backward(ctx, x, dx):
    """Run the backward kernel of the activation selected by ``ctx.activation``.

    The backend call's return value is not used here, so the backend is
    presumably expected to update ``x`` and/or ``dx`` in place.

    Args:
        ctx: Object carrying ``activation`` and, for leaky ReLU, ``slope``.
        x: Tensor holding the activated forward output.
        dx: Gradient tensor paired with ``x``.
    """
    kind = ctx.activation
    if kind == ACT_LEAKY_RELU:
        _backend.leaky_relu_backward(x, dx, ctx.slope)
        return
    if kind == ACT_ELU:
        _backend.elu_backward(x, dx)
    # ACT_NONE and any other value (including ACT_RELU) change nothing.
| |
|
| |
|
class InPlaceABN(autograd.Function):
    """Batch normalization fused with an activation, computed in place.

    The forward pass hands ``x`` to the native backend, which overwrites it,
    and saves that overwritten tensor (not the original input) for the
    backward pass.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, running_mean, running_var,
                training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
        """Normalize and activate ``x`` in place.

        Args:
            ctx: Autograd context.
            x: Input tensor; it is overwritten with the result.
            weight: Optional scale. The affine transform is used only when
                both ``weight`` and ``bias`` are not ``None``.
            bias: Optional shift (see ``weight``).
            running_mean: Running mean buffer, updated in place when training.
            running_var: Running variance buffer, updated in place when
                training.
            training: Use batch statistics (and update the running buffers)
                when true; use the running buffers as statistics otherwise.
            momentum: Weight of the new statistics in the running update.
            eps: Forwarded to the backend normalization kernels.
            activation: One of the ``ACT_*`` identifiers.
            slope: Negative slope handed to the leaky ReLU kernels.

        Returns:
            Tuple ``(x, running_mean, running_var)``.
        """
        # Save context
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        ctx.affine = weight is not None and bias is not None

        # Prepare inputs
        count = _count_samples(x)
        x = x.contiguous()
        # The backend always receives tensors; empty ones stand in for missing
        # affine parameters.
        weight = weight.contiguous() if ctx.affine else x.new_empty(0)
        bias = bias.contiguous() if ctx.affine else x.new_empty(0)

        if ctx.training:
            mean, var = _backend.mean_var(x)

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            # The variance is rescaled by count / (count - 1) before being
            # accumulated. Clamp the denominator so that a single sample per
            # channel does not divide by zero and poison running_var with
            # NaN/inf; results are unchanged for count >= 2.
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / max(count - 1, 1))

            # Mark in-place modified tensors
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            mean, var = running_mean.contiguous(), running_var.contiguous()
            ctx.mark_dirty(x)

        # BN forward + activation
        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        # Output
        ctx.var = var
        # ``x`` now holds the activated output; that is what backward receives.
        ctx.save_for_backward(x, var, weight, bias)
        ctx.mark_non_differentiable(running_mean, running_var)
        return x, running_mean, running_var

    @staticmethod
    @once_differentiable
    def backward(ctx, dz, _drunning_mean, _drunning_var):
        """Compute gradients for ``x``, ``weight`` and ``bias``.

        Args:
            ctx: Autograd context populated by ``forward``.
            dz: Gradient with respect to the first forward output.
            _drunning_mean: Unused (output is non-differentiable).
            _drunning_var: Unused (output is non-differentiable).

        Returns:
            One entry per forward input (10): ``dx``, ``dweight``, ``dbias``
            followed by ``None`` for the seven remaining inputs. ``dweight``
            and ``dbias`` are ``None`` when the affine transform is disabled.
        """
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
        else:
            # Inference mode: the reduction terms are replaced by zeros, so
            # dweight and dbias (when affine) are returned as zeros.
            edz = dz.new_zeros(dz.size(1))
            eydz = dz.new_zeros(dz.size(1))

        dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)

        dweight = eydz if ctx.affine else None
        if dweight is not None:
            # Flip the sign where the weight is negative.
            # NOTE(review): presumably the backend works with |weight| - confirm.
            dweight[weight < 0] *= -1
        dbias = edz if ctx.affine else None

        return dx, dweight, dbias, None, None, None, None, None, None, None
| |
|
| |
|
class InPlaceABNSync(autograd.Function):
    """In-place activated batch normalization with cross-process statistics.

    Behaves like the single-process variant, but when ``torch.distributed``
    is initialized with more than one process the batch statistics (forward)
    and the gradient reduction terms (backward) are combined across processes
    with ``all_reduce``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, running_mean, running_var,
                training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
        """Normalize and activate ``x`` in place using global batch statistics.

        Args:
            ctx: Autograd context.
            x: Input tensor; it is overwritten with the result.
            weight: Optional scale. The affine transform is used only when
                both ``weight`` and ``bias`` are not ``None``.
            bias: Optional shift (see ``weight``).
            running_mean: Running mean buffer, updated in place when training.
            running_var: Running variance buffer, updated in place when
                training.
            training: Use batch statistics (and update the running buffers)
                when true; use the running buffers as statistics otherwise.
            momentum: Weight of the new statistics in the running update.
            eps: Forwarded to the backend normalization kernels.
            activation: One of the ``ACT_*`` identifiers.
            slope: Negative slope handed to the leaky ReLU kernels.
            equal_batches: When true, the global batch size is taken to be
                ``world_size * x.shape[0]`` without communication; otherwise
                the per-process batch sizes are summed with ``all_reduce``.

        Returns:
            Tuple ``(x, running_mean, running_var)``.
        """
        # Save context
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        ctx.affine = weight is not None and bias is not None

        # Prepare inputs
        ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1

        # Kept as a tensor so it can take part in all_reduce.
        batch_size = x.new_tensor([x.shape[0]], dtype=torch.long)

        x = x.contiguous()
        # The backend always receives tensors; empty ones stand in for missing
        # affine parameters.
        weight = weight.contiguous() if ctx.affine else x.new_empty(0)
        bias = bias.contiguous() if ctx.affine else x.new_empty(0)

        if ctx.training:
            mean, var = _backend.mean_var(x)
            if ctx.world_size > 1:
                # Get the global batch size
                if equal_batches:
                    batch_size *= ctx.world_size
                else:
                    dist.all_reduce(batch_size, dist.ReduceOp.SUM)

                # Share of the global batch held by this process.
                ctx.factor = x.shape[0] / float(batch_size.item())

                # Global mean: sum of the local means weighted by batch share.
                mean_all = mean.clone() * ctx.factor
                dist.all_reduce(mean_all, dist.ReduceOp.SUM)

                # Global variance: local variance plus the squared offset of
                # the local mean from the global mean, weighted and summed.
                var_all = (var + (mean - mean_all) ** 2) * ctx.factor
                dist.all_reduce(var_all, dist.ReduceOp.SUM)

                mean = mean_all
                var = var_all

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            count = batch_size.item() * x.view(x.shape[0], x.shape[1], -1).shape[-1]
            # The variance is rescaled by count / (count - 1) before being
            # accumulated. Clamp the denominator so that a single sample per
            # channel does not raise ZeroDivisionError after running_mean has
            # already been modified; results are unchanged for count >= 2.
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / max(count - 1, 1)))

            # Mark in-place modified tensors
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            mean, var = running_mean.contiguous(), running_var.contiguous()
            ctx.mark_dirty(x)

        # BN forward + activation
        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        # Output
        ctx.var = var
        # ``x`` now holds the activated output; that is what backward receives.
        ctx.save_for_backward(x, var, weight, bias)
        ctx.mark_non_differentiable(running_mean, running_var)
        return x, running_mean, running_var

    @staticmethod
    @once_differentiable
    def backward(ctx, dz, _drunning_mean, _drunning_var):
        """Compute gradients for ``x``, ``weight`` and ``bias``.

        Args:
            ctx: Autograd context populated by ``forward``.
            dz: Gradient with respect to the first forward output.
            _drunning_mean: Unused (output is non-differentiable).
            _drunning_var: Unused (output is non-differentiable).

        Returns:
            One entry per forward input (11): ``dx``, ``dweight``, ``dbias``
            followed by ``None`` for the eight remaining inputs, including
            ``equal_batches``. ``dweight`` and ``dbias`` are ``None`` when the
            affine transform is disabled.
        """
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
            # Keep the per-process values: they become the parameter
            # gradients, while the reduced values feed the input gradient.
            edz_local = edz.clone()
            eydz_local = eydz.clone()

            if ctx.world_size > 1:
                edz *= ctx.factor
                dist.all_reduce(edz, dist.ReduceOp.SUM)

                eydz *= ctx.factor
                dist.all_reduce(eydz, dist.ReduceOp.SUM)
        else:
            # Inference mode: the reduction terms are replaced by zeros, so
            # dweight and dbias (when affine) are returned as zeros.
            edz_local = edz = dz.new_zeros(dz.size(1))
            eydz_local = eydz = dz.new_zeros(dz.size(1))

        dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)

        dweight = eydz_local if ctx.affine else None
        if dweight is not None:
            # Flip the sign where the weight is negative.
            # NOTE(review): presumably the backend works with |weight| - confirm.
            dweight[weight < 0] *= -1
        dbias = edz_local if ctx.affine else None

        # forward() accepts 11 inputs, so 11 gradients are returned. Autograd
        # rejects too few gradients but drops surplus trailing ``None``s, so
        # this also works when ``equal_batches`` is left at its default.
        return dx, dweight, dbias, None, None, None, None, None, None, None, None
| |
|
| |
|
# Functional entry points bound to the ``apply`` method of each autograd
# Function defined in this module.
inplace_abn = InPlaceABN.apply
inplace_abn_sync = InPlaceABNSync.apply

# Public API of this module.
__all__ = [
    "inplace_abn",
    "inplace_abn_sync",
    "ACT_RELU",
    "ACT_LEAKY_RELU",
    "ACT_ELU",
    "ACT_NONE",
]
| |
|