|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from scripts.maskscratches import Downsample |
|
|
from scripts.maskscratches import DataParallelWithCallback |
|
|
|
|
|
|
|
|
class UNet(nn.Module):
    """U-Net with a reflection-padded stem, optional antialiased downsampling,
    and a configurable upsampling path.

    Based on:
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015) https://arxiv.org/abs/1505.04597
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        depth=5,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    ):
        """
        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels
            depth (int): number of down/up-sampling stages
            conv_num (int): number of conv layers per UNetConvBlock
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, pad conv blocks so input and output
                spatial sizes match. This may introduce artifacts.
            batch_norm (bool): use BatchNorm after layers with an
                activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                'upconv' uses transposed convolutions (learned upsampling);
                'upsample' uses bilinear upsampling followed by a 3x3 conv.
            with_tanh (bool): if True, append a Tanh to the output head
            sync_bn (bool): accepted for backward compatibility only; see the
                note at the end of __init__ — it has never had an effect.
            antialiasing (bool): if True, downsample with a stride-1 conv
                followed by a blur-pool Downsample instead of a strided conv
        """
        super().__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        # Kept for backward compatibility with external readers; unused internally.
        self.depth = depth - 1

        # Stem: large receptive field (7x7) without shrinking the spatial size.
        self.first = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 2 ** wf, kernel_size=7),
            nn.LeakyReLU(0.2, True),
        )
        prev_channels = 2 ** wf

        # Encoder: each stage halves the spatial size, then doubles channels.
        self.down_path = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(depth):
            if antialiasing:
                # Stride-1 conv + blur-pool Downsample reduces aliasing
                # compared to a plain strided conv.
                down = nn.Sequential(
                    nn.ReflectionPad2d(1),
                    nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),
                    nn.BatchNorm2d(prev_channels),
                    nn.LeakyReLU(0.2, True),
                    Downsample(channels=prev_channels, stride=2),
                )
            else:
                down = nn.Sequential(
                    nn.ReflectionPad2d(1),
                    nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),
                    nn.BatchNorm2d(prev_channels),
                    nn.LeakyReLU(0.2, True),
                )
            self.down_sample.append(down)
            self.down_path.append(
                UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i + 1)

        # Decoder: mirror of the encoder, consuming skip connections.
        self.up_path = nn.ModuleList()
        for i in reversed(range(depth)):
            self.up_path.append(
                UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        # Output head; optional Tanh squashes the output to [-1, 1].
        last_layers = [nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)]
        if with_tanh:
            last_layers.append(nn.Tanh())
        self.last = nn.Sequential(*last_layers)

        # BUG FIX: the original code did `self = DataParallelWithCallback(self)`
        # when sync_bn was set. Rebinding the local name `self` has no effect on
        # the object returned to the caller, so sync_bn never did anything: the
        # wrapper was constructed and immediately discarded. The parameter is
        # retained for interface compatibility; callers that want synchronized
        # BatchNorm must wrap the constructed model themselves, e.g.
        #     model = DataParallelWithCallback(UNet(...))

    def forward(self, x):
        """Run the U-Net.

        Args:
            x (Tensor): input of shape (N, in_channels, H, W)

        Returns:
            Tensor: output of shape (N, out_channels, H', W'); H' == H and
            W' == W when `padding=True` and H, W are divisible by 2**depth.
        """
        x = self.first(x)

        skips = []
        for i, down_block in enumerate(self.down_path):
            skips.append(x)  # record pre-downsampling activation for the skip path
            x = self.down_sample[i](x)
            x = down_block(x)

        for i, up in enumerate(self.up_path):
            # Deepest skip is consumed first: pair up_path[i] with skips[-i-1].
            x = up(x, skips[-i - 1])

        return self.last(x)
|
|
|
|
|
|
|
|
class UNetConvBlock(nn.Module):
    """A stack of `conv_num` conv units.

    Each unit is ReflectionPad -> 3x3 Conv -> (optional) BatchNorm ->
    LeakyReLU(0.2). The pad width is int(padding), so `padding=True` keeps
    the spatial size and `padding=False` shrinks it by 2 per unit. The first
    unit maps in_size -> out_size channels; subsequent units keep out_size.
    """

    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
        super().__init__()
        layers = []
        channels = in_size
        for _ in range(conv_num):
            layers.append(nn.ReflectionPad2d(padding=int(padding)))
            layers.append(nn.Conv2d(channels, out_size, kernel_size=3, padding=0))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.LeakyReLU(0.2, True))
            channels = out_size  # after the first unit, channel count is fixed
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv stack to x (N, in_size, H, W)."""
        return self.block(x)
|
|
|
|
|
|
|
|
class UNetUpBlock(nn.Module):
    """Decoder stage: upscale x2, concatenate the (cropped) encoder skip,
    then run a UNetConvBlock over the concatenation.

    up_mode 'upconv' uses a learned transposed convolution; 'upsample' uses
    bilinear interpolation followed by a reflection-padded 3x3 conv.
    """

    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
            )
        # Consumes out_size (upsampled) + out_size (skip) = in_size channels.
        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Crop `layer` spatially to target_size == (height, width), centered."""
        _, _, height, width = layer.size()
        top = (height - target_size[0]) // 2
        left = (width - target_size[1]) // 2
        bottom = top + target_size[0]
        right = left + target_size[1]
        return layer[:, :, top:bottom, left:right]

    def forward(self, x, bridge):
        """Upsample x, fuse it with the encoder feature `bridge`, and convolve."""
        upsampled = self.up(x)
        skip = self.center_crop(bridge, upsampled.shape[2:])
        fused = torch.cat([upsampled, skip], 1)
        return self.conv_block(fused)
|
|
|
|
|
|
|
|
|