| |
|
|
| from copy import deepcopy |
| import fvcore.nn.weight_init as weight_init |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from .batch_norm import get_norm |
| from .blocks import DepthwiseSeparableConv2d |
| from .wrappers import Conv2d |
|
|
|
|
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP).

    Builds five parallel branches over the input feature map — one 1x1 conv,
    three 3x3 atrous convs (one per dilation), and an image-pooling branch —
    then concatenates them and projects back to `out_channels` with a 1x1 conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dilations,
        *,
        norm,
        activation,
        pool_kernel_size=None,
        dropout: float = 0.0,
        use_depthwise_separable_conv=False,
    ):
        """
        Args:
            in_channels (int): number of input channels for ASPP.
            out_channels (int): number of output channels.
            dilations (list): a list of 3 dilations in ASPP.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format. norm is
                applied to all conv layers except the conv following
                global average pooling.
            activation (callable): activation function.
            pool_kernel_size (tuple, list): the average pooling size (kh, kw)
                for image pooling layer in ASPP. If set to None, it always
                performs global average pooling. If not None, it must
                divide the shape of inputs in forward(). It is recommended
                to use a fixed input feature size in training, and set this
                option to match this size, so that it performs global average
                pooling in training, and the size of the pooling window stays
                consistent in inference.
            dropout (float): apply dropout on the output of ASPP. It is used in
                the official DeepLab implementation with a rate of 0.1:
                https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532 # noqa
            use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d
                for 3x3 convs in ASPP, proposed in :paper:`DeepLabV3+`.
        """
        super().__init__()
        assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations))
        self.pool_kernel_size = pool_kernel_size
        self.dropout = dropout
        # When a real norm layer follows the conv, the conv bias is redundant.
        use_bias = norm == ""
        self.convs = nn.ModuleList()
        # Branch 1: 1x1 conv.
        self.convs.append(
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                bias=use_bias,
                norm=get_norm(norm, out_channels),
                activation=deepcopy(activation),
            )
        )
        weight_init.c2_xavier_fill(self.convs[-1])
        # Branches 2-4: 3x3 atrous convs, one per dilation. padding == dilation
        # keeps the spatial size unchanged for a 3x3 kernel.
        for dilation in dilations:
            if use_depthwise_separable_conv:
                # DepthwiseSeparableConv2d initializes its own weights.
                self.convs.append(
                    DepthwiseSeparableConv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        norm1=norm,
                        activation1=deepcopy(activation),
                        norm2=norm,
                        activation2=deepcopy(activation),
                    )
                )
            else:
                self.convs.append(
                    Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                        activation=deepcopy(activation),
                    )
                )
                weight_init.c2_xavier_fill(self.convs[-1])
        # Branch 5: image pooling. The conv after pooling intentionally has a
        # bias and no norm (norm after a 1x1 spatial map is degenerate).
        if pool_kernel_size is None:
            image_pooling = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
            )
        else:
            image_pooling = nn.Sequential(
                nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1),
                Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
            )
        weight_init.c2_xavier_fill(image_pooling[1])
        self.convs.append(image_pooling)

        # Fuse the 5 concatenated branches back to out_channels.
        self.project = Conv2d(
            5 * out_channels,
            out_channels,
            kernel_size=1,
            bias=use_bias,
            norm=get_norm(norm, out_channels),
            activation=deepcopy(activation),
        )
        weight_init.c2_xavier_fill(self.project)

    def forward(self, x):
        """
        Args:
            x (Tensor): input feature map of shape (N, in_channels, H, W).

        Returns:
            Tensor: output of shape (N, out_channels, H, W).

        Raises:
            ValueError: if `pool_kernel_size` was set and does not divide
                the spatial shape of `x`.
        """
        size = x.shape[-2:]
        if self.pool_kernel_size is not None:
            if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]:
                # Original message stated the relation backwards; the INPUT
                # shape must be divisible by `pool_kernel_size`.
                raise ValueError(
                    "Input shape must be divisible by `pool_kernel_size`. "
                    "Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size)
                )
        res = [conv(x) for conv in self.convs]
        # The pooling branch shrinks the map; upsample it back before concat.
        res[-1] = F.interpolate(res[-1], size=size, mode="bilinear", align_corners=False)
        res = torch.cat(res, dim=1)
        res = self.project(res)
        if self.dropout > 0:
            res = F.dropout(res, self.dropout, training=self.training)
        return res
|
|