# Source: Hugging Face upload by Tianyinus — "init submit", commit edcf5ee (verified).
import torch
import torch.nn as nn
import torchvision
# NOTE(review): this module-level ResNet-50 instantiation downloads pretrained
# weights at import time, and `resnet` is not referenced anywhere else in this
# file (TransMUNet.__init__ constructs its own instance). Presumably dead code —
# confirm no other module does `from <this module> import resnet` before removing.
resnet = torchvision.models.resnet.resnet50(pretrained=True)
from .munet_transformer import transmunet
import cv2
import numpy as np
class ConvBlock(nn.Module):
    """
    Conv2d -> BatchNorm2d, optionally followed by a ReLU non-linearity.
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        # Submodules are created in the same order as registered parameters
        # elsewhere in the file (conv, then bn) so checkpoints keep loading.
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        # Normalized convolution output; the activation is applied only when
        # the block was configured with a non-linearity.
        out = self.bn(self.conv(x))
        return self.relu(out) if self.with_nonlinearity else out
class Bridge(nn.Module):
    """
    Bottleneck between encoder and decoder: two stacked ConvBlocks applied
    at the lowest spatial resolution of the UNet.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels),
        ]
        # Plain nn.Sequential keeps the "bridge.0.*" / "bridge.1.*"
        # state-dict keys unchanged.
        self.bridge = nn.Sequential(*stages)

    def forward(self, x):
        return self.bridge(x)
class UpBlockForUNetWithResNet50(nn.Module):
    """
    One decoder step of the UNet: upsample the incoming feature map,
    concatenate it with the matching encoder (skip) feature map, then refine
    the result with two ConvBlocks.
    """

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        """
        :param in_channels: channel count entering conv_block_1, i.e. after
                            concatenation with the skip connection
        :param out_channels: channel count produced by this block
        :param up_conv_in_channels: channels entering the upsampling layer
                                    (defaults to in_channels)
        :param up_conv_out_channels: channels leaving the upsampling layer
                                     (defaults to out_channels)
        :param upsampling_method: "conv_transpose" (learned) or "bilinear"
                                  (fixed interpolation followed by a 1x1 conv)
        :raises ValueError: if upsampling_method is not one of the two options
        """
        super().__init__()
        # `is None` is the correct identity test; `== None` calls __eq__ and
        # can misbehave (e.g. elementwise comparison for tensor-like values).
        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels
        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            )
        else:
            # Fail fast instead of deferring an AttributeError to forward().
            raise ValueError(f"unknown upsampling_method: {upsampling_method!r}")
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: output of the previous (coarser) decoder block
        :param down_x: skip-connection features from the matching encoder stage
        :return: upsampled and refined feature map
        """
        x = self.upsample(up_x)
        # Concatenate along the channel dimension (dim 1).
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x
class SE_Block(nn.Module):
    """
    Squeeze-and-Excitation block (Hu et al., 2018): global-average-pool the
    spatial dimensions ("squeeze"), derive per-channel gates via a small
    bottleneck MLP ("excitation"), and rescale the input channel-wise.
    """

    def __init__(self, c, r=16):
        """
        :param c: number of input (and output) channels
        :param r: reduction ratio of the bottleneck MLP
        """
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        x = x * y.expand_as(x)
        # BUG FIX: return the recalibrated feature map, not the (bs, c, 1, 1)
        # gate tensor `y` — returning `y` discarded the input features entirely
        # (downstream code only kept working via broadcasting). Checkpoints
        # trained against the old behavior will produce different outputs.
        return x
class TransMUNet(nn.Module):
    """
    UNet-style segmentation network with a ResNet-50 encoder, a transformer
    side branch (`transmunet`) supplying global context / regional attention /
    region coefficients, a boundary-prediction head, and squeeze-and-excitation
    channel attention before the final 1x1 classifier.
    """

    # Number of encoder resolution levels referenced by the skip-connection
    # bookkeeping in forward() (keys "layer_0" .. "layer_4" are stored).
    DEPTH = 6

    def __init__(self, n_classes=2,
                 patch_size: int = 16,
                 emb_size: int = 512,
                 img_size: int = 256,
                 n_channels=3,
                 depth: int = 4,
                 n_regions: int = (256 // 16) ** 2,
                 output_ch: int = 1,
                 bilinear=True):
        # NOTE(review): `output_ch` and `bilinear` are accepted but never used
        # anywhere in this class — confirm before removing them.
        super().__init__()
        self.n_classes = n_classes
        # Transformer branch; its forward() is expected to yield
        # [global_contexual, regional_distribution, region_coeff] (see below).
        self.transformer = transmunet(in_channels=n_channels,
                                      patch_size=patch_size,
                                      emb_size=emb_size,
                                      img_size=img_size,
                                      depth=depth,
                                      n_regions=n_regions)
        # Pretrained ResNet-50 supplies the encoder (downsampling) path.
        resnet = torchvision.models.resnet.resnet50(pretrained=True)
        down_blocks = []
        up_blocks = []
        # First three ResNet children (conv1, bn1, relu) form the stem.
        self.input_block = nn.Sequential(*list(resnet.children()))[:3]
        # Child index 3 is ResNet's max-pool layer.
        self.input_pool = list(resnet.children())[3]
        # The residual stages (layer1..layer4) are the nn.Sequential children.
        for bottleneck in list(resnet.children()):
            if isinstance(bottleneck, nn.Sequential):
                down_blocks.append(bottleneck)
        self.down_blocks = nn.ModuleList(down_blocks)
        self.bridge = Bridge(2048, 2048)
        up_blocks.append(UpBlockForUNetWithResNet50(2048, 1024))
        up_blocks.append(UpBlockForUNetWithResNet50(1024, 512))
        up_blocks.append(UpBlockForUNetWithResNet50(512, 256))
        # The two shallowest decoder steps concatenate with the stem output
        # (64 ch) and the raw input image (3 ch), hence the explicit overrides.
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                                    up_conv_in_channels=256, up_conv_out_channels=128))
        up_blocks.append(UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                                    up_conv_in_channels=128, up_conv_out_channels=64))
        self.up_blocks = nn.ModuleList(up_blocks)
        # Final classifier runs on decoder output (64 ch) concatenated with
        # global_contexual — presumably also 64 ch, giving 128 input channels;
        # TODO confirm against the transmunet implementation.
        self.out = nn.Conv2d(128, n_classes, kernel_size=1, stride=1)
        # Boundary head: 1-channel sigmoid map from the 64-ch decoder features.
        self.boundary = nn.Sequential(nn.Conv2d(64, 32, kernel_size=1, stride=1),
                                      nn.BatchNorm2d(32), nn.ReLU(inplace=True),
                                      nn.Conv2d(32, 1, kernel_size=1, stride=1, bias=False),
                                      nn.Sigmoid())
        self.se = SE_Block(c=64)

    def forward(self, x, with_additional=False):
        """
        :param x: input image batch (assumed (B, n_channels, H, W) — the
                  decoder concatenates x itself at the shallowest skip level)
        :param with_additional: when True, also return the boundary map and
                                the transformer's region coefficients
        :return: sigmoid probability map with n_classes channels, or
                 (map, B_out, region_coeff) when with_additional is True
        """
        # Transformer branch outputs, unpacked positionally.
        [global_contexual, regional_distribution, region_coeff] = self.transformer(x)
        # Cache encoder activations for the decoder's skip connections.
        pre_pools = dict()
        pre_pools[f"layer_0"] = x
        x = self.input_block(x)
        pre_pools[f"layer_1"] = x
        x = self.input_pool(x)
        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            if i == (TransMUNet.DEPTH - 1):
                # The deepest stage feeds the bridge directly; no skip stored.
                continue
            pre_pools[f"layer_{i}"] = x
        x = self.bridge(x)
        for i, block in enumerate(self.up_blocks, 1):
            # Pair each decoder step with the encoder activation one level up.
            key = f"layer_{TransMUNet.DEPTH - 1 - i}"
            x = block(x, pre_pools[key])
        # Boundary map, repeated across all decoder channels and added in.
        B_out = self.boundary(x)
        B = B_out.repeat_interleave(int(x.shape[1]), dim=1)
        # NOTE(review): as written, SE_Block.forward returns its (bs, c, 1, 1)
        # gate tensor rather than the rescaled features, so the additions and
        # multiplications below rely on broadcasting — verify this is intended.
        x = self.se(x)
        x = x + B
        # Regional attention from the transformer, repeated per channel.
        att = regional_distribution.repeat_interleave(int(x.shape[1]), dim=1)
        x = x * att
        x = torch.cat((x, global_contexual), dim=1)
        x = self.out(x)
        # print(x.shape)
        del pre_pools
        x = torch.sigmoid(x)
        # print('x shape: ', x.shape)
        if with_additional:
            return x, B_out, region_coeff
        else:
            return x