Spaces:
Sleeping
Sleeping
import torch
import torch.nn as nn
from .height_head import resize
class ConvModule(nn.Module):
    """Conv2d -> norm -> activation block.

    Submodules are named ``conv`` / ``bn`` / ``activate`` to match
    mmcv.ConvModule, so checkpoints saved with mmcv load without any
    key remapping.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 norm_layer=None, act_layer=None):
        super().__init__()
        # A conv bias is redundant when a normalization layer follows.
        use_bias = norm_layer is None
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              padding=padding, bias=use_bias)
        self.bn = nn.Identity() if norm_layer is None else norm_layer(out_channels)
        self.activate = nn.Identity() if act_layer is None else act_layer()

    def forward(self, x):
        """Apply conv, then norm, then activation."""
        out = self.conv(x)
        out = self.bn(out)
        return self.activate(out)
class PPM(nn.ModuleList):
    """Pyramid Pooling Module (PSPNet-style).

    Each branch adaptively average-pools the input to one of
    ``pool_scales`` and projects it to ``channels`` with a 1x1 conv;
    the forward pass upsamples every branch back to the input size.
    """

    def __init__(self, pool_scales, in_channels, channels, align_corners):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        branches = [
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                ConvModule(in_channels, channels, 1,
                           norm_layer=nn.SyncBatchNorm,
                           act_layer=nn.ReLU),
            )
            for scale in pool_scales
        ]
        self.extend(branches)

    def forward(self, x):
        """Run every pooling branch and resize each output to x's spatial size."""
        target_size = x.size()[2:]
        outs = []
        for branch in self:
            pooled = branch(x)
            outs.append(
                resize(pooled,
                       size=target_size,
                       mode='bilinear',
                       align_corners=self.align_corners))
        return outs
class Decoder(nn.Module):
    """Progressive upsampling decoder with optional skip connections and PSP head.

    The input map is refined by a 3x3 conv, optionally fused with encoder
    skip features (``res_list``), and upsampled ``num_deconv_filters`` times
    via stride-2 transposed convolutions. When ``psp_channel > -1`` the final
    map is concatenated with multi-scale pyramid-pooled features from ``PPM``.

    Args:
        in_channel (int): channels of the incoming feature map.
        short_cut_channels (sequence | None): channels of the three encoder
            skip tensors; ``None`` disables skip connections entirely.
        num_deconv_filters (int): number of 2x transposed-conv stages.
        decover_filter (sequence): output channels of each upsampling stage.
        psp_channel (int): per-branch channels of the PSP module; any value
            <= -1 disables the PSP head.
        pool_scales (sequence): adaptive-pool output sizes for the PSP module.
    """

    def __init__(self,
                 in_channel=320,
                 short_cut_channels=(512, 256, 128),
                 num_deconv_filters=2,
                 decover_filter=(128, 128),
                 psp_channel=16,
                 pool_scales=(1, 2, 3, 6)):
        # NOTE: defaults are tuples (was mutable lists) — same effective
        # values, but immune to accidental shared-state mutation.
        super(Decoder, self).__init__()
        self.in_channel = in_channel
        self.num_deconv_filters = num_deconv_filters
        self.short_cut_channels = short_cut_channels
        self.decover_filter = decover_filter
        self.pool_scales = pool_scales
        self.upper_module_dict = nn.ModuleDict()
        # Initial 3x3 refinement of the encoder output.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 3, stride=1, padding=1),
            nn.GroupNorm(16, in_channel),
            nn.ReLU())
        if self.short_cut_channels is not None:
            self.connect_conv_dict = nn.ModuleDict()
            # Channels each skip branch is projected to before merging:
            # stage 1 merges at in_channel, stages 2/3 at the deconv widths.
            merge_channels = (in_channel, decover_filter[0], decover_filter[1])
            for i, (skip_ch, out_ch) in enumerate(
                    zip(short_cut_channels, merge_channels), start=1):
                # Project the skip tensor to the decoder's channel width...
                self.connect_conv_dict[f'connect_conv{i}'] = nn.Sequential(
                    nn.Conv2d(skip_ch, out_ch, 3, stride=1, padding=1),
                    nn.GroupNorm(16, out_ch),
                    nn.ReLU())
                # ...then fuse [decoder, skip] concat back down with a 1x1 conv.
                self.connect_conv_dict[f'adapt_merge{i}'] = nn.Conv2d(
                    out_ch * 2, out_ch, 1)
        for i in range(num_deconv_filters):
            self._make_deconv_layer(
                f'deconv{i}', in_channel, decover_filter[i])
            in_channel = decover_filter[i]
        self.psp_channel = psp_channel
        if psp_channel > -1:
            self.psp_modules = PPM(
                pool_scales=pool_scales,
                in_channels=in_channel,
                channels=psp_channel,
                align_corners=False)

    def _make_deconv_layer(self, name, in_channel, out_channel):
        """Register a 2x-upsampling deconv block under ``name``."""
        layers = [
            nn.ConvTranspose2d(
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
                bias=False),  # bias is redundant before BatchNorm
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        ]
        self.upper_module_dict[name] = nn.Sequential(*layers)

    def forward(self, x, res_list=None):
        """Decode ``x``, optionally fusing skip features.

        Args:
            x: encoder feature map.
            res_list: optional list of three skip tensors whose spatial sizes
                match the decoder stages (input, after deconv0, after deconv1).

        Returns:
            The upsampled feature map; if the PSP head is enabled, the map
            concatenated with the pyramid-pooled features along channels.
        """
        x = self.conv1(x)  # e.g. 32x32 for the default config
        if res_list is not None:
            res = self.connect_conv_dict['connect_conv1'](res_list[0])
            x = self.connect_conv_dict['adapt_merge1'](torch.cat([x, res], dim=1))
        x = self.upper_module_dict['deconv0'](x)  # 2x upsample
        if res_list is not None:
            res = self.connect_conv_dict['connect_conv2'](res_list[1])
            x = self.connect_conv_dict['adapt_merge2'](torch.cat([x, res], dim=1))
        x = self.upper_module_dict['deconv1'](x)  # 2x upsample
        if res_list is not None:
            res = self.connect_conv_dict['connect_conv3'](res_list[2])
            x = self.connect_conv_dict['adapt_merge3'](torch.cat([x, res], dim=1))
        if self.psp_channel > -1:
            psp_outs = [x]
            psp_outs.extend(self.psp_modules(x))
            return torch.cat(psp_outs, dim=1)
        else:
            return x
if __name__ == '__main__':
    # Smoke test: build the decoder without skip connections.
    model = Decoder(in_channel=320, short_cut_channels=None)