|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import functools |
|
|
from torch.autograd import Variable |
|
|
import numpy as np |
|
|
from torch.nn.utils import spectral_norm |
|
|
|
|
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def weights_init(m):
    """Initialise the parameters of a single module in place.

    Meant to be handed to ``nn.Module.apply``. The module is dispatched on its
    class name:

    * a name containing ``"Conv"``: weight drawn from N(0.0, 0.02);
    * a name containing ``"BatchNorm2d"``: weight drawn from N(1.0, 0.02) and
      bias filled with 0.

    A module whose ``weight`` (or ``bias``) is missing or ``None`` -- for
    example a batch-norm layer created with ``affine=False`` -- is left
    untouched instead of raising ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            bias.data.fill_(0)
|
|
|
|
|
|
|
|
def get_norm_layer(norm_type="instance"):
    """Return a factory that builds a normalisation layer from a channel count.

    Args:
        norm_type: one of
            ``"batch"``      -> ``nn.BatchNorm2d`` with ``affine=True``;
            ``"instance"``   -> ``nn.InstanceNorm2d`` with ``affine=False``;
            ``"SwitchNorm"`` -> the ``SwitchNorm2d`` class.

    Returns:
        A callable; the rest of this file invokes it as ``norm_layer(channels)``.

    Raises:
        NotImplementedError: for ``"spectral"`` and for any unknown name.
    """
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == "spectral":
        # spectral_norm wraps an existing module; it cannot be turned into a
        # layer from a channel count. The previous `spectral_norm()` call with
        # no arguments always failed with an accidental TypeError, so fail
        # explicitly and consistently with the unknown-name case instead.
        raise NotImplementedError(
            "normalization layer [%s] cannot be built from a channel count; "
            "wrap the convolution with spectral_norm instead" % norm_type
        )
    if norm_type == "SwitchNorm":
        # NOTE(review): SwitchNorm2d is not defined in the visible part of this
        # file -- confirm it is imported/defined before relying on this branch.
        return SwitchNorm2d
    raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
|
|
|
|
|
|
|
|
def print_network(net):
    """Print a network followed by its total parameter count.

    Args:
        net: a module, or a list whose first element is the module to report.
    """
    target = net[0] if isinstance(net, list) else net
    total = sum(p.numel() for p in target.parameters())
    print(target)
    print("Total number of parameters: %d" % total)
|
|
|
|
|
|
|
|
def define_G(input_nc, output_nc, ngf, netG, k_size=3, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=None, opt=None):
    """Build, print, optionally move to GPU, and initialise a generator.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base number of generator filters.
        netG: generator kind; only ``'global'`` is supported.
        k_size: kernel size forwarded to the generator.
        n_downsample_global: number of downsampling stages.
        n_blocks_global: number of residual blocks (used by ``GlobalGenerator_v2`` only).
        n_local_enhancers: accepted for interface compatibility; unused here.
        n_blocks_local: accepted for interface compatibility; unused here.
        norm: normalisation name passed to ``get_norm_layer``.
        gpu_ids: GPU ids; when non-empty the generator is moved to
            ``gpu_ids[0]``. ``None`` (the default) is treated as empty.
        opt: options object; ``opt.use_v2`` selects the generator class and
            ``opt`` is forwarded to it.

    Returns:
        The generator after ``weights_init`` has been applied to it.

    Raises:
        NotImplementedError: if ``netG`` is not ``'global'``.
    """
    # The original `raise ('generator not implemented!')` raised a str, which is
    # itself a TypeError. Reject unsupported kinds up front with a real exception.
    if netG != 'global':
        raise NotImplementedError('generator not implemented!')

    gpu_ids = [] if gpu_ids is None else gpu_ids
    norm_layer = get_norm_layer(norm_type=norm)

    if opt.use_v2:
        generator = GlobalGenerator_DCDCv2(input_nc, output_nc, ngf, k_size, n_downsample_global, norm_layer, opt=opt)
    else:
        generator = GlobalGenerator_v2(input_nc, output_nc, ngf, k_size, n_downsample_global, n_blocks_global,
                                       norm_layer, opt=opt)
    print(generator)

    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        generator.cuda(gpu_ids[0])
    generator.apply(weights_init)
    return generator
|
|
|
|
|
|
|
|
def define_D(input_nc, ndf, n_layers_D, opt, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False,
             gpu_ids=[]):
    """Build, print, optionally move to GPU, and initialise the discriminator.

    Args:
        input_nc: number of input channels.
        ndf: base number of discriminator filters.
        n_layers_D: number of layers per single-scale discriminator.
        opt: options object forwarded to ``MultiscaleDiscriminator``.
        norm: normalisation name passed to ``get_norm_layer``.
        use_sigmoid: forwarded to ``MultiscaleDiscriminator``.
        num_D: number of scales.
        getIntermFeat: forwarded to ``MultiscaleDiscriminator``.
        gpu_ids: GPU ids; when non-empty the network is moved to ``gpu_ids[0]``.

    Returns:
        The discriminator after ``weights_init`` has been applied to it.
    """
    discriminator = MultiscaleDiscriminator(
        input_nc, opt, ndf, n_layers_D, get_norm_layer(norm_type=norm), use_sigmoid, num_D, getIntermFeat
    )
    print(discriminator)

    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert torch.cuda.is_available()
        discriminator.cuda(gpu_ids[0])

    discriminator.apply(weights_init)
    return discriminator
|
|
|
|
|
|
|
|
class GlobalGenerator_DCDCv2(nn.Module):
    """Convolutional encoder/decoder generator with residual blocks.

    Two ``nn.Sequential`` stacks are built: ``self.encoder`` and
    ``self.decoder``. Every channel width is capped at ``opt.mc``.

    Fields read from ``opt``: ``mc``, ``start_r``, ``spatio_size``,
    ``feat_dim`` and ``use_segmentation_model``.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base channel width.
        k_size: kernel size of the strided (transposed) convolutions.
        n_downsampling: nominal number of resolution levels.
        norm_layer: callable mapping a channel count to a normalisation layer.
        padding_type: padding mode forwarded to ``ResnetBlock``.
        opt: options object (required despite the ``None`` default).
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        k_size=3,
        n_downsampling=8,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
        opt=None,
    ):
        super(GlobalGenerator_DCDCv2, self).__init__()
        activation = nn.ReLU(True)
        # Transposed convs need output_padding=1 to exactly double the size,
        # except for kernel size 4 where stride 2 / padding 1 already does.
        o_pad = 0 if k_size == 4 else 1

        def width(mult):
            # Channel width at multiplier `mult`, capped at opt.mc.
            return min(ngf * mult, opt.mc)

        def res_block(dim):
            # One residual block at `dim` channels with the shared settings.
            return ResnetBlock(
                dim,
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
            )

        def down_stage(mult):
            # Stride-2 conv + norm + activation: width(mult) -> width(2 * mult).
            out_dim = width(mult * 2)
            return [
                nn.Conv2d(width(mult), out_dim, kernel_size=k_size, stride=2, padding=1),
                norm_layer(out_dim),
                activation,
            ]

        def up_stage(mult):
            # Stride-2 transposed conv + norm + activation: width(mult) -> half.
            out_dim = min(int(ngf * mult / 2), opt.mc)
            return [
                nn.ConvTranspose2d(
                    width(mult),
                    out_dim,
                    kernel_size=k_size,
                    stride=2,
                    padding=1,
                    output_padding=o_pad,
                ),
                norm_layer(out_dim),
                activation,
            ]

        # ---- encoder ----
        stem_dim = width(1)
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, stem_dim, kernel_size=7, padding=0),
            # Fix: normalise the channels the conv actually emits. The original
            # used norm_layer(ngf), which mismatches whenever opt.mc < ngf.
            norm_layer(stem_dim),
            activation,
        ]

        # Plain downsampling for the first opt.start_r levels.
        for i in range(opt.start_r):
            model += down_stage(2 ** i)

        # Remaining levels: downsample, then two residual blocks.
        for i in range(opt.start_r, n_downsampling - 1):
            mult = 2 ** i
            model += down_stage(mult)
            model += [res_block(width(mult * 2)), res_block(width(mult * 2))]

        mult = 2 ** (n_downsampling - 1)
        bottleneck_dim = width(mult * 2)

        # Last level: spatio_size 32 downsamples once more, 64 adds a residual
        # block instead; any other value adds neither.
        if opt.spatio_size == 32:
            model += down_stage(mult)
        if opt.spatio_size == 64:
            model += [res_block(bottleneck_dim)]
        model += [res_block(bottleneck_dim)]

        if opt.feat_dim > 0:
            # Optional 1x1 projection to opt.feat_dim channels.
            model += [nn.Conv2d(bottleneck_dim, opt.feat_dim, 1, 1)]
        self.encoder = nn.Sequential(*model)

        # ---- decoder (mirrors the encoder) ----
        model = []
        if opt.feat_dim > 0:
            model += [nn.Conv2d(opt.feat_dim, bottleneck_dim, 1, 1)]

        mult = 2 ** n_downsampling
        model += [res_block(width(mult))]

        if opt.spatio_size == 32:
            model += up_stage(mult)
        if opt.spatio_size == 64:
            model += [res_block(width(mult))]

        for i in range(1, n_downsampling - opt.start_r):
            mult = 2 ** (n_downsampling - i)
            model += [res_block(width(mult)), res_block(width(mult))]
            model += up_stage(mult)

        for i in range(n_downsampling - opt.start_r, n_downsampling):
            model += up_stage(2 ** (n_downsampling - i))

        model += [nn.ReflectionPad2d(3), nn.Conv2d(width(1), output_nc, kernel_size=7, padding=0)]
        if not opt.use_segmentation_model:
            # Tanh is appended only when not used as a segmentation model.
            model += [nn.Tanh()]
        self.decoder = nn.Sequential(*model)

    def forward(self, input, flow="enc_dec"):
        """Run the encoder, the decoder, or both.

        Args:
            input: tensor fed to the selected stack.
            flow: ``"enc"`` (encoder only), ``"dec"`` (decoder only) or
                ``"enc_dec"`` (encoder followed by decoder).

        Returns:
            The output of the selected stack(s).

        Raises:
            ValueError: if ``flow`` is not one of the three supported values
                (previously this silently returned ``None``).
        """
        if flow == "enc":
            return self.encoder(input)
        if flow == "dec":
            return self.decoder(input)
        if flow == "enc_dec":
            return self.decoder(self.encoder(input))
        raise ValueError("unknown flow [%s]; expected 'enc', 'dec' or 'enc_dec'" % flow)
|
|
|
|
|
|
|
|
|
|
|
class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)``.

    ``conv_block`` is: padding, 3x3 conv (optionally dilated), norm,
    activation, optional dropout, padding, 3x3 conv, norm.

    Args:
        dim: channel count (input and output).
        padding_type: ``"reflect"``, ``"replicate"`` or ``"zero"``.
        norm_layer: callable mapping a channel count to a normalisation layer.
        opt: stored on the instance as ``self.opt``; not otherwise used here.
        activation: activation module placed after the first norm.
        use_dropout: insert ``nn.Dropout(0.5)`` after the activation.
        dilation: dilation (and padding amount) of the first convolution.
    """

    def __init__(
        self, dim, padding_type, norm_layer, opt, activation=nn.ReLU(True), use_dropout=False, dilation=1
    ):
        super(ResnetBlock, self).__init__()
        self.opt = opt
        self.dilation = dilation
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    @staticmethod
    def _padding(padding_type, amount):
        """Return ``(explicit padding layers, conv padding value)`` for *padding_type*."""
        if padding_type == "reflect":
            return [nn.ReflectionPad2d(amount)], 0
        if padding_type == "replicate":
            return [nn.ReplicationPad2d(amount)], 0
        if padding_type == "zero":
            # Zero padding is delegated to the convolution itself.
            return [], amount
        raise NotImplementedError("padding [%s] is not implemented" % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Assemble the two-convolution body as an ``nn.Sequential``.

        Raises:
            NotImplementedError: if ``padding_type`` is not supported.
        """
        # First half: padded by the dilation so the spatial size is preserved.
        layers, conv_pad = self._padding(padding_type, self.dilation)
        layers += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, dilation=self.dilation),
            norm_layer(dim),
            activation,
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        # Second half: always an undilated 3x3 conv padded by 1.
        pad_layers, conv_pad = self._padding(padding_type, 1)
        layers += pad_layers
        layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad, dilation=1), norm_layer(dim)]

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the convolutional body."""
        return x + self.conv_block(x)
|
|
|
|
|
|
|
|
class Encoder(nn.Module):
    """Conv encoder/decoder whose output is averaged per instance region.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base channel width.
        n_downsampling: number of stride-2 down (and up) stages.
        norm_layer: callable mapping a channel count to a normalisation layer.
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        # Stem: 7x7 convolution at full resolution.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        # Downsampling path: the channel count doubles at each level.
        for level in range(n_downsampling):
            in_dim = ngf * 2 ** level
            out_dim = in_dim * 2
            layers.extend(
                (
                    nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1),
                    norm_layer(out_dim),
                    nn.ReLU(True),
                )
            )

        # Upsampling path: the channel count halves at each level.
        for level in range(n_downsampling, 0, -1):
            in_dim = ngf * 2 ** level
            out_dim = int(in_dim / 2)
            layers.extend(
                (
                    nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
                    norm_layer(out_dim),
                    nn.ReLU(True),
                )
            )

        layers.extend((nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()))
        self.model = nn.Sequential(*layers)

    def forward(self, input, inst):
        """Encode ``input`` and replace each instance region by its mean.

        For every distinct value in ``inst``, every batch element and every
        output channel, the network outputs at positions where ``inst`` equals
        that value are replaced by their mean.

        Args:
            input: tensor fed to the network.
            inst: instance-id map; indexed here as a 4-D tensor
                (assumes ``(batch, 1, H, W)`` -- TODO confirm).

        Returns:
            A tensor shaped like the network output holding the averaged values.
        """
        outputs = self.model(input)
        outputs_mean = outputs.clone()
        batch_size = input.size()[0]

        for inst_id in np.unique(inst.cpu().numpy().astype(int)):
            label = int(inst_id)
            for b in range(batch_size):
                # Coordinates (n, c, y, x) inside the single-sample slice.
                hits = (inst[b: b + 1] == label).nonzero()
                sample_idx = hits[:, 0] + b
                rows, cols = hits[:, 2], hits[:, 3]
                for channel in range(self.output_nc):
                    channel_idx = hits[:, 1] + channel
                    region = outputs[sample_idx, channel_idx, rows, cols]
                    outputs_mean[sample_idx, channel_idx, rows, cols] = torch.mean(region).expand_as(region)
        return outputs_mean
|
|
|
|
|
|
|
|
def SN(module, mode=True):
    """Optionally wrap ``module`` with spectral normalisation.

    Args:
        module: the module to wrap.
        mode: when truthy, apply ``torch.nn.utils.spectral_norm``.

    Returns:
        The wrapped module if ``mode`` is truthy, otherwise ``module`` unchanged.
    """
    return torch.nn.utils.spectral_norm(module) if mode else module
|
|
|
|
|
|
|
|
class NonLocalBlock2D_with_mask_Res(nn.Module):
    """Mask-aware non-local (attention) block followed by residual blocks.

    Attention weights are computed between all spatial positions from 1x1
    projections ``theta`` and ``phi``, multiplied by a mask-derived weight,
    applied to the ``g`` projection, projected back by ``W`` and refined by
    three ``ResnetBlock`` modules. With ``mode="combine"`` the result is
    blended with the input using the mask-derived weight.

    Args:
        in_channels: channel count of the input.
        inter_channels: channel count of the projections and residual blocks.
        mode: output mode. Only ``"combine"`` is implemented in ``forward``.
        re_norm: L1-renormalise the attention rows after masking.
        temperature: divisor applied to the attention logits before softmax.
        use_self: force each position's weight on itself to 1 in the mask.
        cosin: L2-normalise ``theta`` / ``phi`` features (cosine similarity).
    """

    def __init__(
        self,
        in_channels,
        inter_channels,
        mode="add",
        re_norm=False,
        temperature=1.0,
        use_self=False,
        cosin=False,
    ):
        super(NonLocalBlock2D_with_mask_Res, self).__init__()

        self.cosin = cosin
        self.renorm = re_norm
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        self.g = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )

        self.W = nn.Conv2d(
            in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0
        )

        # W starts at exactly zero (weights and bias).
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )

        self.phi = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )

        self.mode = mode
        self.temperature = temperature
        self.use_self = use_self

        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)

        model = []
        for i in range(3):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def forward(self, x, mask):
        """Apply masked non-local attention to ``x``.

        Args:
            x: feature map (assumes ``(batch, in_channels, H, W)`` -- TODO confirm).
            mask: mask tensor resized here to the spatial size of ``x``
                (assumes ``(batch, 1, h, w)`` float -- TODO confirm).

        Returns:
            ``full_mask * x + (1 - full_mask) * W_y`` where ``full_mask`` is
            the mask-derived weight repeated over ``inter_channels`` channels.

        Raises:
            NotImplementedError: if ``self.mode`` is not ``"combine"``. No other
                mode ever assigned the result, so the call previously ended in
                an ``UnboundLocalError`` after doing all the work.
        """
        if self.mode != "combine":
            raise NotImplementedError(
                "mode [%s] is not implemented; only 'combine' is supported" % self.mode
            )

        batch_size = x.size(0)

        # Flatten the spatial dims: g_x / theta_x are (B, HW, C), phi_x is (B, C, HW).
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        if self.cosin:
            # Normalise along the channel axis of each operand.
            theta_x = F.normalize(theta_x, dim=2)
            phi_x = F.normalize(phi_x, dim=1)

        # Pairwise similarity between positions, (B, HW, HW).
        f = torch.matmul(theta_x, phi_x)
        f /= self.temperature
        f_div_C = F.softmax(f, dim=2)

        # Mask-derived weight: 1 only where the bilinearly resized mask is
        # exactly 0 and the resized complement of the input mask is non-zero.
        tmp = 1 - mask
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0
        mask = 1 - mask

        tmp = F.interpolate(tmp, (x.size(2), x.size(3)))
        mask *= tmp

        # Same per-key weight for every query row.
        mask_expand = mask.view(batch_size, 1, -1)
        mask_expand = mask_expand.repeat(1, x.size(2) * x.size(3), 1)

        if self.use_self:
            # Always let a position attend to itself.
            mask_expand[:, range(x.size(2) * x.size(3)), range(x.size(2) * x.size(3))] = 1.0

        f_div_C = mask_expand * f_div_C
        if self.renorm:
            f_div_C = F.normalize(f_div_C, p=1, dim=2)

        # Aggregate values and restore the (B, C, H, W) layout.
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        W_y = self.res_block(W_y)

        full_mask = mask.repeat(1, self.inter_channels, 1, 1)
        z = full_mask * x + (1 - full_mask) * W_y
        return z
|
|
|
|
|
|
|
|
class MultiscaleDiscriminator(nn.Module):
    """Several ``NLayerDiscriminator`` networks applied at successive scales.

    The sub-networks are registered as ``layer<i>`` (whole model) or, when
    ``getIntermFeat`` is set, as ``scale<i>_layer<j>`` (one attribute per stage).

    Args:
        input_nc: number of input channels.
        opt: options object forwarded to ``NLayerDiscriminator``.
        ndf: base number of filters.
        n_layers: number of layers of each single-scale discriminator.
        norm_layer: callable mapping a channel count to a normalisation layer.
        use_sigmoid: forwarded to ``NLayerDiscriminator``.
        num_D: number of scales.
        getIntermFeat: return every intermediate stage output, not just the last.
    """

    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for scale in range(num_D):
            single = NLayerDiscriminator(input_nc, opt, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if not getIntermFeat:
                setattr(self, 'layer' + str(scale), single.model)
                continue
            # Re-register each stage so intermediate outputs can be collected.
            for stage in range(n_layers + 2):
                setattr(self, 'scale' + str(scale) + '_layer' + str(stage), getattr(single, 'model' + str(stage)))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator.

        Args:
            model: a list of stages when ``getIntermFeat`` is set, otherwise a
                single callable model.
            input: tensor to classify.

        Returns:
            A list with every stage output (``getIntermFeat``) or a
            one-element list with the model output.
        """
        if not self.getIntermFeat:
            return [model(input)]

        features = []
        current = input
        for stage in model:
            current = stage(current)
            features.append(current)
        return features

    def forward(self, input):
        """Evaluate every scale, downsampling the input between scales.

        The sub-network with the highest index sees the full-resolution input.

        Returns:
            A list with one ``singleD_forward`` result per scale.
        """
        last = self.num_D - 1
        outputs = []
        current = input
        for i in range(self.num_D):
            scale = last - i
            if self.getIntermFeat:
                model = [getattr(self, 'scale' + str(scale) + '_layer' + str(j))
                         for j in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer' + str(scale))
            outputs.append(self.singleD_forward(model, current))
            if i != last:
                current = self.downsample(current)
        return outputs
|
|
|
|
|
|
|
|
|
|
|
class NLayerDiscriminator(nn.Module):
    """Convolutional patch discriminator with optional per-stage outputs.

    The network is a list of stages: a stride-2 conv, ``n_layers - 1``
    stride-2 conv/norm stages, one stride-1 conv/norm stage, a final 1-channel
    conv, and an optional sigmoid. Every convolution goes through ``SN`` with
    ``opt.use_SN``. Channel widths double per stage and are capped at 512.

    With ``getIntermFeat`` each stage is registered as ``model<k>``; otherwise
    all stages are flattened into a single ``self.model``.

    Args:
        input_nc: number of input channels.
        opt: options object; ``opt.use_SN`` is read.
        ndf: base number of filters.
        n_layers: number of strided layers.
        norm_layer: callable mapping a channel count to a normalisation layer.
        use_sigmoid: append ``nn.Sigmoid`` as an extra stage.
        getIntermFeat: keep the stages separate and return all their outputs.
    """

    def __init__(self, input_nc, opt, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False,
                 getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))

        def conv(in_ch, out_ch, stride):
            # 4x4 convolution, optionally spectrally normalised.
            return SN(nn.Conv2d(in_ch, out_ch, kernel_size=kw, stride=stride, padding=padw), opt.use_SN)

        stages = [[conv(input_nc, ndf, 2), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for _ in range(1, n_layers):
            nf_prev, nf = nf, min(nf * 2, 512)
            stages.append([conv(nf_prev, nf, 2), norm_layer(nf), nn.LeakyReLU(0.2, True)])

        nf_prev, nf = nf, min(nf * 2, 512)
        stages.append([conv(nf_prev, nf, 1), norm_layer(nf), nn.LeakyReLU(0.2, True)])

        stages.append([conv(nf, 1, 1)])

        if use_sigmoid:
            stages.append([nn.Sigmoid()])

        if getIntermFeat:
            for idx, stage in enumerate(stages):
                setattr(self, 'model' + str(idx), nn.Sequential(*stage))
        else:
            flat = [layer for stage in stages for layer in stage]
            self.model = nn.Sequential(*flat)

    def forward(self, input):
        """Run the discriminator.

        Returns:
            With ``getIntermFeat``: the list of outputs of stages
            ``model0`` .. ``model<n_layers + 1>`` (only these ``n_layers + 2``
            stages are run). Otherwise the output of ``self.model``.
        """
        if not self.getIntermFeat:
            return self.model(input)

        features = []
        current = input
        for idx in range(self.n_layers + 2):
            current = getattr(self, 'model' + str(idx))(current)
            features.append(current)
        return features
|
|
|
|
|
|
|
|
class Patch_Attention_4(nn.Module):
    """Patch-level hard attention that copies the best-matching patch.

    The input is refined by a residual block, split into non-overlapping
    ``patch_size`` x ``patch_size`` patches, and each patch is replaced by the
    patch it correlates with most strongly among the patches not flagged by the
    mask. The input, the recomposed feature map and the resized mask are then
    concatenated and fused by a 3x3 convolution.

    Args:
        in_channels: not used by the visible code.
        inter_channels: channel count of the residual block.
        patch_size: side length (and stride) of the square patches.
    """

    def __init__(self, in_channels, inter_channels, patch_size):
        super(Patch_Attention_4, self).__init__()

        self.patch_size = patch_size

        # NOTE(review): the fusion conv is hard-coded to 1025 -> 512 channels
        # and ignores `in_channels`. 1025 presumably is 512 (z) + 512 (composed)
        # + 1 (mask) -- confirm against the caller.
        self.F_Combine = nn.Conv2d(in_channels=1025, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
        norm_layer = get_norm_layer(norm_type="instance")
        activation = nn.ReLU(True)

        model = []
        for i in range(1):
            model += [
                ResnetBlock(
                    inter_channels,
                    padding_type="reflect",
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=None,
                )
            ]
        self.res_block = nn.Sequential(*model)

    def Hard_Compose(self, input, dim, index):
        """Gather slices of ``input`` along ``dim`` using a per-batch index.

        ``index`` is reshaped so that it keeps the batch axis and axis ``dim``
        and has size 1 elsewhere, then expanded to the size of ``input`` on the
        remaining axes, and finally passed to ``torch.gather``.

        Args:
            input: source tensor.
            dim: axis to gather along.
            index: indices to pick (assumes ``(batch, n)`` -- TODO confirm).

        Returns:
            The gathered tensor.
        """
        # e.g. for a 3-D input and dim=2: views == [B, 1, -1]
        views = [input.size(0)] + [1 if i != dim else -1 for i in range(1, len(input.size()))]
        expanse = list(input.size())
        # -1 in expand() keeps the size of that axis unchanged.
        expanse[0] = -1
        expanse[dim] = -1
        index = index.view(views).expand(expanse)
        return torch.gather(input, dim, index)

    def forward(self, z, mask):
        """Compose a feature map from best-matching patches and fuse it with ``z``.

        Args:
            z: input feature map (assumes ``(batch, C, H, W)`` -- TODO confirm).
            mask: mask tensor, resized here to the spatial size of the features.

        Returns:
            Output of ``F_Combine`` applied to ``cat(z, composed, mask)``.
        """
        x = self.res_block(z)

        b, c, h, w = x.shape

        # Binarise the resized mask: any positive value becomes 1.
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0

        # One column per patch; a patch is flagged when more than 60% of it is
        # covered by the mask.
        # NOTE(review): despite its name, `non_mask_region` is 1 for the flagged
        # (mostly masked) patches.
        mask_unfold = F.unfold(mask, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)
        non_mask_region = (torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6).float()
        # True division yields a float; it is cast with int() below.
        all_patch_num = h * w / self.patch_size / self.patch_size
        # Repeat the per-patch flag once per row of the correlation matrix.
        non_mask_region = non_mask_region.repeat(1, int(all_patch_num), 1)

        x_unfold = F.unfold(x, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)
        y_unfold = x_unfold.permute(0, 2, 1)
        # Both operands are normalised along the patch-feature axis, so the
        # product is a cosine similarity between every pair of patches.
        x_unfold_normalized = F.normalize(x_unfold, dim=1)
        y_unfold_normalized = F.normalize(y_unfold, dim=2)
        correlation_matrix = torch.bmm(y_unfold_normalized, x_unfold_normalized)
        # Flagged patches get a very negative score so that softmax gives them
        # (near) zero weight as match candidates.
        correlation_matrix = correlation_matrix.masked_fill(non_mask_region == 1., -1e9)
        correlation_matrix = F.softmax(correlation_matrix, dim=2)

        # Index of the best-matching candidate patch for each patch.
        R, max_arg = torch.max(correlation_matrix, dim=2)

        composed_unfold = self.Hard_Compose(x_unfold, 2, max_arg)
        composed_fold = F.fold(composed_unfold, output_size=(h, w), kernel_size=(self.patch_size, self.patch_size),
                               padding=0, stride=self.patch_size)

        concat_1 = torch.cat((z, composed_fold, mask), dim=1)
        concat_1 = self.F_Combine(concat_1)

        return concat_1

    def inference_forward(self, z, mask):
        """Inference variant: only flagged patches are replaced.

        Flagged patches act as queries and the remaining patches as keys; each
        flagged patch is overwritten by its best-matching key patch while all
        other patches are kept as they are.

        NOTE(review): the patch flags are taken from batch element 0 only
        (``[0, 0, :]``) and applied to the whole batch -- presumably this is
        meant for batch size 1; confirm.

        Args:
            z: input feature map.
            mask: mask tensor, resized here to the spatial size of the features.

        Returns:
            Output of ``F_Combine`` applied to ``cat(z, composed, mask)``.
        """
        x = self.res_block(z)

        b, c, h, w = x.shape

        # Binarise the resized mask: any positive value becomes 1.
        mask = F.interpolate(mask, (x.size(2), x.size(3)), mode="bilinear")
        mask[mask > 0] = 1.0

        mask_unfold = F.unfold(mask, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)
        # 1-D flag vector over patches: 1 where more than 60% of the patch is masked.
        non_mask_region = (torch.mean(mask_unfold, dim=1, keepdim=True) > 0.6).float()[0, 0, :]

        # Computed but not used below.
        all_patch_num = h * w / self.patch_size / self.patch_size

        # Indices of the flagged patches.
        mask_index = torch.nonzero(non_mask_region, as_tuple=True)[0]

        if len(mask_index) == 0:
            # Nothing is flagged, so the features pass through unchanged.
            composed_fold = x

        else:
            unmask_index = torch.nonzero(non_mask_region != 1, as_tuple=True)[0]

            x_unfold = F.unfold(x, kernel_size=(self.patch_size, self.patch_size), padding=0, stride=self.patch_size)

            Query_Patch = torch.index_select(x_unfold, 2, mask_index)
            Key_Patch = torch.index_select(x_unfold, 2, unmask_index)

            Query_Patch = Query_Patch.permute(0, 2, 1)
            Query_Patch_normalized = F.normalize(Query_Patch, dim=2)
            Key_Patch_normalized = F.normalize(Key_Patch, dim=1)

            # Cosine similarity between each flagged patch and each key patch.
            correlation_matrix = torch.bmm(Query_Patch_normalized, Key_Patch_normalized)
            correlation_matrix = F.softmax(correlation_matrix, dim=2)

            R, max_arg = torch.max(correlation_matrix, dim=2)

            composed_unfold = self.Hard_Compose(Key_Patch, 2, max_arg)
            # Overwrite only the flagged patch columns in place.
            x_unfold[:, :, mask_index] = composed_unfold
            composed_fold = F.fold(x_unfold, output_size=(h, w), kernel_size=(self.patch_size, self.patch_size),
                                   padding=0, stride=self.patch_size)

        concat_1 = torch.cat((z, composed_fold, mask), dim=1)
        concat_1 = self.F_Combine(concat_1)

        return concat_1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GANLoss(nn.Module):
    """GAN objective against constant real/fake target tensors.

    Uses ``nn.MSELoss`` (least-squares GAN) or ``nn.BCELoss``. Target tensors
    are created with the ``tensor`` constructor and cached between calls.

    Args:
        use_lsgan: use ``nn.MSELoss`` when true, otherwise ``nn.BCELoss``.
        target_real_label: fill value of the "real" target.
        target_fake_label: fill value of the "fake" target.
        tensor: tensor constructor used to allocate targets.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def _cached_target(self, cached, input, value):
        """Return ``cached`` if it matches ``input``'s size, else a new filled tensor."""
        # Compare full sizes, not just numel(): a prediction with the same
        # element count but a different shape must not reuse the old target.
        if cached is None or cached.size() != input.size():
            cached = self.Tensor(input.size()).fill_(value)
        return cached

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target tensor with the same size as ``input``.

        Args:
            input: prediction tensor whose size the target must match.
            target_is_real: select the real label (true) or fake label (false).

        Returns:
            The cached (or newly created) target tensor; it does not require grad.
        """
        if target_is_real:
            self.real_label_var = self._cached_target(self.real_label_var, input, self.real_label)
            return self.real_label_var
        self.fake_label_var = self._cached_target(self.fake_label_var, input, self.fake_label)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        """Compute the loss of the final prediction(s) in ``input``.

        Args:
            input: either a list of lists (one per discriminator scale, the
                last element of each being the prediction) or a single list
                whose last element is the prediction.
            target_is_real: whether the predictions should be classified real.

        Returns:
            The loss; summed over scales for the list-of-lists form.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                loss += self.loss(pred, self.get_target_tensor(pred, target_is_real))
            return loss

        pred = input[-1]
        return self.loss(pred, self.get_target_tensor(pred, target_is_real))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torchvision import models |
|
|
|
|
|
|
|
|
class VGG19_torch(torch.nn.Module):
    """Pretrained torchvision VGG-19 feature extractor split into five slices.

    Feature layers 0-1, 2-6, 7-11, 12-20 and 21-29 are registered as
    ``slice1`` .. ``slice5``; each layer keeps its original index as its name.

    Args:
        requires_grad: when false, every parameter is frozen.
    """

    def __init__(self, requires_grad=False):
        super(VGG19_torch, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features

        # (attribute name, first layer index, one past the last layer index)
        slice_bounds = (
            ("slice1", 0, 2),
            ("slice2", 2, 7),
            ("slice3", 7, 12),
            ("slice4", 12, 21),
            ("slice5", 21, 30),
        )
        for attr, start, stop in slice_bounds:
            block = torch.nn.Sequential()
            for idx in range(start, stop):
                block.add_module(str(idx), vgg_pretrained_features[idx])
            setattr(self, attr, block)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the outputs of the five slices, each fed by the previous one."""
        features = []
        current = X
        for block in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            current = block(current)
            features.append(current)
        return features
|
|
|
|
|
|
|
|
class VGGLoss_torch(nn.Module):
    """Weighted sum of L1 distances between VGG-19 features of two inputs.

    Args:
        gpu_ids: accepted for interface compatibility; not used.
    """

    def __init__(self, gpu_ids):
        super(VGGLoss_torch, self).__init__()
        self.vgg = VGG19_torch()
        # Previously `.cuda()` was unconditional, which fails on CPU-only
        # machines. Keep the CUDA placement whenever CUDA is available.
        if torch.cuda.is_available():
            self.vgg = self.vgg.cuda()
        self.criterion = nn.L1Loss()
        # One weight per VGG slice, in slice order.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted feature-space L1 loss between ``x`` and ``y``.

        The features of ``y`` are detached, so gradients flow through ``x`` only.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for weight, feat_x, feat_y in zip(self.weights, x_vgg, y_vgg):
            loss += weight * self.criterion(feat_x, feat_y.detach())
        return loss
|
|
|