import torch
import torch.nn.functional as F
import warnings
from torch import nn as nn
def upsample(x, size):
    """Resize ``x`` to ``size`` using bilinear interpolation.

    Thin wrapper around ``F.interpolate`` with ``mode='bilinear'`` and
    ``align_corners=False``. Defined with ``def`` rather than as a lambda
    bound to a name so that it has a meaningful ``__name__`` in tracebacks.

    Args:
        x: Input tensor forwarded to ``F.interpolate``.
        size: Target output size forwarded to ``F.interpolate``.

    Returns:
        The interpolated tensor.
    """
    return F.interpolate(x, size, mode='bilinear', align_corners=False)


# Module-level batch-norm momentum constant; it is not referenced by the
# definitions visible in this part of the file.
batchnorm_momentum = 0.01 / 2
def get_n_params(parameters):
    """Return the total number of scalar elements across ``parameters``.

    Args:
        parameters: Iterable of tensors, e.g. ``module.parameters()``.

    Returns:
        int: Sum of the element counts of all tensors; ``0`` when the
        iterable is empty.
    """
    # ``numel()`` equals the product of the tensor's dimensions, which is what
    # the previous nested loop computed by hand (while shadowing the imported
    # ``nn`` module with a local counter).
    return sum(p.numel() for p in parameters)
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: Number of input channels; also the number of groups and
            output channels of the depthwise convolution.
        out_channels: Number of channels produced by the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        bias: Whether both convolutions use a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()
        # groups == in_channels: every input channel gets its own filter.
        depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # 1x1 convolution that maps in_channels to out_channels.
        channel_mixer = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )
        self.conv1 = depthwise
        self.pointwise = channel_mixer

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.conv1(x))
class _BNReluConv(nn.Sequential):
    """Sequential block: optional BatchNorm2d -> ReLU -> convolution -> optional Dropout2d.

    Sub-modules are registered under the names ``norm`` (only when
    ``batch_norm`` is truthy), ``relu``, ``conv`` and ``dropout`` (only when
    ``drop_rate > 0``).

    Args:
        num_maps_in: Number of input channels.
        num_maps_out: Number of output channels of the convolution.
        k: Kernel size of the convolution.
        batch_norm: Whether to prepend a ``BatchNorm2d`` layer.
        bn_momentum: Momentum passed to ``BatchNorm2d``.
        bias: Whether the convolution uses a bias term.
        dilation: Dilation of the convolution.
        drop_rate: Dropout probability; a ``Dropout2d`` layer is appended only
            when this is greater than zero.
        separable: Use ``SeparableConv2d`` instead of ``nn.Conv2d``.
    """

    def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True, bn_momentum=0.1, bias=False, dilation=1,
                 drop_rate=.0, separable=False):
        super(_BNReluConv, self).__init__()
        if batch_norm:
            self.add_module('norm', nn.BatchNorm2d(num_maps_in, momentum=bn_momentum))
        # The ReLU runs in place only when `batch_norm` is literally True, i.e.
        # when it operates on the BatchNorm output. Without the norm layer an
        # in-place ReLU would overwrite the tensor passed in by the caller.
        self.add_module('relu', nn.ReLU(inplace=batch_norm is True))
        # Scale the padding with the dilation so that an odd kernel keeps the
        # spatial size for any dilation (previously `k // 2`, which shrank the
        # output whenever dilation > 1). Identical to before for dilation == 1.
        padding = dilation * (k // 2)
        conv_class = SeparableConv2d if separable else nn.Conv2d
        warnings.warn(f'Using conv type {k}x{k}: {conv_class}')
        self.add_module('conv', conv_class(num_maps_in, num_maps_out, kernel_size=k, padding=padding, bias=bias,
                                           dilation=dilation))
        if drop_rate > 0:
            warnings.warn(f'Using dropout with p: {drop_rate}')
            self.add_module('dropout', nn.Dropout2d(drop_rate, inplace=True))
class _Upsample(nn.Module):
    """Upsample features to a skip tensor's spatial size, add the skip, then blend.

    The skip tensor first passes through a 1x1 ``_BNReluConv`` bottleneck that
    maps ``skip_maps_in`` channels to ``num_maps_in`` channels. The input is
    resized to the bottleneck output's spatial size, optionally summed with
    it, and passed through the ``blend_conv`` block.

    Args:
        num_maps_in: Channels of the tensor being upsampled.
        skip_maps_in: Channels of the skip tensor before the bottleneck.
        num_maps_out: Channels produced by the blend convolution.
        use_bn: Whether the bottleneck and blend blocks use batch norm.
        k: Kernel size of the blend convolution.
        use_skip: Whether the skip tensor is added to the upsampled input.
        only_skip: Stored on the instance; not read by ``forward`` here.
        detach_skip: Detach the bottleneck output from the autograd graph.
        fixed_size: When not ``None``, always resize to this size with
            nearest-neighbour interpolation, ignoring the skip's size.
        separable: Use a separable convolution in the blend block.
        bneck_starts_with_bn: Together with ``use_bn``, controls whether the
            bottleneck starts with batch norm.
    """

    def __init__(self, num_maps_in, skip_maps_in, num_maps_out, use_bn=True, k=3, use_skip=True, only_skip=False,
                 detach_skip=False, fixed_size=None, separable=False, bneck_starts_with_bn=True):
        super(_Upsample, self).__init__()
        print(f'Upsample layer: in = {num_maps_in}, skip = {skip_maps_in}, out = {num_maps_out}')
        self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1, batch_norm=use_bn and bneck_starts_with_bn)
        self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k, batch_norm=use_bn, separable=separable)
        self.use_skip = use_skip
        self.only_skip = only_skip
        self.detach_skip = detach_skip
        warnings.warn(f'\tUsing skips: {self.use_skip} (only skips: {self.only_skip})', UserWarning)
        self.upsampling_method = upsample
        if fixed_size is not None:
            # The `size` argument is deliberately ignored in fixed-size mode.
            self.upsampling_method = lambda x, size: F.interpolate(x, mode='nearest', size=fixed_size)
            warnings.warn(f'Fixed upsample size', UserWarning)

    def forward(self, x, skip):
        """Return the blended result of upsampling ``x`` to ``skip``'s size.

        Args:
            x: Tensor to upsample.
            skip: Skip-connection tensor; its bottleneck output defines the
                target spatial size (dimensions 2 and 3).
        """
        # Call the sub-modules rather than their `.forward()` so that any
        # registered forward hooks run.
        skip = self.bottleneck(skip)
        if self.detach_skip:
            skip = skip.detach()
        skip_size = skip.size()[2:4]
        x = self.upsampling_method(x, skip_size)
        if self.use_skip:
            x = x + skip
        return self.blend_conv(x)
class _UpsampleBlend(nn.Module):
    """Upsample features to a skip tensor's spatial size, add the skip, then blend.

    Unlike a bottlenecked variant, the skip tensor is used as given, so both
    inputs and the output share ``num_features`` channels.

    Args:
        num_features: Channel count of the blend convolution's input and output.
        use_bn: Whether the blend block uses batch norm.
        use_skip: Whether the skip tensor is added to the upsampled input.
        detach_skip: Detach the skip tensor from the autograd graph.
        fixed_size: When not ``None``, always resize to this size with
            nearest-neighbour interpolation, ignoring the skip's size.
        k: Kernel size of the blend convolution.
        separable: Use a separable convolution in the blend block.
    """

    def __init__(self, num_features, use_bn=True, use_skip=True, detach_skip=False, fixed_size=None, k=3,
                 separable=False):
        super(_UpsampleBlend, self).__init__()
        self.blend_conv = _BNReluConv(num_features, num_features, k=k, batch_norm=use_bn, separable=separable)
        self.use_skip = use_skip
        self.detach_skip = detach_skip
        warnings.warn(f'Using skip connections: {self.use_skip}', UserWarning)
        self.upsampling_method = upsample
        if fixed_size is not None:
            # The `size` argument is deliberately ignored in fixed-size mode.
            self.upsampling_method = lambda x, size: F.interpolate(x, mode='nearest', size=fixed_size)
            warnings.warn(f'Fixed upsample size', UserWarning)

    def forward(self, x, skip):
        """Return the blended result of upsampling ``x`` to ``skip``'s size.

        Args:
            x: Tensor to upsample.
            skip: Skip-connection tensor; its last two dimensions define the
                target spatial size.
        """
        if self.detach_skip:
            warnings.warn(f'Detaching skip connection {skip.shape[2:4]}', UserWarning)
            skip = skip.detach()
        skip_size = skip.size()[-2:]
        x = self.upsampling_method(x, skip_size)
        if self.use_skip:
            x = x + skip
        # Call the sub-module rather than its `.forward()` so that any
        # registered forward hooks run.
        return self.blend_conv(x)
class SpatialPyramidPooling(nn.Module):
    """Spatial pyramid pooling built from ``_BNReluConv`` blocks.

    ``self.spp`` holds, in order: a 1x1 input block (``spp_bn``),
    ``num_levels`` per-level 1x1 blocks (``spp0``, ``spp1``, ...) and a 1x1
    fusion block (``spp_fuse``). In ``forward`` the input block's output is
    average-pooled to each grid size, passed through the matching level
    block, resized back to the target size, concatenated with the input
    block's output along the channel dimension and fused.

    Args:
        num_maps_in: Channels of the input tensor.
        num_levels: Number of pyramid levels.
        bt_size: Output channels of the input block.
        level_size: Output channels of each level block.
        out_size: Output channels of the fusion block.
        grids: Pooling grid sizes, one per level. ``forward`` reads
            ``grids[i]`` for level ``i``, so at least ``num_levels`` entries
            must remain after the ``fixed_size`` filtering below.
        square_grid: Pool to square ``g x g`` grids instead of grids scaled by
            the target aspect ratio.
        bn_momentum: Batch-norm momentum for all blocks.
        use_bn: Whether the blocks use batch norm.
        drop_rate: Dropout probability of the level blocks.
        fixed_size: Optional fixed target size. When truthy, grids larger than
            its smaller side are dropped; when not ``None``, levels are resized
            to it with nearest-neighbour interpolation.
        starts_with_bn: Together with ``use_bn``, controls whether the input
            block starts with batch norm.
    """

    def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128, out_size=128,
                 grids=(6, 3, 2, 1), square_grid=False, bn_momentum=0.1, use_bn=True, drop_rate=.0,
                 fixed_size=None, starts_with_bn=True):
        super(SpatialPyramidPooling, self).__init__()
        self.fixed_size = fixed_size
        self.grids = grids
        if self.fixed_size:
            # Keep only the grids that fit into the smaller side of the fixed size.
            ref = min(self.fixed_size)
            self.grids = [g for g in self.grids if g <= ref]
        self.square_grid = square_grid
        self.upsampling_method = upsample
        if self.fixed_size is not None:
            # The `size` argument is deliberately ignored in fixed-size mode.
            self.upsampling_method = lambda x, size: F.interpolate(x, mode='nearest', size=fixed_size)
            warnings.warn(f'Fixed upsample size', UserWarning)
        self.spp = nn.Sequential()
        self.spp.add_module('spp_bn', _BNReluConv(num_maps_in, bt_size, k=1, bn_momentum=bn_momentum,
                                                  batch_norm=use_bn and starts_with_bn))
        num_features = bt_size
        # The fusion block sees the input block's output plus every level's output.
        final_size = num_features
        for i in range(num_levels):
            final_size += level_size
            self.spp.add_module('spp' + str(i),
                                _BNReluConv(num_features, level_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn,
                                            drop_rate=drop_rate))
        self.spp.add_module('spp_fuse',
                            _BNReluConv(final_size, out_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn))

    def forward(self, x):
        """Return the fused pyramid features for ``x``.

        Args:
            x: Input tensor; dimensions 2 and 3 are used as the target size
                when no fixed size was configured.
        """
        target_size = self.fixed_size if self.fixed_size is not None else x.size()[2:4]
        # Ratio of target_size[1] to target_size[0], used to shape non-square grids.
        ar = target_size[1] / target_size[0]

        # Sub-modules are called rather than invoked through `.forward()` so
        # that any registered forward hooks run.
        x = self.spp[0](x)
        levels = [x]
        # Children 1 .. num-1 are the level blocks; the last child fuses.
        num = len(self.spp) - 1
        for i in range(1, num):
            grid = self.grids[i - 1]
            if not self.square_grid:
                grid_size = (grid, max(1, round(ar * grid)))
            else:
                grid_size = grid
            x_pooled = F.adaptive_avg_pool2d(x, grid_size)
            level = self.spp[i](x_pooled)
            level = self.upsampling_method(level, target_size)
            levels.append(level)
        x = torch.cat(levels, 1)
        return self.spp[-1](x)
class Identity(nn.Module):
    """Pass-through module whose ``forward`` returns its argument unchanged.

    The constructor accepts and ignores any positional or keyword arguments,
    so it can be constructed with the same arguments as another module.
    """

    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, input):
        """Return ``input`` as is."""
        return input