# Ref-Band: src/model/common.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a square Conv2d whose padding keeps spatial size for odd kernels.

    Padding is kernel_size // 2, i.e. "same"-style output size when
    kernel_size is odd (all call sites in this file use 1, 3 or 5).
    """
    pad = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=pad, bias=bias)
class ResUnit(nn.Module):
    """Residual bottleneck: 3x3 conv, 1x1 expand (x2), ReLU, 1x1 reduce, skip add.

    Input and output both have `dim` channels and the same spatial size.
    Attribute names are kept as in the original so state_dict keys match.
    """

    def __init__(self, dim):
        super(ResUnit, self).__init__()
        self.act = nn.ReLU(True)
        self.conv1 = default_conv(dim, dim, 3)          # spatial mixing
        self.conv2 = default_conv(dim, dim * 2, 1)      # channel expand
        self.conv3 = default_conv(dim * 2, dim, 1)      # channel reduce

    def forward(self, x):
        # conv1 -> conv2 -> ReLU -> conv3, then identity shortcut.
        out = self.conv3(self.act(self.conv2(self.conv1(x))))
        return out + x
class FusionBlock(nn.Module):
    """Fuse a snowy image with a mask via two stages of feature subtraction.

    Both inputs are lifted to `embed_dim` channels, their difference is taken
    at two depths, and the two difference maps are concatenated and projected
    back to `embed_dim`. Attribute names are preserved for checkpoint
    compatibility.
    """

    def __init__(self, n_color, embed_dim):
        super(FusionBlock, self).__init__()
        self.act = nn.ReLU(True)
        # embedding convs for the two input streams
        self.conv_1 = default_conv(n_color, embed_dim, 3)
        self.conv_2 = default_conv(embed_dim, embed_dim, 3)
        # convs applied to the stream differences
        self.conv_1_2 = default_conv(embed_dim, embed_dim, 3)
        self.conv_2_2 = default_conv(embed_dim, embed_dim, 3)
        # residual units per stream and per fusion stage
        self.ru_1 = ResUnit(embed_dim)
        self.ru_2 = ResUnit(embed_dim)
        self.ru_1_1 = ResUnit(embed_dim)
        self.ru_2_1 = ResUnit(embed_dim)
        self.ru = ResUnit(embed_dim)
        self.ru_ = ResUnit(embed_dim)
        # merge the two difference maps back to embed_dim channels
        self.conv_tail_1 = default_conv(embed_dim * 2, embed_dim, 3)
        self.conv_tail_2 = default_conv(embed_dim, embed_dim, 3)

    def forward(self, img_snow, mask):
        # Stage 1: embed both streams and take an early difference.
        snow_feat = self.ru_1(self.conv_1(img_snow))
        mask_feat = self.ru_2(self.conv_2(mask))
        diff_early = self.ru(self.conv_1_2(snow_feat - mask_feat))
        # Stage 2: refine each stream, then take a late difference.
        snow_feat = self.ru_1_1(snow_feat)
        mask_feat = self.ru_2_1(mask_feat)
        diff_late = self.ru_(self.conv_2_2(snow_feat - mask_feat))
        # Concatenate both difference maps and project back down.
        fused = torch.cat((diff_early, diff_late), dim=1)
        return self.conv_tail_2(self.act(self.conv_tail_1(fused)))
class MARB(nn.Module):
    """Multi-scale aggregation residual block.

    Extracts features at three receptive fields (1x1, 3x3, 5x5), merges
    adjacent scale pairs, concatenates the pairs, and adds the input back
    as a residual. Channel count `dim` is preserved end to end. Attribute
    names are kept so saved weights still load.
    """

    def __init__(self, dim):
        super(MARB, self).__init__()
        self.act = nn.ReLU(True)
        # three parallel branches with increasing receptive field
        self.conv_dl2 = default_conv(dim, dim, 1)
        self.conv_dl3 = default_conv(dim, dim, 3)
        self.conv_dl5 = default_conv(dim, dim, 5)
        # 1x1 projections per branch
        self.conv1_1 = default_conv(dim, dim, 1)
        self.conv1_2 = default_conv(dim, dim, 1)
        self.conv1_3 = default_conv(dim, dim, 1)
        # pairwise fusion of adjacent scales
        self.conv2_1 = default_conv(dim * 2, dim, 1)
        self.conv2_2 = default_conv(dim * 2, dim, 1)
        self.conv_tail = default_conv(dim * 2, dim, 1)

    def forward(self, x):
        branch_s = self.conv1_1(self.conv_dl2(x))   # small receptive field
        branch_m = self.conv1_2(self.conv_dl3(x))   # medium
        branch_l = self.conv1_3(self.conv_dl5(x))   # large
        pair_sm = self.conv2_1(torch.cat((branch_s, branch_m), dim=1))
        pair_ml = self.conv2_2(torch.cat((branch_m, branch_l), dim=1))
        merged = torch.cat((pair_sm, pair_ml), dim=1)
        return self.conv_tail(self.act(merged)) + x
# class MARB(nn.Module):
# def __init__(self, dim):
# super(MARB, self).__init__()
#
# self.act = nn.ReLU(True)
#
# self.conv_dl2 = default_conv(dim, dim, 1)
# self.conv_dl3 = default_conv(dim, dim, 3)
# self.conv_dl5 = default_conv(dim, dim, 5)
#
# self.conv1_1 = default_conv(dim, dim, 3)
# self.conv1_2 = default_conv(dim, dim, 3)
# self.conv1_3 = default_conv(dim, dim, 3)
#
# # self.conv2_1 = default_conv(dim*2, dim, 3)
# # self.conv2_2 = default_conv(dim*2, dim, 3)
#
# self.conv_tail = default_conv(dim*3, dim, 3)
#
# def forward(self, x):
# x1 = self.conv1_1(self.conv_dl2(x))
# x2 = self.conv1_2(self.conv_dl3(x))
# x3 = self.conv1_3(self.conv_dl5(x))
#
# # x_cat_1 = self.conv2_1(torch.cat((x1, x2), dim=1))
# # x_cat_2 = self.conv2_2(torch.cat((x2, x3), dim=1))
#
# return self.conv_tail(self.act(torch.cat((x1, x2, x3), dim=1))) + x
class MaskBlock(nn.Module):
    """Self-gated attention-style block operating at a fixed channel width.

    The feature map is squared element-wise after a 1x1 projection, then a
    second gating step multiplies two parallel 1x1 projections. Channel
    count `embed_dim` is preserved. Attribute names match the original for
    state_dict compatibility.
    """

    def __init__(self, embed_dim):
        super(MaskBlock, self).__init__()
        self.act = nn.ReLU(True)
        self.conv_head = default_conv(embed_dim, embed_dim, 3)
        self.conv_self = default_conv(embed_dim, embed_dim, 1)
        self.conv1 = default_conv(embed_dim, embed_dim, 3)
        self.conv1_1 = default_conv(embed_dim, embed_dim, 1)
        self.conv1_2 = default_conv(embed_dim, embed_dim, 1)
        self.conv_tail = default_conv(embed_dim, embed_dim, 3)

    def forward(self, x):
        feat = self.conv_self(self.conv_head(x))
        feat = feat * feat                       # element-wise self-gating (square)
        feat = self.act(self.conv1(feat))
        gated = self.conv1_1(feat) * self.conv1_2(feat)  # second gating stage
        return self.conv_tail(gated)
def dwt_init(x):
    """One-level 2D Haar discrete wavelet transform.

    Input (B, C, H, W) with even H and W; output (B, 4C, H/2, W/2) with the
    four sub-bands stacked along channels in the order LL, HL, LH, HH.
    """
    even_rows = x[:, :, 0::2, :] * 0.5
    odd_rows = x[:, :, 1::2, :] * 0.5
    a = even_rows[:, :, :, 0::2]   # even row, even col
    b = odd_rows[:, :, :, 0::2]    # odd row,  even col
    c = even_rows[:, :, :, 1::2]   # even row, odd col
    d = odd_rows[:, :, :, 1::2]    # odd row,  odd col
    ll = a + b + c + d
    hl = c + d - a - b
    lh = b + d - a - c
    hh = a + d - b - c
    return torch.cat((ll, hl, lh, hh), dim=1)
def iwt_init(x):
    """Inverse of one-level 2D Haar DWT (exact inverse of dwt_init).

    Input (B, 4C, H, W) with sub-bands LL, HL, LH, HH stacked along the
    channel dim; output (B, C, 2H, 2W).

    Bug fix: the original allocated the output with a hard-coded
    ``.float().cuda()``, which crashed on CPU-only machines and ignored the
    input's device/dtype. The output is now allocated on the input's own
    device and dtype.
    """
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    out_batch = in_batch
    out_channel = in_channel // (r ** 2)
    out_height, out_width = r * in_height, r * in_width
    # Halve each sub-band (dwt_init already halved once; together: /4).
    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
    # Allocate on the input's device/dtype (fix for the hard-coded .cuda()).
    h = torch.zeros(out_batch, out_channel, out_height, out_width,
                    dtype=x1.dtype, device=x.device)
    # Interleave the reconstructed even/odd rows and columns.
    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
    return h
class DWT(nn.Module):
    """Parameter-free module wrapping dwt_init (one-level Haar DWT)."""

    def __init__(self):
        super(DWT, self).__init__()
        # Kept for parity with the original code; the module has no
        # learnable parameters, so this flag has no practical effect.
        self.requires_grad = False

    def forward(self, x):
        """Return dwt_init(x): (B, C, H, W) -> (B, 4C, H/2, W/2)."""
        return dwt_init(x)
class IWT(nn.Module):
    """Parameter-free module wrapping iwt_init (inverse one-level Haar DWT)."""

    def __init__(self):
        super(IWT, self).__init__()
        # Kept for parity with the original code; the module has no
        # learnable parameters, so this flag has no practical effect.
        self.requires_grad = False

    def forward(self, x):
        """Return iwt_init(x): (B, 4C, H, W) -> (B, C, 2H, 2W)."""
        return iwt_init(x)