| | """ |
| | @Author : Peike Li |
| | @Contact : peike.li@yahoo.com |
| | @File : ocnet.py |
| | @Time : 8/4/19 3:36 PM |
| | @Desc : |
| | @License : This source code is licensed under the license found in the |
| | LICENSE file in the root directory of this source tree. |
| | """ |

import functools

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

from modules import InPlaceABNSync
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')

class _SelfAttentionBlock(nn.Module):
    '''
    The basic implementation of the self-attention block / non-local block.
    Input:
        N X C X H X W
    Parameters:
        in_channels    : the dimension of the input feature map
        key_channels   : the dimension after the key/query transform
        value_channels : the dimension after the value transform
        scale          : the scale used to downsample the input feature maps (to save memory)
    Return:
        N X C X H X W
        position-aware context features (without concatenation or addition with the input).
    '''

    def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1):
        super(_SelfAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        if out_channels is None:
            self.out_channels = in_channels
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            InPlaceABNSync(self.key_channels),
        )
        # The query transform shares its weights with the key transform.
        self.f_query = self.f_key
        self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
                                 kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.value_channels, out_channels=self.out_channels,
                           kernel_size=1, stride=1, padding=0)
        # Zero-initialize the output projection so the context contribution starts at zero.
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)

    def forward(self, x):
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)

        value = self.f_value(x).view(batch_size, self.value_channels, -1)
        value = value.permute(0, 2, 1)
        query = self.f_query(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_key(x).view(batch_size, self.key_channels, -1)

        # Scaled dot-product attention over all spatial positions.
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.value_channels, *x.size()[2:])
        context = self.W(context)
        if self.scale > 1:
            context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True)
        return context


class SelfAttentionBlock2D(_SelfAttentionBlock):
    def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1):
        super(SelfAttentionBlock2D, self).__init__(in_channels,
                                                   key_channels,
                                                   value_channels,
                                                   out_channels,
                                                   scale)
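

# --- Usage sketch (illustrative only, not part of the original file) ---
# A minimal check of SelfAttentionBlock2D, assuming the InPlaceABNSync
# dependency from `modules` is importable. The helper name and sizes below are
# hypothetical; the sketch only verifies that the block keeps the
# N x C x H x W shape of its input.
def _demo_self_attention_block():
    block = SelfAttentionBlock2D(in_channels=512, key_channels=256, value_channels=512, scale=2)
    feats = torch.randn(2, 512, 32, 32)    # N x C x H x W input feature map
    context = block(feats)                 # position-aware context features
    assert context.shape == feats.shape    # out_channels defaults to in_channels
    return context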


class BaseOC_Module(nn.Module):
    """
    Implementation of the BaseOC module.
    Parameters:
        in_channels / out_channels : the channels of the input / output feature maps.
        dropout                    : the dropout ratio (0.05 is typically used).
        sizes                      : multiple downsampling sizes can be applied; only one size is used here.
    Return:
        features fused with object context information.
    """

    def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
        super(BaseOC_Module, self).__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0),
            InPlaceABNSync(out_channels),
            nn.Dropout2d(dropout)
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels,
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(torch.cat([context, feats], 1))
        return output
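

# --- Usage sketch (illustrative only, not part of the original file) ---
# BaseOC_Module concatenates the attention context with the input features,
# so in_channels must match the channel count of `feats`. The numbers are
# illustrative and assume InPlaceABNSync is available.
def _demo_base_oc_module():
    oc = BaseOC_Module(in_channels=512, out_channels=512, key_channels=256,
                       value_channels=256, dropout=0.05, sizes=([1]))
    feats = torch.randn(2, 512, 32, 32)
    fused = oc(feats)                      # context concatenated with feats, then 1x1 conv + BN + dropout
    assert fused.shape == (2, 512, 32, 32)
    return fused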


class BaseOC_Context_Module(nn.Module):
    """
    Output only the context features.
    Parameters:
        in_channels / out_channels : the channels of the input / output feature maps.
        dropout                    : the dropout ratio.
        fusion                     : two different fusion methods are provided, "concat" or "add".
        sizes                      : directly learning the attention weights even on 1/8-resolution
                                     feature maps is hard, so a downsampling size can be applied.
    Return:
        features after "concat" or "add".
    """

    def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
        super(BaseOC_Context_Module, self).__init__()
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
            InPlaceABNSync(out_channels),
        )

    def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
        return SelfAttentionBlock2D(in_channels,
                                    key_channels,
                                    value_channels,
                                    output_channels,
                                    size)

    def forward(self, feats):
        priors = [stage(feats) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        output = self.conv_bn_dropout(context)
        return output
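

# --- Usage sketch (illustrative only, not part of the original file) ---
# Unlike BaseOC_Module, this variant returns only the projected context
# features and does not concatenate them with the input; ASP_OC_Module below
# uses it that way inside its context branch. The numbers are illustrative.
def _demo_base_oc_context_module():
    ctx = BaseOC_Context_Module(in_channels=256, out_channels=256, key_channels=128,
                                value_channels=256, dropout=0, sizes=([2]))
    feats = torch.randn(2, 256, 32, 32)
    context = ctx(feats)                   # context features only, no fusion with feats
    assert context.shape == (2, 256, 32, 32)
    return context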


class ASP_OC_Module(nn.Module):
    def __init__(self, features, out_features=256, dilations=(12, 24, 36)):
        super(ASP_OC_Module, self).__init__()
        # Object-context branch: 3x3 conv followed by the self-attention context module.
        self.context = nn.Sequential(nn.Conv2d(features, out_features, kernel_size=3, padding=1, dilation=1, bias=True),
                                     InPlaceABNSync(out_features),
                                     BaseOC_Context_Module(in_channels=out_features, out_channels=out_features,
                                                           key_channels=out_features // 2, value_channels=out_features,
                                                           dropout=0, sizes=([2])))
        # ASPP branches: a 1x1 conv and three 3x3 atrous convolutions with increasing dilation rates.
        self.conv2 = nn.Sequential(nn.Conv2d(features, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
                                   InPlaceABNSync(out_features))
        self.conv3 = nn.Sequential(
            nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False),
            InPlaceABNSync(out_features))
        self.conv4 = nn.Sequential(
            nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False),
            InPlaceABNSync(out_features))
        self.conv5 = nn.Sequential(
            nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False),
            InPlaceABNSync(out_features))

        # Fuse the five concatenated branches back to out_features channels.
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(out_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(out_features),
            nn.Dropout2d(0.1)
        )

    def _cat_each(self, feat1, feat2, feat3, feat4, feat5):
        assert (len(feat1) == len(feat2))
        z = []
        for i in range(len(feat1)):
            z.append(torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]), 1))
        return z

    def forward(self, x):
        if isinstance(x, Variable):
            _, _, h, w = x.size()
        elif isinstance(x, tuple) or isinstance(x, list):
            _, _, h, w = x[0].size()
        else:
            raise RuntimeError('unknown input type')

        feat1 = self.context(x)
        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)

        if isinstance(x, Variable):
            out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)
        elif isinstance(x, tuple) or isinstance(x, list):
            out = self._cat_each(feat1, feat2, feat3, feat4, feat5)
        else:
            raise RuntimeError('unknown input type')
        output = self.conv_bn_dropout(out)
        return output
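

# --- Usage sketch (illustrative only, not part of the original file) ---
# ASP_OC_Module runs four atrous-convolution branches plus the object-context
# branch in parallel and fuses the five outputs with a 1x1 conv. The call below
# assumes a single-tensor input; the channel count is illustrative (e.g. the
# final feature map of a dilated ResNet backbone).
def _demo_asp_oc_module():
    head = ASP_OC_Module(features=2048, out_features=256, dilations=(12, 24, 36))
    x = torch.randn(2, 2048, 32, 32)       # N x C x H x W backbone features
    out = head(x)                          # five branches concatenated, then 1x1 conv + BN + dropout
    assert out.shape == (2, 256, 32, 32)
    return out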