Spaces:
Sleeping
Sleeping
File size: 6,443 Bytes
33a707b fbefebb d68b77d 93f7cd5 d68b77d 93f7cd5 d68b77d 33a707b 5fdd273 33a707b 93f7cd5 33a707b 472dac7 33a707b fbefebb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | import torch
import torch.nn as nn
from .height_head import resize
class ConvModule(nn.Module):
    """Convolution followed by optional normalization and activation.

    Child modules are named ``conv``, ``bn`` and ``activate`` so that the
    submodule (and hence state-dict) names match ``mmcv.ConvModule``. The
    convolution carries a bias only when no normalization layer is used.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        kernel_size: convolution kernel size.
        padding: convolution padding (default 0).
        norm_layer: class called as ``norm_layer(out_channels)``, or ``None``
            for no normalization (an ``nn.Identity`` is used instead).
        act_layer: class called with no arguments, or ``None`` for no
            activation (an ``nn.Identity`` is used instead).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 norm_layer=None, act_layer=None):
        super().__init__()
        has_norm = norm_layer is not None
        # A conv bias is redundant when a normalization layer follows it.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            bias=not has_norm,
        )
        if has_norm:
            self.bn = norm_layer(out_channels)
        else:
            self.bn = nn.Identity()
        if act_layer is None:
            self.activate = nn.Identity()
        else:
            self.activate = act_layer()

    def forward(self, x):
        """Apply conv -> norm -> activation."""
        out = self.conv(x)
        out = self.bn(out)
        return self.activate(out)
class PPM(nn.ModuleList):
    """Pyramid pooling module: one pooled 1x1-conv branch per pool scale.

    Each branch applies ``nn.AdaptiveAvgPool2d(pool_scale)`` followed by a
    1x1 ``ConvModule`` with normalization and activation. ``forward`` resizes
    every branch output back to the input's spatial size and returns them as
    a list (the caller concatenates them).

    Args:
        pool_scales: output sizes for ``nn.AdaptiveAvgPool2d``, one branch each.
        in_channels: channels of the input feature map.
        channels: output channels of every branch.
        align_corners: forwarded to ``resize`` when upsampling.
        norm_layer: normalization layer class for each branch. Defaults to
            ``nn.SyncBatchNorm`` (the previously hard-coded choice); a plain
            ``nn.BatchNorm2d`` may be preferable outside distributed training.
        act_layer: activation layer class for each branch (default ``nn.ReLU``).
    """

    def __init__(self, pool_scales, in_channels, channels, align_corners,
                 norm_layer=nn.SyncBatchNorm, act_layer=nn.ReLU):
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(in_channels, channels, 1,
                               norm_layer=norm_layer,
                               act_layer=act_layer),
                ))

    def forward(self, x):
        """Run every branch on ``x`` and bilinearly resize each output.

        Args:
            x: input feature map; ``x.size()[2:]`` is taken as the target
                spatial size (presumably ``(N, C, H, W)`` layout).

        Returns:
            A list with one tensor per pool scale, each resized to ``x``'s
            spatial size via ``resize`` (imported from ``.height_head``).
        """
        out_size = x.size()[2:]
        return [
            resize(
                branch(x),
                size=out_size,
                mode='bilinear',
                align_corners=self.align_corners)
            for branch in self
        ]
class Decoder(nn.Module):
    """Upsampling decoder with optional skip connections and a PPM head.

    ``conv1`` refines the input, then ``num_deconv_filters`` transposed-conv
    stages (kernel 2, stride 2, padding 0) each double the spatial size. When
    ``short_cut_channels`` is given, a skip feature from ``res_list`` is merged
    before the first stage and after every stage, i.e. at
    ``num_deconv_filters + 1`` points. When ``psp_channel > -1`` the result is
    concatenated with the outputs of a ``PPM`` built on ``pool_scales``.

    Args:
        in_channel: channels of the input feature map (must be divisible by
            16, the ``GroupNorm`` group count used throughout).
        short_cut_channels: channels of each skip feature, ordered from the
            input resolution upward, or ``None`` to disable skip connections.
            Needs at least ``num_deconv_filters + 1`` entries.
        num_deconv_filters: number of 2x upsampling stages.
        decover_filter: output channels of each upsampling stage; needs at
            least ``num_deconv_filters`` entries (spelling kept for backward
            compatibility).
        psp_channel: channels of each PPM branch; values <= -1 disable the PPM.
        pool_scales: adaptive-pool output sizes of the PPM branches.

    Raises:
        ValueError: if ``decover_filter`` or ``short_cut_channels`` are too
            short for ``num_deconv_filters``.
    """

    def __init__(self,
                 in_channel=320,
                 short_cut_channels=(512, 256, 128),
                 num_deconv_filters=2,
                 decover_filter=(128, 128),
                 psp_channel=16,
                 pool_scales=(1, 2, 3, 6)):
        super().__init__()
        if len(decover_filter) < num_deconv_filters:
            raise ValueError(
                f'decover_filter needs at least {num_deconv_filters} entries, '
                f'got {len(decover_filter)}')
        if (short_cut_channels is not None
                and len(short_cut_channels) < num_deconv_filters + 1):
            raise ValueError(
                f'short_cut_channels needs at least {num_deconv_filters + 1} '
                f'entries, got {len(short_cut_channels)}')
        self.in_channel = in_channel
        self.num_deconv_filters = num_deconv_filters
        self.short_cut_channels = short_cut_channels
        self.decover_filter = decover_filter
        self.pool_scales = pool_scales
        self.upper_module_dict = nn.ModuleDict()
        # Channel width at each merge point: the input width, then the output
        # width of every upsampling stage.
        stage_channels = [in_channel] + list(decover_filter[:num_deconv_filters])
        # define network layers
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 3, stride=1, padding=1),
            nn.GroupNorm(16, in_channel),
            nn.ReLU())
        if short_cut_channels is not None:
            self.connect_conv_dict = nn.ModuleDict()
            # 1-based key names ('connect_conv1', 'adapt_merge1', ...) and this
            # construction order keep the original state_dict keys.
            for idx, (skip_ch, width) in enumerate(
                    zip(short_cut_channels, stage_channels), start=1):
                self.connect_conv_dict[f'connect_conv{idx}'] = nn.Sequential(
                    nn.Conv2d(skip_ch, width, 3, stride=1, padding=1),
                    nn.GroupNorm(16, width),
                    nn.ReLU())
                self.connect_conv_dict[f'adapt_merge{idx}'] = nn.Conv2d(
                    width * 2, width, 1)
        for i in range(num_deconv_filters):
            self._make_deconv_layer(
                f'deconv{i}', in_channel, decover_filter[i])
            in_channel = decover_filter[i]
        self.psp_channel = psp_channel
        if psp_channel > -1:
            # in_channel now holds the width of the last upsampling stage.
            self.psp_modules = PPM(
                pool_scales=pool_scales,
                in_channels=in_channel,
                channels=psp_channel,
                align_corners=False)

    def _make_deconv_layer(self, name, in_channel, out_channel):
        """Register a 2x upsampling stage under ``upper_module_dict[name]``.

        The stage is ConvTranspose2d(kernel 2, stride 2, no bias) ->
        BatchNorm2d -> ReLU(inplace=True).
        """
        layers = []
        layers.append(
            nn.ConvTranspose2d(
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
                bias=False))
        layers.append(nn.BatchNorm2d(out_channel))
        layers.append(nn.ReLU(inplace=True))
        self.upper_module_dict[name] = nn.Sequential(*layers)

    def forward(self, x, res_list=None):
        """Decode ``x``, optionally merging skip features and appending PPM outputs.

        Args:
            x: input feature map with ``in_channel`` channels, (N, C, H, W);
                each stage doubles H and W (e.g. 32x32 -> 128x128 with the
                default two stages).
            res_list: ``None`` or a sequence of at least
                ``num_deconv_filters + 1`` skip features. Entry ``k`` must have
                ``short_cut_channels[k]`` channels and the spatial size of the
                decoder features after ``k`` upsampling stages, because it is
                concatenated with them along dim 1.

        Returns:
            The decoded map, concatenated along dim 1 with the PPM branch
            outputs when ``psp_channel > -1``.

        Raises:
            ValueError: if ``res_list`` is given but the decoder was built
                with ``short_cut_channels=None``.
        """
        if res_list is not None and self.short_cut_channels is None:
            raise ValueError(
                'res_list was given but this Decoder was built with '
                'short_cut_channels=None')
        x = self.conv1(x)
        for stage in range(self.num_deconv_filters + 1):
            if stage > 0:
                x = self.upper_module_dict[f'deconv{stage - 1}'](x)
            if res_list is not None:
                res = self.connect_conv_dict[f'connect_conv{stage + 1}'](
                    res_list[stage])
                x = self.connect_conv_dict[f'adapt_merge{stage + 1}'](
                    torch.cat([x, res], dim=1))
        if self.psp_channel > -1:
            return torch.cat([x, *self.psp_modules(x)], dim=1)
        return x
if __name__ == '__main__':
    # Smoke test: build the decoder in its three configurations and print the
    # output shape and parameter count of each. Because of the relative import
    # of `resize`, run this as a module (python -m <package>.<module>).
    # eval() + no_grad(): in training mode, batch norm on the 1x1 PPM branch
    # would see a single value per channel with batch size 1 and raise.
    # FLOP counting (previously via get_model_complexity_info, which is not
    # imported here) is omitted.
    input_data = torch.randn(1, 320, 32, 32)
    res = [torch.randn(1, 512, 32, 32),
           torch.randn(1, 256, 64, 64),
           torch.randn(1, 128, 128, 128)]
    configs = [
        ('shortcuts + PPM', dict(in_channel=320), res),
        ('no shortcuts + PPM',
         dict(in_channel=320, short_cut_channels=None), None),
        ('no shortcuts, no PPM',
         dict(in_channel=320, short_cut_channels=None, psp_channel=-1), None),
    ]
    with torch.no_grad():
        for label, kwargs, res_list in configs:
            model = Decoder(**kwargs).eval()
            output = model(input_data, res_list)
            n_params = sum(p.numel() for p in model.parameters())
            print(f'{label}: output {tuple(output.shape)}, params {n_params:,}')