| """ | |
| ISDNet building blocks: STDC-like modules and Laplacian pyramid | |
| """ | |
| import os | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import init | |
class ConvX(nn.Module):
    """Conv -> SyncBatchNorm -> ReLU, the elementary STDC building block.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count.
        kernel: Square kernel size (padding is kernel // 2, i.e. 'same' for odd kernels).
        stride: Convolution stride.
    """

    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super().__init__()
        # bias is omitted because the following BatchNorm has its own shift.
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel,
                              stride=stride, padding=kernel // 2, bias=False)
        self.bn = nn.SyncBatchNorm(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)
class AddBottleneck(nn.Module):
    """STDC bottleneck that fuses the multi-scale branches with a residual add.

    The output of every internal ConvX is concatenated channel-wise and the
    input (optionally projected by a stride-2 shortcut) is added on top.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count (split across the internal branches).
        block_num: Number of internal ConvX stages (expected >= 2).
        stride: 1 keeps resolution; 2 downsamples via a depthwise conv.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            # Depthwise stride-2 conv that downsamples the first branch.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, 3, 2, 1,
                          groups=out_planes // 2, bias=False),
                nn.SyncBatchNorm(out_planes // 2),
            )
            # Projection shortcut: depthwise stride-2 conv + 1x1 channel map.
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, 3, 2, 1,
                          groups=in_planes, bias=False),
                nn.SyncBatchNorm(in_planes),
                nn.Conv2d(in_planes, out_planes, 1, bias=False),
                nn.SyncBatchNorm(out_planes),
            )
            # Inner convs stay stride-1; downsampling is handled above.
            stride = 1
        for idx in range(block_num):
            if idx == 0:
                stage = ConvX(in_planes, out_planes // 2, kernel=1)
            elif idx == 1:
                # Last-of-two keeps width; otherwise halve it again.
                width = out_planes // 2 if block_num == 2 else out_planes // 4
                stage = ConvX(out_planes // 2, width, stride=stride)
            else:
                c_in = out_planes >> idx
                # The final stage repeats its width so channels sum to out_planes.
                c_out = c_in if idx == block_num - 1 else out_planes >> (idx + 1)
                stage = ConvX(c_in, c_out)
            self.conv_list.append(stage)

    def forward(self, x):
        feats = []
        out = x
        for idx, conv in enumerate(self.conv_list):
            out = conv(out)
            if idx == 0 and self.stride == 2:
                out = self.avd_layer(out)
            feats.append(out)
        residual = self.skip(x) if self.stride == 2 else x
        return torch.cat(feats, dim=1) + residual
class CatBottleneck(nn.Module):
    """STDC bottleneck that fuses the multi-scale branches by concatenation.

    The first branch (or its avg-pooled copy when downsampling) is placed at
    the front of the channel concatenation, followed by each deeper stage.

    Args:
        in_planes: Input channel count.
        out_planes: Output channel count (split across the internal branches).
        block_num: Number of internal ConvX stages (expected >= 2).
        stride: 1 keeps resolution; 2 downsamples via a depthwise conv.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            # Depthwise stride-2 conv that downsamples the first branch.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, 3, 2, 1,
                          groups=out_planes // 2, bias=False),
                nn.SyncBatchNorm(out_planes // 2),
            )
            # Parameter-free shortcut for the first branch.
            self.skip = nn.AvgPool2d(3, 2, 1)
            # Inner convs stay stride-1; downsampling is handled above.
            stride = 1
        for idx in range(block_num):
            if idx == 0:
                stage = ConvX(in_planes, out_planes // 2, kernel=1)
            elif idx == 1:
                # Last-of-two keeps width; otherwise halve it again.
                width = out_planes // 2 if block_num == 2 else out_planes // 4
                stage = ConvX(out_planes // 2, width, stride=stride)
            else:
                c_in = out_planes >> idx
                # The final stage repeats its width so channels sum to out_planes.
                c_out = c_in if idx == block_num - 1 else out_planes >> (idx + 1)
                stage = ConvX(c_in, c_out)
            self.conv_list.append(stage)

    def forward(self, x):
        first = self.conv_list[0](x)
        if self.stride == 2:
            chain = self.avd_layer(first)
            head = self.skip(first)
        else:
            chain = first
            head = first
        feats = [head]
        for conv in self.conv_list[1:]:
            chain = conv(chain)
            feats.append(chain)
        return torch.cat(feats, dim=1)
class ShallowNet(nn.Module):
    """
    STDC-like shallow network for high-resolution feature extraction.

    Args:
        base: Base channel number.
        in_channels: Input channels (3 for RGB, 6 for pyramid concat).
        layers: Number of blocks per stage (sequence of ints).
        block_num: Number of convs per block.
        type: 'cat' for CatBottleneck, 'add' for AddBottleneck.
        pretrain_model: Path to pretrained STDC weights.
    """

    def __init__(self, base=64, in_channels=3, layers=(2, 2), block_num=4,
                 type="cat", pretrain_model=''):
        # FIX: default for `layers` was a mutable list ([2, 2]); a tuple avoids
        # the shared-mutable-default pitfall and is backward compatible since
        # `layers` is only iterated, never mutated.
        super().__init__()
        block = CatBottleneck if type == "cat" else AddBottleneck
        self.in_channels = in_channels
        # Stem: two stride-2 convs -> 1/4 resolution at `base` channels.
        features = [
            ConvX(in_channels, base // 2, 3, 2),
            ConvX(base // 2, base, 3, 2),
        ]
        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    # First block of the first stage: base -> 4*base, stride 2.
                    features.append(block(base, base * 4, block_num, 2))
                elif j == 0:
                    # First block of each later stage downsamples and doubles channels.
                    features.append(block(base * (2 ** (i + 1)),
                                          base * (2 ** (i + 2)), block_num, 2))
                else:
                    # Remaining blocks keep resolution and width.
                    features.append(block(base * (2 ** (i + 2)),
                                          base * (2 ** (i + 2)), block_num, 1))
        self.features = nn.Sequential(*features)
        # NOTE(review): these fixed split points assume layers == (2, 2);
        # other stage configs would need different slices -- confirm at callers.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:4])
        self.x16 = nn.Sequential(self.features[4:6])
        if pretrain_model and os.path.exists(pretrain_model):
            print(f'Loading pretrain model {pretrain_model}')
            # SECURITY: weights_only=False unpickles arbitrary objects; only
            # load checkpoints from trusted sources.
            sd = torch.load(pretrain_model, weights_only=False)["state_dict"]
            ssd = self.state_dict()
            for k, v in sd.items():
                # Duplicate the RGB stem filters along the input-channel axis
                # so a 3-channel checkpoint can seed a wider stem (assumes
                # in_channels == 6 -- TODO confirm for other widths).
                if k == 'features.0.conv.weight' and in_channels != 3:
                    v = torch.cat([v, v], dim=1)
                if k in ssd:
                    ssd[k] = v
            # strict=False: checkpoint may lack heads / extra buffers.
            self.load_state_dict(ssd, strict=False)
        else:
            # No checkpoint: Kaiming init for convs, unit-scale for norms.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_normal_(m.weight, mode='fan_out')
                elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Return the 1/8 and 1/16 resolution feature maps."""
        x2 = self.x2(x)
        x4 = self.x4(x2)
        x8 = self.x8(x4)
        x16 = self.x16(x8)
        return x8, x16
class Lap_Pyramid_Conv(nn.Module):
    """
    Laplacian pyramid decomposition.

    Blurs with a fixed 5x5 binomial (Gaussian-like) kernel, downsamples by 2,
    and records the high-frequency residual (image minus re-upsampled blur)
    at each level.

    Args:
        num_high: Number of high-frequency levels to extract.
        gauss_chl: Number of image channels (kernel applied depthwise).
    """

    def __init__(self, num_high=3, gauss_chl=3):
        super().__init__()
        self.num_high = num_high
        self.gauss_chl = gauss_chl
        # 5x5 binomial filter, normalized to sum to 1.
        k = torch.tensor([
            [1., 4., 6., 4., 1.],
            [4., 16., 24., 16., 4.],
            [6., 24., 36., 24., 6.],
            [4., 16., 24., 16., 4.],
            [1., 4., 6., 4., 1.]
        ]) / 256.
        # One kernel copy per channel for a grouped (depthwise) convolution;
        # registered as a buffer so it follows .to(device)/.half() moves.
        self.register_buffer('kernel', k.repeat(gauss_chl, 1, 1, 1))

    def conv_gauss(self, img, k):
        """Depthwise Gaussian blur with reflect padding (shape-preserving)."""
        return F.conv2d(F.pad(img, (2, 2, 2, 2), mode='reflect'), k,
                        groups=img.shape[1])

    def downsample(self, x):
        """Drop every other row and column (factor-2 decimation)."""
        return x[:, :, ::2, ::2]

    def upsample(self, x):
        """Zero-insertion 2x upsample followed by Gaussian smoothing.

        The kernel is scaled by 4 to compensate for the 3-out-of-4 inserted
        zeros, so a constant image upsamples to the same constant.
        """
        b, c, h, w = x.shape
        # Interleave zeros along the width, then fold into the height axis.
        cc = torch.cat([x, torch.zeros_like(x)], dim=3)
        cc = cc.view(b, c, h * 2, w)
        cc = cc.permute(0, 1, 3, 2)
        # BUGFIX: the zero block must match x's dtype as well as its device;
        # the original omitted dtype=, breaking fp16/bf16 inputs at torch.cat.
        cc = torch.cat([cc, torch.zeros(b, c, w, h * 2,
                                        dtype=x.dtype, device=x.device)], dim=3)
        cc = cc.view(b, c, w * 2, h * 2)
        return self.conv_gauss(cc.permute(0, 1, 3, 2), 4 * self.kernel)

    def pyramid_decom(self, img):
        """Decompose image into Laplacian pyramid (high-frequency residuals).

        Returns a list of `num_high` residual tensors, from full resolution
        down, each `current - upsample(downsample(blur(current)))`.
        """
        current = img
        pyr = []
        for _ in range(self.num_high):
            down = self.downsample(self.conv_gauss(current, self.kernel))
            up = self.upsample(down)
            # Odd input sizes make the round trip off by one pixel; resize
            # (nearest) so the residual subtraction is well-defined.
            if up.shape[2:] != current.shape[2:]:
                up = F.interpolate(up, current.shape[2:])
            pyr.append(current - up)
            current = down
        return pyr