|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from torchvision.transforms.functional import rgb_to_grayscale |
|
|
import numpy as np |
|
|
|
|
|
class DoubleConv(nn.Module):
    """Two stacked 3x3 same-padding convolutions, each followed by ReLU.

    Note: despite the classic "(conv => BN => ReLU) * 2" pattern, this
    variant contains no BatchNorm layers; `_init_weights` keeps a
    BatchNorm branch only so a later-added BN layer would be handled.
    If `mid_channels` is falsy it defaults to `out_channels`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )
        # Replace PyTorch's default conv init with He-style normal init.
        self.apply(self._init_weights)

    def forward(self, x):
        """Run both conv+ReLU stages on `x`."""
        return self.double_conv(x)

    def _init_weights(self, m):
        # He/Kaiming normal initialisation: std = sqrt(2 / fan_out),
        # where fan_out = kH * kW * out_channels.
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / fan_out))
        elif isinstance(m, nn.BatchNorm2d):
            # Dead branch today (no BN layers above); kept for safety.
            m.weight.data.fill_(1)
            m.bias.data.zero_()
|
|
|
|
|
|
|
|
class Down(nn.Module):
    """Encoder step: halve H/W with 2x2 max-pooling, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute name kept as `maxpool_conv` so checkpoints stay loadable.
        stages = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Downsample `x` and apply the double convolution."""
        return self.maxpool_conv(x)
|
|
|
|
|
|
|
|
class Up(nn.Module):
    """Decoder step: upsample `x1`, pad it to match skip tensor `x2`,
    concatenate along channels, then apply DoubleConv.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Parameter-free upsampling; channel reduction happens inside
            # DoubleConv via its mid_channels argument.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # Learned upsampling that also halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Odd input sizes can leave x1 smaller than the skip tensor x2;
        # pad x1 (left/right, top/bottom) so spatial dims match exactly.
        dh = x2.size()[2] - x1.size()[2]
        dw = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
|
|
|
|
|
|
|
|
|
|
|
class SpatialGate(nn.Module):
    """Spatial attention: a 1x1 conv squeezes channels to a single map,
    a sigmoid turns it into a [0, 1] mask, and the mask rescales the input
    (broadcast across channels).
    """

    def __init__(self, in_channels):
        super(SpatialGate, self).__init__()
        self.spatial = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (B, C, H, W) -> (B, 1, H, W) attention map, then gate the input.
        attention = self.sigmoid(self.spatial(x))
        return attention * x
|
|
|
|
|
|
|
|
|
|
|
class SobelOperator(nn.Module):
    """Gradient-magnitude filter built from the two fixed 3x3 Sobel kernels.

    Expects a single-channel input (B, 1, H, W) and returns a tensor of the
    same shape. The conv weights stay trainable parameters; only their
    initial values are overwritten with the Sobel kernels.
    """

    def __init__(self):
        super(SobelOperator, self).__init__()
        self.conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        # Replace the random init with the horizontal/vertical Sobel kernels.
        kernel_x = torch.FloatTensor([[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]])
        kernel_y = torch.FloatTensor([[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]])
        self.conv_x.weight.data.copy_(kernel_x.unsqueeze(0))
        self.conv_y.weight.data.copy_(kernel_y.unsqueeze(0))

    def forward(self, x):
        # Per-pixel gradient magnitude sqrt(Gx^2 + Gy^2).
        gx = self.conv_x(x)
        gy = self.conv_y(x)
        return torch.sqrt(gx * gx + gy * gy)
|
|
|
|
|
|
|
|
class offset_estimator(nn.Sequential):
    """Stack of Gaussian-initialised conv+ReLU layers.

    `kernel_size` and `fwhm` are parallel sequences: layer i is a
    kernel_size[i] x kernel_size[i] same-padding convolution whose first
    output channel is initialised from a 2-D Gaussian of width fwhm[i],
    followed by ReLU.  Channel layout: in_channel -> mid_channel -> ...
    -> out_channel.

    The original three if/elif/else branches duplicated the layer
    construction verbatim except for channel counts; they are collapsed
    into `_gauss_block`.
    """

    def __init__(self, kernel_size, fwhm, in_channel, mid_channel, out_channel) -> None:
        super().__init__()
        assert len(kernel_size) == len(fwhm), "length error"
        layers = []
        last = len(kernel_size) - 1
        for i, (k, f) in enumerate(zip(kernel_size, fwhm)):
            # First layer maps in -> mid; last maps mid -> out; middle
            # layers keep mid -> mid.  With a single entry the i == 0
            # branch wins, so the stack ends at mid_channel rather than
            # out_channel — this reproduces the original branch priority.
            # NOTE(review): confirm the single-layer case is intended.
            in_ch = in_channel if i == 0 else mid_channel
            out_ch = out_channel if (i == last and i != 0) else mid_channel
            layers += self._gauss_block(in_ch, out_ch, k, f)
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _gauss_block(in_ch, out_ch, k, f):
        # One bias-free same-padding conv plus ReLU; only the conv's first
        # output channel is overwritten with the Gaussian kernel — the
        # remaining output channels keep their random init.
        # NOTE(review): `gaussian_2d` is not defined in this file; it must
        # come from the enclosing module's scope — verify the import.
        weight = torch.FloatTensor(gaussian_2d(k, fwhm=f))
        conv = nn.Conv2d(in_ch, out_ch, k, padding=(k - 1) // 2, bias=False)
        conv.weight[0].data[:, :, :] = weight
        return [conv, nn.ReLU(inplace=True)]

    def forward(self, x):
        """Apply the full conv/ReLU stack to `x`."""
        return self.model(x)
|
|
|
|
|
|
|
|
|
|
|
def logsumexp_2d(tensor):
    """Numerically stable log-sum-exp over all spatial positions.

    Flattens every dimension after (batch, channel) and reduces it,
    returning a tensor of shape (B, C, 1).
    """
    flat = tensor.view(tensor.size(0), tensor.size(1), -1)
    # Subtract the per-(B, C) max before exponentiating to avoid overflow.
    m, _ = flat.max(dim=2, keepdim=True)
    return m + (flat - m).exp().sum(dim=2, keepdim=True).log()
|
|
|
|
|
|
|
|
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one."""

    def forward(self, x):
        batch = x.size(0)
        # view (not reshape) keeps the original's contiguity requirement.
        return x.view(batch, -1)
|
|
|
|
|
|
|
|
class ChannelGate(nn.Module):
    """Channel attention gate: pool spatially per channel, pass the pooled
    descriptor through a shared bottleneck MLP, sum over pool types, and
    rescale the input by the sigmoid of the result.

    Fixes over the original: the mutable default argument
    `pool_types=['avg', 'max']` is now a tuple, and an unknown pool type
    raises ValueError instead of an UnboundLocalError.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        # Shared MLP applied to every pooled (B, C, 1, 1) descriptor.
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        """Return `x` scaled per-channel by the learned attention weights."""
        channel_att_sum = None
        h, w = x.size(2), x.size(3)
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, (h, w), stride=(h, w))
            elif pool_type == 'max':
                pooled = F.max_pool2d(x, (h, w), stride=(h, w))
            elif pool_type == 'lp':
                pooled = F.lp_pool2d(x, 2, (h, w), stride=(h, w))
            elif pool_type == 'lse':
                # Log-sum-exp pooling, a smooth alternative to max pooling.
                pooled = logsumexp_2d(x)
            else:
                # Previously fell through and crashed with UnboundLocalError.
                raise ValueError("unknown pool type: {}".format(pool_type))
            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        # (B, C) -> (B, C, 1, 1) broadcast back over the spatial dims.
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
|
|
|
|
|
|
|
|
|
|
|
def LBP(image):
    """Per-image uniform local binary pattern of a grayscale-converted batch.

    Takes an RGB batch (B, 3, H, W) and returns a float tensor of shape
    (B, 1, H, W) holding each image's LBP map (radius 2, 16 points).

    NOTE(review): `feature` is not imported in this file — presumably
    `from skimage import feature` is required; confirm against the
    project's imports.  It also receives torch tensors where skimage
    expects numpy arrays — verify this works in the caller's setup.
    """
    radius = 2
    n_points = 8 * radius
    method = 'uniform'
    gray = rgb_to_grayscale(image).squeeze(1)
    out = np.zeros(gray.shape)
    for idx, img in enumerate(gray):
        out[idx] = feature.local_binary_pattern(img, n_points, radius, method)
    return torch.FloatTensor(out).unsqueeze(1)
|
|
|
|
|
|
|
|
class Discriminator(nn.Module):
    """Small conv discriminator: four stride-2 4x4 conv blocks (4 channels
    each), an asymmetric top-left zero pad, and a 1-channel scoring conv.
    Each block halves the spatial resolution.
    """

    def __init__(self, in_channel):
        super().__init__()
        self.in_channel = in_channel

        def discriminator_block(in_filters, out_filters):
            # 4x4 stride-2 conv halves H/W; LeakyReLU keeps gradient flow
            # on negative activations.
            return [
                nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
            ]

        blocks = []
        for ic, oc in [(self.in_channel, 4), (4, 4), (4, 4), (4, 4)]:
            blocks += discriminator_block(ic, oc)
        blocks.append(nn.ZeroPad2d((1, 0, 1, 0)))
        blocks.append(nn.Conv2d(4, 1, 4, padding=1, bias=False))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Return the (B, 1, h, w) score map for input `x`."""
        return self.model(x)
|
|
|
|
|
|
|
|
class Discriminator_new(nn.Module):
    """Discriminator built from pairs of (stride-1, stride-2) 3x3 convs with
    LeakyReLU, widening 3 -> 4 -> 6 -> 8 -> 10 channels, then a top-left
    zero pad and a 1-channel scoring conv.

    Fix over the original: `discriminator_block` took a `first_block`
    argument that was computed and passed but never used — dead code, now
    removed.  The post-loop use of the loop variable `out_filters` is
    replaced with `in_filters` (same value, no loop-variable leak).
    """

    def __init__(self):
        super().__init__()

        def discriminator_block(in_filters, out_filters):
            # One resolution-preserving conv, then a stride-2 downsampling
            # conv, each followed by LeakyReLU.
            return [
                nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers = []
        in_filters = 3  # expects RGB input
        for out_filters in [4, 6, 8, 10]:
            layers.extend(discriminator_block(in_filters, out_filters))
            in_filters = out_filters

        # Pad top/left by one, then score with a single-channel conv.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the (B, 1, h, w) score map for input `img`."""
        return self.model(img)
|
|
|
|
|
|