"""ISDNet building blocks: STDC-like modules and a Laplacian pyramid."""

import math
import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class ConvX(nn.Module):
    """Basic conv-bn-relu block.

    Args:
        in_planes: Input channels.
        out_planes: Output channels.
        kernel: Square kernel size; padding is ``kernel // 2`` so odd kernels
            keep the spatial size at stride 1.
        stride: Convolution stride.
    """

    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes, out_planes, kernel_size=kernel, stride=stride,
            padding=kernel // 2, bias=False,
        )
        self.bn = nn.SyncBatchNorm(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


def _check_block_num(block_num):
    """Raise ValueError when ``block_num`` is too small for the STDC channel split."""
    if block_num < 2:
        raise ValueError(f"block_num must be >= 2, got {block_num}")


def _bottleneck_convs(in_planes, out_planes, block_num, stride):
    """Build the ConvX chain shared by AddBottleneck and CatBottleneck.

    Widths halve along the chain (out/2, out/4, ...) and the last conv repeats
    the previous width, so concatenating all outputs gives exactly
    ``out_planes`` channels when ``out_planes`` is divisible by
    ``2 ** (block_num - 1)``. ``stride`` is applied only to the second conv.
    Modules are created in chain order (this matters for RNG reproducibility).
    """
    half = out_planes // 2
    convs = [ConvX(in_planes, half, kernel=1)]
    for idx in range(1, block_num):
        if idx == 1:
            cout = half if block_num == 2 else out_planes // 4
            convs.append(ConvX(half, cout, stride=stride))
        elif idx < block_num - 1:
            convs.append(ConvX(out_planes // (2 ** idx), out_planes // (2 ** (idx + 1))))
        else:
            # Last conv keeps its input width so the widths sum to out_planes.
            convs.append(ConvX(out_planes // (2 ** idx), out_planes // (2 ** idx)))
    return convs


class AddBottleneck(nn.Module):
    """STDC AddBottleneck: residual addition fusion.

    The outputs of the ConvX chain are concatenated and a shortcut is added:
    the input itself at stride 1 (so ``in_planes`` must equal ``out_planes``),
    or a depthwise-strided conv + 1x1 projection of the input at stride 2.

    Raises:
        ValueError: if ``block_num < 2``.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        _check_block_num(block_num)
        # Registered before avd_layer/skip to keep the original parameter order.
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            # Downsampling happens in this depthwise conv after the first 1x1
            # ConvX, so the ConvX chain itself runs at stride 1.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, 3, 2, 1,
                          groups=out_planes // 2, bias=False),
                nn.SyncBatchNorm(out_planes // 2),
            )
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, 3, 2, 1, groups=in_planes, bias=False),
                nn.SyncBatchNorm(in_planes),
                nn.Conv2d(in_planes, out_planes, 1, bias=False),
                nn.SyncBatchNorm(out_planes),
            )
            stride = 1
        self.conv_list.extend(_bottleneck_convs(in_planes, out_planes, block_num, stride))

    def forward(self, x):
        out_list, out = [], x
        for idx, conv in enumerate(self.conv_list):
            out = conv(out)
            if idx == 0 and self.stride == 2:
                out = self.avd_layer(out)
            out_list.append(out)
        shortcut = self.skip(x) if self.stride == 2 else x
        return torch.cat(out_list, dim=1) + shortcut


class CatBottleneck(nn.Module):
    """STDC CatBottleneck: concatenation fusion.

    Concatenates the first 1x1 ConvX output (average-pooled at stride 2) with
    the outputs of the remaining ConvX chain.

    Raises:
        ValueError: if ``block_num < 2``.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        _check_block_num(block_num)
        # Registered before avd_layer/skip to keep the original parameter order.
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            # Depthwise strided conv does the downsampling; the chain runs at stride 1.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, 3, 2, 1,
                          groups=out_planes // 2, bias=False),
                nn.SyncBatchNorm(out_planes // 2),
            )
            self.skip = nn.AvgPool2d(3, 2, 1)
            stride = 1
        self.conv_list.extend(_bottleneck_convs(in_planes, out_planes, block_num, stride))

    def forward(self, x):
        out1 = self.conv_list[0](x)
        out = self.avd_layer(out1) if self.stride == 2 else out1
        out_list = []
        for conv in self.conv_list[1:]:
            out = conv(out)
            out_list.append(out)
        first = self.skip(out1) if self.stride == 2 else out1
        return torch.cat([first, *out_list], dim=1)


class ShallowNet(nn.Module):
    """STDC-like shallow network for high-resolution feature extraction.

    Args:
        base: Base channel number.
        in_channels: Input channels (3 for RGB, 6 for pyramid concat).
        layers: Number of blocks in each of the two stages. ``layers[0]``
            blocks produce the 1/8 feature, ``layers[1]`` the 1/16 feature.
        block_num: Number of convs per block.
        type: 'cat' for CatBottleneck; any other value selects AddBottleneck.
        pretrain_model: Path to pretrained STDC weights. If empty, or if the
            path does not exist, weights are randomly initialized. A missing
            path also emits a warning.

    Forward returns ``(x8, x16)``: features at 1/8 resolution with ``base * 4``
    channels and at 1/16 resolution with ``base * 8`` channels.

    Raises:
        ValueError: if ``layers`` has fewer than two entries or a non-positive
            count in its first two entries.
    """

    # The stem conv weight appears under both keys in state_dict(), because
    # ``x2`` wraps ``features[:1]``.
    _STEM_KEYS = ("features.0.conv.weight", "x2.0.0.conv.weight")

    def __init__(self, base=64, in_channels=3, layers=(2, 2), block_num=4,
                 type="cat", pretrain_model=''):  # noqa: A002 - public kwarg name
        super().__init__()
        layers = tuple(layers)
        if len(layers) < 2 or min(layers[:2]) < 1:
            raise ValueError(
                f"layers must give at least one block for each of two stages, got {layers}"
            )
        block = CatBottleneck if type == "cat" else AddBottleneck
        self.in_channels = in_channels

        features = [
            ConvX(in_channels, base // 2, 3, 2),
            ConvX(base // 2, base, 3, 2),
        ]
        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base * 4, block_num, 2))
                elif j == 0:
                    features.append(
                        block(base * 2 ** (i + 1), base * 2 ** (i + 2), block_num, 2)
                    )
                else:
                    features.append(
                        block(base * 2 ** (i + 2), base * 2 ** (i + 2), block_num, 1)
                    )
        self.features = nn.Sequential(*features)

        # Stage boundaries follow `layers`; for the default (2, 2) these are
        # features[2:4] and features[4:6].
        end8 = 2 + layers[0]
        end16 = end8 + layers[1]
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:end8])
        self.x16 = nn.Sequential(self.features[end8:end16])

        if pretrain_model and os.path.exists(pretrain_model):
            self._load_pretrained(pretrain_model)
        else:
            if pretrain_model:
                warnings.warn(
                    f"pretrain_model {pretrain_model!r} not found; "
                    "using random initialization"
                )
            self._init_weights()

    @staticmethod
    def _adapt_stem_weight(weight, in_channels):
        """Tile a stem conv weight along its input-channel axis to ``in_channels``.

        For 3 -> 6 channels this equals ``torch.cat([w, w], dim=1)``. Values are
        copied, not rescaled.
        """
        repeats = math.ceil(in_channels / weight.shape[1])
        return weight.repeat(1, repeats, 1, 1)[:, :in_channels]

    def _load_pretrained(self, path):
        """Copy matching tensors from a checkpoint into this network (non-strict)."""
        print(f'Loading pretrain model {path}')
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load trusted checkpoint files.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        # STDC checkpoints wrap the weights in "state_dict"; bare state dicts
        # are accepted as well.
        sd = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        ssd = self.state_dict()
        for k, v in sd.items():
            if k in self._STEM_KEYS and v.shape[1] != self.in_channels:
                v = self._adapt_stem_weight(v, self.in_channels)
            if k in ssd:
                ssd[k] = v
        self.load_state_dict(ssd, strict=False)

    def _init_weights(self):
        """Kaiming-init convs (fan_out); set norm layers to weight 1 / bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, x):
        x2 = self.x2(x)
        x4 = self.x4(x2)
        x8 = self.x8(x4)
        x16 = self.x16(x8)
        return x8, x16


class Lap_Pyramid_Conv(nn.Module):
    """Laplacian pyramid decomposition.

    Extracts high-frequency details at multiple scales using a 5x5 binomial
    (Gaussian-like) kernel replicated for ``gauss_chl`` channels and applied
    depthwise.

    Args:
        num_high: Number of high-frequency levels returned by ``pyramid_decom``.
        gauss_chl: Channel count the kernel is replicated for; intended to
            match the channel count of the input images.
    """

    def __init__(self, num_high=3, gauss_chl=3):
        super().__init__()
        self.num_high = num_high
        self.gauss_chl = gauss_chl
        k = torch.tensor([
            [1., 4., 6., 4., 1.],
            [4., 16., 24., 16., 4.],
            [6., 24., 36., 24., 6.],
            [4., 16., 24., 16., 4.],
            [1., 4., 6., 4., 1.],
        ]) / 256.
        # Shape (gauss_chl, 1, 5, 5): one depthwise filter per channel.
        self.register_buffer('kernel', k.repeat(gauss_chl, 1, 1, 1))

    def conv_gauss(self, img, k):
        """Depthwise-convolve ``img`` with ``k`` using reflect padding (same size out).

        The kernel is cast to ``img``'s dtype/device so half/double inputs work.
        Reflect padding requires each spatial dim to exceed the pad (2 for 5x5).
        """
        pad_h, pad_w = k.shape[-2] // 2, k.shape[-1] // 2
        padded = F.pad(img, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
        return F.conv2d(padded, k.to(img), groups=img.shape[1])

    def downsample(self, x):
        """Keep every second row and column."""
        return x[:, :, ::2, ::2]

    def upsample(self, x):
        """Zero-insert to 2x size, then smooth with a 4x-scaled kernel.

        The factor 4 compensates for three of every four samples being zero.
        """
        n, c, h, w = x.shape
        # Interleave zero rows: (n, c, h, w) -> (n, c, 2h, w).
        cc = torch.cat([x, torch.zeros_like(x)], dim=3)
        cc = cc.view(n, c, h * 2, w)
        # Transpose, interleave zero "rows" again (i.e. zero columns), and
        # transpose back, so the input ends up at even (row, col) positions.
        cc = cc.permute(0, 1, 3, 2)
        cc = torch.cat([cc, x.new_zeros(n, c, w, h * 2)], dim=3)
        cc = cc.view(n, c, w * 2, h * 2)
        return self.conv_gauss(cc.permute(0, 1, 3, 2), 4 * self.kernel)

    def pyramid_decom(self, img):
        """Decompose ``img`` into ``num_high`` Laplacian (high-frequency) residuals.

        Level i has the spatial size of ``img`` downsampled i times (finest first).
        """
        current = img
        pyr = []
        for _ in range(self.num_high):
            down = self.downsample(self.conv_gauss(current, self.kernel))
            up = self.upsample(down)
            # Odd input sizes upsample one pixel too large; resize back.
            if up.shape[2:] != current.shape[2:]:
                up = F.interpolate(up, current.shape[2:])
            pyr.append(current - up)
            current = down
        return pyr