dxcanh commited on
Commit
2866035
Β·
verified Β·
1 Parent(s): f255ec3

Upload deform_conv.py

Browse files
Files changed (1) hide show
  1. basicsr/ops/dcn/deform_conv.py +379 -0
basicsr/ops/dcn/deform_conv.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import torch
4
+ from torch import nn as nn
5
+ from torch.autograd import Function
6
+ from torch.autograd.function import once_differentiable
7
+ from torch.nn import functional as F
8
+ from torch.nn.modules.utils import _pair, _single
9
+
10
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
11
+ if BASICSR_JIT == 'True':
12
+ from torch.utils.cpp_extension import load
13
+ module_path = os.path.dirname(__file__)
14
+ deform_conv_ext = load(
15
+ 'deform_conv',
16
+ sources=[
17
+ os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
18
+ os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
19
+ os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
20
+ ],
21
+ )
22
+ else:
23
+ try:
24
+ from . import deform_conv_ext
25
+ except ImportError:
26
+ pass
27
+ # avoid annoying print output
28
+ # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
29
+ # '1. compile with BASICSR_EXT=True. or\n '
30
+ # '2. set BASICSR_JIT=True during running')
31
+
32
+
33
+ class DeformConvFunction(Function):
34
+
35
+ @staticmethod
36
+ def forward(ctx,
37
+ input,
38
+ offset,
39
+ weight,
40
+ stride=1,
41
+ padding=0,
42
+ dilation=1,
43
+ groups=1,
44
+ deformable_groups=1,
45
+ im2col_step=64):
46
+ if input is not None and input.dim() != 4:
47
+ raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
48
+ ctx.stride = _pair(stride)
49
+ ctx.padding = _pair(padding)
50
+ ctx.dilation = _pair(dilation)
51
+ ctx.groups = groups
52
+ ctx.deformable_groups = deformable_groups
53
+ ctx.im2col_step = im2col_step
54
+
55
+ ctx.save_for_backward(input, offset, weight)
56
+
57
+ output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
58
+
59
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
60
+
61
+ if not input.is_cuda:
62
+ raise NotImplementedError
63
+ else:
64
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
65
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
66
+ deform_conv_ext.deform_conv_forward(input, weight,
67
+ offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
68
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
69
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
70
+ ctx.deformable_groups, cur_im2col_step)
71
+ return output
72
+
73
+ @staticmethod
74
+ @once_differentiable
75
+ def backward(ctx, grad_output):
76
+ input, offset, weight = ctx.saved_tensors
77
+
78
+ grad_input = grad_offset = grad_weight = None
79
+
80
+ if not grad_output.is_cuda:
81
+ raise NotImplementedError
82
+ else:
83
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
84
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
85
+
86
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
87
+ grad_input = torch.zeros_like(input)
88
+ grad_offset = torch.zeros_like(offset)
89
+ deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
90
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
91
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
92
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
93
+ ctx.deformable_groups, cur_im2col_step)
94
+
95
+ if ctx.needs_input_grad[2]:
96
+ grad_weight = torch.zeros_like(weight)
97
+ deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
98
+ ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
99
+ weight.size(2), ctx.stride[1], ctx.stride[0],
100
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
101
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
102
+ cur_im2col_step)
103
+
104
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
105
+
106
+ @staticmethod
107
+ def _output_size(input, weight, padding, dilation, stride):
108
+ channels = weight.size(0)
109
+ output_size = (input.size(0), channels)
110
+ for d in range(input.dim() - 2):
111
+ in_size = input.size(d + 2)
112
+ pad = padding[d]
113
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
114
+ stride_ = stride[d]
115
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
116
+ if not all(map(lambda s: s > 0, output_size)):
117
+ raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
118
+ return output_size
119
+
120
+
121
+ class ModulatedDeformConvFunction(Function):
122
+
123
+ @staticmethod
124
+ def forward(ctx,
125
+ input,
126
+ offset,
127
+ mask,
128
+ weight,
129
+ bias=None,
130
+ stride=1,
131
+ padding=0,
132
+ dilation=1,
133
+ groups=1,
134
+ deformable_groups=1):
135
+ ctx.stride = stride
136
+ ctx.padding = padding
137
+ ctx.dilation = dilation
138
+ ctx.groups = groups
139
+ ctx.deformable_groups = deformable_groups
140
+ ctx.with_bias = bias is not None
141
+ if not ctx.with_bias:
142
+ bias = input.new_empty(1) # fake tensor
143
+ if not input.is_cuda:
144
+ raise NotImplementedError
145
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
146
+ ctx.save_for_backward(input, offset, mask, weight, bias)
147
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
148
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
149
+ deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
150
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
151
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
152
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
153
+ return output
154
+
155
+ @staticmethod
156
+ @once_differentiable
157
+ def backward(ctx, grad_output):
158
+ if not grad_output.is_cuda:
159
+ raise NotImplementedError
160
+ input, offset, mask, weight, bias = ctx.saved_tensors
161
+ grad_input = torch.zeros_like(input)
162
+ grad_offset = torch.zeros_like(offset)
163
+ grad_mask = torch.zeros_like(mask)
164
+ grad_weight = torch.zeros_like(weight)
165
+ grad_bias = torch.zeros_like(bias)
166
+ deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
167
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
168
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
169
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
170
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
171
+ if not ctx.with_bias:
172
+ grad_bias = None
173
+
174
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
175
+
176
+ @staticmethod
177
+ def _infer_shape(ctx, input, weight):
178
+ n = input.size(0)
179
+ channels_out = weight.size(0)
180
+ height, width = input.shape[2:4]
181
+ kernel_h, kernel_w = weight.shape[2:4]
182
+ height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
183
+ width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
184
+ return n, channels_out, height_out, width_out
185
+
186
+
187
+ deform_conv = DeformConvFunction.apply
188
+ modulated_deform_conv = ModulatedDeformConvFunction.apply
189
+
190
+
191
+ class DeformConv(nn.Module):
192
+
193
+ def __init__(self,
194
+ in_channels,
195
+ out_channels,
196
+ kernel_size,
197
+ stride=1,
198
+ padding=0,
199
+ dilation=1,
200
+ groups=1,
201
+ deformable_groups=1,
202
+ bias=False):
203
+ super(DeformConv, self).__init__()
204
+
205
+ assert not bias
206
+ assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}'
207
+ assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}'
208
+
209
+ self.in_channels = in_channels
210
+ self.out_channels = out_channels
211
+ self.kernel_size = _pair(kernel_size)
212
+ self.stride = _pair(stride)
213
+ self.padding = _pair(padding)
214
+ self.dilation = _pair(dilation)
215
+ self.groups = groups
216
+ self.deformable_groups = deformable_groups
217
+ # enable compatibility with nn.Conv2d
218
+ self.transposed = False
219
+ self.output_padding = _single(0)
220
+
221
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
222
+
223
+ self.reset_parameters()
224
+
225
+ def reset_parameters(self):
226
+ n = self.in_channels
227
+ for k in self.kernel_size:
228
+ n *= k
229
+ stdv = 1. / math.sqrt(n)
230
+ self.weight.data.uniform_(-stdv, stdv)
231
+
232
+ def forward(self, x, offset):
233
+ # To fix an assert error in deform_conv_cuda.cpp:128
234
+ # input image is smaller than kernel
235
+ input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
236
+ if input_pad:
237
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
238
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
239
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
240
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
241
+ out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
242
+ self.deformable_groups)
243
+ if input_pad:
244
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
245
+ return out
246
+
247
+
248
+ class DeformConvPack(DeformConv):
249
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
250
+
251
+ Args:
252
+ in_channels (int): Same as nn.Conv2d.
253
+ out_channels (int): Same as nn.Conv2d.
254
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
255
+ stride (int or tuple[int]): Same as nn.Conv2d.
256
+ padding (int or tuple[int]): Same as nn.Conv2d.
257
+ dilation (int or tuple[int]): Same as nn.Conv2d.
258
+ groups (int): Same as nn.Conv2d.
259
+ bias (bool or str): If specified as `auto`, it will be decided by the
260
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
261
+ False.
262
+ """
263
+
264
+ _version = 2
265
+
266
+ def __init__(self, *args, **kwargs):
267
+ super(DeformConvPack, self).__init__(*args, **kwargs)
268
+
269
+ self.conv_offset = nn.Conv2d(
270
+ self.in_channels,
271
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
272
+ kernel_size=self.kernel_size,
273
+ stride=_pair(self.stride),
274
+ padding=_pair(self.padding),
275
+ dilation=_pair(self.dilation),
276
+ bias=True)
277
+ self.init_offset()
278
+
279
+ def init_offset(self):
280
+ self.conv_offset.weight.data.zero_()
281
+ self.conv_offset.bias.data.zero_()
282
+
283
+ def forward(self, x):
284
+ offset = self.conv_offset(x)
285
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
286
+ self.deformable_groups)
287
+
288
+
289
+ class ModulatedDeformConv(nn.Module):
290
+
291
+ def __init__(self,
292
+ in_channels,
293
+ out_channels,
294
+ kernel_size,
295
+ stride=1,
296
+ padding=0,
297
+ dilation=1,
298
+ groups=1,
299
+ deformable_groups=1,
300
+ bias=True):
301
+ super(ModulatedDeformConv, self).__init__()
302
+ self.in_channels = in_channels
303
+ self.out_channels = out_channels
304
+ self.kernel_size = _pair(kernel_size)
305
+ self.stride = stride
306
+ self.padding = padding
307
+ self.dilation = dilation
308
+ self.groups = groups
309
+ self.deformable_groups = deformable_groups
310
+ self.with_bias = bias
311
+ # enable compatibility with nn.Conv2d
312
+ self.transposed = False
313
+ self.output_padding = _single(0)
314
+
315
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
316
+ if bias:
317
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
318
+ else:
319
+ self.register_parameter('bias', None)
320
+ self.init_weights()
321
+
322
+ def init_weights(self):
323
+ n = self.in_channels
324
+ for k in self.kernel_size:
325
+ n *= k
326
+ stdv = 1. / math.sqrt(n)
327
+ self.weight.data.uniform_(-stdv, stdv)
328
+ if self.bias is not None:
329
+ self.bias.data.zero_()
330
+
331
+ def forward(self, x, offset, mask):
332
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
333
+ self.groups, self.deformable_groups)
334
+
335
+
336
+ class ModulatedDeformConvPack(ModulatedDeformConv):
337
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
338
+
339
+ Args:
340
+ in_channels (int): Same as nn.Conv2d.
341
+ out_channels (int): Same as nn.Conv2d.
342
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
343
+ stride (int or tuple[int]): Same as nn.Conv2d.
344
+ padding (int or tuple[int]): Same as nn.Conv2d.
345
+ dilation (int or tuple[int]): Same as nn.Conv2d.
346
+ groups (int): Same as nn.Conv2d.
347
+ bias (bool or str): If specified as `auto`, it will be decided by the
348
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
349
+ False.
350
+ """
351
+
352
+ _version = 2
353
+
354
+ def __init__(self, *args, **kwargs):
355
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
356
+
357
+ self.conv_offset = nn.Conv2d(
358
+ self.in_channels,
359
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
360
+ kernel_size=self.kernel_size,
361
+ stride=_pair(self.stride),
362
+ padding=_pair(self.padding),
363
+ dilation=_pair(self.dilation),
364
+ bias=True)
365
+ self.init_weights()
366
+
367
+ def init_weights(self):
368
+ super(ModulatedDeformConvPack, self).init_weights()
369
+ if hasattr(self, 'conv_offset'):
370
+ self.conv_offset.weight.data.zero_()
371
+ self.conv_offset.bias.data.zero_()
372
+
373
+ def forward(self, x):
374
+ out = self.conv_offset(x)
375
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
376
+ offset = torch.cat((o1, o2), dim=1)
377
+ mask = torch.sigmoid(mask)
378
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
379
+ self.groups, self.deformable_groups)