from typing import Optional, Tuple

import torch
import torch.nn as nn

__all__ = ["ReparamLargeKernelConv"]


class ReparamLargeKernelConv(nn.Module):
    """Building block of RepLKNet.

    This class defines the over-parameterized large kernel conv block
    introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_.

    Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
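
    Example (illustrative; hyperparameters chosen arbitrarily)::

        >>> block = ReparamLargeKernelConv(
        ...     in_channels=64, out_channels=64, kernel_size=7,
        ...     stride=1, groups=64, small_kernel=3)
        >>> y = block(torch.randn(1, 64, 56, 56))
        >>> y.shape
        torch.Size([1, 64, 56, 56])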
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        groups: int,
        small_kernel: Optional[int],
        inference_mode: bool = False,
        activation: nn.Module = nn.GELU(),
    ) -> None:
        """Construct a ReparamLargeKernelConv module.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size of the large kernel conv branch.
            stride: Stride size.
            groups: Group number.
            small_kernel: Kernel size of the small kernel conv branch.
                If ``None``, the small kernel branch is omitted.
            inference_mode: If True, instantiates model in inference mode.
                Default: ``False``
            activation: Activation module. Default: ``nn.GELU()``
        """
        super().__init__()

        self.stride = stride
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.activation = activation

        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.padding = kernel_size // 2
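        # padding = kernel_size // 2 preserves spatial resolution for odd
        # kernel sizes at stride 1.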
        if inference_mode:
            self.lkb_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=self.padding,
                dilation=1,
                groups=groups,
                bias=True,
            )
        else:
            self.lkb_origin = self._conv_bn(
                kernel_size=kernel_size, padding=self.padding
            )
            if small_kernel is not None:
                assert (
                    small_kernel <= kernel_size
                ), "The kernel size for re-param cannot be larger than the large kernel!"
                self.small_conv = self._conv_bn(
                    kernel_size=small_kernel, padding=small_kernel // 2
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        if hasattr(self, "lkb_reparam"):
            out = self.lkb_reparam(x)
        else:
            out = self.lkb_origin(x)
            if hasattr(self, "small_conv"):
                out += self.small_conv(x)

        # nn.GELU is not an in-place op, so the result must be reassigned.
        out = self.activation(out)
        return out

    def get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.

        Reference: https://github.com/DingXiaoH/RepLKNet-pytorch

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        eq_k, eq_b = self._fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, "small_conv"):
            small_k, small_b = self._fuse_bn(self.small_conv.conv, self.small_conv.bn)
            eq_b += small_b
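            # Zero-pad the fused small kernel on all four sides so it matches
            # the large kernel's spatial size; the two branches then collapse
            # into a single equivalent convolution.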
            eq_k += nn.functional.pad(
                small_k, [(self.kernel_size - self.small_kernel) // 2] * 4
            )
        return eq_k, eq_b

    def reparameterize(self) -> None:
        """Re-parameterize the multi-branched architecture used at training
        time into a plain CNN-like structure for inference, following works
        like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf.
        """
        eq_k, eq_b = self.get_kernel_bias()
        self.lkb_reparam = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.lkb_origin.conv.dilation,
            groups=self.groups,
            bias=True,
        )

        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__("lkb_origin")
        if hasattr(self, "small_conv"):
            self.__delattr__("small_conv")

    @staticmethod
    def _fuse_bn(
        conv: nn.Conv2d, bn: nn.BatchNorm2d
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse a batchnorm layer with the preceding conv layer.

        Args:
            conv: Convolutional layer.
            bn: Batchnorm 2d layer.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        kernel = conv.weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
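        # Fold BN into the conv: BN(W*x) = gamma * (W*x - mean) / std + beta
        # is itself a convolution with kernel W * (gamma / std) and bias
        # beta - mean * gamma / std.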
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int = 0) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            An nn.Sequential Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
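

# Minimal usage sketch: build a block in training form, fuse it, and check
# that the reparameterized conv reproduces the multi-branch output. The
# hyperparameters below are arbitrary illustrative choices.
if __name__ == "__main__":
    torch.manual_seed(0)
    block = ReparamLargeKernelConv(
        in_channels=8,
        out_channels=8,
        kernel_size=7,
        stride=1,
        groups=8,
        small_kernel=3,
    )
    block.eval()  # use BN running stats so both paths are deterministic
    x = torch.randn(2, 8, 32, 32)
    with torch.no_grad():
        y_multi = block(x)
        block.reparameterize()
        y_fused = block(x)
    # The two outputs should agree up to floating-point error.
    assert torch.allclose(y_multi, y_fused, atol=1e-5)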