| """ |
| Copyright (C) 2019 NVIDIA Corporation. All rights reserved. |
| Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). |
| """ |
| import os |
| from importlib import import_module |
| import torch.nn.utils.spectral_norm as spectral_norm |
| import torch.nn as nn |
| import numpy as np |
| from torch.nn import init as init |
| import math |
| from win_util import window_partitions,window_partitionx,window_reversex |
| from models.batchnorm import SynchronizedBatchNorm2d |
| from models.base_network import BaseNetwork |
| from models.normalization import get_nonspade_norm_layer |
| import torch.nn.functional as F |
| import torch |
| import re |
| from PIL import Image |
|
|
| class conv_bench(nn.Module): |
| def __init__(self, n_feat = 3, kernel_size=3, act_method=nn.ReLU, bias=False): |
| super(conv_bench, self).__init__() |
| self.conv1 = nn.Conv2d(6,8,1) |
| self.conv2 = nn.Conv2d(8,8,1) |
| self.act = act_method() |
| self.relu = nn.ReLU(inplace=True) |
|
|
| def forward(self, x): |
| x1 = self.conv1(x) |
| |
| y = self.conv2(x1) |
| return y |
|
|
| class fft_bench_complex_mlp_flops(nn.Module): |
| def __init__(self, dim=6, dw=1, norm='backward', act_method=nn.ReLU, window_size=0, bias=False): |
| super(fft_bench_complex_mlp_flops, self).__init__() |
| self.act_fft = act_method() |
| self.window_size = window_size |
| hid_dim = dim * dw |
| self.complex_weight1 = nn.Conv2d(dim*2, hid_dim*2, 1) |
| self.complex_weight2 = nn.Conv2d(hid_dim*2, dim*2, 1) |
|
|
| |
| |
| self.conv = nn.Conv2d(6,8,1) |
| self.bias = bias |
| self.norm = norm |
| |
| |
| def forward(self, x): |
| _, _, H, W = x.shape |
| if self.window_size > 0 and (H != self.window_size or W != self.window_size): |
| x, batch_list = window_partitionx(x, self.window_size) |
| y = torch.fft.rfft2(x, norm=self.norm) |
| dim = 1 |
| y = torch.cat([y.real, y.imag], dim=dim) |
| |
| y = self.complex_weight1(y) |
| y = self.act_fft(y) |
| y = self.complex_weight2(y) |
| |
|
|
| y_real, y_imag = torch.chunk(y, 2, dim=dim) |
| y = torch.complex(y_real, y_imag) |
| |
| if self.window_size > 0 and (H != self.window_size or W != self.window_size): |
| y = torch.fft.irfft2(y, s=(self.window_size, self.window_size), norm=self.norm) |
| y = window_reversex(y, self.window_size, H, W, batch_list) |
| else: |
| y = torch.fft.irfft2(y, s=(H, W), norm=self.norm) |
| y = self.conv(y) |
| return y |
|
|
|
|
| class symy(nn.Module): |
|
|
| def __init__(self, in_chn=3, wf=64, depth=5, relu_slope=0.2, hin_position_left=0, hin_position_right=4): |
| super(symy, self).__init__() |
| self.generator = sdmy() |
| def forward(self, x): |
| out = self.generator(x) |
| return out |
|
|
| def get_input_chn(self, in_chn): |
| return in_chn |
|
|
| def _initialize(self): |
| gain = nn.init.calculate_gain('leaky_relu', 0.20) |
| for m in self.modules(): |
| if isinstance(m, nn.Conv2d): |
| nn.init.orthogonal_(m.weight, gain=gain) |
| if not m.bias is None: |
| nn.init.constant_(m.bias, 0) |
| class ResnetBlock(nn.Module): |
| def __init__(self, dim, act, kernel_size=3): |
| super().__init__() |
| self.act = act |
| pw = (kernel_size - 1) // 2 |
| self.conv_block = nn.Sequential( |
| nn.ReflectionPad2d(pw), |
| nn.Conv2d(dim, dim, kernel_size=kernel_size), |
| act, |
| nn.ReflectionPad2d(pw), |
| nn.Conv2d(dim, dim, kernel_size=kernel_size) |
| ) |
|
|
| def forward(self, x): |
| y = self.conv_block(x) |
| out = x + y |
| return self.act(out) |
|
|
|
|
|
|
| class Convres(BaseNetwork): |
| def __init__(self): |
| super().__init__() |
| wf = 48 |
| kw = 3 |
| pw = int(np.ceil((kw - 1.0) / 2)) |
| ndf = 48 |
| self.ndf = ndf |
| norm_E = "spectralinstance" |
| norm_layer = get_nonspade_norm_layer(None, norm_E) |
| self.layer11 = norm_layer(nn.Conv2d(4, ndf, kw, stride=2, padding=pw)) |
| self.layer12 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)) |
| |
| self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)) |
| self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)) |
| self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)) |
| |
| |
| self.res_0 = ResnetBlock(ndf * 8, nn.LeakyReLU(0.2, False)) |
| self.res_1 = ResnetBlock(ndf * 8, nn.LeakyReLU(0.2, False)) |
| self.res_2 = ResnetBlock(ndf * 8, nn.LeakyReLU(0.2, False)) |
| self.so = 4 |
|
|
|
|
| self.mu_make = nn.Sequential( |
| |
| nn.Conv2d(ndf * 16,ndf * 8,1) |
| ) |
|
|
| self.mu_make_0 = nn.Sequential( |
| |
| nn.Upsample(scale_factor=2), |
| nn.Conv2d(ndf * 16,ndf * 8,1) |
| ) |
|
|
| |
| self.actvn = nn.LeakyReLU(0.2, False) |
| self.pad_3 = nn.ReflectionPad2d(3) |
| |
| self.conv_7x7 = nn.Conv2d(ndf, ndf, kernel_size=7, padding=0, bias=True) |
|
|
| self.upp = nn.Conv2d(8*ndf, 16*ndf, kernel_size=1, padding=0, bias=True) |
|
|
| self.conv_latent_up2 = Up_ConvBlock(8 * wf, 4 * wf) |
| self.conv_latent_up3 = Up_ConvBlock(4 * wf, 2 * wf) |
| self.conv_latent_up4 = Up_ConvBlock(2 * wf, 1 * wf) |
|
|
|
|
| def forward(self, x, gray, white, flag): |
| if x.size(2) != 256 or x.size(3) != 256: |
| x = F.interpolate(x, size=(256, 256), mode='bilinear') |
|
|
| if flag=='low': |
| gray = self.layer11(gray) |
| gray = self.conv_7x7(self.pad_3(self.actvn(gray))) |
| gray = self.layer2(self.actvn(gray)) |
| gray = self.layer3(self.actvn(gray)) |
| gray = self.layer4(self.actvn(gray)) |
| gray = self.res_0(gray) |
| gray = self.res_1(gray) |
| gray = self.res_2(gray) |
|
|
| white = self.layer11(white) |
| white = self.conv_7x7(self.pad_3(self.actvn(white))) |
| white = self.layer2(self.actvn(white)) |
| white = self.layer3(self.actvn(white)) |
| white = self.layer4(self.actvn(white)) |
| white = self.res_0(white) |
| white = self.res_1(white) |
| white = self.res_2(white) |
|
|
| mu = torch.cat([gray, white], dim=1) |
| mu = self.mu_make_0(mu) |
| else: |
| x = self.layer12(x) |
| x = self.conv_7x7(self.pad_3(self.actvn(x))) |
| x = self.layer2(self.actvn(x)) |
| x = self.layer3(self.actvn(x)) |
| x = self.layer4(self.actvn(x)) |
| x = self.res_0(x) |
| x = self.res_1(x) |
| x = self.res_2(x) |
| up = self.upp(x) |
| mu = self.mu_make(up) |
|
|
| latent_2 = self.conv_latent_up2(mu) |
| latent_3 = self.conv_latent_up3(latent_2) |
| latent_4 = self.conv_latent_up4(latent_3) |
| latent_list = [latent_4, latent_3, latent_2, mu] |
|
|
| return mu, latent_list |
|
|
| class ConvEncoderLoss(BaseNetwork): |
| """ Same architecture as the image discriminator """ |
|
|
| def __init__(self): |
| super().__init__() |
|
|
| kw = 3 |
| pw = int(np.ceil((kw - 1.0) / 2)) |
| ndf = 64 |
| self.ndf = ndf |
| norm_E = "spectralinstance" |
| norm_layer = get_nonspade_norm_layer(None, norm_E) |
| self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)) |
| self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)) |
| |
| self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)) |
| |
| self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)) |
| |
| self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) |
| self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=1, padding=pw)) |
| |
| |
| self.so = s0 = 4 |
| self.out = norm_layer(nn.Conv2d(ndf * 8, ndf * 4, kw, stride=1, padding=0)) |
| self.down = nn.AvgPool2d(2,2) |
| |
|
|
|
|
|
|
| self.actvn = nn.LeakyReLU(0.2, False) |
| self.pad_3 = nn.ReflectionPad2d(3) |
| self.pad_1 = nn.ReflectionPad2d(1) |
| self.conv_7x7 = nn.Conv2d(ndf, ndf, kernel_size=7, padding=0, bias=True) |
| |
|
|
| def forward(self, x): |
|
|
| x1 = self.layer1(x) |
| x2 = self.conv_7x7(self.pad_3(self.actvn(x1))) |
| x3 = self.layer2(self.actvn(x2)) |
| |
| x4 = self.layer3(self.actvn(x3)) |
| |
| x5 = self.layer4(self.actvn(x4)) |
| |
| return [x1, x2, x3, x4, x5] |
| class EncodeMap(BaseNetwork): |
| """ Same architecture as the image discriminator """ |
|
|
| def __init__(self, opt): |
| super().__init__() |
|
|
| kw = 3 |
| pw = int(np.ceil((kw - 1.0) / 2)) |
| ndf = opt.ngf |
| norm_layer = get_nonspade_norm_layer(opt, opt.norm_E) |
| self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)) |
| self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)) |
| self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)) |
| self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)) |
| self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) |
| if opt.crop_size >= 256: |
| self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=1, padding=pw)) |
|
|
| self.so = s0 = 4 |
| self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256) |
| self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256) |
| self.layer_final = nn.Conv2d(ndf * 8, ndf * 16, kw, stride=1, padding=pw) |
|
|
| self.actvn = nn.LeakyReLU(0.2, False) |
| self.opt = opt |
|
|
| def forward(self, x): |
| if x.size(2) != 256 or x.size(3) != 256: |
| x = F.interpolate(x, size=(256, 256), mode='bilinear') |
|
|
| x = self.layer1(x) |
| x = self.layer2(self.actvn(x)) |
| x = self.layer3(self.actvn(x)) |
| x = self.layer4(self.actvn(x)) |
| x = self.layer5(self.actvn(x)) |
| |
| |
| x = self.actvn(x) |
| return self.layer_final(x) |
|
|
| x = x.view(x.size(0), -1) |
| mu = self.fc_mu(x) |
| logvar = self.fc_var(x) |
|
|
| return mu, logvar |
|
|
|
|
| class Up_ConvBlock(nn.Module): |
| def __init__(self, dim_in, dim_out, activation=nn.LeakyReLU(0.2, False), kernel_size=3): |
| super().__init__() |
|
|
| pw = (kernel_size - 1) // 2 |
| '''self.conv1 = nn.Sequential( |
| nn.ReflectionPad2d(pw), |
| nn.Conv2d(dim, dim, kernel_size=kernel_size), |
| activation)''' |
| |
| self.conv_block = nn.Sequential( |
| nn.ReflectionPad2d(pw), |
| spectral_norm(nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size)), |
| activation, |
| nn.Upsample(scale_factor=2), |
| nn.ReflectionPad2d(pw), |
| spectral_norm(nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size)), |
| activation |
| ) |
|
|
| def forward(self, x): |
| |
| y = self.conv_block(x) |
| return y |
|
|
|
|
| class UNetConvBlock(nn.Module): |
| def __init__(self, in_size, out_size, downsample, relu_slope, use_csff=False, use_HIN=False): |
| super(UNetConvBlock, self).__init__() |
| self.downsample = downsample |
| self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0) |
| self.use_csff = use_csff |
|
|
| self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True) |
| self.relu_1 = nn.LeakyReLU(relu_slope, inplace=False) |
| self.conv_2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True) |
| self.relu_2 = nn.LeakyReLU(relu_slope, inplace=False) |
|
|
| if downsample and use_csff: |
| self.csff_enc = nn.Conv2d(out_size, out_size, 3, 1, 1) |
| self.csff_dec = nn.Conv2d(out_size, out_size, 3, 1, 1) |
|
|
| if use_HIN: |
| self.norm = nn.InstanceNorm2d(out_size // 2, affine=True) |
| self.use_HIN = use_HIN |
|
|
| if downsample: |
| self.downsample = conv_down(out_size, out_size, bias=False) |
|
|
| def forward(self, x, enc=None, dec=None): |
| out = self.conv_1(x) |
|
|
| if self.use_HIN: |
| out_1, out_2 = torch.chunk(out, 2, dim=1) |
| out = torch.cat([self.norm(out_1), out_2], dim=1) |
| out = self.relu_1(out) |
| out = self.relu_2(self.conv_2(out)) |
|
|
| out += self.identity(x) |
| if enc is not None and dec is not None: |
| assert self.use_csff |
| out = out + self.csff_enc(enc) + self.csff_dec(dec) |
| if self.downsample: |
| out_down = self.downsample(out) |
| return out_down, out |
| else: |
| return out |
|
|
|
|
| class UNetUpBlock(nn.Module): |
| def __init__(self, in_size, out_size, relu_slope): |
| super(UNetUpBlock, self).__init__() |
| self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True) |
| self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope) |
|
|
| def forward(self, x, bridge): |
| up = self.up(x) |
| out = torch.cat([up, bridge], 1) |
| out = self.conv_block(out) |
| return out |
|
|
|
|
| class Subspace(nn.Module): |
|
|
| def __init__(self, in_size, out_size): |
| super(Subspace, self).__init__() |
| self.blocks = nn.ModuleList() |
| self.blocks.append(UNetConvBlock(in_size, out_size, False, 0.2)) |
| self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True) |
|
|
| def forward(self, x): |
| sc = self.shortcut(x) |
| for i in range(len(self.blocks)): |
| x = self.blocks[i](x) |
| return x + sc |
|
|
|
|
| class skip_blocks(nn.Module): |
|
|
| def __init__(self, in_size, out_size, repeat_num=1): |
| super(skip_blocks, self).__init__() |
| self.blocks = nn.ModuleList() |
| self.re_num = repeat_num |
| mid_c = 128 |
| self.blocks.append(UNetConvBlock(in_size, mid_c, False, 0.2)) |
| for i in range(self.re_num - 2): |
| self.blocks.append(UNetConvBlock(mid_c, mid_c, False, 0.2)) |
| self.blocks.append(UNetConvBlock(mid_c, out_size, False, 0.2)) |
| self.shortcut = nn.Conv2d(in_size, out_size, kernel_size=1, bias=True) |
|
|
| def forward(self, x): |
| sc = self.shortcut(x) |
| for m in self.blocks: |
| x = m(x) |
| return x + sc |
|
|
| def conv3x3(in_chn, out_chn, bias=True): |
| layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias) |
| return layer |
|
|
|
|
| def conv_down(in_chn, out_chn, bias=False): |
| layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias) |
| return layer |
|
|
|
|
| class SAM(nn.Module): |
| def __init__(self, n_feat, kernel_size=3, bias=True): |
| super(SAM, self).__init__() |
| self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) |
| self.conv2 = conv(n_feat, 3, kernel_size, bias=bias) |
| self.conv3 = conv(3, n_feat, kernel_size, bias=bias) |
|
|
| def forward(self, x, x_img): |
| x1 = self.conv1(x) |
| img = self.conv2(x) + x_img |
| x2 = torch.sigmoid(self.conv3(img)) |
| x1 = x1*x2 |
| x1 = x1+x |
| return x1, img |
| def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): |
| return nn.Conv2d( |
| in_channels, out_channels, kernel_size, |
| padding=(kernel_size//2), bias=bias, stride = stride) |
| def conv3x3(in_chn, out_chn, bias=True): |
| layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias) |
| return layer |
|
|
| def conv_down(in_chn, out_chn, bias=False): |
| layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias) |
| return layer |