| | """Custom replacement for `torch.nn.functional.conv2d` that supports |
| | arbitrarily high order gradients with zero performance penalty.""" |
| |
|
| | import contextlib |
| | import warnings |
| |
|
| | import torch |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
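
# Typical usage (a sketch; the import path is whatever this file lives under):
# set the module-level `enabled` flag, then call conv2d() as a drop-in
# replacement for torch.nn.functional.conv2d.
#
#   import conv2d_gradfix
#   conv2d_gradfix.enabled = True
#   y = conv2d_gradfix.conv2d(x, w, b, stride=2, padding=1)
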
enabled = False                     # Set to True to route convolutions through the custom op.
weight_gradients_disabled = False   # Forcefully disable gradients w.r.t. convolution weights.

@contextlib.contextmanager
def no_weight_gradients():
    """Temporarily disable gradient computation w.r.t. convolution weights."""
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore the previous value even if the body raises.
        weight_gradients_disabled = old
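
# Example (a sketch): compute gradients w.r.t. the input only, skipping the
# weight-gradient kernel entirely:
#
#   with no_weight_gradients():
#       y = conv2d(x, w)
#   (grad_x,) = torch.autograd.grad(y.sum(), [x])
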
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(
            transpose=False,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=0,
            dilation=dilation,
            groups=groups
        ).apply(input, weight, bias)
    return torch.nn.functional.conv2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups
    )
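
# The forward pass delegates to the reference op, so outputs are identical; a
# sketch (assuming compatible CUDA tensors x, w, b):
#
#   assert torch.allclose(conv2d(x, w, b), torch.nn.functional.conv2d(x, w, b))
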
def conv_transpose2d(
    input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1
):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(
            transpose=True,
            weight_shape=weight.shape,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation
        ).apply(input, weight, bias)
    return torch.nn.functional.conv_transpose2d(
        input=input,
        weight=weight,
        bias=bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation
    )
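
# Example (a sketch): 2x upsampling via a transposed convolution, with the
# reference op's shape semantics:
#
#   x = torch.randn(1, 8, 16, 16, device='cuda')
#   w = torch.randn(8, 4, 4, 4, device='cuda')        # (in_ch, out_ch/groups, kH, kW)
#   y = conv_transpose2d(x, w, stride=2, padding=1)   # -> (1, 4, 32, 32)
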
def _should_use_custom_op(input):
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if input.device.type != 'cuda':
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9.']):
        return True
    warnings.warn(
        f'conv2d_gradfix not supported on PyTorch {torch.__version__}. '
        f'Falling back to torch.nn.functional.conv2d().'
    )
    return False

def _tuple_of_ints(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs
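
# e.g. _tuple_of_ints(3, 2) -> (3, 3); _tuple_of_ints([1, 2], 2) -> (1, 2)
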
# Cache of generated Conv2d op classes, keyed by the full op configuration.
_conv2d_gradfix_cache = dict()

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else:  # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Keyword arguments shared by the regular and transposed reference ops.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    # The gradient w.r.t. the input is computed by the opposite convolution
    # (regular <-> transposed). When that opposite op is transposed, its
    # output_padding must be chosen so the result matches input_shape exactly;
    # the expression below inverts the standard conv output-size formula.
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            if not transpose:
                output = torch.nn.functional.conv2d(
                    input=input, weight=weight, bias=bias, **common_kwargs
                )
            else:
                output = torch.nn.functional.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs
                )
            ctx.save_for_backward(input, weight)
            return output
        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad_input = _conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs
                ).apply(grad_output, weight, None)
                assert grad_input.shape == input.shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            # Call the raw cuDNN kernel directly so that this op, too, is
            # wrapped in an autograd.Function and can be differentiated again.
            op = torch._C._jit_get_operation(
                'aten::cudnn_convolution_backward_weight' if not transpose
                else 'aten::cudnn_convolution_transpose_backward_weight'
            )
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32
            ]
            grad_weight = op(
                weight_shape, grad_output, input, padding, stride, dilation, groups, *flags
            )
            assert grad_weight.shape == weight_shape
            ctx.save_for_backward(grad_output, input)
            return grad_weight
        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output.shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad2_input = _conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs
                ).apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input.shape

            return grad2_grad_output, grad2_input
    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d
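
if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module): exercises
    # the double-backward path this file exists to support. Requires a CUDA
    # device; on unsupported PyTorch versions the calls silently fall back to
    # the reference ops, which also support double backward.
    if torch.cuda.is_available():
        enabled = True
        x = torch.randn(2, 3, 8, 8, device='cuda', requires_grad=True)
        w = torch.randn(4, 3, 3, 3, device='cuda', requires_grad=True)
        y = conv2d(x, w, padding=1)
        (g_x,) = torch.autograd.grad(y.sum(), [x], create_graph=True)  # 1st order
        (g2_w,) = torch.autograd.grad(g_x.sum(), [w])                  # 2nd order
        print('ok:', tuple(y.shape), tuple(g2_w.shape))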