Mhara's picture
Upload folder using huggingface_hub
dae5c90 verified
Raw
History Blame Contribute Delete
7.06 kB
import torch
import torch.nn.functional as F
import warnings
from torch import nn as nn
def upsample(x, size):
    """Resize ``x`` to ``size`` with bilinear interpolation.

    Thin wrapper around ``F.interpolate`` using ``mode='bilinear'`` and
    ``align_corners=False``. Written as a ``def`` instead of a lambda bound
    to a name so the function carries a real ``__name__`` and can be
    resolved by name (tracebacks, pickling by reference).

    Args:
        x: tensor handed to ``F.interpolate`` as its input.
        size: target output size, passed through unchanged.

    Returns:
        The interpolated tensor.
    """
    return F.interpolate(x, size, mode='bilinear', align_corners=False)


# Evaluates to 0.005. Presumably a batch-norm momentum for callers; it is not
# read by any code visible in this part of the file.
batchnorm_momentum = 0.01 / 2
def get_n_params(parameters):
    """Return the total number of scalar elements across ``parameters``.

    Args:
        parameters: iterable of tensors (anything exposing ``numel()``),
            for example the result of ``model.parameters()``.

    Returns:
        int: the sum of the element counts; ``0`` for an empty iterable.
    """
    # ``numel()`` is the product of the tensor's dimensions, which is what the
    # previous hand-written loop computed (and it avoided nothing by shadowing
    # the imported ``nn`` module with a local counter).
    return sum(p.numel() for p in parameters)
class SeparableConv2d(nn.Module):
    """Two-stage convolution: a per-channel conv followed by a 1x1 conv.

    ``conv1`` keeps the channel count and uses ``groups=in_channels``;
    ``pointwise`` is a 1x1 convolution mapping ``in_channels`` to
    ``out_channels``. Both stages share the same ``bias`` flag.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()
        # Stage 1: one filter group per input channel; all spatial
        # hyper-parameters (kernel, stride, padding, dilation) apply here.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # Stage 2: 1x1 convolution that mixes channels.
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        """Apply ``conv1`` and then ``pointwise`` to ``x``."""
        return self.pointwise(self.conv1(x))
class _BNReluConv(nn.Sequential):
def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True, bn_momentum=0.1, bias=False, dilation=1,
drop_rate=.0, separable=False):
super(_BNReluConv, self).__init__()
if batch_norm:
self.add_module('norm', nn.BatchNorm2d(num_maps_in, momentum=bn_momentum))
self.add_module('relu', nn.ReLU(inplace=batch_norm is True))
padding = k // 2
conv_class = SeparableConv2d if separable else nn.Conv2d
warnings.warn(f'Using conv type {k}x{k}: {conv_class}')
self.add_module('conv', conv_class(num_maps_in, num_maps_out, kernel_size=k, padding=padding, bias=bias,
dilation=dilation))
if drop_rate > 0:
warnings.warn(f'Using dropout with p: {drop_rate}')
self.add_module('dropout', nn.Dropout2d(drop_rate, inplace=True))
class _Upsample(nn.Module):
    """Resize a feature map to a skip connection's spatial size and blend them.

    ``bottleneck`` is a 1x1 ``_BNReluConv`` that maps the skip tensor from
    ``skip_maps_in`` to ``num_maps_in`` channels so it can be added to ``x``;
    ``blend_conv`` is a ``k`` x ``k`` ``_BNReluConv`` mapping ``num_maps_in``
    to ``num_maps_out`` channels.

    Args:
        num_maps_in: channel count of ``x`` (and of the projected skip).
        skip_maps_in: channel count of the raw skip tensor.
        num_maps_out: channel count produced by ``blend_conv``.
        use_bn: enable batch norm in ``blend_conv`` (and in ``bottleneck``
            when ``bneck_starts_with_bn`` is also true).
        k: kernel size of ``blend_conv``.
        use_skip: add the projected skip to the resized ``x``.
        only_skip: stored on the module; not read by any code in this class.
        detach_skip: detach the projected skip from the autograd graph.
        fixed_size: when not ``None``, resize with nearest-neighbour
            interpolation to this size instead of to the skip's size.
        separable: use a separable convolution in ``blend_conv``.
        bneck_starts_with_bn: see ``use_bn``.
    """

    def __init__(self, num_maps_in, skip_maps_in, num_maps_out, use_bn=True, k=3, use_skip=True, only_skip=False,
                 detach_skip=False, fixed_size=None, separable=False, bneck_starts_with_bn=True):
        super(_Upsample, self).__init__()
        print(f'Upsample layer: in = {num_maps_in}, skip = {skip_maps_in}, out = {num_maps_out}')
        self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1, batch_norm=use_bn and bneck_starts_with_bn)
        self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k, batch_norm=use_bn, separable=separable)
        self.use_skip = use_skip
        self.only_skip = only_skip
        self.detach_skip = detach_skip
        warnings.warn(f'\tUsing skips: {self.use_skip} (only skips: {self.only_skip})', UserWarning)
        self.upsampling_method = upsample
        if fixed_size is not None:
            # The ``size`` argument is deliberately ignored: the output size is
            # pinned to ``fixed_size`` and nearest-neighbour mode is used.
            self.upsampling_method = lambda x, size: F.interpolate(x, mode='nearest', size=fixed_size)
            warnings.warn('Fixed upsample size', UserWarning)

    def forward(self, x, skip):
        """Return ``blend_conv(resize(x) [+ bottleneck(skip)])``.

        Args:
            x: feature map to be resized.
            skip: skip-connection tensor; dimensions 2 and 3 of its projected
                form give the resize target.

        Returns:
            The output of ``blend_conv``.
        """
        # Call sub-modules through ``__call__`` rather than ``.forward`` so
        # that registered forward hooks are executed.
        skip = self.bottleneck(skip)
        if self.detach_skip:
            skip = skip.detach()
        skip_size = skip.size()[2:4]
        x = self.upsampling_method(x, skip_size)
        if self.use_skip:
            x = x + skip
        return self.blend_conv(x)
class _UpsampleBlend(nn.Module):
    """Resize ``x`` to a skip tensor's size, optionally add the skip, then blend.

    Unlike a projecting upsample block, the skip tensor is added as-is, so it
    must already have ``num_features`` channels like ``x``.

    Args:
        num_features: channel count of the input and output of ``blend_conv``.
        use_bn: enable batch norm in ``blend_conv``.
        use_skip: add ``skip`` to the resized ``x``.
        detach_skip: detach ``skip`` from the autograd graph before use.
        fixed_size: when not ``None``, resize with nearest-neighbour
            interpolation to this size instead of to the skip's size.
        k: kernel size of ``blend_conv``.
        separable: use a separable convolution in ``blend_conv``.
    """

    def __init__(self, num_features, use_bn=True, use_skip=True, detach_skip=False, fixed_size=None, k=3,
                 separable=False):
        super(_UpsampleBlend, self).__init__()
        self.blend_conv = _BNReluConv(num_features, num_features, k=k, batch_norm=use_bn, separable=separable)
        self.use_skip = use_skip
        self.detach_skip = detach_skip
        warnings.warn(f'Using skip connections: {self.use_skip}', UserWarning)
        self.upsampling_method = upsample
        if fixed_size is not None:
            # The ``size`` argument is deliberately ignored: the output size is
            # pinned to ``fixed_size`` and nearest-neighbour mode is used.
            self.upsampling_method = lambda x, size: F.interpolate(x, mode='nearest', size=fixed_size)
            warnings.warn('Fixed upsample size', UserWarning)

    def forward(self, x, skip):
        """Return ``blend_conv(resize(x) [+ skip])``.

        Args:
            x: feature map to be resized.
            skip: skip-connection tensor; its last two dimensions give the
                resize target.

        Returns:
            The output of ``blend_conv``.
        """
        if self.detach_skip:
            warnings.warn(f'Detaching skip connection {skip.shape[2:4]}', UserWarning)
            skip = skip.detach()
        skip_size = skip.size()[-2:]
        x = self.upsampling_method(x, skip_size)
        if self.use_skip:
            x = x + skip
        # Call the sub-module through ``__call__`` rather than ``.forward`` so
        # that registered forward hooks are executed.
        return self.blend_conv(x)
class SpatialPyramidPooling(nn.Module):
    """Pyramid pooling head built from 1x1 ``_BNReluConv`` blocks.

    ``self.spp`` holds, in order:

    * ``spp_bn``: ``num_maps_in`` -> ``bt_size`` channels,
    * ``spp0`` .. ``spp{num_levels - 1}``: ``bt_size`` -> ``level_size`` each,
    * ``spp_fuse``: ``bt_size + num_levels * level_size`` -> ``out_size``.

    Args:
        num_maps_in: channel count of the input tensor.
        num_levels: number of pooled pyramid levels. Must not exceed the
            number of entries in ``grids`` (after the ``fixed_size`` filter
            below), otherwise ``forward`` raises ``IndexError``.
        bt_size: channel count after the first 1x1 block.
        level_size: channel count of every pyramid level.
        out_size: channel count of the fused output.
        grids: pooling grid size for each level, indexed by level.
        square_grid: pool to a square ``g`` x ``g`` grid instead of scaling
            the second grid dimension by the aspect ratio.
        bn_momentum: batch-norm momentum for all blocks.
        use_bn: enable batch norm in the blocks (``spp_bn`` additionally
            requires ``starts_with_bn``).
        drop_rate: dropout probability for the level blocks.
        fixed_size: optional fixed target size; when set, grid entries larger
            than ``min(fixed_size)`` are dropped and levels are resized to
            this size with nearest-neighbour interpolation.
        starts_with_bn: see ``use_bn``.
    """

    def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128, out_size=128,
                 grids=(6, 3, 2, 1), square_grid=False, bn_momentum=0.1, use_bn=True, drop_rate=.0,
                 fixed_size=None, starts_with_bn=True):
        super(SpatialPyramidPooling, self).__init__()
        self.fixed_size = fixed_size
        self.grids = grids
        if self.fixed_size:
            # A grid cannot be finer than the smaller side of the fixed size.
            ref = min(self.fixed_size)
            self.grids = list(filter(lambda x: x <= ref, self.grids))
        self.square_grid = square_grid
        self.upsampling_method = upsample
        if self.fixed_size is not None:
            # The ``size`` argument is deliberately ignored: the output size is
            # pinned to ``fixed_size`` and nearest-neighbour mode is used.
            self.upsampling_method = lambda x, size: F.interpolate(x, mode='nearest', size=fixed_size)
            warnings.warn('Fixed upsample size', UserWarning)
        self.spp = nn.Sequential()
        self.spp.add_module('spp_bn', _BNReluConv(num_maps_in, bt_size, k=1, bn_momentum=bn_momentum,
                                                  batch_norm=use_bn and starts_with_bn))
        num_features = bt_size
        final_size = num_features
        for i in range(num_levels):
            # Every level contributes ``level_size`` channels to the concat.
            final_size += level_size
            self.spp.add_module('spp' + str(i),
                                _BNReluConv(num_features, level_size, k=1, bn_momentum=bn_momentum,
                                            batch_norm=use_bn, drop_rate=drop_rate))
        self.spp.add_module('spp_fuse',
                            _BNReluConv(final_size, out_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn))

    def forward(self, x):
        """Pool ``x`` at each pyramid level, resize, concatenate and fuse.

        Args:
            x: input tensor; dimensions 2 and 3 are used as the target size
                unless ``fixed_size`` was given.

        Returns:
            The output of the ``spp_fuse`` block.
        """
        levels = []
        target_size = self.fixed_size if self.fixed_size is not None else x.size()[2:4]
        # Ratio of the second to the first target dimension; used to stretch
        # the second pooling-grid dimension for non-square grids.
        ar = target_size[1] / target_size[0]

        # Sub-modules are invoked through ``__call__`` rather than ``.forward``
        # so that registered forward hooks are executed.
        x = self.spp[0](x)
        levels.append(x)
        num = len(self.spp) - 1  # index of the final fuse block

        for i in range(1, num):
            grid = self.grids[i - 1]
            if not self.square_grid:
                grid_size = (grid, max(1, round(ar * grid)))
                x_pooled = F.adaptive_avg_pool2d(x, grid_size)
            else:
                x_pooled = F.adaptive_avg_pool2d(x, grid)
            level = self.spp[i](x_pooled)
            level = self.upsampling_method(level, target_size)
            levels.append(level)

        # Concatenate along dimension 1 and fuse.
        x = torch.cat(levels, 1)
        return self.spp[-1](x)
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back its argument unchanged.

    The constructor accepts and ignores any positional or keyword arguments,
    so it can be instantiated with whatever signature is used for the layer
    it replaces.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are intentionally discarded.
        super().__init__()

    def forward(self, input):
        """Return ``input`` itself (the same object, not a copy)."""
        return input