|
|
|
|
|
|
|
|
import logging |
|
|
import torch.nn as nn |
|
|
from . import networks |
|
|
|
|
|
|
|
|
class Mapping_Model_with_mask(nn.Module): |
|
|
def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): |
|
|
super(Mapping_Model_with_mask, self).__init__() |
|
|
|
|
|
norm_layer = networks.get_norm_layer(norm_type=norm) |
|
|
activation = nn.ReLU(True) |
|
|
model = [] |
|
|
|
|
|
tmp_nc = 64 |
|
|
n_up = 4 |
|
|
|
|
|
for i in range(n_up): |
|
|
ic = min(tmp_nc * (2 ** i), mc) |
|
|
oc = min(tmp_nc * (2 ** (i + 1)), mc) |
|
|
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] |
|
|
|
|
|
self.before_NL = nn.Sequential(*model) |
|
|
|
|
|
if opt.NL_res: |
|
|
self.NL = networks.NonLocalBlock2D_with_mask_Res( |
|
|
mc, |
|
|
mc, |
|
|
opt.NL_fusion_method, |
|
|
opt.correlation_renormalize, |
|
|
opt.softmax_temperature, |
|
|
opt.use_self, |
|
|
opt.cosin_similarity, |
|
|
) |
|
|
print("You are using NL + Res") |
|
|
|
|
|
model = [] |
|
|
for i in range(n_blocks): |
|
|
model += [ |
|
|
networks.ResnetBlock( |
|
|
mc, |
|
|
padding_type=padding_type, |
|
|
activation=activation, |
|
|
norm_layer=norm_layer, |
|
|
opt=opt, |
|
|
dilation=opt.mapping_net_dilation, |
|
|
) |
|
|
] |
|
|
|
|
|
for i in range(n_up - 1): |
|
|
ic = min(64 * (2 ** (4 - i)), mc) |
|
|
oc = min(64 * (2 ** (3 - i)), mc) |
|
|
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] |
|
|
model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] |
|
|
if opt.feat_dim > 0 and opt.feat_dim < 64: |
|
|
model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] |
|
|
|
|
|
self.after_NL = nn.Sequential(*model) |
|
|
|
|
|
def forward(self, input, mask): |
|
|
x1 = self.before_NL(input) |
|
|
del input |
|
|
x2 = self.NL(x1, mask) |
|
|
del x1, mask |
|
|
x3 = self.after_NL(x2) |
|
|
del x2 |
|
|
|
|
|
return x3 |
|
|
|
|
|
|
|
|
class Mapping_Model_with_mask_2(nn.Module): |
|
|
def __init__(self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None): |
|
|
super(Mapping_Model_with_mask_2, self).__init__() |
|
|
|
|
|
norm_layer = networks.get_norm_layer(norm_type=norm) |
|
|
activation = nn.ReLU(True) |
|
|
model = [] |
|
|
|
|
|
tmp_nc = 64 |
|
|
n_up = 4 |
|
|
|
|
|
for i in range(n_up): |
|
|
ic = min(tmp_nc * (2 ** i), mc) |
|
|
oc = min(tmp_nc * (2 ** (i + 1)), mc) |
|
|
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] |
|
|
|
|
|
for i in range(2): |
|
|
model += [ |
|
|
networks.ResnetBlock( |
|
|
mc, |
|
|
padding_type=padding_type, |
|
|
activation=activation, |
|
|
norm_layer=norm_layer, |
|
|
opt=opt, |
|
|
dilation=opt.mapping_net_dilation, |
|
|
) |
|
|
] |
|
|
|
|
|
logging.info("Mapping: You are using multi-scale patch attention, conv combine + mask input") |
|
|
|
|
|
self.before_NL = nn.Sequential(*model) |
|
|
|
|
|
if opt.mapping_exp == 1: |
|
|
self.NL_scale_1 = networks.Patch_Attention_4(mc, mc, 8) |
|
|
|
|
|
model = [] |
|
|
for i in range(2): |
|
|
model += [ |
|
|
networks.ResnetBlock( |
|
|
mc, |
|
|
padding_type=padding_type, |
|
|
activation=activation, |
|
|
norm_layer=norm_layer, |
|
|
opt=opt, |
|
|
dilation=opt.mapping_net_dilation, |
|
|
) |
|
|
] |
|
|
|
|
|
self.res_block_1 = nn.Sequential(*model) |
|
|
|
|
|
if opt.mapping_exp == 1: |
|
|
self.NL_scale_2 = networks.Patch_Attention_4(mc, mc, 4) |
|
|
|
|
|
model = [] |
|
|
for i in range(2): |
|
|
model += [ |
|
|
networks.ResnetBlock( |
|
|
mc, |
|
|
padding_type=padding_type, |
|
|
activation=activation, |
|
|
norm_layer=norm_layer, |
|
|
opt=opt, |
|
|
dilation=opt.mapping_net_dilation, |
|
|
) |
|
|
] |
|
|
|
|
|
self.res_block_2 = nn.Sequential(*model) |
|
|
|
|
|
if opt.mapping_exp == 1: |
|
|
self.NL_scale_3 = networks.Patch_Attention_4(mc, mc, 2) |
|
|
|
|
|
|
|
|
model = [] |
|
|
for i in range(2): |
|
|
model += [ |
|
|
networks.ResnetBlock( |
|
|
mc, |
|
|
padding_type=padding_type, |
|
|
activation=activation, |
|
|
norm_layer=norm_layer, |
|
|
opt=opt, |
|
|
dilation=opt.mapping_net_dilation, |
|
|
) |
|
|
] |
|
|
|
|
|
for i in range(n_up - 1): |
|
|
ic = min(64 * (2 ** (4 - i)), mc) |
|
|
oc = min(64 * (2 ** (3 - i)), mc) |
|
|
model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation] |
|
|
model += [nn.Conv2d(tmp_nc * 2, tmp_nc, 3, 1, 1)] |
|
|
if opt.feat_dim > 0 and opt.feat_dim < 64: |
|
|
model += [norm_layer(tmp_nc), activation, nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1)] |
|
|
|
|
|
self.after_NL = nn.Sequential(*model) |
|
|
|
|
|
def forward(self, input, mask): |
|
|
x1 = self.before_NL(input) |
|
|
x2 = self.NL_scale_1(x1, mask) |
|
|
x3 = self.res_block_1(x2) |
|
|
x4 = self.NL_scale_2(x3, mask) |
|
|
x5 = self.res_block_2(x4) |
|
|
x6 = self.NL_scale_3(x5, mask) |
|
|
x7 = self.after_NL(x6) |
|
|
return x7 |
|
|
|
|
|
def inference_forward(self, input, mask): |
|
|
x1 = self.before_NL(input) |
|
|
del input |
|
|
x2 = self.NL_scale_1.inference_forward(x1, mask) |
|
|
del x1 |
|
|
x3 = self.res_block_1(x2) |
|
|
del x2 |
|
|
x4 = self.NL_scale_2.inference_forward(x3, mask) |
|
|
del x3 |
|
|
x5 = self.res_block_2(x4) |
|
|
del x4 |
|
|
x6 = self.NL_scale_3.inference_forward(x5, mask) |
|
|
del x5 |
|
|
x7 = self.after_NL(x6) |
|
|
del x6 |
|
|
return x7 |
|
|
|