# LMAR/model/model.py
try:
    from .module import *  # package-relative import (used when loaded as part of the package)
except ImportError:
    from module import *  # fallback when the file is run directly

import torch
import torch.nn as nn
class SuperUnet_MS(nn.Module):
    """Two-level U-Net with learned skip fusion. Returns the decoded feature map
    and the deepest encoder feature."""

    def __init__(self, channels, block="INV"):
        super(SuperUnet_MS, self).__init__()
        # ---------ENCODE (attribute names keep the original "dowm" spelling so
        # that released checkpoint keys still match)
        self.layer_dowm1 = basic_block(channels, channels, block)
        self.dowm1 = nn.Sequential(
            nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=True),
            nn.InstanceNorm2d(channels * 2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer_dowm2 = basic_block(channels * 2, channels * 2, block)
        self.dowm2 = nn.Sequential(
            nn.Conv2d(channels * 2, channels * 4, 4, 2, 1, bias=True),
            nn.InstanceNorm2d(channels * 4, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # ---------DECODE
        self.layer_bottom = basic_block(channels * 4, channels * 4, block)
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(channels * 4, channels * 2, 4, 2, 1, bias=True),
            nn.InstanceNorm2d(channels * 2, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer_up2 = basic_block(channels * 2, channels * 2, block)
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(channels * 2, channels, 4, 2, 1, bias=True),
            nn.InstanceNorm2d(channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer_up1 = basic_block(channels, channels, block)
        # ---------FUSION of encoder features into the decoder
        self.fus2 = skip(channels * 4, channels * 2, "HIN")
        self.fus1 = skip(channels * 2, channels, "HIN")
        # ---------EXTRA SKIPS across scales
        self.skip_down1 = nn.Sequential(
            nn.Conv2d(channels, channels, 4, 2, 1, bias=True),
            nn.InstanceNorm2d(channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.skip1 = skip(channels * 3, channels * 2, "CONV")
        self.skip_down2 = nn.Sequential(
            nn.Conv2d(channels * 2, channels, 4, 2, 1, bias=True),
            nn.InstanceNorm2d(channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.skip2 = skip(channels * 5, channels * 4, "CONV")
        self.skip_up4 = nn.Sequential(
            nn.ConvTranspose2d(channels * 4, channels, 4, 2, 1, bias=True),
            nn.InstanceNorm2d(channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.skip4 = skip(channels * 3, channels * 2, "CONV")
        self.skip_up6 = nn.Sequential(
            nn.ConvTranspose2d(channels * 2, channels, 4, 2, 1, bias=True),
            nn.InstanceNorm2d(channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.skip6 = skip(channels * 2, channels, "CONV")
    def forward(self, x):
        # ---------ENCODE
        x_11 = self.layer_dowm1(x)
        x_down1 = self.dowm1(x_11)
        # fuse a downsampled copy of the input into the first encoder scale
        x_down1 = self.skip1(torch.cat([self.skip_down1(x), x_down1], 1), x_down1)
        x_12 = self.layer_dowm2(x_down1)
        x_down2 = self.dowm2(x_12)
        x_down2 = self.skip2(torch.cat([self.skip_down2(x_down1), x_down2], 1), x_down2)
        x_bottom = self.layer_bottom(x_down2)
        # ---------DECODE
        x_up2 = self.up2(x_bottom)
        x_22 = self.layer_up2(x_up2)
        x_22 = self.skip4(torch.cat([self.skip_up4(x_bottom), x_22], 1), x_22)
        x_22 = self.fus2(torch.cat([x_12, x_22], 1), x_22)
        x_up1 = self.up1(x_22)
        x_21 = self.layer_up1(x_up1)
        x_21 = self.skip6(torch.cat([self.skip_up6(x_22), x_21], 1), x_21)
        x_21 = self.fus1(torch.cat([x_11, x_21], 1), x_21)
        return x_21, x_down2
class skip(nn.Module):
    """Learned weighted fusion: out = alpha1 * body(x) + alpha2 * y."""

    def __init__(self, channels_in, channels_out, block):
        super(skip, self).__init__()
        if block == "CONV":
            self.body = nn.Sequential(
                nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=True),
                nn.InstanceNorm2d(channels_out, affine=True),
                nn.ReLU(inplace=True),
            )
        if block == "ID":
            self.body = nn.Identity()
        if block == "INV":
            self.body = nn.Sequential(
                InvBlock(channels_in, channels_in // 2),
                nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=True),
            )
        if block == "HIN":
            self.body = nn.Sequential(HinBlock(channels_in, channels_out))
        # learned fusion weights, initialised to 1.0 and 0.5
        self.alpha1 = nn.Parameter(torch.ones(1))
        self.alpha2 = nn.Parameter(torch.full((1,), 0.5))

    def forward(self, x, y):
        return self.alpha1 * self.body(x) + self.alpha2 * y
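# Usage sketch with hypothetical shapes: fuse an encoder feature `enc` into a
# decoder feature `dec`, both (N, 32, H, W):
#   fuse = skip(64, 32, "CONV")
#   out = fuse(torch.cat([enc, dec], 1), dec)   # out: (N, 32, H, W)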
def subnet(net_structure, init='xavier'):
    """Return a constructor for the requested sub-network type.

    Note: `init` is accepted for interface compatibility but unused; only the
    'HIN' structure is implemented, and any other value yields a constructor
    that returns None.
    """
    def constructor(channel_in, channel_out):
        if net_structure == 'HIN':
            return HinBlock(channel_in, channel_out)
        return None

    return constructor
class InvBlock(nn.Module):
    def __init__(self, channel_num, channel_split_num, subnet_constructor=subnet('HIN'),
                 clamp=0.8):
        # channel_split_num is usually set to half of channel_num
        super(InvBlock, self).__init__()
        self.split_len1 = channel_split_num
        self.split_len2 = channel_num - channel_split_num
        self.clamp = clamp
        self.F = subnet_constructor(self.split_len2, self.split_len1)
        self.G = subnet_constructor(self.split_len1, self.split_len2)
        self.H = subnet_constructor(self.split_len1, self.split_len2)

    def forward(self, x):
        x1, x2 = (x.narrow(1, 0, self.split_len1), x.narrow(1, self.split_len1, self.split_len2))
        y1 = x1 + self.F(x2)
        # clamped log-scale in (-clamp, clamp); cached on the module as in the original
        self.s = self.clamp * (torch.sigmoid(self.H(y1)) * 2 - 1)
        y2 = x2.mul(torch.exp(self.s)) + self.G(y1)
        out = torch.cat((y1, y2), 1)
        return out + x
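# The coupling above, written out for channel_num = 4, channel_split_num = 2
# (a sketch; F, G, H are HinBlock subnets built by subnet('HIN')):
#   x1, x2 = x[:, :2], x[:, 2:]
#   y1  = x1 + F(x2)
#   s   = clamp * (2 * sigmoid(H(y1)) - 1)   # log-scale clamped to (-0.8, 0.8)
#   y2  = x2 * exp(s) + G(y1)
#   out = cat([y1, y2], 1) + x               # extra residual; not strictly invertible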
class sample_block(nn.Module):
    """Dilated 3x3 conv block. Despite the "DOWN"/"UP" labels, both branches use
    stride 1 with padding equal to the dilation, so the spatial size is preserved;
    "UP" simply swaps Conv2d for ConvTranspose2d."""

    def __init__(self, channels_in, channels_out, size, dil):
        super(sample_block, self).__init__()
        if size == "DOWN":
            self.conv = nn.Sequential(
                nn.Conv2d(channels_in, channels_out, 3, 1, dil, dilation=dil),
                nn.InstanceNorm2d(channels_out, affine=True),
                nn.ReLU(inplace=True),
            )
        if size == "UP":
            self.conv = nn.Sequential(
                nn.ConvTranspose2d(channels_in, channels_out, 3, 1, dil, dilation=dil),
                nn.InstanceNorm2d(channels_out, affine=True),
                nn.ReLU(inplace=True),
            )

    def forward(self, x):
        return self.conv(x)
class HinBlock(nn.Module):
    """Half Instance Normalization block: instance norm is applied to half of
    the channels only, with a 1x1 conv as the residual shortcut."""

    def __init__(self, in_size, out_size):
        super(HinBlock, self).__init__()
        self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)
        self.norm = nn.InstanceNorm2d(out_size // 2, affine=True)
        self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu_1 = nn.LeakyReLU(0.2, inplace=False)
        self.conv_2 = nn.Sequential(
            nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=False),
        )

    def forward(self, x):
        out = self.conv_1(x)
        out_1, out_2 = torch.chunk(out, 2, dim=1)  # normalise only the first half
        out = torch.cat([self.norm(out_1), out_2], dim=1)
        out = self.relu_1(out)
        out = self.conv_2(out)
        return out + self.identity(x)
class net(nn.Module):
    """Full model: DoubleConv stem, SuperUnet_MS backbone, 1x1 conv head to RGB."""

    def __init__(self, args):
        super().__init__()
        self.args = args.model
        self.hr_inc = DoubleConv(self.args["in_channel"], self.args["model_channel"] * 2)
        self.hr_backbone = SuperUnet_MS(self.args["model_channel"] * 2)
        self.final_out = nn.Conv2d(self.args["model_channel"] * 2, 3, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.hr_inc(x)
        x, mid_feat = self.hr_backbone(x)  # mid_feat (deep encoder feature) is unused here
        out = self.final_out(x)
        return out
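# A minimal smoke test, as a sketch: it assumes module.py provides DoubleConv and
# basic_block (pulled in by the star import above) and that `args` exposes a
# `model` dict with "in_channel" and "model_channel" keys; the values below are
# illustrative, not the released configuration.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(model={"in_channel": 3, "model_channel": 16})
    model = net(args)
    # spatial size must be divisible by 4: the encoder downsamples twice by stride 2
    x = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 3, 64, 64])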