import math

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['masked_im2col_forward', 'masked_col2im_forward'])
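# The two compiled extension kernels loaded above: ``masked_im2col_forward``
# gathers input patches at the masked positions into columns, and
# ``masked_col2im_forward`` scatters the convolved columns back into a dense
# output map.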


class MaskedConv2dFunction(Function):
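    """Convolve ``features`` with ``weight`` only at the spatial positions
    where ``mask`` is positive; every other output position stays zero."""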

    @staticmethod
    def symbolic(g, features, mask, weight, bias, padding, stride):
        return g.op(
            'mmcv::MMCVMaskedConv2d',
            features,
            mask,
            weight,
            bias,
            padding_i=padding,
            stride_i=stride)

    @staticmethod
    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
        assert mask.dim() == 3 and mask.size(0) == 1
        assert features.dim() == 4 and features.size(0) == 1
        assert features.size()[2:] == mask.size()[1:]
        pad_h, pad_w = _pair(padding)
        stride_h, stride_w = _pair(stride)
        if stride_h != 1 or stride_w != 1:
            raise ValueError(
                'Only stride 1 is supported in masked_conv2d currently.')
        out_channel, in_channel, kernel_h, kernel_w = weight.size()

        batch_size = features.size(0)
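        # Standard convolution output size:
        #   out = floor((in + 2 * pad - kernel) / stride) + 1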
        out_h = int(
            math.floor((features.size(2) + 2 * pad_h -
                        (kernel_h - 1) - 1) / stride_h + 1))
        out_w = int(
            math.floor((features.size(3) + 2 * pad_w -
                        (kernel_w - 1) - 1) / stride_w + 1))
        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
        if mask_inds.numel() > 0:
            mask_h_idx = mask_inds[:, 0].contiguous()
            mask_w_idx = mask_inds[:, 1].contiguous()
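            # Gather the kernel_h x kernel_w input patch around each masked
            # position into one column of ``data_col`` (a masked im2col).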
            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                          mask_inds.size(0))
            ext_module.masked_im2col_forward(
                features,
                mask_h_idx,
                mask_w_idx,
                data_col,
                kernel_h=kernel_h,
                kernel_w=kernel_w,
                pad_h=pad_h,
                pad_w=pad_w)

            # Convolution at the masked positions reduces to a single matrix
            # multiplication of the flattened weight with the gathered
            # columns, plus the bias.
            masked_output = torch.addmm(bias[:, None],
                                        weight.view(out_channel, -1),
                                        data_col)
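            # Scatter each convolved column back to its (h, w) position in
            # the dense output map (a masked col2im).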
            ext_module.masked_col2im_forward(
                masked_output,
                mask_h_idx,
                mask_w_idx,
                output,
                height=out_h,
                width=out_w,
                channels=out_channel)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
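        # No gradients are implemented; one ``None`` must be returned for
        # each of the six inputs to ``forward``.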
        return (None, ) * 6


masked_conv2d = MaskedConv2dFunction.apply
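# A minimal usage sketch of the functional interface (assumes a CUDA build of
# the ``_ext`` module; the shapes follow the asserts in ``forward``):
#
#   x = torch.randn(1, 16, 20, 20, device='cuda')
#   mask = (torch.rand(1, 20, 20, device='cuda') > 0.5).float()
#   weight = torch.randn(32, 16, 3, 3, device='cuda')
#   bias = torch.zeros(32, device='cuda')
#   out = masked_conv2d(x, mask, weight, bias, 1, 1)  # zeros off the mask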


class MaskedConv2d(nn.Conv2d):
| """A MaskedConv2d which inherits the official Conv2d. |
| |
| The masked forward doesn't implement the backward function and only |
| supports the stride parameter to be 1 currently. |
| """ |

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(MaskedConv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias)

    def forward(self, input, mask=None):
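        # Without a mask this is an ordinary dense convolution; with a mask,
        # the convolution is only evaluated at the masked positions.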
        if mask is None:
            return super(MaskedConv2d, self).forward(input)
        else:
            return masked_conv2d(input, mask, self.weight, self.bias,
                                 self.padding)