# LMAR / model / module.py
# SuperPatrick's picture
# Upload 20 files
# 98feea6 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.transforms.functional import rgb_to_grayscale
import numpy as np
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU: (conv => ReLU) * 2.

    No batch normalization is used: both convolutions are bias-free and feed
    directly into an in-place ReLU. ``padding=1`` keeps H and W unchanged.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; any falsy
            value (``None`` by default) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
        )
        self.apply(self._init_weights)

    def forward(self, x):
        return self.double_conv(x)

    def _init_weights(self, m):
        """Re-initialise conv weights from N(0, 2 / fan_out); reset BatchNorm to identity."""
        if isinstance(m, nn.Conv2d):
            # fan_out = kH * kW * out_channels (He-style normal initialisation).
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / n))
        elif isinstance(m, nn.BatchNorm2d):
            # Not reached by the layers built above (there is no BatchNorm);
            # kept so the initialiser stays correct if one is added later.
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
class Down(nn.Module):
    """Downsample by 2 with max-pooling, then refine with a DoubleConv.

    ``nn.MaxPool2d(2)`` halves H and W (flooring odd sizes); the DoubleConv
    then maps ``in_channels`` to ``out_channels`` at the pooled resolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved
class Up(nn.Module):
    """Upsample by 2, align with the skip tensor, concatenate, then DoubleConv.

    Args:
        in_channels: channel count the DoubleConv receives after the upsampled
            map and the skip tensor are concatenated.
        out_channels: channels of the returned feature map.
        bilinear: if True, upsample with parameter-free bilinear interpolation
            and let the DoubleConv narrow to ``in_channels // 2`` internally;
            otherwise use a learnable 2x2 transposed convolution that itself
            halves the channel count of ``x1``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Assumes (N, C, H, W): pad the upsampled map (negative amounts crop)
        # so its H and W match the skip tensor x2, splitting the difference.
        dh = x2.size()[2] - upsampled.size()[2]
        dw = x2.size()[3] - upsampled.size()[3]
        left = dw // 2
        top = dh // 2
        upsampled = F.pad(upsampled, [left, dw - left, top, dh - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
# spatial attention
class SpatialGate(nn.Module):
    """Spatial attention: rescale every position of ``x`` by a sigmoid gate.

    A 1x1 convolution collapses the channels into a single map; its sigmoid
    (values in [0, 1]) is broadcast across channels and multiplied into ``x``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.spatial = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.spatial(x))
        return gate * x
# sobel
class SobelOperator(nn.Module):
    """Sobel gradient magnitude of a single-channel image.

    ``conv_x`` / ``conv_y`` are 1-in / 1-out 3x3 convolutions (padding 1, so H
    and W are preserved) loaded with the horizontal and vertical Sobel kernels.
    They remain ordinary ``nn.Parameter`` weights, so they are trainable unless
    the caller freezes them.
    """

    def __init__(self):
        super().__init__()
        self.conv_x = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.conv_y = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        sobel_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
        sobel_y = torch.tensor([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]])
        with torch.no_grad():
            self.conv_x.weight.copy_(sobel_x.view(1, 1, 3, 3))
            self.conv_y.weight.copy_(sobel_y.view(1, 1, 3, 3))

    def forward(self, x):
        """Return sqrt(G_x**2 + G_y**2), same shape as the conv outputs.

        The forward value equals the plain ``sqrt``. The difference is in the
        backward pass: where both responses are exactly 0 (e.g. flat regions),
        the gradient is 0 instead of NaN.
        """
        g_x = self.conv_x(x)
        g_y = self.conv_y(x)
        sq = g_x.pow(2) + g_y.pow(2)
        nonzero = sq != 0
        # d/dz sqrt(z) is infinite at 0 and, chained with 2*G = 0, produces NaN.
        # Feed 1 to sqrt there and mask the result back to 0 (= sqrt(0)).
        safe_sq = torch.where(nonzero, sq, torch.ones_like(sq))
        return torch.where(nonzero, torch.sqrt(safe_sq), torch.zeros_like(sq))
class offset_estimator(nn.Sequential):
    """Stack of Gaussian-initialised, bias-free convolutions, each followed by ReLU.

    Layer ``i`` uses kernel size ``kernel_size[i]`` with padding
    ``(k - 1) // 2`` (size-preserving for odd ``k``). Channels flow
    ``in_channel -> mid_channel -> ... -> mid_channel -> out_channel``; with a
    single kernel size the one layer maps ``in_channel -> out_channel``.

    Only the first output filter (``weight[0]``) of each conv is overwritten
    with ``gaussian_2d(k, fwhm=...)``, broadcast across its input channels; the
    other filters keep PyTorch's default random initialisation.
    NOTE(review): confirm that initialising only filter 0 is intended.
    NOTE(review): ``gaussian_2d`` is not defined or imported in the code shown
    here — confirm where it comes from.

    Args:
        kernel_size: sequence of per-layer kernel sizes.
        fwhm: sequence of per-layer Gaussian FWHM values; same length as
            ``kernel_size``.
        in_channel: channels of the input.
        mid_channel: channels between layers.
        out_channel: channels produced by the last layer.
    """

    def __init__(self, kernel_size, fwhm, in_channel, mid_channel, out_channel) -> None:
        super().__init__()
        assert len(kernel_size) == len(fwhm), "length error"
        last = len(kernel_size) - 1
        model = []
        for i, (k, width) in enumerate(zip(kernel_size, fwhm)):
            layer_in = in_channel if i == 0 else mid_channel
            # The last layer always emits out_channel, including when it is
            # also the first (single-kernel case).
            layer_out = out_channel if i == last else mid_channel
            gaussian_weight = torch.FloatTensor(gaussian_2d(k, fwhm=width))
            gauss_filter = nn.Conv2d(layer_in, layer_out, k, padding=(k - 1) // 2,
                                     bias=False)
            with torch.no_grad():
                gauss_filter.weight[0].copy_(gaussian_weight)
            model += [gauss_filter, nn.ReLU(inplace=True)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
# Channel attention
def logsumexp_2d(tensor):
    """Log-sum-exp over every position after the first two dims.

    Args:
        tensor: tensor whose first two dims are kept (used by ChannelGate as
            batch and channel); all remaining dims are flattened together.

    Returns:
        Tensor of shape ``(tensor.size(0), tensor.size(1), 1)``.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed maps, work.
    tensor_flatten = tensor.reshape(tensor.size(0), tensor.size(1), -1)
    # torch.logsumexp applies the same max-shift for numerical stability and
    # returns -inf (rather than NaN) for rows that are entirely -inf.
    return torch.logsumexp(tensor_flatten, dim=2, keepdim=True)
class Flatten(nn.Module):
    """Collapse every dimension after the first: (N, ...) -> (N, prod(...))."""

    def forward(self, x):
        # reshape returns a view when possible and copies otherwise, so
        # non-contiguous inputs no longer raise as they did with view.
        return x.reshape(x.size(0), -1)
class ChannelGate(nn.Module):
    """Channel attention: rescale each channel of ``x`` by a learned sigmoid gate.

    Each entry of ``pool_types`` summarises the spatial dims of ``x`` into one
    value per channel. A shared MLP maps each summary to per-channel logits;
    the logits are summed over pool types, passed through a sigmoid and
    broadcast back over H and W.

    Args:
        gate_channels: channel count of the input (``x.size(1)``).
        reduction_ratio: the MLP hidden width is
            ``gate_channels // reduction_ratio``.
        pool_types: iterable drawn from ``'avg'``, ``'max'``, ``'lp'`` (power-2
            LP pooling) and ``'lse'`` (log-sum-exp); defaults to avg and max.

    Raises:
        ValueError: from ``forward`` if ``pool_types`` is empty or contains an
            unknown entry.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super().__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        # Copy so instances never share a default list or alias the caller's.
        self.pool_types = list(pool_types)

    @staticmethod
    def _pool(x, pool_type):
        """Reduce the spatial dims of ``x`` as selected by ``pool_type``."""
        spatial = (x.size(2), x.size(3))
        if pool_type == 'avg':
            return F.avg_pool2d(x, spatial, stride=spatial)
        if pool_type == 'max':
            return F.max_pool2d(x, spatial, stride=spatial)
        if pool_type == 'lp':
            return F.lp_pool2d(x, 2, spatial, stride=spatial)
        if pool_type == 'lse':
            return logsumexp_2d(x)
        # Previously an unknown type left channel_att_raw unbound or silently
        # re-added the previous pool's logits.
        raise ValueError(
            f"unknown pool type {pool_type!r}; expected 'avg', 'max', 'lp' or 'lse'"
        )

    def forward(self, x):
        if not self.pool_types:
            raise ValueError("ChannelGate requires at least one pool type")
        channel_att_sum = None
        for pool_type in self.pool_types:
            channel_att_raw = self.mlp(self._pool(x, pool_type))
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
# LBP
def LBP(image):  # b, 3, h, w tensor
    """Per-image local binary pattern (LBP) map for a batch of RGB images.

    Each image is converted to grayscale with ``rgb_to_grayscale`` and passed
    to ``feature.local_binary_pattern`` with radius 2, 16 sampling points and
    the ``'uniform'`` method.

    NOTE(review): ``feature`` is not imported in the code shown here; it looks
    like ``skimage.feature`` — confirm and add ``from skimage import feature``.

    Args:
        image: RGB batch, presumably (B, 3, H, W) — TODO confirm with callers.

    Returns:
        float32 tensor of shape (B, 1, H, W) on the same device as ``image``.
    """
    radius = 2
    n_points = 8 * radius
    method = 'uniform'
    gray_img = rgb_to_grayscale(image).squeeze(1)  # b, h, w
    # local_binary_pattern needs host NumPy data. Converting a CUDA tensor or a
    # tensor that requires grad straight to NumPy raises, so detach and move
    # it first.
    gray_np = gray_img.detach().cpu().numpy()
    lbp_feature = np.zeros(gray_np.shape)
    for i, gray in enumerate(gray_np):
        lbp_feature[i] = feature.local_binary_pattern(gray, n_points, radius, method)
    return torch.from_numpy(lbp_feature).float().unsqueeze(1).to(image.device)
class Discriminator(nn.Module):
    """Compact convolutional discriminator producing a single-channel score map.

    Four 4x4 stride-2 convolutions (4 channels each, LeakyReLU(0.2) after
    each) shrink H and W by 16x. A one-pixel top/left zero pad and a bias-free
    4x4 convolution then give a 1-channel output. For H and W divisible by 16,
    the output is (N, 1, H/16, W/16).

    Args:
        in_channel: channels of the input image.
    """

    def __init__(self, in_channel):
        super().__init__()
        self.in_channel = in_channel

        def discriminator_block(in_filters, out_filters):
            return [
                nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=False),
            ]

        widths = [self.in_channel, 4, 4, 4, 4]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend(discriminator_block(c_in, c_out))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(4, 1, 4, padding=1, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores
class Discriminator_new(nn.Module):
    """Discriminator with four two-conv stages widening to 4, 6, 8 and 10 channels.

    Each stage is conv3x3 (stride 1) -> LeakyReLU(0.2) -> conv3x3 (stride 2)
    -> LeakyReLU(0.2), so every stage halves H and W (rounding up). A one-pixel
    top/left zero pad and a final stride-1 3x3 conv map to one channel.

    Args:
        in_channel: channels of the input image. Defaults to 3, the value this
            class used to hard-code, so ``Discriminator_new()`` is unchanged.
    """

    def __init__(self, in_channel=3):
        super().__init__()

        def discriminator_block(in_filters, out_filters):
            return [
                nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers = []
        in_filters = in_channel
        for out_filters in (4, 6, 8, 10):
            layers.extend(discriminator_block(in_filters, out_filters))
            in_filters = out_filters
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)