import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
class SegformerBranch(nn.Module):
    """Wrap ``smp.Segformer`` with a ``mobilenet_v2`` encoder.

    No pretrained encoder weights are loaded (``encoder_weights=None``).

    Args:
        in_channels: Number of channels in the input tensor.
        classes: Number of output channels produced by the model.
        benchmark: When True, ``forward`` passes the model output through
            ``torch.sigmoid``. This matches ``UNetBranch`` and
            ``UNetPlusPlusBranch``. Defaults to False, which returns the
            model output unchanged (the original behaviour).
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        self.segformer = smp.Segformer(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )
        # Plain Python attribute, not a parameter or buffer, so the
        # module's state_dict is the same as before this flag existed.
        self.benchmark = benchmark

    def forward(self, x):
        """Run the Segformer model; apply a sigmoid only in benchmark mode.

        No ``activation`` is passed to ``smp.Segformer``, so the raw output
        is presumably unnormalised scores. NOTE(review): confirm against
        the installed smp version.
        """
        results = self.segformer(x)
        if self.benchmark:
            results = torch.sigmoid(results)
        return results
class UNetBranch(nn.Module):
    """U-Net branch built on ``smp.Unet`` with a ``mobilenet_v2`` encoder.

    No pretrained encoder weights are loaded (``encoder_weights=None``).

    Args:
        in_channels: Number of channels in the input tensor.
        classes: Number of output channels produced by the model.
        benchmark: If True, ``forward`` passes the output through
            ``torch.sigmoid``; otherwise the output is returned unchanged.
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        unet_kwargs = dict(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )
        self.unet = smp.Unet(**unet_kwargs)
        self.benchmark = benchmark

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the U-Net, applying a sigmoid only when ``benchmark`` is set."""
        raw = self.unet(x)
        return torch.sigmoid(raw) if self.benchmark else raw
class UNetPlusPlusBranch(nn.Module):
    """UNet++ branch built on ``smp.UnetPlusPlus`` with a ``mobilenet_v2`` encoder.

    No pretrained encoder weights are loaded (``encoder_weights=None``).

    Args:
        in_channels: Number of channels in the input tensor.
        classes: Number of output channels produced by the model.
        benchmark: If True, ``forward`` passes the output through
            ``torch.sigmoid``; otherwise the output is returned unchanged.
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        self.unet_pp = smp.UnetPlusPlus(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )
        self.benchmark = benchmark

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run UNet++; squash through a sigmoid only in benchmark mode."""
        out = self.unet_pp(x)
        if not self.benchmark:
            return out
        return torch.sigmoid(out)
class DeepLabV3Branch(nn.Module):
    """Wrap ``smp.DeepLabV3`` with a ``mobilenet_v2`` encoder.

    No pretrained encoder weights are loaded (``encoder_weights=None``).

    Args:
        in_channels: Number of channels in the input tensor.
        classes: Number of output channels produced by the model.
        benchmark: When True, ``forward`` passes the model output through
            ``torch.sigmoid``. This matches ``UNetBranch`` and
            ``UNetPlusPlusBranch``. Defaults to False, which returns the
            model output unchanged (the original behaviour).
    """

    def __init__(self, in_channels=4, classes=4, benchmark=False):
        super().__init__()
        self.deeplabv3 = smp.DeepLabV3(
            encoder_name="mobilenet_v2",
            encoder_weights=None,
            in_channels=in_channels,
            classes=classes,
        )
        # Plain Python attribute, not a parameter or buffer, so the
        # module's state_dict is the same as before this flag existed.
        self.benchmark = benchmark

    def forward(self, x):
        """Run DeepLabV3; apply a sigmoid only in benchmark mode."""
        results = self.deeplabv3(x)
        if self.benchmark:
            results = torch.sigmoid(results)
        return results
class PixelWiseNet(nn.Module):
    """Per-pixel network made of three 1x1 convolutions.

    Every convolution has ``kernel_size=1``. In eval mode, each output pixel
    therefore depends only on the input channels at the same spatial
    position. In training mode, BatchNorm computes its statistics across
    the batch.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        base_channels: Width of the two hidden layers.
    """

    def __init__(self, in_channels=4, out_channels=4, base_channels=32):
        super().__init__()
        hidden = base_channels
        # conv1/conv2 are bias-free and each is followed by a BatchNorm layer.
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        # The output projection is bias-free as well.
        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply two (conv -> BN -> ReLU) stages, then the output projection."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(norm(conv(x)))
        return self.conv3(x)
class CombinedNet(nn.Module):
    """Fuse a ``SegformerBranch`` with a ``PixelWiseNet`` branch.

    Both branches receive the same input. Their outputs are added
    element-wise and mixed by a bias-free 1x1 convolution. The sum requires
    both branch outputs to have identical shapes. NOTE(review): this assumes
    the Segformer output matches the input's spatial size; confirm for the
    input sizes you use.

    Args:
        in_channels: Number of channels in the input tensor.
        classes: Output channels of each branch and of the fused result.
        base_channels: Hidden width of the pixel-wise branch.
        benchmark: If True, apply ``torch.sigmoid`` to the fused output.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        self.seg_branch = SegformerBranch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sum the two branch outputs, fuse them, optionally apply a sigmoid."""
        # Python evaluates the left operand first: seg branch, then pixel branch.
        fused = self.seg_branch(x) + self.pixel_branch(x)
        out = self.fusion_conv(fused)
        return torch.sigmoid(out) if self.benchmark else out
class CombinedNet3(nn.Module):
    """Fuse a ``UNetPlusPlusBranch`` with a ``PixelWiseNet`` branch.

    Both branches receive the same input. Their outputs are added
    element-wise and mixed by a bias-free 1x1 convolution. The sum requires
    both branch outputs to have identical shapes.

    Args:
        in_channels: Number of channels in the input tensor.
        classes: Output channels of each branch and of the fused result.
        base_channels: Hidden width of the pixel-wise branch.
        benchmark: If True, apply ``torch.sigmoid`` to the fused output.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        # The UNet++ branch is created without `benchmark`, so it keeps its
        # default (False). Any sigmoid is applied once, in forward().
        self.seg_branch = UNetPlusPlusBranch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sum the two branch outputs, fuse them, optionally apply a sigmoid."""
        seg_logits = self.seg_branch(x)
        pix_logits = self.pixel_branch(x)
        out = self.fusion_conv(seg_logits + pix_logits)
        if not self.benchmark:
            return out
        return torch.sigmoid(out)
class CombinedNet4(nn.Module):
    """Fuse a ``DeepLabV3Branch`` with a ``PixelWiseNet`` branch.

    Both branches receive the same input. Their outputs are added
    element-wise and mixed by a bias-free 1x1 convolution. The sum requires
    both branch outputs to have identical shapes. NOTE(review): this assumes
    the DeepLabV3 output matches the input's spatial size; confirm for the
    input sizes you use.

    Args:
        in_channels: Number of channels in the input tensor.
        classes: Output channels of each branch and of the fused result.
        base_channels: Hidden width of the pixel-wise branch.
        benchmark: If True, apply ``torch.sigmoid`` to the fused output.
    """

    def __init__(self, in_channels=4, classes=4, base_channels=32, benchmark=False):
        super().__init__()
        self.seg_branch = DeepLabV3Branch(in_channels=in_channels, classes=classes)
        self.pixel_branch = PixelWiseNet(
            in_channels=in_channels,
            out_channels=classes,
            base_channels=base_channels,
        )
        self.fusion_conv = nn.Conv2d(classes, classes, kernel_size=1, bias=False)
        self.benchmark = benchmark

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sum the two branch outputs, fuse them, optionally apply a sigmoid."""
        branch_outputs = (self.seg_branch(x), self.pixel_branch(x))
        fused = branch_outputs[0] + branch_outputs[1]
        mixed = self.fusion_conv(fused)
        if self.benchmark:
            mixed = torch.sigmoid(mixed)
        return mixed