# ISDNet-pytorch / models / modules.py
# NOTE(review): the lines above/below were a scraped Hugging Face page header
# ("Upload folder using huggingface_hub", commit a585f5a); converted to a
# comment so the file parses as Python.
"""
ISDNet building blocks: STDC-like modules and Laplacian pyramid
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
class ConvX(nn.Module):
    """Conv2d -> SyncBatchNorm -> ReLU block with 'same'-style padding."""

    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super().__init__()
        # padding = kernel // 2 keeps the spatial size when stride == 1
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel,
                              stride=stride, padding=kernel // 2, bias=False)
        self.bn = nn.SyncBatchNorm(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)
class AddBottleneck(nn.Module):
    """STDC bottleneck fused by residual addition.

    Runs `block_num` ConvX stages with halving channel widths, concatenates
    all stage outputs (total width == out_planes) and adds a skip branch.
    For stride == 1 the skip is the input itself, so callers must ensure
    in_planes == out_planes in that case; for stride == 2 the skip is a
    depthwise stride-2 conv followed by a 1x1 projection.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            half = out_planes // 2
            # Depthwise strided conv downsamples the first stage's output.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(half, half, 3, 2, 1, groups=half, bias=False),
                nn.SyncBatchNorm(half),
            )
            # Skip: depthwise stride-2 conv + 1x1 projection to out_planes.
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, 3, 2, 1,
                          groups=in_planes, bias=False),
                nn.SyncBatchNorm(in_planes),
                nn.Conv2d(in_planes, out_planes, 1, bias=False),
                nn.SyncBatchNorm(out_planes),
            )
            stride = 1  # inner stages stay unstrided; avd_layer downsamples
        for idx in range(block_num):
            if idx == 0:
                stage = ConvX(in_planes, out_planes // 2, kernel=1)
            elif idx == 1 and block_num == 2:
                stage = ConvX(out_planes // 2, out_planes // 2, stride=stride)
            elif idx == 1:
                stage = ConvX(out_planes // 2, out_planes // 4, stride=stride)
            elif idx < block_num - 1:
                stage = ConvX(out_planes // (2 ** idx),
                              out_planes // (2 ** (idx + 1)))
            else:
                # Last stage keeps its width so the channel counts sum
                # exactly to out_planes.
                stage = ConvX(out_planes // (2 ** idx),
                              out_planes // (2 ** idx))
            self.conv_list.append(stage)

    def forward(self, x):
        outputs = []
        feat = x
        for idx, conv in enumerate(self.conv_list):
            feat = conv(feat)
            if idx == 0 and self.stride == 2:
                feat = self.avd_layer(feat)
            outputs.append(feat)
        fused = torch.cat(outputs, dim=1)
        residual = self.skip(x) if self.stride == 2 else x
        return fused + residual
class CatBottleneck(nn.Module):
    """STDC bottleneck fused by channel concatenation.

    The first 1x1 ConvX output is retained (directly, or average-pooled
    when stride == 2) and concatenated with the outputs of the remaining
    stages, whose widths halve until the last stage, so the total channel
    count equals out_planes.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            half = out_planes // 2
            # Depthwise stride-2 conv downsamples the first stage's output
            # before it feeds the rest of the chain.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(half, half, 3, 2, 1, groups=half, bias=False),
                nn.SyncBatchNorm(half),
            )
            # The retained branch is downsampled by plain average pooling.
            self.skip = nn.AvgPool2d(3, 2, 1)
            stride = 1  # inner stages stay unstrided
        for idx in range(block_num):
            if idx == 0:
                stage = ConvX(in_planes, out_planes // 2, kernel=1)
            elif idx == 1 and block_num == 2:
                stage = ConvX(out_planes // 2, out_planes // 2, stride=stride)
            elif idx == 1:
                stage = ConvX(out_planes // 2, out_planes // 4, stride=stride)
            elif idx < block_num - 1:
                stage = ConvX(out_planes // (2 ** idx),
                              out_planes // (2 ** (idx + 1)))
            else:
                # Last stage keeps its width so channel counts add up.
                stage = ConvX(out_planes // (2 ** idx),
                              out_planes // (2 ** idx))
            self.conv_list.append(stage)

    def forward(self, x):
        out1 = self.conv_list[0](x)
        if self.stride == 2:
            feat = self.avd_layer(out1)
            head = self.skip(out1)
        else:
            feat = out1
            head = out1
        outputs = [head]
        for conv in self.conv_list[1:]:
            feat = conv(feat)
            outputs.append(feat)
        return torch.cat(outputs, dim=1)
class ShallowNet(nn.Module):
    """
    STDC-like shallow network for high-resolution feature extraction.

    Builds a two-conv stem (1/2 then 1/4 resolution) followed by `layers`
    stages of STDC bottlenecks, then exposes the 1/8 and 1/16 feature maps
    via `forward`.

    Args:
        base: Base channel number
        in_channels: Input channels (3 for RGB, 6 for pyramid concat)
        layers: Number of blocks per stage
        block_num: Number of convs per block
        type: 'cat' for CatBottleneck, 'add' for AddBottleneck
        pretrain_model: Path to pretrained STDC weights
    """
    # NOTE(review): `type` shadows the builtin and `layers=[2, 2]` is a
    # mutable default; both kept as-is to preserve the public interface
    # (`layers` is never mutated here, so the default is harmless).
    def __init__(self, base=64, in_channels=3, layers=[2, 2], block_num=4,
                 type="cat", pretrain_model=''):
        super().__init__()
        block = CatBottleneck if type == "cat" else AddBottleneck
        self.in_channels = in_channels
        # Stem: two stride-2 ConvX blocks -> 1/4 resolution, `base` channels.
        features = [
            ConvX(in_channels, base // 2, 3, 2),
            ConvX(base // 2, base, 3, 2)
        ]
        # Each stage starts with a stride-2 block that grows the channel
        # count (base*4 for the first stage, doubling thereafter); the
        # remaining blocks of a stage preserve shape.
        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base * 4, block_num, 2))
                elif j == 0:
                    features.append(
                        block(base * int(math.pow(2, i + 1)),
                              base * int(math.pow(2, i + 2)), block_num, 2)
                    )
                else:
                    features.append(
                        block(base * int(math.pow(2, i + 2)),
                              base * int(math.pow(2, i + 2)), block_num, 1)
                    )
        self.features = nn.Sequential(*features)
        # These stage views wrap the SAME module objects as self.features
        # (no parameter duplication; the state_dict contains both key
        # prefixes). Do not restructure: pretrained key matching below
        # depends on the 'features.*' names.
        # NOTE(review): the hard-coded slice indices assume layers == [2, 2]
        # (6 modules total) — confirm before using other `layers` values.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:4])
        self.x16 = nn.Sequential(self.features[4:6])
        if pretrain_model and os.path.exists(pretrain_model):
            print(f'Loading pretrain model {pretrain_model}')
            # SECURITY: weights_only=False unpickles arbitrary objects —
            # only load checkpoints from a trusted source.
            sd = torch.load(pretrain_model, weights_only=False)["state_dict"]
            ssd = self.state_dict()
            for k, v in sd.items():
                # Duplicate the stem's RGB filters along the input-channel
                # axis; presumably targets in_channels == 6 (pyramid
                # concat) — TODO confirm for other channel counts.
                if k == 'features.0.conv.weight' and in_channels != 3:
                    v = torch.cat([v, v], dim=1)
                # Only copy checkpoint tensors whose keys exist here.
                if k in ssd:
                    ssd.update({k: v})
            # strict=False is redundant (ssd starts from self.state_dict(),
            # so keys always match) but kept as written.
            self.load_state_dict(ssd, strict=False)
        else:
            # No checkpoint: Kaiming init for convs, unit-affine for norms.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_normal_(m.weight, mode='fan_out')
                elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)
    def forward(self, x):
        """Return the 1/8 and 1/16 resolution feature maps (x2/x4 are
        intermediate activations and are not returned)."""
        x2 = self.x2(x)
        x4 = self.x4(x2)
        x8 = self.x8(x4)
        x16 = self.x16(x8)
        return x8, x16
class Lap_Pyramid_Conv(nn.Module):
    """
    Laplacian pyramid decomposition.
    Extracts high-frequency details at multiple scales.

    Args:
        num_high: Number of high-frequency levels to extract.
        gauss_chl: Number of image channels the Gaussian kernel serves
            (one kernel copy per channel, applied depthwise).
    """
    def __init__(self, num_high=3, gauss_chl=3):
        super().__init__()
        self.num_high = num_high
        self.gauss_chl = gauss_chl
        # 5x5 binomial (Gaussian-approximating) kernel, normalized to sum
        # to 1 so blurring preserves mean intensity.
        k = torch.tensor([
            [1., 4., 6., 4., 1.],
            [4., 16., 24., 16., 4.],
            [6., 24., 36., 24., 6.],
            [4., 16., 24., 16., 4.],
            [1., 4., 6., 4., 1.]
        ]) / 256.
        # Buffer (not parameter): moves with .to()/.half(), never trained.
        self.register_buffer('kernel', k.repeat(gauss_chl, 1, 1, 1))
    def conv_gauss(self, img, k):
        """Depthwise Gaussian blur with reflect padding (shape-preserving)."""
        return F.conv2d(F.pad(img, (2, 2, 2, 2), mode='reflect'), k, groups=img.shape[1])
    def downsample(self, x):
        """Drop every other row and column (caller blurs first)."""
        return x[:, :, ::2, ::2]
    def upsample(self, x):
        """2x upsampling: interleave zeros, then blur with 4x the kernel.

        The 4x scale compensates for the 3/4 of samples that are zero, so
        constant regions keep their intensity.
        """
        cc = torch.cat([x, torch.zeros_like(x)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
        cc = cc.permute(0, 1, 3, 2)
        # Fix: match dtype as well as device. torch.zeros defaults to
        # float32, which broke the concat for half-precision inputs
        # (zeros_like above already matched dtype; this now does too).
        cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3],
                                        x.shape[2] * 2,
                                        device=x.device, dtype=x.dtype)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
        return self.conv_gauss(cc.permute(0, 1, 3, 2), 4 * self.kernel)
    def pyramid_decom(self, img):
        """Decompose image into Laplacian pyramid (high-frequency residuals).

        Returns:
            List of `num_high` residual tensors (current - upsample(down)),
            ordered from full resolution downward. The final low-frequency
            band is not returned.
        """
        current = img
        pyr = []
        for _ in range(self.num_high):
            down = self.downsample(self.conv_gauss(current, self.kernel))
            up = self.upsample(down)
            if up.shape[2:] != current.shape[2:]:
                # Odd input sizes leave `up` one pixel larger; resize back.
                up = F.interpolate(up, current.shape[2:])
            pyr.append(current - up)
            current = down
        return pyr