import functools
import warnings

import torch
import torch.nn.functional as F
from torch import nn as nn


def upsample(x, size):
    """Resize ``x`` to ``size`` with bilinear interpolation.

    Defined with ``def`` rather than as a lambda so that modules which store
    it as an attribute stay picklable.
    """
    return F.interpolate(x, size, mode='bilinear', align_corners=False)


def _upsample_fixed(x, size, fixed_size):
    """Resize ``x`` to ``fixed_size`` with nearest-neighbour interpolation.

    ``size`` is accepted only to keep the ``(x, size)`` call signature shared
    with :func:`upsample`; it is ignored.
    """
    return F.interpolate(x, mode='nearest', size=fixed_size)


def _make_upsampler(fixed_size):
    """Return the ``(x, size)`` upsampling callable for a layer.

    With ``fixed_size=None`` this is :func:`upsample`; otherwise it is a
    picklable partial that always resizes to ``fixed_size``.
    """
    if fixed_size is None:
        return upsample
    return functools.partial(_upsample_fixed, fixed_size=fixed_size)


batchnorm_momentum = 0.01 / 2


def get_n_params(parameters):
    """Return the total number of elements across an iterable of tensors.

    An empty iterable yields ``0``.
    """
    return sum(p.numel() for p in parameters)


class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()
        # groups=in_channels makes this a per-channel (depthwise) convolution.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation,
                               groups=in_channels, bias=bias)
        # 1x1 convolution that mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.conv1(x))


class _BNReluConv(nn.Sequential):
    """Sequential ``[BatchNorm2d] -> ReLU -> conv -> [Dropout2d]`` unit.

    Args:
        num_maps_in: Input channel count.
        num_maps_out: Output channel count.
        k: Square kernel size of the convolution.
        batch_norm: When truthy, prepend ``BatchNorm2d``.
        bn_momentum: Momentum passed to ``BatchNorm2d``.
        bias: Whether the convolution has a bias term.
        dilation: Dilation of the convolution.
        drop_rate: When ``> 0``, append in-place ``Dropout2d`` with this rate.
        separable: Use :class:`SeparableConv2d` instead of ``nn.Conv2d``.

    A ``UserWarning`` reporting the chosen conv type (and the dropout rate,
    if any) is emitted at construction time.
    """

    def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True, bn_momentum=0.1, bias=False, dilation=1,
                 drop_rate=.0, separable=False):
        super().__init__()
        if batch_norm:
            self.add_module('norm', nn.BatchNorm2d(num_maps_in, momentum=bn_momentum))
        # In-place ReLU is only enabled when batch_norm is exactly True, i.e.
        # when the ReLU input is the fresh BatchNorm output and not the
        # caller's tensor.
        self.add_module('relu', nn.ReLU(inplace=batch_norm is True))
        # NOTE(review): padding ignores dilation, so the spatial size is only
        # preserved for dilation == 1 -- confirm this is intended.
        padding = k // 2
        conv_class = SeparableConv2d if separable else nn.Conv2d
        warnings.warn(f'Using conv type {k}x{k}: {conv_class}')
        self.add_module('conv', conv_class(num_maps_in, num_maps_out, kernel_size=k, padding=padding, bias=bias,
                                           dilation=dilation))
        if drop_rate > 0:
            warnings.warn(f'Using dropout with p: {drop_rate}')
            self.add_module('dropout', nn.Dropout2d(drop_rate, inplace=True))


class _Upsample(nn.Module):
    """Upsample ``x`` to the size of a skip tensor, add the skip, then blend.

    The skip tensor first passes through a 1x1 ``_BNReluConv`` bottleneck that
    maps ``skip_maps_in`` channels to ``num_maps_in``; the sum is then passed
    through a ``k`` x ``k`` ``_BNReluConv`` producing ``num_maps_out`` channels.

    Args:
        num_maps_in: Channel count of ``x`` (and of the bottlenecked skip).
        skip_maps_in: Channel count of the raw skip tensor.
        num_maps_out: Output channel count.
        use_bn: Enable batch norm in the blend conv (and in the bottleneck,
            subject to ``bneck_starts_with_bn``).
        k: Kernel size of the blend conv.
        use_skip: When false, the skip is not added (it still determines the
            target size).
        only_skip: Stored on the instance and reported in the construction
            warning; not used by ``forward``.
        detach_skip: Detach the bottlenecked skip from the autograd graph.
        fixed_size: When not ``None``, always resize ``x`` to this size using
            nearest-neighbour interpolation instead of bilinear.
        separable: Use a separable convolution in the blend conv.
        bneck_starts_with_bn: Allow batch norm at the start of the bottleneck.
    """

    def __init__(self, num_maps_in, skip_maps_in, num_maps_out, use_bn=True, k=3, use_skip=True, only_skip=False,
                 detach_skip=False, fixed_size=None, separable=False, bneck_starts_with_bn=True):
        super().__init__()
        print(f'Upsample layer: in = {num_maps_in}, skip = {skip_maps_in}, out = {num_maps_out}')
        self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1, batch_norm=use_bn and bneck_starts_with_bn)
        self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k, batch_norm=use_bn, separable=separable)
        self.use_skip = use_skip
        self.only_skip = only_skip
        self.detach_skip = detach_skip
        warnings.warn(f'\tUsing skips: {self.use_skip} (only skips: {self.only_skip})', UserWarning)
        self.upsampling_method = _make_upsampler(fixed_size)
        if fixed_size is not None:
            warnings.warn('Fixed upsample size', UserWarning)

    def forward(self, x, skip):
        """Return the blended combination of upsampled ``x`` and ``skip``."""
        # Call the sub-module itself (not .forward) so registered hooks run.
        skip = self.bottleneck(skip)
        if self.detach_skip:
            skip = skip.detach()
        # Dimensions 2 and 3 of the skip give the target size.
        skip_size = skip.size()[2:4]
        x = self.upsampling_method(x, skip_size)
        if self.use_skip:
            x = x + skip
        return self.blend_conv(x)


class _UpsampleBlend(nn.Module):
    """Upsample ``x`` to the size of ``skip``, add the skip, then blend.

    Unlike ``_Upsample`` there is no bottleneck: the skip is added as-is, and
    the blend conv keeps the channel count at ``num_features``.

    Args:
        num_features: Channel count of the blend conv input and output.
        use_bn: Enable batch norm in the blend conv.
        use_skip: When false, the skip is not added (it still determines the
            target size).
        detach_skip: Detach the skip from the autograd graph (a warning is
            emitted on every forward pass that does so).
        fixed_size: When not ``None``, always resize ``x`` to this size using
            nearest-neighbour interpolation instead of bilinear.
        k: Kernel size of the blend conv.
        separable: Use a separable convolution in the blend conv.
    """

    def __init__(self, num_features, use_bn=True, use_skip=True, detach_skip=False, fixed_size=None, k=3,
                 separable=False):
        super().__init__()
        self.blend_conv = _BNReluConv(num_features, num_features, k=k, batch_norm=use_bn, separable=separable)
        self.use_skip = use_skip
        self.detach_skip = detach_skip
        warnings.warn(f'Using skip connections: {self.use_skip}', UserWarning)
        self.upsampling_method = _make_upsampler(fixed_size)
        if fixed_size is not None:
            warnings.warn('Fixed upsample size', UserWarning)

    def forward(self, x, skip):
        """Return the blended combination of upsampled ``x`` and ``skip``."""
        if self.detach_skip:
            warnings.warn(f'Detaching skip connection {skip.shape[2:4]}', UserWarning)
            skip = skip.detach()
        # The last two dimensions of the skip give the target size.
        skip_size = skip.size()[-2:]
        x = self.upsampling_method(x, skip_size)
        if self.use_skip:
            x = x + skip
        # Call the sub-module itself (not .forward) so registered hooks run.
        return self.blend_conv(x)


class SpatialPyramidPooling(nn.Module):
    """Pyramid pooling head built from ``_BNReluConv`` units.

    The input is reduced to ``bt_size`` channels by a 1x1 unit. That result is
    then average-pooled to ``num_levels`` grid sizes; each pooled map goes
    through its own 1x1 unit (``level_size`` channels) and is resized back.
    All maps are concatenated along dimension 1 and fused to ``out_size``
    channels by a final 1x1 unit.

    Args:
        num_maps_in: Input channel count.
        num_levels: Number of pooled pyramid levels. Level ``i`` uses
            ``grids[i]``, so ``forward`` raises ``IndexError`` if this exceeds
            the number of usable grids.
        bt_size: Channel count after the initial 1x1 unit.
        level_size: Channel count of every pyramid level.
        out_size: Output channel count.
        grids: Pooling grid sizes, one per level. When ``fixed_size`` is
            truthy, grids larger than ``min(fixed_size)`` are dropped.
        square_grid: Pool to square ``g x g`` grids instead of grids scaled
            by the aspect ratio of the target size.
        bn_momentum: Batch-norm momentum for all units.
        use_bn: Enable batch norm in the units.
        drop_rate: Dropout rate applied inside the level units.
        fixed_size: When not ``None``, used as the target size and levels are
            resized with nearest-neighbour interpolation instead of bilinear.
        starts_with_bn: Allow batch norm in the initial 1x1 unit.
    """

    def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128, out_size=128, grids=(6, 3, 2, 1),
                 square_grid=False, bn_momentum=0.1, use_bn=True, drop_rate=.0, fixed_size=None,
                 starts_with_bn=True):
        super().__init__()
        self.fixed_size = fixed_size
        self.grids = grids
        if self.fixed_size:
            # Drop grids larger than the smallest side of the fixed size.
            ref = min(self.fixed_size)
            self.grids = [g for g in self.grids if g <= ref]
        self.square_grid = square_grid
        self.upsampling_method = _make_upsampler(fixed_size)
        if self.fixed_size is not None:
            warnings.warn('Fixed upsample size', UserWarning)

        self.spp = nn.Sequential()
        self.spp.add_module('spp_bn', _BNReluConv(num_maps_in, bt_size, k=1, bn_momentum=bn_momentum,
                                                  batch_norm=use_bn and starts_with_bn))
        num_features = bt_size
        # The fuse unit sees the bottleneck output plus every level's output.
        final_size = num_features + num_levels * level_size
        for i in range(num_levels):
            self.spp.add_module(f'spp{i}', _BNReluConv(num_features, level_size, k=1, bn_momentum=bn_momentum,
                                                       batch_norm=use_bn, drop_rate=drop_rate))
        self.spp.add_module('spp_fuse', _BNReluConv(final_size, out_size, k=1, bn_momentum=bn_momentum,
                                                    batch_norm=use_bn))

    def forward(self, x):
        """Return the fused pyramid features for ``x``.

        Dimensions 2 and 3 of ``x`` are used as the target size unless
        ``fixed_size`` was given.
        """
        target_size = self.fixed_size if self.fixed_size is not None else x.size()[2:4]
        # Ratio of the second to the first target dimension; used to stretch
        # non-square pooling grids.
        ar = target_size[1] / target_size[0]

        # Sub-modules are called directly (not via .forward) so hooks run.
        x = self.spp[0](x)
        levels = [x]
        # self.spp holds [bottleneck, level_0 .. level_{n-1}, fuse]; indices
        # 1 .. len - 2 are the pyramid levels.
        num = len(self.spp) - 1
        for i in range(1, num):
            grid = self.grids[i - 1]
            if self.square_grid:
                x_pooled = F.adaptive_avg_pool2d(x, grid)
            else:
                grid_size = (grid, max(1, round(ar * grid)))
                x_pooled = F.adaptive_avg_pool2d(x, grid_size)
            level = self.spp[i](x_pooled)
            levels.append(self.upsampling_method(level, target_size))

        return self.spp[-1](torch.cat(levels, 1))


class Identity(nn.Module):
    """Module that returns its input unchanged; constructor args are ignored."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input):
        """Return ``input`` as-is."""
        return input