| | |
| |
|
| | from copy import deepcopy |
| | import fvcore.nn.weight_init as weight_init |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| | from .batch_norm import get_norm |
| | from .blocks import DepthwiseSeparableConv2d |
| | from .wrappers import Conv2d |
| |
|
| |
|
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP).

    The module runs five parallel branches on the input feature map:
    one 1x1 conv, three 3x3 atrous convs (one per dilation), and an
    image-pooling branch (average pooling followed by a 1x1 conv). It then
    concatenates their outputs along the channel dimension and projects
    them back to ``out_channels`` with a 1x1 conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dilations,
        *,
        norm,
        activation,
        pool_kernel_size=None,
        dropout: float = 0.0,
        use_depthwise_separable_conv=False,
    ):
        """
        Args:
            in_channels (int): number of input channels for ASPP.
            out_channels (int): number of output channels.
            dilations (list): a list of 3 dilations in ASPP.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format. norm is
                applied to all conv layers except the conv following
                global average pooling.
            activation (callable): activation function. It is deep-copied
                for every conv layer that uses it.
            pool_kernel_size (int, tuple, list): the average pooling size
                (kh, kw) for the image pooling layer in ASPP. An int k is
                treated as (k, k). If set to None, it always performs global
                average pooling. If not None, the spatial shape (H, W) of the
                inputs in forward() must be divisible by it. It is
                recommended to use a fixed input feature size in training
                and to set this option to match that size. Then training
                performs global average pooling, and the pooling window size
                stays consistent at inference time.
            dropout (float): apply dropout on the output of ASPP. It is used in
                the official DeepLab implementation with a rate of 0.1:
                https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532  # noqa
            use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d
                for 3x3 convs in ASPP, proposed in :paper:`DeepLabV3+`.
        """
        super().__init__()
        assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations))
        # Accept a scalar kernel size; forward() indexes it as (kh, kw).
        if isinstance(pool_kernel_size, int):
            pool_kernel_size = (pool_kernel_size, pool_kernel_size)
        self.pool_kernel_size = pool_kernel_size
        self.dropout = dropout
        # Bias is only enabled when norm is the empty string.
        # NOTE(review): norm=None gives use_bias=False; if get_norm(None, ...)
        # returns no norm, those convs end up with neither bias nor norm —
        # confirm this is intended.
        use_bias = norm == ""
        self.convs = nn.ModuleList()

        # Branch 1: plain 1x1 conv.
        self.convs.append(
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                bias=use_bias,
                norm=get_norm(norm, out_channels),
                activation=deepcopy(activation),
            )
        )
        weight_init.c2_xavier_fill(self.convs[-1])

        # Branches 2-4: 3x3 atrous convs. padding == dilation keeps H, W unchanged.
        for dilation in dilations:
            if use_depthwise_separable_conv:
                # Composite module; c2_xavier_fill is not applied to it here
                # (presumably it initializes its own sub-convs — confirm in .blocks).
                self.convs.append(
                    DepthwiseSeparableConv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        norm1=norm,
                        activation1=deepcopy(activation),
                        norm2=norm,
                        activation2=deepcopy(activation),
                    )
                )
            else:
                self.convs.append(
                    Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                        activation=deepcopy(activation),
                    )
                )
                weight_init.c2_xavier_fill(self.convs[-1])

        # Branch 5: image pooling (no norm, always biased). Its output is
        # resized back to the input resolution in forward().
        if pool_kernel_size is None:
            pool = nn.AdaptiveAvgPool2d(1)
        else:
            pool = nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1)
        image_pooling = nn.Sequential(
            pool,
            Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
        )
        weight_init.c2_xavier_fill(image_pooling[1])
        self.convs.append(image_pooling)

        # Fuse the 5 concatenated branches back down to out_channels.
        self.project = Conv2d(
            5 * out_channels,
            out_channels,
            kernel_size=1,
            bias=use_bias,
            norm=get_norm(norm, out_channels),
            activation=deepcopy(activation),
        )
        weight_init.c2_xavier_fill(self.project)

    def forward(self, x):
        """
        Run all ASPP branches on ``x`` and fuse their outputs.

        Args:
            x (Tensor): input feature map; the last two dims are (H, W).

        Returns:
            Tensor: the projected feature map at the same (H, W) as ``x``.

        Raises:
            ValueError: if ``pool_kernel_size`` is set and H or W is not a
                multiple of the corresponding kernel size.
        """
        size = x.shape[-2:]
        if self.pool_kernel_size is not None:
            if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]:
                raise ValueError(
                    "The shape of inputs must be divisible by `pool_kernel_size`. "
                    "Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size)
                )
        outputs = [branch(x) for branch in self.convs]
        # Only the image-pooling branch (last) changes the spatial size.
        outputs[-1] = F.interpolate(outputs[-1], size=size, mode="bilinear", align_corners=False)
        res = self.project(torch.cat(outputs, dim=1))
        if self.dropout > 0:
            res = F.dropout(res, self.dropout, training=self.training)
        return res
| |
|