|
|
try: |
|
|
from .module import * |
|
|
except: |
|
|
from module import * |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.nn.init as init |
|
|
|
|
|
class SuperUnet_MS(nn.Module):
    """U-shaped network with two stride-2 down stages, a bottom stage and two up stages.

    Besides the main encoder/decoder path, every resampling step has a
    parallel lightweight resampling branch (``skip_down*`` / ``skip_up*``)
    whose output is concatenated with the main-path feature and merged by a
    ``skip`` module. Encoder features are additionally fused into the decoder
    through ``fus2`` / ``fus1``.

    Args:
        channels: channel count at the first level; the deeper levels use
            ``channels * 2`` and ``channels * 4``.
        block: forwarded unchanged to every ``basic_block`` (defined outside
            this class).

    Note:
        The ``dowm`` spelling in attribute names is kept on purpose: attribute
        names are the keys of the module's ``state_dict``.
    """

    @staticmethod
    def _scale_stage(conv_type, c_in, c_out):
        """Return ``conv_type(4x4, stride 2, pad 1) -> InstanceNorm2d -> LeakyReLU(0.2)``."""
        return nn.Sequential(
            conv_type(c_in, c_out, 4, 2, 1, bias=True),
            nn.InstanceNorm2d(c_out, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def __init__(self, channels, block="INV"):
        super().__init__()

        c1 = channels
        c2 = channels * 2
        c4 = channels * 4
        down, up = nn.Conv2d, nn.ConvTranspose2d

        # Sub-modules are created in the same order as before so that
        # parameter registration and random initialisation are unchanged.

        # Encoder.
        self.layer_dowm1 = basic_block(c1, c1, block)
        self.dowm1 = self._scale_stage(down, c1, c2)
        self.layer_dowm2 = basic_block(c2, c2, block)
        self.dowm2 = self._scale_stage(down, c2, c4)

        # Bottom stage and decoder.
        self.layer_bottom = basic_block(c4, c4, block)
        self.up2 = self._scale_stage(up, c4, c2)
        self.layer_up2 = basic_block(c2, c2, block)
        self.up1 = self._scale_stage(up, c2, c1)
        self.layer_up1 = basic_block(c1, c1, block)

        # Encoder -> decoder fusion at each level.
        self.fus2 = skip(c4, c2, "HIN")
        self.fus1 = skip(c2, c1, "HIN")

        # Parallel down-sampling branches and their merge modules.
        self.skip_down1 = self._scale_stage(down, c1, c1)
        self.skip1 = skip(channels * 3, c2, "CONV")
        self.skip_down2 = self._scale_stage(down, c2, c1)
        self.skip2 = skip(channels * 5, c4, "CONV")

        # Parallel up-sampling branches and their merge modules.
        self.skip_up4 = self._scale_stage(up, c4, c1)
        self.skip4 = skip(channels * 3, c2, "CONV")

        self.skip_up6 = self._scale_stage(up, c2, c1)
        self.skip6 = skip(c2, c1, "CONV")

    def forward(self, x):
        """Run the network.

        Args:
            x: input feature map with ``channels`` channels.

        Returns:
            A tuple ``(decoded, deepest)``: the output of the last fusion
            (``fus1``) and the merged feature produced by ``skip2`` after the
            second down-sampling stage (the input to the bottom stage).
        """
        # Level 1 -> level 2.
        enc1 = self.layer_dowm1(x)
        lvl2 = self.dowm1(enc1)
        side = self.skip_down1(x)
        lvl2 = self.skip1(torch.cat([side, lvl2], 1), lvl2)

        # Level 2 -> level 3.
        enc2 = self.layer_dowm2(lvl2)
        lvl3 = self.dowm2(enc2)
        side = self.skip_down2(lvl2)
        lvl3 = self.skip2(torch.cat([side, lvl3], 1), lvl3)

        bottom = self.layer_bottom(lvl3)

        # Level 3 -> level 2, then fuse with the encoder feature of level 2.
        dec2 = self.layer_up2(self.up2(bottom))
        side = self.skip_up4(bottom)
        dec2 = self.skip4(torch.cat([side, dec2], 1), dec2)
        dec2 = self.fus2(torch.cat([enc2, dec2], 1), dec2)

        # Level 2 -> level 1, then fuse with the encoder feature of level 1.
        dec1 = self.layer_up1(self.up1(dec2))
        side = self.skip_up6(dec2)
        dec1 = self.skip6(torch.cat([side, dec1], 1), dec1)
        dec1 = self.fus1(torch.cat([enc1, dec1], 1), dec1)

        return dec1, lvl3
|
|
|
|
|
|
|
|
class skip(nn.Module):
    """Learnable weighted sum ``alpha1 * body(x) + alpha2 * y``.

    ``body`` maps ``x`` (``channels_in`` channels) so that it can be added to
    ``y``; ``alpha1`` and ``alpha2`` are one-element trainable parameters
    initialised to ``1.0`` and ``0.5``.

    Args:
        channels_in: number of channels of ``x``.
        channels_out: number of channels produced by ``body``.
        block: which transform to use for ``body``:
            ``"CONV"`` - 1x1 conv, InstanceNorm2d, ReLU;
            ``"ID"``   - identity (``channels_out`` is ignored);
            ``"INV"``  - ``InvBlock`` followed by a 1x1 conv;
            ``"HIN"``  - ``HinBlock``.

    Raises:
        ValueError: if ``block`` is not one of the names above. Previously an
            unknown name left ``self.body`` undefined and the mistake only
            surfaced as an ``AttributeError`` inside ``forward``.
    """

    def __init__(self, channels_in, channels_out, block):
        super(skip, self).__init__()
        if block == "CONV":
            self.body = nn.Sequential(
                nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=True),
                nn.InstanceNorm2d(channels_out, affine=True),
                nn.ReLU(inplace=True),
            )
        elif block == "ID":
            self.body = nn.Identity()
        elif block == "INV":
            self.body = nn.Sequential(
                InvBlock(channels_in, channels_in // 2),
                nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=True),
            )
        elif block == "HIN":
            self.body = nn.Sequential(HinBlock(channels_in, channels_out))
        else:
            raise ValueError(
                "unsupported skip block %r; expected one of 'CONV', 'ID', 'INV', 'HIN'"
                % (block,)
            )

        # Build the mixing weights directly with their initial values instead
        # of allocating uninitialised storage and filling it through `.data`.
        # float32 is pinned to match the former torch.FloatTensor.
        self.alpha1 = nn.Parameter(
            torch.full((1,), 1.0, dtype=torch.float32), requires_grad=True
        )
        self.alpha2 = nn.Parameter(
            torch.full((1,), 0.5, dtype=torch.float32), requires_grad=True
        )

    def forward(self, x, y):
        """Return ``alpha1 * body(x) + alpha2 * y``."""
        return self.alpha1 * self.body(x) + self.alpha2 * y
|
|
|
|
|
|
|
|
def subnet(net_structure, init='xavier'):
    """Return a factory for the sub-network type named by ``net_structure``.

    Args:
        net_structure: name of the sub-network; only ``'HIN'`` is supported.
        init: not used by this function; kept so existing callers that pass
            it keep working.

    Returns:
        A callable ``constructor(channel_in, channel_out)`` that builds the
        requested sub-network.

    Raises:
        ValueError: raised by the returned ``constructor`` when
            ``net_structure`` is not supported. Previously the constructor
            silently returned ``None`` in that case, which only failed later
            when the missing sub-network was called.
    """
    def constructor(channel_in, channel_out):
        if net_structure == 'HIN':
            return HinBlock(channel_in, channel_out)
        raise ValueError(
            "unsupported net_structure %r; expected 'HIN'" % (net_structure,)
        )

    return constructor
|
|
|
|
|
|
|
|
class InvBlock(nn.Module):
    """Coupling block over a two-way channel split, with an outer residual.

    The input is split along dim 1 into the first ``channel_split_num``
    channels and the remaining ``channel_num - channel_split_num`` channels.
    Three sub-networks built by ``subnet_constructor`` (``F``, ``G``, ``H``)
    update the two parts, and the re-joined result is added to the input.

    Args:
        channel_num: total number of input channels.
        channel_split_num: number of channels in the first part.
        subnet_constructor: callable ``(channel_in, channel_out)`` returning
            a module; used to build ``F``, ``G`` and ``H``.
        clamp: bound applied to the scale exponent computed in ``forward``.
    """

    def __init__(self, channel_num, channel_split_num, subnet_constructor=subnet('HIN'),
                 clamp=0.8):
        super().__init__()

        first_len = channel_split_num
        second_len = channel_num - channel_split_num
        self.split_len1 = first_len
        self.split_len2 = second_len

        self.clamp = clamp

        # F maps second part -> first part; G and H map first part -> second part.
        self.F = subnet_constructor(second_len, first_len)
        self.G = subnet_constructor(first_len, second_len)
        self.H = subnet_constructor(first_len, second_len)

    def forward(self, x):
        """Apply the coupling transform and add the input back.

        Side effect: the scale exponent of the current call is stored on the
        instance as ``self.s``.
        """
        head = torch.narrow(x, 1, 0, self.split_len1)
        tail = torch.narrow(x, 1, self.split_len1, self.split_len2)

        # Additive update of the first part.
        head_out = head + self.F(tail)

        # sigmoid(.) * 2 - 1 lies in (-1, 1), so the exponent lies in (-clamp, clamp).
        gate = torch.sigmoid(self.H(head_out))
        self.s = self.clamp * (gate * 2 - 1)

        # Scale-and-shift update of the second part.
        tail_out = tail * torch.exp(self.s) + self.G(head_out)

        coupled = torch.cat((head_out, tail_out), 1)
        return coupled + x
|
|
|
|
|
|
|
|
class sample_block(nn.Module):
    """3x3 dilated (transposed) convolution followed by InstanceNorm2d and ReLU.

    Both variants use stride 1 with ``padding == dilation == dil``.

    Args:
        channels_in: number of input channels.
        channels_out: number of output channels.
        size: ``"DOWN"`` selects ``nn.Conv2d``; ``"UP"`` selects
            ``nn.ConvTranspose2d``.
        dil: value used for both the padding and the dilation of the
            convolution.

    Raises:
        ValueError: if ``size`` is neither ``"DOWN"`` nor ``"UP"``. Previously
            such a value left ``self.conv`` undefined and the error only
            appeared as an ``AttributeError`` in ``forward``.
    """

    def __init__(self, channels_in, channels_out, size, dil):
        super(sample_block, self).__init__()

        if size == "DOWN":
            conv = nn.Conv2d(channels_in, channels_out, 3, 1, dil, dilation=dil)
        elif size == "UP":
            conv = nn.ConvTranspose2d(channels_in, channels_out, 3, 1, dil, dilation=dil)
        else:
            raise ValueError(
                "unsupported size %r; expected 'DOWN' or 'UP'" % (size,)
            )

        # The norm/activation tail is identical for both variants.
        self.conv = nn.Sequential(
            conv,
            nn.InstanceNorm2d(channels_out, affine=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply convolution, instance normalisation and ReLU to ``x``."""
        return self.conv(x)
|
|
|
|
|
|
|
|
class HinBlock(nn.Module):
    """Two 3x3 convolutions with instance norm on part of the channels, plus a 1x1 shortcut.

    After the first convolution the leading ``out_size // 2`` channels go
    through ``InstanceNorm2d``; the remaining channels are passed through
    untouched. Both parts are re-joined, activated, convolved again and added
    to a 1x1-convolved copy of the input.

    Args:
        in_size: number of input channels.
        out_size: number of output channels. Odd values are supported: the
            un-normalised part then has one channel more than the normalised
            part. ``out_size`` is expected to be at least 2.
    """

    def __init__(self, in_size, out_size):
        super(HinBlock, self).__init__()
        self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)
        self.norm = nn.InstanceNorm2d(out_size // 2, affine=True)

        self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu_1 = nn.Sequential(nn.LeakyReLU(0.2, inplace=False), )
        self.conv_2 = nn.Sequential(nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=1, bias=True),
                                    nn.LeakyReLU(0.2, inplace=False), )

    def forward(self, x):
        """Return the block output for ``x`` (``in_size`` channels in, ``out_size`` out)."""
        out = self.conv_1(x)

        # Split at the channel count the norm layer was built for.
        # torch.chunk(out, 2) would hand ceil(C / 2) channels to the norm when
        # C is odd, which does not match its out_size // 2 features and makes
        # InstanceNorm2d raise. For even C this split is identical to chunk.
        normed_channels = self.norm.num_features
        out_1 = out[:, :normed_channels]
        out_2 = out[:, normed_channels:]

        out = torch.cat([self.norm(out_1), out_2], dim=1)
        out = self.relu_1(out)
        out = self.conv_2(out)
        out += self.identity(x)
        return out
|
|
|
|
|
|
|
|
class net(nn.Module):
    """Top-level model: ``DoubleConv`` stem -> ``SuperUnet_MS`` backbone -> 1x1 conv head.

    Args:
        args: configuration object whose ``model`` attribute is a mapping
            providing ``"in_channel"`` (input channels of the stem) and
            ``"model_channel"`` (the backbone runs at twice this width).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args.model
        cfg = self.args

        # Stem, backbone and head all work at twice the configured width.
        width = cfg["model_channel"] * 2

        self.hr_inc = DoubleConv(cfg["in_channel"], width)
        self.hr_backbone = SuperUnet_MS(width)
        # Bias-free 1x1 projection to a fixed 3-channel output.
        self.final_out = nn.Conv2d(width, 3, kernel_size=1, bias=False)

    def forward(self, x):
        """Map ``x`` to a 3-channel output.

        The second value returned by the backbone is not used here.
        """
        features = self.hr_inc(x)
        features, _ = self.hr_backbone(features)
        return self.final_out(features)
|
|
|
|
|
|