# networks/decoder.py — HeightAdaptor decoder: skip-connection upsampling with optional PPM head.
import torch
import torch.nn as nn
from .height_head import resize
class ConvModule(nn.Module):
    """Conv2d followed by optional normalization and activation.

    Mirrors the submodule naming of mmcv's ConvModule (``conv``, ``bn``,
    ``activate``) so checkpoint keys stay compatible.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 norm_layer=None, act_layer=None):
        super().__init__()
        # A following norm layer makes the conv bias redundant, so only
        # enable the bias when no norm layer is given.
        use_bias = norm_layer is None
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              padding=padding, bias=use_bias)
        self.bn = nn.Identity() if norm_layer is None else norm_layer(out_channels)
        self.activate = nn.Identity() if act_layer is None else act_layer()

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.activate(out)
class PPM(nn.ModuleList):
    """Pyramid Pooling Module (PSPNet-style).

    Each branch adaptively pools the input to one size from ``pool_scales``,
    projects it to ``channels`` with a 1x1 ConvModule, and the forward pass
    bilinearly resizes every branch output back to the input's spatial size.
    """

    def __init__(self, pool_scales, in_channels, channels, align_corners):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        for scale in pool_scales:
            branch = nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                ConvModule(in_channels, channels, 1,
                           norm_layer=nn.SyncBatchNorm,
                           act_layer=nn.ReLU),
            )
            self.append(branch)

    def forward(self, x):
        """Return a list with one upsampled feature map per pooling branch."""
        target_size = x.size()[2:]
        outs = []
        for branch in self:
            pooled = branch(x)
            upsampled = resize(pooled,
                               size=target_size,
                               mode='bilinear',
                               align_corners=self.align_corners)
            outs.append(upsampled)
        return outs
class Decoder(nn.Module):
    """Decoder that upsamples a backbone feature map 4x via two deconv stages.

    Pipeline: 3x3 conv refinement -> (optional skip merge) -> deconv stages
    interleaved with further skip merges -> optional PPM head whose branch
    outputs are concatenated onto the feature map along channels.

    Args:
        in_channel: channels of the incoming backbone feature.
        short_cut_channels: channels of the three skip features (ordered from
            lowest to highest resolution); ``None`` disables skip merging.
        num_deconv_filters: number of x2-upsampling deconv stages.
            NOTE(review): forward() hardcodes 'deconv0'/'deconv1', so values
            other than 2 break forward — confirm before changing this.
        decover_filter: output channels of each deconv stage.
        psp_channel: per-branch channels of the PPM head; any value <= -1
            disables the PPM.
        pool_scales: adaptive-pool output sizes used by the PPM branches.
    """

    def __init__(self,
                 in_channel=320,
                 short_cut_channels=(512, 256, 128),
                 num_deconv_filters=2,
                 decover_filter=(128, 128),
                 psp_channel=16,
                 pool_scales=(1, 2, 3, 6)):
        # Tuples (not lists) as defaults: avoids the shared-mutable-default
        # pitfall while indexing/iteration behaves the same for callers.
        super(Decoder, self).__init__()
        self.in_channel = in_channel
        self.num_deconv_filters = num_deconv_filters
        self.short_cut_channels = short_cut_channels
        self.decover_filter = decover_filter
        self.pool_scales = pool_scales
        self.upper_module_dict = nn.ModuleDict()
        # Refine the incoming feature before any merging/upsampling.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 3, stride=1, padding=1),
            nn.GroupNorm(16, in_channel),
            nn.ReLU())
        if self.short_cut_channels is not None:
            self.connect_conv_dict = nn.ModuleDict()
            # Each skip is projected to the decoder's channel count at that
            # stage: in_channel before deconv0, then each deconv's output.
            merge_channels = (in_channel, decover_filter[0], decover_filter[1])
            for idx, (skip_ch, out_ch) in enumerate(
                    zip(short_cut_channels, merge_channels), start=1):
                self._make_connect_layer(idx, skip_ch, out_ch)
        for i in range(num_deconv_filters):
            self._make_deconv_layer(
                f'deconv{i}', in_channel, decover_filter[i])
            in_channel = decover_filter[i]
        self.psp_channel = psp_channel
        if psp_channel > -1:
            self.psp_modules = PPM(
                pool_scales=pool_scales,
                in_channels=in_channel,
                channels=psp_channel,
                align_corners=False)

    def _make_connect_layer(self, idx, skip_channel, out_channel):
        """Register the skip projection ('connect_conv{idx}': 3x3 conv + GN +
        ReLU) and the 1x1 fusion conv ('adapt_merge{idx}') for one stage."""
        self.connect_conv_dict[f'connect_conv{idx}'] = nn.Sequential(
            nn.Conv2d(skip_channel, out_channel, 3, stride=1, padding=1),
            nn.GroupNorm(16, out_channel),
            nn.ReLU())
        self.connect_conv_dict[f'adapt_merge{idx}'] = nn.Conv2d(
            out_channel * 2, out_channel, 1)

    def _make_deconv_layer(self, name, in_channel, out_channel):
        """Register a x2-upsampling stage (deconv + BN + ReLU) under `name`."""
        layers = []
        layers.append(
            nn.ConvTranspose2d(
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
                bias=False))  # the following BatchNorm makes a bias redundant
        layers.append(nn.BatchNorm2d(out_channel))
        layers.append(nn.ReLU(inplace=True))
        self.upper_module_dict[name] = nn.Sequential(*layers)

    def forward(self, x, res_list=None):
        """Decode `x`, optionally fusing skip features.

        Args:
            x: backbone feature map.
            res_list: optional list of three skip features, consumed in order
                (lowest resolution first); requires the module to have been
                built with matching short_cut_channels.

        Returns:
            The decoded feature map; when the PPM is enabled, that map
            concatenated with every upsampled PPM branch along channels.
        """
        x = self.conv1(x)
        if res_list is not None:
            res = self.connect_conv_dict['connect_conv1'](res_list[0])
            x = self.connect_conv_dict['adapt_merge1'](torch.cat([x, res], dim=1))
        x = self.upper_module_dict['deconv0'](x)  # x2 spatial
        if res_list is not None:
            res = self.connect_conv_dict['connect_conv2'](res_list[1])
            x = self.connect_conv_dict['adapt_merge2'](torch.cat([x, res], dim=1))
        x = self.upper_module_dict['deconv1'](x)  # x4 spatial
        if res_list is not None:
            res = self.connect_conv_dict['connect_conv3'](res_list[2])
            x = self.connect_conv_dict['adapt_merge3'](torch.cat([x, res], dim=1))
        if self.psp_channel > -1:
            psp_outs = [x]
            psp_outs.extend(self.psp_modules(x))
            return torch.cat(psp_outs, dim=1)
        return x
if __name__ == '__main__':
    # Smoke test: build a decoder without skip connections.
    # Example forward usage:
    #   out = model(torch.randn(1, 320, 32, 32), None)
    model = Decoder(in_channel=320, short_cut_channels=None)