|
|
import math |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from torch.autograd import Variable |
|
|
|
|
|
|
|
|
def default_conv(in_channels, out_channels, kernel_size, bias=True): |
|
|
return nn.Conv2d( |
|
|
in_channels, out_channels, kernel_size, |
|
|
padding=(kernel_size//2), bias=bias) |
|
|
|
|
|
|
|
|
class ResUnit(nn.Module): |
|
|
def __init__(self, dim): |
|
|
super(ResUnit, self).__init__() |
|
|
|
|
|
self.act = nn.ReLU(True) |
|
|
self.conv1 = default_conv(dim, dim, 3) |
|
|
self.conv2 = default_conv(dim, dim*2, 1) |
|
|
self.conv3 = default_conv(dim*2, dim, 1) |
|
|
|
|
|
def forward(self, x): |
|
|
shortcut = x |
|
|
|
|
|
x = self.conv1(x) |
|
|
x = self.conv2(x) |
|
|
x = self.act(x) |
|
|
x = self.conv3(x) |
|
|
|
|
|
return x + shortcut |
|
|
|
|
|
|
|
|
class FusionBlock(nn.Module): |
|
|
def __init__(self, n_color, embed_dim): |
|
|
super(FusionBlock, self).__init__() |
|
|
|
|
|
self.act = nn.ReLU(True) |
|
|
|
|
|
self.conv_1 = default_conv(n_color, embed_dim, 3) |
|
|
self.conv_2 = default_conv(embed_dim, embed_dim, 3) |
|
|
|
|
|
self.conv_1_2 = default_conv(embed_dim, embed_dim, 3) |
|
|
self.conv_2_2 = default_conv(embed_dim, embed_dim, 3) |
|
|
|
|
|
self.ru_1 = ResUnit(embed_dim) |
|
|
self.ru_2 = ResUnit(embed_dim) |
|
|
|
|
|
self.ru_1_1 = ResUnit(embed_dim) |
|
|
self.ru_2_1 = ResUnit(embed_dim) |
|
|
|
|
|
self.ru = ResUnit(embed_dim) |
|
|
self.ru_ = ResUnit(embed_dim) |
|
|
|
|
|
self.conv_tail_1 = default_conv(embed_dim*2, embed_dim, 3) |
|
|
self.conv_tail_2 = default_conv(embed_dim, embed_dim, 3) |
|
|
|
|
|
def forward(self, img_snow, mask): |
|
|
|
|
|
img_snow = self.ru_1(self.conv_1(img_snow)) |
|
|
mask = self.ru_2(self.conv_2(mask)) |
|
|
|
|
|
img_1 = self.ru(self.conv_1_2((img_snow-mask))) |
|
|
|
|
|
img_snow = self.ru_1_1(img_snow) |
|
|
mask = self.ru_2_1(mask) |
|
|
|
|
|
img_2 = self.ru_(self.conv_2_2((img_snow-mask))) |
|
|
|
|
|
|
|
|
return self.conv_tail_2(self.act(self.conv_tail_1(torch.cat((img_1, img_2), dim=1)))) |
|
|
|
|
|
|
|
|
class MARB(nn.Module): |
|
|
def __init__(self, dim): |
|
|
super(MARB, self).__init__() |
|
|
|
|
|
self.act = nn.ReLU(True) |
|
|
|
|
|
self.conv_dl2 = default_conv(dim, dim, 1) |
|
|
self.conv_dl3 = default_conv(dim, dim, 3) |
|
|
self.conv_dl5 = default_conv(dim, dim, 5) |
|
|
|
|
|
self.conv1_1 = default_conv(dim, dim, 1) |
|
|
self.conv1_2 = default_conv(dim, dim, 1) |
|
|
self.conv1_3 = default_conv(dim, dim, 1) |
|
|
|
|
|
self.conv2_1 = default_conv(dim*2, dim, 1) |
|
|
self.conv2_2 = default_conv(dim*2, dim, 1) |
|
|
|
|
|
self.conv_tail = default_conv(dim*2, dim, 1) |
|
|
|
|
|
def forward(self, x): |
|
|
x1 = self.conv1_1(self.conv_dl2(x)) |
|
|
x2 = self.conv1_2(self.conv_dl3(x)) |
|
|
x3 = self.conv1_3(self.conv_dl5(x)) |
|
|
|
|
|
x_cat_1 = self.conv2_1(torch.cat((x1, x2), dim=1)) |
|
|
x_cat_2 = self.conv2_2(torch.cat((x2, x3), dim=1)) |
|
|
|
|
|
return self.conv_tail(self.act(torch.cat((x_cat_1, x_cat_2), dim=1))) + x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaskBlock(nn.Module): |
|
|
def __init__(self, embed_dim): |
|
|
super(MaskBlock, self).__init__() |
|
|
self.act = nn.ReLU(True) |
|
|
self.conv_head = default_conv(embed_dim, embed_dim, 3) |
|
|
|
|
|
self.conv_self = default_conv(embed_dim, embed_dim, 1) |
|
|
|
|
|
self.conv1 = default_conv(embed_dim, embed_dim, 3) |
|
|
self.conv1_1 = default_conv(embed_dim, embed_dim, 1) |
|
|
self.conv1_2 = default_conv(embed_dim, embed_dim, 1) |
|
|
self.conv_tail = default_conv(embed_dim, embed_dim, 3) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.conv_head(x) |
|
|
x = self.conv_self(x) |
|
|
x = x.mul(x) |
|
|
x = self.act(self.conv1(x)) |
|
|
x = self.conv1_1(x).mul(self.conv1_2(x)) |
|
|
|
|
|
return self.conv_tail(x) |
|
|
|
|
|
def dwt_init(x): |
|
|
x01 = x[:, :, 0::2, :] / 2 |
|
|
x02 = x[:, :, 1::2, :] / 2 |
|
|
x1 = x01[:, :, :, 0::2] |
|
|
x2 = x02[:, :, :, 0::2] |
|
|
x3 = x01[:, :, :, 1::2] |
|
|
x4 = x02[:, :, :, 1::2] |
|
|
x_LL = x1 + x2 + x3 + x4 |
|
|
x_HL = -x1 - x2 + x3 + x4 |
|
|
x_LH = -x1 + x2 - x3 + x4 |
|
|
x_HH = x1 - x2 - x3 + x4 |
|
|
|
|
|
return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) |
|
|
|
|
|
|
|
|
def iwt_init(x): |
|
|
r = 2 |
|
|
in_batch, in_channel, in_height, in_width = x.size() |
|
|
|
|
|
out_batch, out_channel, out_height, out_width = in_batch, int( |
|
|
in_channel / (r ** 2)), r * in_height, r * in_width |
|
|
x1 = x[:, 0:out_channel, :, :] / 2 |
|
|
x2 = x[:, out_channel:out_channel * 2, :, :] / 2 |
|
|
x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2 |
|
|
x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2 |
|
|
|
|
|
h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda() |
|
|
|
|
|
h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 |
|
|
h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 |
|
|
h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 |
|
|
h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 |
|
|
|
|
|
return h |
|
|
|
|
|
class DWT(nn.Module): |
|
|
def __init__(self): |
|
|
super(DWT, self).__init__() |
|
|
self.requires_grad = False |
|
|
|
|
|
def forward(self, x): |
|
|
return dwt_init(x) |
|
|
|
|
|
|
|
|
class IWT(nn.Module): |
|
|
def __init__(self): |
|
|
super(IWT, self).__init__() |
|
|
self.requires_grad = False |
|
|
|
|
|
def forward(self, x): |
|
|
return iwt_init(x) |