"""Segmentation networks built on segmentation_models_pytorch (smp) backbones.

Each ``CombinedNet*`` model adds the output of an smp segmentation branch to
the output of a purely per-pixel branch (:class:`PixelWiseNet`) and mixes the
sum with a bias-free 1x1 convolution.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp

# Encoder used by every smp branch unless the caller overrides it.
_DEFAULT_ENCODER = "mobilenet_v2"


class SegformerBranch(nn.Module):
    """Wrapper around ``smp.Segformer`` that returns its raw output.

    Args:
        in_channels: Number of channels of the input tensor.
        classes: Number of output channels.
        encoder_name: smp encoder identifier (default ``"mobilenet_v2"``).
        encoder_weights: Pretrained-weight tag forwarded to smp. The default
            ``None`` loads no pretrained encoder weights.
    """

    def __init__(
        self,
        in_channels=4,
        classes=4,
        *,
        encoder_name=_DEFAULT_ENCODER,
        encoder_weights=None,
    ):
        super().__init__()
        self.segformer = smp.Segformer(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
        )

    def forward(self, x):
        return self.segformer(x)


class UNetBranch(nn.Module):
    """Wrapper around ``smp.Unet``.

    Args:
        in_channels: Number of channels of the input tensor.
        classes: Number of output channels.
        benchmark: If True, ``forward`` applies ``torch.sigmoid`` to the output.
        encoder_name: smp encoder identifier (default ``"mobilenet_v2"``).
        encoder_weights: Pretrained-weight tag forwarded to smp. The default
            ``None`` loads no pretrained encoder weights.
    """

    def __init__(
        self,
        in_channels=4,
        classes=4,
        benchmark=False,
        *,
        encoder_name=_DEFAULT_ENCODER,
        encoder_weights=None,
    ):
        super().__init__()
        self.unet = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
        )
        self.benchmark = benchmark

    def forward(self, x):
        results = self.unet(x)
        if self.benchmark:
            results = torch.sigmoid(results)
        return results


class UNetPlusPlusBranch(nn.Module):
    """Wrapper around ``smp.UnetPlusPlus``.

    Args:
        in_channels: Number of channels of the input tensor.
        classes: Number of output channels.
        benchmark: If True, ``forward`` applies ``torch.sigmoid`` to the output.
        encoder_name: smp encoder identifier (default ``"mobilenet_v2"``).
        encoder_weights: Pretrained-weight tag forwarded to smp. The default
            ``None`` loads no pretrained encoder weights.
    """

    def __init__(
        self,
        in_channels=4,
        classes=4,
        benchmark=False,
        *,
        encoder_name=_DEFAULT_ENCODER,
        encoder_weights=None,
    ):
        super().__init__()
        self.unet_pp = smp.UnetPlusPlus(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
        )
        self.benchmark = benchmark

    def forward(self, x):
        results = self.unet_pp(x)
        if self.benchmark:
            results = torch.sigmoid(results)
        return results


class DeepLabV3Branch(nn.Module):
    """Wrapper around ``smp.DeepLabV3`` that returns its raw output.

    Args:
        in_channels: Number of channels of the input tensor.
        classes: Number of output channels.
        encoder_name: smp encoder identifier (default ``"mobilenet_v2"``).
        encoder_weights: Pretrained-weight tag forwarded to smp. The default
            ``None`` loads no pretrained encoder weights.
    """

    def __init__(
        self,
        in_channels=4,
        classes=4,
        *,
        encoder_name=_DEFAULT_ENCODER,
        encoder_weights=None,
    ):
        super().__init__()
        self.deeplabv3 = smp.DeepLabV3(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
        )

    def forward(self, x):
        return self.deeplabv3(x)


class PixelWiseNet(nn.Module):
    """Per-pixel network made of three 1x1 convolutions.

    The layers are conv-BN-ReLU, conv-BN-ReLU, conv. Every kernel is 1x1 with
    default stride and padding, so the spatial size is preserved and there is
    no spatial mixing. BatchNorm batch statistics in training mode are the
    only exception.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of output channels.
        base_channels: Width of the two hidden layers.
    """

    def __init__(self, in_channels=4, out_channels=4, base_channels=32):
        super().__init__()
        # Convolutions are bias-free because BatchNorm (or, for conv3, the
        # downstream fusion) supplies the affine shift.
        self.conv1 = nn.Conv2d(in_channels, base_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(base_channels)
        self.conv2 = nn.Conv2d(base_channels, base_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(base_channels)
        self.conv3 = nn.Conv2d(base_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.conv3(x)
        return x


class _FusionNet(nn.Module):
    """Shared implementation of the ``CombinedNet*`` models.

    Sums the outputs of ``seg_branch`` and a :class:`PixelWiseNet` element-wise,
    then applies a bias-free 1x1 convolution.

    Args:
        seg_branch: Already-constructed segmentation module applied to the input.
        in_channels: Number of channels of the input tensor.
        classes: Number of output channels of both branches and of the fusion.
        base_channels: Hidden width of the :class:`PixelWiseNet` branch.
        benchmark: If True, ``forward`` applies ``torch.sigmoid`` to the output.
    """

    def __init__(self, seg_branch, in_channels, classes, base_channels, benchmark):
        super().__init__()
        # Attribute names and registration order (seg -> pixel -> fusion) match
        # the original per-class implementations. This keeps state_dict keys
        # and parameters() order, and therefore checkpoints and optimizer
        # state, compatible.
        self.seg_branch = seg_branch
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x):
        seg_out = self.seg_branch(x)
        pixel_out = self.pixel_branch(x)
        # Element-wise sum: both branch outputs must have matching
        # (broadcastable) shapes.
        out = self.fusion_conv(seg_out + pixel_out)
        if self.benchmark:
            out = torch.sigmoid(out)
        return out


class CombinedNet(_FusionNet):
    """Segformer branch plus pixel-wise branch, fused by a 1x1 convolution."""

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        # The segmentation branch is built before the base class builds the
        # pixel branch, so weight initialisation consumes the RNG in the same
        # order as before the refactor.
        seg_branch = SegformerBranch(in_channels=in_channels, classes=classes)
        super().__init__(seg_branch, in_channels, classes, base_channels, benchmark)


class CombinedNet3(_FusionNet):
    """UNet++ branch plus pixel-wise branch, fused by a 1x1 convolution."""

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        # The inner branch keeps benchmark=False; any sigmoid is applied once,
        # after fusion.
        seg_branch = UNetPlusPlusBranch(in_channels=in_channels, classes=classes)
        super().__init__(seg_branch, in_channels, classes, base_channels, benchmark)


class CombinedNet4(_FusionNet):
    """DeepLabV3 branch plus pixel-wise branch, fused by a 1x1 convolution."""

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        seg_branch = DeepLabV3Branch(in_channels=in_channels, classes=classes)
        super().__init__(seg_branch, in_channels, classes, base_channels, benchmark)