| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as functional
|
|
|
| from models._util import try_index
|
| from .bn import ABN
|
|
|
|
|
| class DeeplabV3(nn.Module):
|
| def __init__(self,
|
| in_channels,
|
| out_channels,
|
| hidden_channels=256,
|
| dilations=(12, 24, 36),
|
| norm_act=ABN,
|
| pooling_size=None):
|
| super(DeeplabV3, self).__init__()
|
| self.pooling_size = pooling_size
|
|
|
| self.map_convs = nn.ModuleList([
|
| nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
|
| nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
|
| nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
|
| nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
|
| ])
|
| self.map_bn = norm_act(hidden_channels * 4)
|
|
|
| self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
|
| self.global_pooling_bn = norm_act(hidden_channels)
|
|
|
| self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
|
| self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
|
| self.red_bn = norm_act(out_channels)
|
|
|
| self.reset_parameters(self.map_bn.activation, self.map_bn.slope)
|
|
|
| def reset_parameters(self, activation, slope):
|
| gain = nn.init.calculate_gain(activation, slope)
|
| for m in self.modules():
|
| if isinstance(m, nn.Conv2d):
|
| nn.init.xavier_normal_(m.weight.data, gain)
|
| if hasattr(m, "bias") and m.bias is not None:
|
| nn.init.constant_(m.bias, 0)
|
| elif isinstance(m, ABN):
|
| if hasattr(m, "weight") and m.weight is not None:
|
| nn.init.constant_(m.weight, 1)
|
| if hasattr(m, "bias") and m.bias is not None:
|
| nn.init.constant_(m.bias, 0)
|
|
|
| def forward(self, x):
|
|
|
| out = torch.cat([m(x) for m in self.map_convs], dim=1)
|
| out = self.map_bn(out)
|
| out = self.red_conv(out)
|
|
|
|
|
| pool = self._global_pooling(x)
|
| pool = self.global_pooling_conv(pool)
|
| pool = self.global_pooling_bn(pool)
|
| pool = self.pool_red_conv(pool)
|
| if self.training or self.pooling_size is None:
|
| pool = pool.repeat(1, 1, x.size(2), x.size(3))
|
|
|
| out += pool
|
| out = self.red_bn(out)
|
| return out
|
|
|
| def _global_pooling(self, x):
|
| if self.training or self.pooling_size is None:
|
| pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
|
| pool = pool.view(x.size(0), x.size(1), 1, 1)
|
| else:
|
| pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
|
| min(try_index(self.pooling_size, 1), x.shape[3]))
|
| padding = (
|
| (pooling_size[1] - 1) // 2,
|
| (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
|
| (pooling_size[0] - 1) // 2,
|
| (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
|
| )
|
|
|
| pool = functional.avg_pool2d(x, pooling_size, stride=1)
|
| pool = functional.pad(pool, pad=padding, mode="replicate")
|
| return pool
|
|
|