import torch
import torch.nn as nn
import torchvision
import cv2
import numpy as np

from .munet_transformer import transmunet

# NOTE(review): the original module also ran
#   resnet = torchvision.models.resnet.resnet50(pretrained=True)
# at import time. That triggers a pretrained-weight download as a side effect
# of merely importing this file, and the resulting global was never used
# (TransMUNet.__init__ constructs its own backbone), so it has been removed.


class ConvBlock(nn.Module):
    """Helper module: Conv2d -> BatchNorm2d -> (optional) ReLU."""

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3,
                 stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding,
                              kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        # When False the block is linear (Conv -> BN only); used by callers
        # that want to apply their own activation.
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x


class Bridge(nn.Module):
    """Middle (bottleneck) layer of the UNet: two stacked ConvBlocks."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)


class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step:
    Upsample -> concat skip connection -> ConvBlock -> ConvBlock.
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None,
                 up_conv_out_channels=None, upsampling_method="conv_transpose"):
        super().__init__()
        # Upsampling channel counts default to the block's own in/out channels
        # but can differ when the concatenated skip changes the channel math.
        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels
        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels,
                                               kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: output of the previous up block (to be upsampled)
        :param down_x: matching skip connection from the down path
        :return: upsampled feature map after two ConvBlocks
        """
        x = self.upsample(up_x)
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x


class SE_Block(nn.Module):
    """Squeeze-and-Excitation block: channel-wise feature recalibration."""

    def __init__(self, c, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        x = x * y.expand_as(x)
        # BUGFIX: the original returned `y` (the (bs, c, 1, 1) attention
        # weights), discarding the recalibrated feature map computed on the
        # previous line. A standard SE block returns the reweighted input.
        return x


class TransMUNet(nn.Module):
    """UNet with a ResNet-50 encoder fused with transformer-derived
    global context, regional attention, and a boundary head."""

    DEPTH = 6

    def __init__(self, n_classes=2, patch_size: int = 16, emb_size: int = 512,
                 img_size: int = 256, n_channels=3, depth: int = 4,
                 n_regions: int = (256 // 16) ** 2, output_ch: int = 1,
                 bilinear=True):
        super().__init__()
        self.n_classes = n_classes
        self.transformer = transmunet(in_channels=n_channels, patch_size=patch_size,
                                      emb_size=emb_size, img_size=img_size,
                                      depth=depth, n_regions=n_regions)
        # NOTE: downloads pretrained ImageNet weights on first use.
        resnet = torchvision.models.resnet.resnet50(pretrained=True)
        down_blocks = []
        up_blocks = []
        # Stem: conv + bn + relu (first 3 children of ResNet-50).
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        # Child 3 is the max-pool that precedes layer1.
        self.input_pool = list(resnet.children())[3]
        # The four residual stages (layer1..layer4) are the nn.Sequential children.
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                down_blocks.append(bottleneck)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.bridge = Bridge(2048, 2048)
        up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
        up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
        up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                                    up_conv_in_channels=256,
                                                    up_conv_out_channels=128))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                                    up_conv_in_channels=128,
                                                    up_conv_out_channels=64))
        self.up_blocks = nn.ModuleList(up_blocks)
        # 128 input channels = 64 (decoder output) + 64 (transformer global
        # context, concatenated in forward()) — presumably; confirm against
        # transmunet's output channel count.
        self.out = nn.Conv2d(128, n_classes, kernel_size=1, stride=1)
        # Boundary head: predicts a 1-channel edge map in [0, 1].
        self.boundary = nn.Sequential(nn.Conv2d(64, 32, kernel_size=1, stride=1),
                                      nn.BatchNorm2d(32),
                                      nn.ReLU(inplace=True),
                                      nn.Conv2d(32, 1, kernel_size=1, stride=1, bias=False),
                                      nn.Sigmoid())
        self.se = SE_Block(c=64)

    def forward(self, x, with_additional=False):
        [global_contexual, regional_distribution, region_coeff] = self.transformer(x)
        # Cache each resolution's features for the decoder skip connections.
        pre_pools = dict()
        pre_pools["layer_0"] = x
        x = self.input_block(x)
        pre_pools["layer_1"] = x
        x = self.input_pool(x)
        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            # The deepest stage feeds the bridge directly; no skip needed.
            if i == (TransMUNet.DEPTH - 1):
                continue
            pre_pools[f"layer_{i}"] = x
        x = self.bridge(x)
        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{TransMUNet.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])
        # Boundary map, broadcast across all feature channels.
        B_out = self.boundary(x)
        B = B_out.repeat_interleave(int(x.shape[1]), dim=1)
        x = self.se(x)
        x = x + B
        # Regional attention from the transformer, broadcast channel-wise.
        att = regional_distribution.repeat_interleave(int(x.shape[1]), dim=1)
        x = x * att
        x = torch.cat((x, global_contexual), dim=1)
        x = self.out(x)
        del pre_pools
        x = torch.sigmoid(x)
        if with_additional:
            return x, B_out, region_coeff
        else:
            return x