Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # @Time : 2024/7/21 下午3:51 | |
| # @Author : xiaoshun | |
| # @Email : 3038523973@qq.com | |
| # @File : mcdnet.py | |
| # @Software: PyCharm | |
| import cv2 | |
| import image_dehazer | |
| import numpy as np | |
| # 论文地址:https://www.sciencedirect.com/science/article/pii/S1569843224001742?via%3Dihub | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class _DPFF(nn.Module): | |
| def __init__(self, in_channels) -> None: | |
| super(_DPFF, self).__init__() | |
| self.cbr1 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False) | |
| self.cbr2 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False) | |
| # self.sigmoid = nn.Sigmoid() | |
| self.cbr3 = nn.Conv2d(in_channels, in_channels, 1, 1, bias=False) | |
| self.cbr4 = nn.Conv2d(in_channels * 2, in_channels, 1, 1, bias=False) | |
| def forward(self, feature1, feature2): | |
| d1 = torch.abs(feature1 - feature2) | |
| d2 = self.cbr1(torch.cat([feature1, feature2], dim=1)) | |
| d = torch.cat([d1, d2], dim=1) | |
| d = self.cbr2(d) | |
| # d = self.sigmoid(d) | |
| v1, v2 = self.cbr3(feature1), self.cbr3(feature2) | |
| v1, v2 = v1 * d, v2 * d | |
| features = torch.cat([v1, v2], dim=1) | |
| features = self.cbr4(features) | |
| return features | |
class DPFF(nn.Module):
    """Applies one _DPFF fusion unit per pyramid level."""

    def __init__(self, layer_channels) -> None:
        super(DPFF, self).__init__()
        self.cfes = nn.ModuleList([_DPFF(ch) for ch in layer_channels])

    def forward(self, features1, features2):
        # Fuse level-by-level; output order matches the input level order.
        return [cfe(f1, f2) for f1, f2, cfe in zip(features1, features2, self.cfes)]
class DirectDPFF(nn.Module):
    """Baseline fusion: concat the two branches at each level, mix with a 1x1 conv."""

    def __init__(self, layer_channels) -> None:
        super(DirectDPFF, self).__init__()
        self.fusions = nn.ModuleList(
            [nn.Conv2d(ch * 2, ch, 1, 1) for ch in layer_channels]
        )

    def forward(self, features1, features2):
        fused = []
        for f1, f2, fusion in zip(features1, features2, self.fusions):
            fused.append(fusion(torch.cat([f1, f2], dim=1)))
        return fused
class ConvBlock(nn.Module):
    """Downsampling block: maxpool + 1x1 conv (default) or a strided conv,
    optionally followed by BatchNorm and PReLU."""

    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
                 bn=False, activation=True, maxpool=True):
        super(ConvBlock, self).__init__()
        if maxpool:
            # Halve resolution by pooling, then remap channels with a 1x1 conv.
            down = nn.Sequential(
                nn.MaxPool2d(2),
                nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias),
            )
        else:
            down = nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
        layers = [down]
        if bn:
            layers.append(nn.BatchNorm2d(output_size))
        if activation:
            layers.append(nn.PReLU())
        self.module = nn.Sequential(*layers)

    def forward(self, x):
        return self.module(x)
class DeconvBlock(nn.Module):
    """Upsampling block: bilinear upsample + 1x1 conv (default) or a transposed
    conv, optionally followed by BatchNorm and PReLU."""

    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
                 bn=False, activation=True, bilinear=True):
        super(DeconvBlock, self).__init__()
        if bilinear:
            # Double resolution by interpolation, then remap channels.
            deconv = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(input_size, output_size, 1, 1, 0, bias=bias),
            )
        else:
            deconv = nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
        layers = [deconv]
        if bn:
            layers.append(nn.BatchNorm2d(output_size))
        if activation:
            layers.append(nn.PReLU())
        self.module = nn.Sequential(*layers)

    def forward(self, x):
        return self.module(x)
class FusionBlock(torch.nn.Module):
    """Iteratively refines a low-resolution feature with higher-resolution ones.

    For each high-resolution feature, the running fusion result is upsampled
    to that scale, the residual against the high-res feature is taken, the
    residual is downsampled back, and added to the running fusion.
    """

    def __init__(self, num_filter, num_ft, kernel_size=4, stride=2, padding=1, bias=True, maxpool=False,
                 bilinear=False):
        super(FusionBlock, self).__init__()
        self.num_ft = num_ft
        self.up_convs = nn.ModuleList()
        self.down_convs = nn.ModuleList()
        # Level i halves the channel count of level i-1.
        for i in range(self.num_ft):
            self.up_convs.append(
                DeconvBlock(num_filter // (2 ** i), num_filter // (2 ** (i + 1)), kernel_size, stride, padding,
                            bias=bias, bilinear=bilinear)
            )
            self.down_convs.append(
                ConvBlock(num_filter // (2 ** (i + 1)), num_filter // (2 ** i), kernel_size, stride, padding,
                          bias=bias, maxpool=maxpool)
            )

    def forward(self, ft_l, ft_h_list):
        # ft_h_list[i] is expected at the scale reached after (num_ft - i) upsamples.
        ft_fusion = ft_l
        for i, ft_h in enumerate(ft_h_list):
            steps = self.num_ft - i
            ft = ft_fusion
            for j in range(steps):
                ft = self.up_convs[j](ft)
            ft = ft - ft_h
            # Walk back down through the same levels in reverse order.
            for j in reversed(range(steps)):
                ft = self.down_convs[j](ft)
            ft_fusion = ft_fusion + ft
        return ft_fusion
class ConvLayer(nn.Module):
    """Conv2d preceded by reflection padding of kernel_size // 2, so the
    spatial size is reduced only by the stride."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
        super(ConvLayer, self).__init__()
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))
class UpsampleConvLayer(torch.nn.Module):
    """Thin wrapper around a transposed convolution used for learned upsampling."""

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(UpsampleConvLayer, self).__init__()
        self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        return self.conv2d(x)
class AddRelu(nn.Module):
    """Sums three tensors and applies a PReLU.

    Used in the expanding path to merge two feed-forward skips with the
    decoder feature before activation.
    """

    def __init__(self) -> None:
        super(AddRelu, self).__init__()
        self.relu = nn.PReLU()

    def forward(self, input_tensor1, input_tensor2, input_tensor3):
        total = input_tensor1 + input_tensor2 + input_tensor3
        return self.relu(total)
class BasicBlock(nn.Module):
    """Residual block: two 3x3 reflection-padded convs with BN + PReLU, plus a
    1x1 shortcut projection so channel counts always match."""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(BasicBlock, self).__init__()
        # mid_channels defaults to out_channels when not given.
        if not mid_channels:
            mid_channels = out_channels
        self.conv1 = ConvLayer(in_channels, mid_channels, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm2d(mid_channels, momentum=0.1)
        self.relu = nn.PReLU()
        self.conv2 = ConvLayer(mid_channels, out_channels, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)
        self.conv3 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        shortcut = self.conv3(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + shortcut)
class Bottleneck(nn.Module):
    """Residual block with three 3x3 reflection-padded convs (BN + shared PReLU)
    and a 1x1 shortcut projection."""

    def __init__(self, in_channels, out_channels):
        super(Bottleneck, self).__init__()
        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.1)
        self.conv2 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.1)
        self.conv3 = ConvLayer(out_channels, out_channels, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels, momentum=0.1)
        self.conv4 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
        self.relu = nn.PReLU()

    def forward(self, x):
        shortcut = self.conv4(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)
class PPM(nn.Module):
    """Pyramid pooling module (PSPNet-style).

    Pools the input at four grid sizes, projects each pooled map to
    in_channels // 4 channels, upsamples back to the input size, concatenates
    everything with the input, and mixes with a 1x1 conv.

    NOTE(review): in_channels should be divisible by 4, otherwise the
    concatenated channel count will not equal the 2 * in_channels the output
    conv expects.
    """

    def __init__(self, in_channels, out_channels):
        super(PPM, self).__init__()
        self.pool_sizes = [1, 2, 3, 6]  # subregion grid size per pyramid level
        self.num_levels = len(self.pool_sizes)
        branch_channels = in_channels // self.num_levels
        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=size),
                nn.Conv2d(in_channels, branch_channels, kernel_size=1),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
            )
            for size in self.pool_sizes
        ])
        self.out_conv = nn.Conv2d(in_channels * 2, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        target_size = x.size()[2:]
        pyramid = [x]
        for layer in self.conv_layers:
            level = layer(x)
            pyramid.append(F.interpolate(level, size=target_size, mode='bilinear', align_corners=True))
        return self.out_conv(torch.cat(pyramid, dim=1))
class MCDNet(nn.Module):
    """MCDNet cloud-detection network.

    Paper: https://www.sciencedirect.com/science/article/pii/S1569843224001742

    The raw image and a dehazed copy of it (built on the fly by
    ``get_lr_data``) are pushed through the same shared encoder. The five
    per-scale feature pairs are merged by DPFF modules, then decoded back up
    to a ``num_classes``-channel prediction at the input resolution.
    """

    def __init__(self, in_channels=4, num_classes=4, maxpool=False, bilinear=False) -> None:
        super(MCDNet, self).__init__()
        level = 1  # width multiplier kept from the original implementation
        # encoder (shared by both inputs); each stage halves the resolution
        self.conv_input = ConvLayer(in_channels, 32 * level, kernel_size=3, stride=2)
        self.dense0 = BasicBlock(32 * level, 32 * level)
        self.conv2x = ConvLayer(32 * level, 64 * level, kernel_size=3, stride=2)
        self.dense1 = BasicBlock(64 * level, 64 * level)
        self.conv4x = ConvLayer(64 * level, 128 * level, kernel_size=3, stride=2)
        self.dense2 = BasicBlock(128 * level, 128 * level)
        self.conv8x = ConvLayer(128 * level, 256 * level, kernel_size=3, stride=2)
        self.dense3 = BasicBlock(256 * level, 256 * level)
        self.conv16x = ConvLayer(256 * level, 512 * level, kernel_size=3, stride=2)
        self.dense4 = PPM(512 * level, 512 * level)
        # dual-perspective feature fusion, one module per encoder scale
        self.dpffm = DPFF([32, 64, 128, 256, 512])
        # decoder: upsample, fuse with shallower scales, bottleneck, add-relu
        self.convd16x = UpsampleConvLayer(512 * level, 256 * level, kernel_size=3, stride=2)
        self.fusion4 = FusionBlock(256 * level, 3, maxpool=maxpool, bilinear=bilinear)
        self.dense_4 = Bottleneck(512 * level, 256 * level)
        self.add_block4 = AddRelu()
        self.convd8x = UpsampleConvLayer(256 * level, 128 * level, kernel_size=3, stride=2)
        self.fusion3 = FusionBlock(128 * level, 2, maxpool=maxpool, bilinear=bilinear)
        self.dense_3 = Bottleneck(256 * level, 128 * level)
        self.add_block3 = AddRelu()
        self.convd4x = UpsampleConvLayer(128 * level, 64 * level, kernel_size=3, stride=2)
        self.fusion2 = FusionBlock(64 * level, 1, maxpool=maxpool, bilinear=bilinear)
        self.dense_2 = Bottleneck(128 * level, 64 * level)
        self.add_block2 = AddRelu()
        self.convd2x = UpsampleConvLayer(64 * level, 32 * level, kernel_size=3, stride=2)
        self.dense_1 = Bottleneck(64 * level, 32 * level)
        self.add_block1 = AddRelu()
        self.head = UpsampleConvLayer(32 * level, num_classes, kernel_size=3, stride=2)
        self.apply(self._weights_init)

    def _weights_init(self, m):
        """Xavier init for linears, Kaiming for convs, constants for BatchNorm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def get_lr_data(self, x: torch.Tensor) -> torch.Tensor:
        """Build the dehazed second input from the raw batch.

        Runs image_dehazer per sample on the CPU, then min-max normalizes each
        result to [0, 1] and restacks on the original device.

        NOTE(review): assumes x is a 3-channel RGB batch (N, 3, H, W) — the
        cv2 color conversions require 3 channels; confirm against callers.
        """
        images = x.cpu().permute(0, 2, 3, 1).numpy()
        batch_size = images.shape[0]
        lr = []
        for i in range(batch_size):
            lr_image = cv2.cvtColor(images[i], cv2.COLOR_RGB2BGR)
            lr_image = image_dehazer.remove_haze(lr_image, showHazeTransmissionMap=False)[0]
            lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2RGB)
            max_pix = np.max(lr_image)
            min_pix = np.min(lr_image)
            # Fix: guard the min-max scaling — the original divided by
            # (max - min) unconditionally, producing NaN/inf for a flat image.
            scale = max_pix - min_pix
            if scale > 0:
                lr_image = (lr_image - min_pix) / scale
            else:
                lr_image = np.zeros_like(lr_image, dtype=np.float32)
            lr_image = np.clip(lr_image, 0, 1)
            lr_tensor = torch.from_numpy(lr_image).permute(2, 0, 1).float()
            lr.append(lr_tensor)
        return torch.stack(lr, dim=0).to(x.device)

    def forward(self, x1):
        """Run the two-branch encoder, DPFF fusion, and decoder.

        Returns a (N, num_classes, H, W) map at the resolution of x1.
        """
        x2 = self.get_lr_data(x1)
        # encoder1: raw image
        res1x_1 = self.conv_input(x1)
        res1x_1 = self.dense0(res1x_1)
        res2x_1 = self.conv2x(res1x_1)
        res2x_1 = self.dense1(res2x_1)
        res4x_1 = self.conv4x(res2x_1)
        res4x_1 = self.dense2(res4x_1)
        res8x_1 = self.conv8x(res4x_1)
        res8x_1 = self.dense3(res8x_1)
        res16x_1 = self.conv16x(res8x_1)
        res16x_1 = self.dense4(res16x_1)
        # encoder2: dehazed image, through the SAME layers (shared weights)
        res1x_2 = self.conv_input(x2)
        res1x_2 = self.dense0(res1x_2)
        res2x_2 = self.conv2x(res1x_2)
        res2x_2 = self.dense1(res2x_2)
        res4x_2 = self.conv4x(res2x_2)
        res4x_2 = self.dense2(res4x_2)
        res8x_2 = self.conv8x(res4x_2)
        res8x_2 = self.dense3(res8x_2)
        res16x_2 = self.conv16x(res8x_2)
        res16x_2 = self.dense4(res16x_2)
        # dual-perspective feature fusion, one fused map per scale
        res1x, res2x, res4x, res8x, res16x = self.dpffm(
            [res1x_1, res2x_1, res4x_1, res8x_1, res16x_1],
            [res1x_2, res2x_2, res4x_2, res8x_2, res16x_2]
        )
        # decoder: at each scale, upsample, fuse with shallower features,
        # bottleneck the concat, then add-relu the three streams together
        res8x1 = self.convd16x(res16x)
        res8x1 = F.interpolate(res8x1, res8x.size()[2:], mode='bilinear')
        res8x2 = self.fusion4(res8x, [res1x, res2x, res4x])
        res8x2 = torch.cat([res8x1, res8x2], dim=1)
        res8x2 = self.dense_4(res8x2)
        res8x2 = self.add_block4(res8x1, res8x, res8x2)
        res4x1 = self.convd8x(res8x2)
        res4x1 = F.interpolate(res4x1, res4x.size()[2:], mode='bilinear')
        res4x2 = self.fusion3(res4x, [res1x, res2x])
        res4x2 = torch.cat([res4x1, res4x2], dim=1)
        res4x2 = self.dense_3(res4x2)
        res4x2 = self.add_block3(res4x1, res4x, res4x2)
        res2x1 = self.convd4x(res4x2)
        res2x1 = F.interpolate(res2x1, res2x.size()[2:], mode='bilinear')
        res2x2 = self.fusion2(res2x, [res1x])
        res2x2 = torch.cat([res2x1, res2x2], dim=1)
        res2x2 = self.dense_2(res2x2)
        res2x2 = self.add_block2(res2x1, res2x, res2x2)
        res1x1 = self.convd2x(res2x2)
        res1x1 = F.interpolate(res1x1, res1x.size()[2:], mode='bilinear')
        res1x2 = torch.cat([res1x1, res1x], dim=1)
        res1x2 = self.dense_1(res1x2)
        res1x2 = self.add_block1(res1x1, res1x, res1x2)
        out = self.head(res1x2)
        # final resize back to the exact input resolution
        out = F.interpolate(out, x1.size()[2:], mode='bilinear')
        return out
def lr_lambda(epoch, max_epochs=50, power=0.9):
    """Polynomial learning-rate decay factor for ``torch.optim.lr_scheduler.LambdaLR``.

    Returns ``(1 - epoch / max_epochs) ** power``. The base is clamped at 0 so
    that epochs past ``max_epochs`` return 0.0 instead of raising/producing a
    complex value from a negative base with a fractional exponent.

    Args:
        epoch: current epoch index (0-based).
        max_epochs: epoch at which the factor reaches 0 (default 50,
            matching the original hard-coded schedule).
        power: decay exponent (default 0.9, as in the original).
    """
    return max(0.0, 1.0 - epoch / max_epochs) ** power
if __name__ == "__main__":
    # Smoke test: push one random batch through the model and report the
    # output shape (expected: torch.Size([2, 7, 256, 256])).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MCDNet(in_channels=3, num_classes=7).to(device)
    fake_img = torch.randn(size=(2, 3, 256, 256)).to(device)
    out = model(fake_img).detach().cpu()
    print(out.shape)