|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torchvision
|
|
|
|
|
|
# NOTE(review): module-level side effect -- this constructs a pretrained torchvision
# ResNet-50 (downloading the weights if they are not cached) every time this module is
# imported. TransMUNet.__init__ builds its own local `resnet` and does not use this one;
# it is presumably kept only for external code that imports `resnet` from here -- confirm
# and consider removing.
resnet = torchvision.models.resnet.resnet50(pretrained=True)
|
|
|
from .munet_transformer import transmunet
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
class ConvBlock(nn.Module):
    """
    Conv2d followed by BatchNorm2d and, optionally, a ReLU.

    :param in_channels: channels of the input feature map
    :param out_channels: channels produced by the convolution
    :param padding: zero-padding passed to ``nn.Conv2d`` (default 1)
    :param kernel_size: convolution kernel size (default 3)
    :param stride: convolution stride (default 1)
    :param with_nonlinearity: apply a ReLU after batch norm when True
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        # Attribute names are part of the state_dict key layout -- keep them stable.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        if not self.with_nonlinearity:
            return normalized
        return self.relu(normalized)
|
|
|
|
|
|
|
|
|
class Bridge(nn.Module):
    """
    Middle (bottleneck) layer of the UNet: two stacked ConvBlocks applied to the deepest
    encoder features.

    With ConvBlock's default settings (3x3 kernel, padding 1, stride 1) the spatial size
    of the input is preserved; only the channel count may change.

    :param in_channels: channels of the incoming feature map
    :param out_channels: channels produced by both ConvBlocks
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
        ]
        # Stored as `bridge` so the state_dict keys stay bridge.0.* / bridge.1.*.
        self.bridge = nn.Sequential(*layers)

    def forward(self, x):
        out = self.bridge(x)
        return out
|
|
|
|
|
|
|
|
|
class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step which consists of Upsample -> ConvBlock -> ConvBlock.

    ``up_x`` is upsampled, concatenated with the skip connection ``down_x`` along the
    channel dimension, and passed through two ConvBlocks. Therefore ``in_channels`` must
    equal ``up_conv_out_channels`` plus the channel count of ``down_x``.

    :param in_channels: channels entering ``conv_block_1`` (after concatenation); also the
        default for ``up_conv_in_channels``
    :param out_channels: channels produced by the block; also the default for
        ``up_conv_out_channels``
    :param up_conv_in_channels: channels of ``up_x`` fed to the upsampling step
    :param up_conv_out_channels: channels produced by the upsampling step
    :param upsampling_method: ``"conv_transpose"`` (2x2 transposed conv, stride 2) or
        ``"bilinear"`` (bilinear x2 upsampling followed by a 1x1 conv)
    :raises ValueError: if ``upsampling_method`` is not one of the supported values
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            # The 1x1 conv operates on the upsampled ``up_x``, so it must use the up-conv
            # channel counts (like the transposed-conv branch), not the post-concatenation
            # ``in_channels``; otherwise blocks built with explicit up-conv channels break.
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(up_conv_in_channels, up_conv_out_channels, kernel_size=1, stride=1)
            )
        else:
            # Fail fast: leaving ``self.upsample`` unset would only surface later as an
            # AttributeError inside forward().
            raise ValueError(
                f"Unsupported upsampling_method {upsampling_method!r}; "
                f"expected 'conv_transpose' or 'bilinear'"
            )
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block (skip connection)
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        # Concatenate on the channel axis; the upsampled map and the skip connection must
        # share batch and spatial sizes.
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x
|
|
|
|
|
|
|
|
|
class SE_Block(nn.Module):
    """
    Squeeze-and-Excitation block: rescales each channel of a feature map by a learned gate.

    The input is globally average-pooled to one value per channel ("squeeze"), passed
    through a two-layer bottleneck MLP ending in a sigmoid ("excitation"), and the
    resulting per-channel weights in (0, 1) multiply the input.

    :param c: number of channels of the input feature map
    :param r: reduction ratio of the bottleneck MLP; the hidden width is ``c // r``,
        clamped to at least 1 so that ``c < r`` does not yield a zero-width layer
    """

    def __init__(self, c, r=16):
        super().__init__()
        hidden = max(1, c // r)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # Layer order/indices are kept so state_dict keys stay excitation.0/excitation.2.
        self.excitation = nn.Sequential(
            nn.Linear(c, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: feature map of shape (batch, c, H, W)
        :return: ``x`` rescaled channel-wise; same shape as ``x``
        """
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        # Return the recalibrated features, not the (bs, c, 1, 1) gate ``y`` itself.
        return x * y.expand_as(x)
|
|
|
|
|
|
|
|
|
class TransMUNet(nn.Module):
    """
    UNet-style segmentation network with a ResNet-50 encoder, fused with a transformer branch.

    The input is processed by two branches:

    * ``transmunet`` (imported from ``.munet_transformer``), whose three outputs are named
      here ``global_contexual``, ``regional_distribution`` and ``region_coeff``;
    * a pretrained torchvision ResNet-50 encoder, a ``Bridge`` and five up blocks with
      skip connections.

    The decoder output is combined with a predicted single-channel boundary map, passed
    through an ``SE_Block``, multiplied by ``regional_distribution``, concatenated with
    ``global_contexual``, projected to ``n_classes`` channels by a 1x1 conv and squashed
    with a sigmoid.
    """

    # Skip-connection level count used by forward(): levels layer_0 .. layer_4 are stored
    # in `pre_pools`; level DEPTH - 1 (= 5, the output of the last down block) is not
    # stored and feeds the bridge instead.
    DEPTH = 6

    def __init__(self, n_classes=2,
                 patch_size: int = 16,
                 emb_size: int = 512,
                 img_size: int = 256,
                 n_channels=3,
                 depth: int = 4,
                 n_regions: int = (256 // 16) ** 2,
                 output_ch: int = 1,
                 bilinear=True):
        """
        :param n_classes: number of output channels of the final 1x1 conv (``self.out``)
        :param patch_size: forwarded to ``transmunet``
        :param emb_size: forwarded to ``transmunet``
        :param img_size: forwarded to ``transmunet``
        :param n_channels: forwarded to ``transmunet`` as ``in_channels``.
            NOTE(review): the ResNet-50 stem and the last up block (``64 + 3``, whose skip
            is the raw input) both assume 3 input channels, so other values will fail.
        :param depth: forwarded to ``transmunet``
        :param n_regions: forwarded to ``transmunet``. The default is
            ``(256 // 16) ** 2``, presumably the patch count for the default
            ``img_size``/``patch_size``; it is NOT recomputed when those are changed.
        :param output_ch: not used anywhere in this class.
        :param bilinear: not used anywhere in this class (the up blocks are built with
            their default ``"conv_transpose"`` upsampling).
        """
        super().__init__()
        self.n_classes = n_classes
        # Transformer branch; its three outputs are consumed in forward().
        self.transformer = transmunet(in_channels=n_channels,
                                      patch_size=patch_size,
                                      emb_size=emb_size,
                                      img_size=img_size,
                                      depth=depth,
                                      n_regions=n_regions)
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
        # favour of the `weights=` argument -- confirm against the pinned torchvision.
        # The order in which submodules are created/assigned below fixes parameter
        # registration order (state_dict / optimizer layout) and RNG consumption; keep it.
        resnet = torchvision.models.resnet.resnet50(pretrained=True)
        down_blocks = []
        up_blocks = []
        # Stem: the first three ResNet children (conv1, bn1, relu in torchvision's
        # ResNet). Its output is stored as the layer_1 skip in forward().
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        # Fourth child (maxpool), applied after layer_1 has been recorded.
        self.input_pool = list(resnet.children())[3]
        # Keep only the nn.Sequential children, i.e. the residual stages
        # (layer1..layer4); the non-Sequential avgpool and fc heads are dropped.
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                down_blocks.append(bottleneck)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.bridge = Bridge(2048, 2048)
        # Decoder. Each block upsamples its input and concatenates one skip connection,
        # so in_channels = upsampled channels + skip channels:
        #   2048 = 1024 + 1024, 1024 = 512 + 512, 512 = 256 + 256 (ResNet-50 stage widths),
        #   128 + 64 (stem output skip), 64 + 3 (raw input image skip).
        up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
        up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
        up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                                    up_conv_in_channels=256, up_conv_out_channels=128))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                                    up_conv_in_channels=128, up_conv_out_channels=64))

        self.up_blocks = nn.ModuleList(up_blocks)

        # 128 input channels = 64 decoder channels + the channels of `global_contexual`
        # concatenated in forward(); assumes the transformer's global context has 64
        # channels -- TODO confirm against transmunet.
        self.out = nn.Conv2d(128, n_classes, kernel_size=1, stride=1)

        # Boundary head: 64 -> 32 -> 1 channel via 1x1 convs, squashed to (0, 1) by a sigmoid.
        self.boundary = nn.Sequential(nn.Conv2d(64, 32, kernel_size=1, stride=1),
                                      nn.BatchNorm2d(32), nn.ReLU(inplace=True),
                                      nn.Conv2d(32, 1, kernel_size=1, stride=1, bias=False),
                                      nn.Sigmoid())

        # Squeeze-and-excitation over the 64-channel decoder output.
        self.se = SE_Block(c=64)

    def forward(self, x, with_additional=False):
        """
        :param x: input image batch. Must have 3 channels (ResNet stem and last up block);
            H and W presumably need to be multiples of 32 so that the upsampled maps line
            up with the skip connections -- TODO confirm.
        :param with_additional: if True, also return the boundary map and ``region_coeff``.
        :return: sigmoid output with ``n_classes`` channels, or the tuple
            ``(output, B_out, region_coeff)`` when ``with_additional`` is True.
        """
        [global_contexual, regional_distribution, region_coeff] = self.transformer(x)

        # Skip connections keyed by level: layer_0 = raw input, layer_1 = stem output,
        # layer_2..layer_4 = outputs of the first three down blocks.
        pre_pools = dict()
        pre_pools[f"layer_0"] = x
        x = self.input_block(x)
        pre_pools[f"layer_1"] = x
        x = self.input_pool(x)

        # Start the count at 2 so down-block indices continue the layer_N numbering; the
        # last block's output (level DEPTH - 1) is not stored and goes to the bridge.
        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            if i == (TransMUNet.DEPTH - 1):
                continue
            pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        # i = 1..5 consumes the skips deepest-first: layer_4, layer_3, ..., layer_0.
        for i, block in enumerate(self.up_blocks, 1):
            key = f"layer_{TransMUNet.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])

        # Single-channel boundary map (last conv of self.boundary has 1 output channel),
        # replicated along dim 1 to match x's channel count.
        B_out = self.boundary(x)
        B = B_out.repeat_interleave(int(x.shape[1]), dim=1)
        x = self.se(x)
        x = x + B
        # Replicate regional_distribution along the channel axis and use it as a
        # multiplicative attention; assumes it has one channel and the same spatial size
        # as x -- TODO confirm against transmunet.
        att = regional_distribution.repeat_interleave(int(x.shape[1]), dim=1)
        x = x * att
        x = torch.cat((x, global_contexual), dim=1)
        x = self.out(x)

        # Drops only this local reference to the skip tensors.
        del pre_pools
        x = torch.sigmoid(x)

        if with_additional:
            return x, B_out, region_coeff
        else:
            return x
|
|
|
|