import torch
import torch.nn as nn

from .height_head import resize


class ConvModule(nn.Module):
    """Conv + Norm + Activation with same submodule names as mmcv.ConvModule.

    The submodules are registered as ``conv``, ``bn`` and ``activate``.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``.
        norm_layer: Callable taking ``out_channels`` and returning a norm
            module; ``None`` installs ``nn.Identity``.
        act_layer: Zero-argument callable returning an activation module;
            ``None`` installs ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 norm_layer=None, act_layer=None):
        super().__init__()
        # The convolution only carries a bias when no norm layer follows it.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            bias=(norm_layer is None))
        self.bn = (norm_layer(out_channels)
                   if norm_layer is not None else nn.Identity())
        self.activate = act_layer() if act_layer is not None else nn.Identity()

    def forward(self, x):
        """Apply convolution, then norm, then activation."""
        return self.activate(self.bn(self.conv(x)))


class PPM(nn.ModuleList):
    """List of pooling branches, one per entry of ``pool_scales``.

    Each branch is ``AdaptiveAvgPool2d(scale)`` followed by a 1x1
    ``ConvModule`` (``nn.SyncBatchNorm`` + ``nn.ReLU``) mapping
    ``in_channels`` to ``channels``.

    Args:
        pool_scales: Output sizes given to ``nn.AdaptiveAvgPool2d``.
        in_channels: Channels of the input feature map.
        channels: Channels produced by every branch.
        align_corners: Forwarded to ``resize`` when upsampling.
    """

    def __init__(self, pool_scales, in_channels, channels, align_corners):
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        in_channels,
                        channels,
                        1,
                        norm_layer=nn.SyncBatchNorm,
                        act_layer=nn.ReLU),
                ))

    def forward(self, x):
        """Run every branch on ``x`` and resize each result to ``x``'s size.

        Returns:
            A list with one tensor per branch, each resized (bilinear) to the
            spatial size ``x.size()[2:]`` via the project's ``resize`` helper.
        """
        target_size = x.size()[2:]
        return [
            resize(
                branch(x),
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
            for branch in self
        ]


class Decoder(nn.Module):
    """Upsampling decoder with optional shortcut merging and pooled context.

    Pipeline of ``forward``: ``conv1`` -> [merge shortcut 1] -> ``deconv0``
    -> [merge shortcut 2] -> ``deconv1`` -> [merge shortcut 3] -> optional
    PPM concatenation. Each deconv stage is a ``ConvTranspose2d`` with
    kernel 2 / stride 2, i.e. it doubles the spatial size.

    NOTE: ``forward`` always runs exactly ``deconv0`` and ``deconv1`` and the
    shortcut branches are built for exactly three stages, so only
    ``num_deconv_filters=2`` is exercised by the forward pass.

    Args:
        in_channel: Channels of the input feature map. ``conv1`` and the
            first shortcut branch use ``GroupNorm(16, in_channel)``.
        short_cut_channels: Channels of the three shortcut feature maps, or
            ``None`` to build no shortcut branches. The shortcut branches
            apply ``GroupNorm(16, ...)`` to ``in_channel`` and to
            ``decover_filter[0]`` / ``decover_filter[1]``.
        num_deconv_filters: Number of deconv stages registered as
            ``deconv0`` ... ``deconv{n-1}``.
        decover_filter: Output channels of each deconv stage.
        psp_channel: Channels of every PPM branch; a value of ``-1`` or
            lower disables the PPM.
        pool_scales: Pooling scales handed to the PPM.
    """

    def __init__(self, in_channel=320, short_cut_channels=(512, 256, 128),
                 num_deconv_filters=2, decover_filter=(128, 128),
                 psp_channel=16, pool_scales=(1, 2, 3, 6)):
        super().__init__()
        self.in_channel = in_channel
        self.num_deconv_filters = num_deconv_filters
        # Defaults are immutable tuples; private list copies are stored so
        # that no state is shared between instances or with the caller.
        self.short_cut_channels = (
            None if short_cut_channels is None else list(short_cut_channels))
        self.decover_filter = list(decover_filter)
        self.pool_scales = list(pool_scales)
        self.upper_module_dict = nn.ModuleDict()

        # define network layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 3, stride=1, padding=1),
            nn.GroupNorm(16, in_channel),
            nn.ReLU())

        if self.short_cut_channels is not None:
            self.connect_conv_dict = nn.ModuleDict()
            # Channel count of the main path at each of the three merges.
            stage_channels = (
                in_channel,
                self.decover_filter[0],
                self.decover_filter[1],
            )
            for idx in range(3):
                stage = idx + 1
                width = stage_channels[idx]
                self.connect_conv_dict[f'connect_conv{stage}'] = (
                    self._make_shortcut_branch(
                        self.short_cut_channels[idx], width))
                # 1x1 conv squeezing the concatenated pair back to `width`.
                self.connect_conv_dict[f'adapt_merge{stage}'] = nn.Conv2d(
                    width * 2, width, 1)

        channels = in_channel
        for i in range(num_deconv_filters):
            self._make_deconv_layer(
                f'deconv{i}', channels, self.decover_filter[i])
            channels = self.decover_filter[i]

        self.psp_channel = psp_channel
        if psp_channel > -1:
            self.psp_modules = PPM(
                pool_scales=self.pool_scales,
                in_channels=channels,
                channels=psp_channel,
                align_corners=False)

    @staticmethod
    def _make_shortcut_branch(in_channels, out_channels):
        """Build the 3x3 conv + GroupNorm(16) + ReLU applied to a shortcut."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.GroupNorm(16, out_channels),
            nn.ReLU())

    def _make_deconv_layer(self, name, in_channel, out_channel):
        """Make deconv layers.

        Registers ``ConvTranspose2d(kernel 2, stride 2, no bias)`` +
        ``BatchNorm2d`` + ``ReLU`` under ``self.upper_module_dict[name]``.
        """
        self.upper_module_dict[name] = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
                bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True))

    def _merge_shortcut(self, x, res_list, stage):
        """Fuse shortcut ``res_list[stage - 1]`` into ``x`` (no-op if None)."""
        if res_list is None:
            return x
        res = self.connect_conv_dict[f'connect_conv{stage}'](
            res_list[stage - 1])
        return self.connect_conv_dict[f'adapt_merge{stage}'](
            torch.cat([x, res], dim=1))

    def forward(self, x, res_list=None):
        """Decode ``x``, optionally merging three shortcut feature maps.

        Args:
            x: Input feature map with ``in_channel`` channels.
            res_list: ``None`` or a sequence of three shortcut tensors;
                entry ``i`` is concatenated with the main path along dim 1,
                so it must match the main path's spatial size at merge
                ``i + 1`` (input size, then x2, then x4).

        Returns:
            The decoded feature map. When the PPM is enabled, the PPM branch
            outputs are concatenated to it along the channel dimension.

        Raises:
            ValueError: If ``res_list`` is given but the decoder was built
                with ``short_cut_channels=None``.
        """
        if res_list is not None and self.short_cut_channels is None:
            raise ValueError(
                'res_list was given, but this Decoder was built with '
                'short_cut_channels=None and has no shortcut branches.')

        x = self.conv1(x)  # e.g. 32*32 for a 32*32 input
        x = self._merge_shortcut(x, res_list, 1)

        x = self.upper_module_dict['deconv0'](x)  # e.g. 64*64
        x = self._merge_shortcut(x, res_list, 2)

        x = self.upper_module_dict['deconv1'](x)  # e.g. 128*128
        x = self._merge_shortcut(x, res_list, 3)

        if self.psp_channel > -1:
            psp_outs = [x]
            psp_outs.extend(self.psp_modules(x))
            return torch.cat(psp_outs, dim=1)
        return x


if __name__ == '__main__':
    # Example with shortcut feature maps:
    # model = Decoder(in_channel=320)
    # input_data = torch.randn(1, 320, 32, 32)
    # res = [torch.randn(1, 512, 32, 32), torch.randn(1, 256, 64, 64),
    #        torch.randn(1, 128, 128, 128)]
    # output = model(input_data, res)
    # print(output.shape)

    model = Decoder(in_channel=320, short_cut_channels=None)
    # input_data = torch.randn(1, 320, 32, 32)
    # output = model(input_data, None)
    # flops, params = get_model_complexity_info(model, (320, 32, 32))
    # print(f"Params: {params}")
    # print(f"FLOPs: {flops}")
    # print("-" * 30)
    # print(output.shape)

    # Example without shortcuts and without the PPM:
    # model = Decoder(in_channel=320, short_cut_channels=None, psp_channel=-1)
    # input_data = torch.randn(2, 320, 32, 32)
    # output = model(input_data, None)
    # print(output.shape)