Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import math | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from torch.autograd import Function | |
| from torch.autograd.function import once_differentiable | |
| from torch.nn.modules.utils import _pair | |
| from ..utils import ext_loader | |
| ext_module = ext_loader.load_ext( | |
| '_ext', ['masked_im2col_forward', 'masked_col2im_forward']) | |
| class MaskedConv2dFunction(Function): | |
| def symbolic(g, features, mask, weight, bias, padding, stride=1): | |
| return g.op( | |
| 'mmcv::MMCVMaskedConv2d', | |
| features, | |
| mask, | |
| weight, | |
| bias, | |
| padding_i=padding, | |
| stride_i=stride) | |
| def forward(ctx, | |
| features: torch.Tensor, | |
| mask: torch.Tensor, | |
| weight: torch.nn.Parameter, | |
| bias: torch.nn.Parameter, | |
| padding: int = 0, | |
| stride: int = 1) -> torch.Tensor: | |
| assert mask.dim() == 3 and mask.size(0) == 1 | |
| assert features.dim() == 4 and features.size(0) == 1 | |
| assert features.size()[2:] == mask.size()[1:] | |
| pad_h, pad_w = _pair(padding) | |
| stride_h, stride_w = _pair(stride) | |
| if stride_h != 1 or stride_w != 1: | |
| raise ValueError( | |
| 'Stride could not only be 1 in masked_conv2d currently.') | |
| out_channel, in_channel, kernel_h, kernel_w = weight.size() | |
| if features.device.type == 'npu': | |
| import torch_npu | |
| output = torch_npu.npu_conv2d( | |
| features, | |
| weight, | |
| bias, | |
| stride=(stride_h, stride_w), | |
| padding=(pad_h, pad_w), | |
| dilation=(1, 1), | |
| groups=1) | |
| if mask.size()[1:] != output.size()[2:]: | |
| raise ValueError( | |
| 'The mask is inconsistent with the shape of output_conv.') | |
| mask = mask > 0 | |
| mask = mask.type(output.dtype) | |
| output = output * mask | |
| return output | |
| batch_size = features.size(0) | |
| out_h = int( | |
| math.floor( | |
| torch.true_divide((features.size(2) + 2 * pad_h - | |
| (kernel_h - 1) - 1), stride_h) + 1)) | |
| out_w = int( | |
| math.floor( | |
| torch.true_divide((features.size(3) + 2 * pad_w - | |
| (kernel_w - 1) - 1), stride_w) + 1)) | |
| mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False) | |
| output = features.new_zeros(batch_size, out_channel, out_h, out_w) | |
| if mask_inds.numel() > 0: | |
| mask_h_idx = mask_inds[:, 0].contiguous() | |
| mask_w_idx = mask_inds[:, 1].contiguous() | |
| data_col = features.new_zeros(in_channel * kernel_h * kernel_w, | |
| mask_inds.size(0)) | |
| ext_module.masked_im2col_forward( | |
| features, | |
| mask_h_idx, | |
| mask_w_idx, | |
| data_col, | |
| kernel_h=kernel_h, | |
| kernel_w=kernel_w, | |
| pad_h=pad_h, | |
| pad_w=pad_w) | |
| masked_output = torch.addmm(1, bias[:, None], 1, | |
| weight.view(out_channel, -1), data_col) | |
| ext_module.masked_col2im_forward( | |
| masked_output, | |
| mask_h_idx, | |
| mask_w_idx, | |
| output, | |
| height=out_h, | |
| width=out_w, | |
| channels=out_channel) | |
| return output | |
| def backward(ctx, grad_output: torch.Tensor) -> tuple: | |
| return (None, ) * 5 | |
| masked_conv2d = MaskedConv2dFunction.apply | |
| class MaskedConv2d(nn.Conv2d): | |
| """A MaskedConv2d which inherits the official Conv2d. | |
| The masked forward doesn't implement the backward function and only | |
| supports the stride parameter to be 1 currently. | |
| """ | |
| def __init__(self, | |
| in_channels: int, | |
| out_channels: int, | |
| kernel_size: Union[int, Tuple[int, ...]], | |
| stride: int = 1, | |
| padding: int = 0, | |
| dilation: int = 1, | |
| groups: int = 1, | |
| bias: bool = True): | |
| super().__init__(in_channels, out_channels, kernel_size, stride, | |
| padding, dilation, groups, bias) | |
| def forward(self, | |
| input: torch.Tensor, | |
| mask: Optional[torch.Tensor] = None) -> torch.Tensor: | |
| if mask is None: # fallback to the normal Conv2d | |
| return super().forward(input) | |
| else: | |
| return masked_conv2d(input, mask, self.weight, self.bias, | |
| self.padding) | |