File size: 12,150 Bytes
9294bc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
Custom replacement for `torch.nn.functional.convNd` and `torch.nn.functional.conv_transposeNd`
that supports arbitrarily high order gradients with zero performance penalty.
Modified from https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/conv2d_gradfix.py
"""

import contextlib
import warnings
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Conv2d, Conv3d

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

# ----------------------------------------------------------------------------

# Master switch for the custom op: convNd / conv_transposeNd only route through
# it when this is set to True.
enabled = False

# When True, the custom op's backward skips the gradient with respect to the
# weights. Prefer toggling it via the no_weight_gradients() context manager.
weight_gradients_disabled = False


@contextlib.contextmanager
def no_weight_gradients():
    """Temporarily disable weight-gradient computation in the custom conv op.

    While the context is active, the module-level ``weight_gradients_disabled``
    flag is True, so the custom op's backward returns ``None`` for the weight
    gradient. The previous value of the flag is restored on exit, including
    when the body of the ``with`` statement raises.
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore even if the with-body raises; otherwise the flag would stay
        # stuck at True and weight gradients would be silently dropped from
        # then on.
        weight_gradients_disabled = old


# ----------------------------------------------------------------------------
class GradFixConv2d(Conv2d):
    """``torch.nn.Conv2d`` with an optional gradfix path and per-call parameter overrides.

    Args:
        *args, **kwargs: Forwarded unchanged to ``torch.nn.Conv2d``.
        use_gradfix: When True, the convolution is computed with this module's
            ``convNd`` instead of ``torch.nn.functional.conv2d``.
    """

    def __init__(self, *args, use_gradfix: bool = False, **kwargs):
        # Stored before delegating to Conv2d.__init__.
        self.use_gradfix = use_gradfix
        super().__init__(*args, **kwargs)

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Run the 2-d convolution, applying non-"zeros" padding modes explicitly."""
        conv_fn = convNd if self.use_gradfix else F.conv2d
        if self.padding_mode == "zeros":
            # Zero padding is handled by the conv op itself.
            data, pad = input, self.padding
        else:
            # Other padding modes are applied up front with F.pad; the conv
            # itself then runs with no padding.
            data = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            pad = (0, 0)
        return conv_fn(data, weight, bias, self.stride, pad, self.dilation, self.groups)

    def forward(
        self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
    ) -> Tensor:
        """Convolve ``input``; ``weight``/``bias`` default to the module's own parameters."""
        if weight is None:
            weight = self.weight
        if bias is None:
            bias = self.bias
        return self._conv_forward(input, weight, bias)


class GradFixConv3d(Conv3d):
    """``torch.nn.Conv3d`` with an optional gradfix path and per-call parameter overrides.

    Args:
        *args, **kwargs: Forwarded unchanged to ``torch.nn.Conv3d``.
        use_gradfix: When True, the convolution is computed with this module's
            ``convNd`` instead of ``torch.nn.functional.conv3d``.
    """

    def __init__(self, *args, use_gradfix: bool = False, **kwargs):
        # Stored before delegating to Conv3d.__init__.
        self.use_gradfix = use_gradfix
        super().__init__(*args, **kwargs)

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Run the 3-d convolution, applying non-"zeros" padding modes explicitly."""
        conv_fn = convNd if self.use_gradfix else F.conv3d
        if self.padding_mode == "zeros":
            # Zero padding is handled by the conv op itself.
            data, pad = input, self.padding
        else:
            # Other padding modes are applied up front with F.pad; the conv
            # itself then runs with no padding.
            data = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            pad = (0, 0, 0)
        return conv_fn(data, weight, bias, self.stride, pad, self.dilation, self.groups)

    def forward(
        self, input: Tensor, weight: Optional[Tensor] = None, bias: Optional[Tensor] = None
    ) -> Tensor:
        """Convolve ``input``; ``weight``/``bias`` default to the module's own parameters."""
        if weight is None:
            weight = self.weight
        if bias is None:
            bias = self.bias
        return self._conv_forward(input, weight, bias)


# ----------------------------------------------------------------------------


def convNd(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Drop-in replacement for ``torch.nn.functional.conv{N}d``.

    The spatial rank N is taken from ``weight`` (``weight.ndim - 2``). When
    ``_should_use_custom_op(input)`` is True, the call goes through the cached
    autograd Function built by ``_conv_gradfix``. Otherwise the matching
    ``torch.nn.functional.conv{N}d`` is called with the same arguments.
    """
    spatial_dims = weight.ndim - 2
    if not _should_use_custom_op(input):
        conv = getattr(torch.nn.functional, f"conv{spatial_dims}d")
        return conv(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
    # A regular (non-transposed) convolution never takes output padding.
    op = _conv_gradfix(
        transpose=False,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=0,
        dilation=dilation,
        groups=groups,
    )
    return op.apply(input, weight, bias)


def conv_transposeNd(
    input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1
):
    """Drop-in replacement for ``torch.nn.functional.conv_transpose{N}d``.

    The argument order mirrors torch's ``conv_transpose*d`` (``groups`` before
    ``dilation``). The spatial rank N is taken from ``weight``
    (``weight.ndim - 2``). When ``_should_use_custom_op(input)`` is True, the
    call goes through the cached autograd Function built by ``_conv_gradfix``
    with ``transpose=True``. Otherwise the matching torch op is called directly.
    """
    spatial_dims = weight.ndim - 2
    if not _should_use_custom_op(input):
        conv_transpose = getattr(torch.nn.functional, f"conv_transpose{spatial_dims}d")
        return conv_transpose(
            input=input,
            weight=weight,
            bias=bias,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation,
        )
    op = _conv_gradfix(
        transpose=True,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )
    return op.apply(input, weight, bias)


# ----------------------------------------------------------------------------


def _should_use_custom_op(input):
    """Return True if convNd / conv_transposeNd should use the custom gradfix op.

    The custom op is used only when all of the following hold:
      * the module-level ``enabled`` flag is truthy,
      * ``torch.backends.cudnn.enabled`` is truthy,
      * ``input`` lives on a CUDA device, and
      * the PyTorch version starts with "1.7.", "1.8.", "1.9" or "2".
    On any other PyTorch version a warning is emitted and False is returned,
    so callers fall back to the stock ``torch.nn.functional`` op.

    Args:
        input: The convolution input tensor.

    Raises:
        AssertionError: If ``input`` is not a ``torch.Tensor``.
    """
    assert isinstance(input, torch.Tensor)
    if not (enabled and torch.backends.cudnn.enabled):
        return False
    if input.device.type != "cuda":
        return False
    # Allow-listed versions; anything else falls back below.
    if torch.__version__.startswith(("1.7.", "1.8.", "1.9", "2")):
        return True
    # The fallback may be any conv{N}d / conv_transpose{N}d, so the message
    # must not name conv2d specifically.
    warnings.warn(
        f"conv_gradfix not supported on PyTorch {torch.__version__}. "
        "Falling back to the standard torch.nn.functional convolution."
    )
    return False


def _tuple_of_ints(xs, ndim):
    """Normalize ``xs`` to a tuple of ``ndim`` ints.

    A tuple or list is converted to a tuple as-is. Any other value is repeated
    ``ndim`` times.

    Raises:
        AssertionError: If the result does not have length ``ndim`` or any
            element is not an ``int``.
    """
    if isinstance(xs, (tuple, list)):
        result = tuple(xs)
    else:
        result = (xs,) * ndim
    assert len(result) == ndim
    assert all(isinstance(value, int) for value in result)
    return result


# ----------------------------------------------------------------------------

# Memo of autograd Function classes produced by _conv_gradfix, keyed by the
# normalized (transpose, weight_shape, stride, padding, output_padding,
# dilation, groups) tuple.
_conv_gradfix_cache = {}


def _conv_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    """Return a cached ``torch.autograd.Function`` for one fixed conv configuration.

    The returned class (``ConvNd``) performs an N-d convolution, or a transposed
    convolution when ``transpose`` is True, where N = ``len(weight_shape) - 2``.
    Its backward pass is itself built from custom autograd Functions: the
    opposite-direction op from this same factory, plus ``ConvNdGradWeight``.
    The gradients it produces are therefore differentiable again.

    Args:
        transpose: If True, build a transposed convolution.
        weight_shape: Full weight shape, including the two leading channel dims.
        stride, padding, output_padding, dilation: An int (repeated for all N
            spatial dims) or a tuple/list of N ints.
        groups: Number of convolution groups (must be >= 1).

    Returns:
        The ``ConvNd`` autograd Function class; use ``.apply(input, weight, bias)``.
        Classes are memoized in ``_conv_gradfix_cache`` keyed by the normalized
        arguments. Repeated calls with the same configuration return the same class.

    Raises:
        AssertionError: If an argument fails normalization or validation.
    """
    ndim = len(weight_shape) - 2
    # Parse arguments.
    # Normalize to hashable tuples of ndim ints so the cache key is canonical.
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    # Validation is skipped on a hit: only configurations that passed the
    # checks below are ever stored in the cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv_gradfix_cache:
        return _conv_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    # NOTE(review): this accepts dilation == 0; presumably torch rejects it
    # downstream — consider tightening to >= 1.
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        # Only transposed convolutions take output padding.
        assert all(output_padding[i] == 0 for i in range(ndim))
    else:  # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)

    def calc_output_padding(input_shape, output_shape):
        """Output padding for the opposite-direction op used in the backward passes.

        For a regular conv, the input gradient is a transposed conv that maps
        ``output_shape`` (the grad_output shape) back to ``input_shape``. The
        formula is the transposed-conv output-size equation solved for
        output_padding, per spatial dim ``i``. For a transposed conv the
        opposite op is a regular conv, which takes no output padding, so
        zeros are returned.
        """
        if transpose:
            return [
                0,
            ] * ndim
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class ConvNd(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            """
            input size: [B, C, ...]
            weight size:
                -> Conv:        [C_out, C_in // groups, ...]
                -> Transpose:   [C_in, C_out // groups, ...]

            The weight is cast to ``input.dtype`` before the convolution.
            """
            assert weight.shape == weight_shape
            # input is needed for the weight gradient and weight for the input
            # gradient; both are saved unconditionally.
            ctx.save_for_backward(input, weight)

            # General case => cuDNN.
            # Delegates to the stock torch.nn.functional op for this rank.
            if transpose:
                return getattr(torch.nn.functional, f"conv_transpose{ndim}d")(
                    input=input,
                    weight=weight.to(input.dtype),
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )
            return getattr(torch.nn.functional, f"conv{ndim}d")(
                input=input, weight=weight.to(input.dtype), bias=bias, **common_kwargs
            )

        @staticmethod
        def backward(ctx, grad_output):
            """Return (grad_input, grad_weight, grad_bias); None for unneeded ones."""
            input, weight = ctx.saved_tensors
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:  # Input
                # The input gradient is the opposite-direction conv (from this
                # same factory, hence itself differentiable) of grad_output
                # with the same weight.
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                op = _conv_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                )
                grad_input = op.apply(grad_output, weight, None)
                assert grad_input.shape == input.shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:  # Weight
                # Skipped (left as None) while weight_gradients_disabled is set.
                grad_weight = ConvNdGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:  # Bias
                # Sum grad_output over every dim except the channel dim (1).
                grad_bias = grad_output.transpose(0, 1).flatten(1).sum(1)

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    # Wrapped in its own autograd Function so that the weight gradient can be
    # differentiated again (see its backward below).
    class ConvNdGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            """Compute the gradient of ConvNd's output w.r.t. its weight.

            The returned tensor has shape ``weight_shape`` and the dtype of
            ``grad_output``; ``input`` is cast to that dtype first.
            """
            # cuDNN settings, passed positionally to the legacy ops in the 1.x
            # branch only (unused in the 2.x branch).
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            if torch.__version__.startswith("1"):
                op = torch._C._jit_get_operation(
                    "aten::cudnn_convolution_backward_weight"
                    if not transpose
                    else "aten::cudnn_convolution_transpose_backward_weight"
                )
                grad_weight = op(
                    weight_shape,
                    grad_output,
                    input.to(grad_output.dtype),
                    padding,
                    stride,
                    dilation,
                    groups,
                    *flags,
                )
            elif torch.__version__.startswith("2"):
                # https://github.com/pytorch/pytorch/issues/74437
                # aten::convolution_backward returns (grad_input, grad_weight,
                # grad_bias); output_mask [False, True, False] requests only
                # grad_weight, which is selected with [1].
                op, _ = torch._C._jit_get_operation("aten::convolution_backward")
                # grad_weight does not depend on the weight values, so a zero
                # scalar expanded (as a view) to weight_shape stands in for it.
                dummy_weight = torch.tensor(
                    0.0, dtype=grad_output.dtype, device=input.device
                ).expand(weight_shape)
                # NOTE(review): output_padding is passed as zeros here rather
                # than the configured output_padding; confirm this is correct
                # for transposed convs with nonzero output_padding.
                grad_weight = op(
                    grad_output,
                    input.to(grad_output.dtype),
                    dummy_weight,
                    None,
                    stride,
                    padding,
                    dilation,
                    transpose,
                    (0,) * ndim,
                    groups,
                    [False, True, False],
                )[1]
            else:
                raise NotImplementedError
            assert grad_weight.shape == weight_shape
            ctx.save_for_backward(grad_output, input)
            return grad_weight

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            """Return gradients w.r.t. (grad_output, input) given d(loss)/d(grad_weight)."""
            grad_output, input = ctx.saved_tensors
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:  # grad_output
                # grad_weight is linear in grad_output; its adjoint is the
                # forward conv of input with grad2_grad_weight used as the weight.
                grad2_grad_output = ConvNd.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output.shape

            if ctx.needs_input_grad[1]:  # Input
                # grad_weight is also linear in input; its adjoint is the
                # opposite-direction conv of grad_output with grad2_grad_weight.
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                op = _conv_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                )
                grad2_input = op.apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input.shape

            return grad2_grad_output, grad2_input

    # Memoize. The backward passes above call _conv_gradfix again and hit this
    # cache for configurations already built.
    _conv_gradfix_cache[key] = ConvNd
    return ConvNd


# ----------------------------------------------------------------------------