|
|
"""
|
|
|
Adapted from: https://github.com/Maggiking/AdaIN-Style-Transfer-PyTorch
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import logging
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.jit
|
|
|
|
|
|
|
|
|
# Module-level logger named after this module; TransferNet uses it for
# pre-training progress and checkpoint messages.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Full convolutional stack (37 modules). TransferNet.__init__ loads the
# checkpoint given by ``ckpt_path_vgg`` into it and then copies the state of
# its first 21 entries into ``encoder``, so the module order below fixes the
# state-dict keys and must not change.
vggnet = nn.Sequential(
    # 1x1 convolution on the three input channels
    nn.Conv2d(3, 3, kernel_size=1),

    # 64-channel stage
    nn.Conv2d(3, 64, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # 128-channel stage
    nn.Conv2d(64, 128, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(128, 128, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # 256-channel stage
    nn.Conv2d(128, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # first 512-channel stage
    nn.Conv2d(256, 512, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # second 512-channel stage
    nn.Conv2d(512, 512, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(512, 512, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
)
|
|
|
|
|
|
|
|
|
# Feature extractor (22 modules) with the same layout as the first 22 entries
# of ``vggnet``; TransferNet copies the weights of ``vggnet[:21]`` into it.
# TransferNet.forward splits it at index 5: ``encoder[:5]`` yields the
# 64-channel full-resolution feature map, ``encoder[5:]`` the 512-channel
# output at 1/8 resolution (three 2x2 max-pools).
encoder = nn.Sequential(
    # 1x1 convolution on the three input channels
    nn.Conv2d(3, 3, kernel_size=1),

    # 64-channel stage (entries 1-4 end the shallow part used by TransferNet)
    nn.Conv2d(3, 64, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # 128-channel stage
    nn.Conv2d(64, 128, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(128, 128, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # 256-channel stage
    nn.Conv2d(128, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # 512-channel output layer
    nn.Conv2d(256, 512, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
)
|
|
|
|
|
|
|
|
|
# Image generator (21 modules): maps 512-channel features back to a
# 3-channel image in (0, 1) while upsampling by 8 (three nearest-neighbour
# 2x upsamplings). TransferNet.forward splits it at index 17: ``decoder[:17]``
# ends right after the last upsampling with 64 channels, where the
# normalised shallow encoder features are added before ``decoder[17:]``.
decoder = nn.Sequential(
    # 512 -> 256 channels, then first upsampling
    nn.Conv2d(512, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=2, mode='nearest'),

    # 256-channel stage, reduced to 128 channels, then second upsampling
    nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 128, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=2, mode='nearest'),

    # 128-channel stage, reduced to 64 channels, then third upsampling
    nn.Conv2d(128, 128, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(128, 64, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=2, mode='nearest'),

    # output head (entries 17-20): 64 -> 64 -> 3 channels, squashed to (0, 1)
    nn.Conv2d(64, 64, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 3, kernel_size=3, padding=1, padding_mode='reflect'),
    nn.Sigmoid(),
)
|
|
|
|
|
|
|
|
|
class AdaIN(nn.Module):
    """
    Adaptive instance normalisation driven by a bank of stored style moments.

    Every feature map of the content tensor is normalised with its own mean
    and standard deviation (taken over dims 2 and 3) and then re-scaled and
    shifted with style statistics selected from ``moments_list``.
    """

    def __init__(self, eps=1e-5):
        """
        :param eps: Constant added to both the content and the selected style
            standard deviations
        """
        super().__init__()
        self.eps = eps

    def forward(self, x_content, moments_list, pretrain=False):
        """
        Re-normalise ``x_content`` with style moments taken from ``moments_list``.

        :param x_content: Feature maps; statistics are reduced over dims 2 and 3
            (assumes a 4-D (N, C, H, W) tensor)
        :param moments_list: Pair ``(means, stds)``; each is indexed along dim 0
            to select one row per content sample (assumes shape (num_stored, C))
        :param pretrain: If True, rows are drawn uniformly at random (with
            replacement) from the bank; otherwise the last N rows are used,
            newest first (indices -1, -2, ..., -N)
        :return: Tuple ``(adain_out, means_style, stds_style)`` where
            ``adain_out`` has the shape of ``x_content`` and the two moment
            tensors are the selected rows (``stds_style`` already includes
            ``eps``), one row per content sample
        """
        bank_means, bank_stds = moments_list[0], moments_list[1]
        batch_size = x_content.shape[0]

        # Per-sample, per-channel statistics of the content features
        means_content = torch.mean(x_content, dim=[2, 3], keepdim=True)
        stds_content = torch.std(x_content, dim=[2, 3], keepdim=True) + self.eps

        # One bank row per content sample
        if pretrain:
            rows = torch.randint(bank_means.shape[0], size=x_content.shape[:1])
        else:
            rows = -1 * torch.arange(1, batch_size + 1)

        means_style = bank_means[rows]
        stds_style = bank_stds[rows] + self.eps

        # Trailing singleton dims let the moments broadcast over the spatial dims
        means_style_4d = means_style.unsqueeze(-1).unsqueeze(-1)
        stds_style_4d = stds_style.unsqueeze(-1).unsqueeze(-1)

        adain_out = (x_content - means_content) / stds_content * stds_style_4d + means_style_4d

        # Return the selected rows as they are. A bare ``.squeeze()`` on the 4-D
        # tensors would also drop the batch (or channel) dim whenever it has
        # size 1, yielding (C,) instead of (1, C) for single-sample batches.
        return adain_out, means_style, stds_style
|
|
|
|
|
|
|
|
|
class TransferNet(nn.Module):
    """
    AdaIN style-transfer network: frozen encoder, AdaIN layers and a trainable decoder.

    NOTE: the instance binds the module-level ``encoder`` and ``decoder``
    objects (moved to CUDA), so several instances share the same weights.
    """

    # Index at which ``encoder`` is split into the shallow 64-channel part and the rest
    _ENCODER_SPLIT = 5
    # Index at which ``decoder`` is split; the shallow encoder features are added there
    _DECODER_SPLIT = 17
    # Number of leading ``vggnet`` entries whose weights are copied into ``encoder``
    _NUM_VGG_ENTRIES = 21
    # The running loss averages are logged every this many pre-training iterations
    _LOG_INTERVAL = 500

    def __init__(self, ckpt_path_vgg, ckpt_path_dec, data_loader, num_iters_pretrain=20000):
        """
        Style transfer network

        :param ckpt_path_vgg: Path to the state dict of the ImageNet pre-trained vgg19
            model; it is loaded into the module-level ``vggnet``
        :param ckpt_path_dec: Path to the decoder checkpoint (a dict with key
            ``'decoder'``). If no file exists there, the decoder is pre-trained
            first and the result is written to this path
        :param data_loader: Iterable yielding ``(images, labels)`` batches; used
            for collecting style moments and for pre-training
        :param num_iters_pretrain: Number of optimisation steps if pre-training is needed
        """
        super().__init__()
        self.mse_criterion = nn.MSELoss()
        self.data_loader = data_loader
        self.data_loader_iter = iter(data_loader)

        # Load on the CPU so the checkpoint does not depend on the device it was
        # saved from; load_state_dict copies the values into the existing parameters.
        vgg_state = torch.load(ckpt_path_vgg, map_location="cpu")
        vggnet.load_state_dict(vgg_state)

        # Frozen feature extractor initialised from the leading vgg layers
        self.encoder = encoder.cuda()
        self.encoder.load_state_dict(vggnet[:self._NUM_VGG_ENTRIES].state_dict())
        for parameter in self.encoder.parameters():
            parameter.requires_grad = False

        self.decoder = decoder.cuda()
        self.adain = AdaIN()

        if not os.path.isfile(ckpt_path_dec):
            logger.info("Start pre-training the style transfer model...")
            self.opt_adain_dec = torch.optim.Adam(self.decoder.parameters(), lr=1e-4)
            self.pretrain_adain(final_ckpt_path=ckpt_path_dec, num_iters=num_iters_pretrain)

        # The checkpoint is (re-)loaded in either case, also right after pre-training
        checkpoint = torch.load(ckpt_path_dec, map_location="cuda")
        self.decoder.load_state_dict(checkpoint['decoder'])
        logger.info("Successfully loaded AdaIN checkpoint: %s", ckpt_path_dec)
        # Fresh optimizer (state from pre-training is discarded)
        self.opt_adain_dec = torch.optim.Adam(self.decoder.parameters(), lr=1e-4)

    def _collect_style_moments(self, max_samples):
        """
        Gather encoder feature moments over the data loader without tracking gradients.

        :param max_samples: Stop after the batch in which this many images have been seen
        :return: ``[[means, stds], [means, stds]]`` for the shallow and the deep
            encoder features, each concatenated along dim 0
        :raises ValueError: If the data loader yields no batch
        """
        # One (means, stds) pair of chunk lists per AdaIN layer; concatenating once
        # at the end avoids re-copying the growing bank on every batch.
        chunks = [([], []), ([], [])]
        n_samples = 0

        with torch.no_grad():
            for images, _ in self.data_loader:
                n_samples += images.shape[0]
                layer_moments = self.forward(images=images.cuda())
                for (mean_chunks, std_chunks), (means, stds) in zip(chunks, layer_moments):
                    mean_chunks.append(means)
                    std_chunks.append(stds)
                if n_samples >= max_samples:
                    break

        if not chunks[0][0]:
            raise ValueError("data_loader yielded no batches; cannot collect style moments")

        return [[torch.cat(mean_chunks, dim=0), torch.cat(std_chunks, dim=0)]
                for mean_chunks, std_chunks in chunks]

    def pretrain_adain(self, final_ckpt_path, num_iters=20000, max_moment_samples=100000):
        """
        Pre-train the decoder and save it to ``final_ckpt_path``.

        First collects feature moments from the data loader, then optimises the
        decoder with ``self.opt_adain_dec`` on ``loss_content + 0.1 * loss_style``,
        drawing random style moments from the collected bank.

        :param final_ckpt_path: Destination of the checkpoint ``{'decoder': state_dict}``
        :param num_iters: Number of optimisation steps
        :param max_moment_samples: Upper bound (checked after each batch) on the
            number of images used for collecting the moments
        :raises ValueError: If the data loader yields no batch
        """
        moments_list = self._collect_style_moments(max_moment_samples)

        # forward() only returns the losses in training mode
        self.train()
        sum_loss = 0
        sum_loss_content = 0
        sum_loss_style = 0

        for i in range(1, num_iters + 1):
            try:
                images, _ = next(self.data_loader_iter)
            except StopIteration:
                # Loader exhausted: start the next pass over the data
                self.data_loader_iter = iter(self.data_loader)
                images, _ = next(self.data_loader_iter)

            self.opt_adain_dec.zero_grad()
            _, loss_content, loss_style = self.forward(images=images.cuda(),
                                                       moments_list=moments_list,
                                                       pretrain=True)

            loss = loss_content + 0.1 * loss_style
            loss.backward()
            self.opt_adain_dec.step()

            sum_loss += loss.item()
            sum_loss_content += loss_content.item()
            sum_loss_style += loss_style.item()

            if i % self._LOG_INTERVAL == 0:
                logger.info(f"[{i}/{num_iters}] loss: {sum_loss / self._LOG_INTERVAL:.4f}, "
                            f"content: {sum_loss_content / self._LOG_INTERVAL:.4f}, "
                            f"style: {sum_loss_style / self._LOG_INTERVAL:.4f}")
                sum_loss = 0
                sum_loss_content = 0
                sum_loss_style = 0

        ckpt_dict = {'decoder': self.decoder.state_dict()}
        torch.save(ckpt_dict, final_ckpt_path)
        logger.info(f"Saved pre-trained AdaIN model to: {final_ckpt_path}")

    def _calculate_moments(self, x):
        """
        :param x: Tensor whose dims 2 and 3 are reduced
        :return: Tuple ``(means, stds)`` over dims 2 and 3 (both dims are removed)
        """
        means = torch.mean(x, dim=[2, 3])
        stds = torch.std(x, dim=[2, 3])
        return means, stds

    def forward(self, images, moments_list=None, pretrain=False):
        """
        Either extract feature moments or generate stylised images.

        :param images: Input batch fed to the encoder
        :param moments_list: ``None`` to only compute moments; otherwise
            ``[[means, stds], [means, stds]]`` with the style moments for the
            shallow (index 0) and the deep (index 1) encoder features
        :param pretrain: Passed on to the AdaIN layers
        :return: * ``moments_list is None``: ``([means, stds], [means, stds])``
              of the shallow and the deep encoder features
            * training mode: ``(gen_img, loss_content, loss_style)``
            * otherwise: ``gen_img``
        """
        fm11_enc = self.encoder[:self._ENCODER_SPLIT](images)
        out_encoder = self.encoder[self._ENCODER_SPLIT:](fm11_enc)

        if moments_list is None:
            means_fm11, stds_fm11 = self._calculate_moments(fm11_enc)
            means_enc, stds_enc = self._calculate_moments(out_encoder)
            return [means_fm11, stds_fm11], [means_enc, stds_enc]

        # Stylise the deep features and decode them up to the skip connection
        out_encoder, means_style_enc, stds_style_enc = self.adain(
            out_encoder, moments_list=moments_list[1], pretrain=pretrain)
        fm11_dec = self.decoder[:self._DECODER_SPLIT](out_encoder)

        # Stylise the shallow features and add them as a skip connection
        fm11_enc, means_style_11, stds_style_11 = self.adain(
            fm11_enc, moments_list=moments_list[0], pretrain=pretrain)
        fm11_dec = torch.add(fm11_dec, fm11_enc)
        gen_img = self.decoder[self._DECODER_SPLIT:](fm11_dec)

        if not self.training:
            return gen_img

        # Re-encode the generated image to compare it with the AdaIN targets
        fm11_gen = self.encoder[:self._ENCODER_SPLIT](gen_img)
        encode_gen = self.encoder[self._ENCODER_SPLIT:](fm11_gen)

        means_gen_11, stds_gen_11 = self._calculate_moments(fm11_gen)
        means_gen_enc, stds_gen_enc = self._calculate_moments(encode_gen)

        # Content: deep features of the generated image vs. the stylised deep features
        loss_content = self.mse_criterion(encode_gen, out_encoder)
        # Style: moments of the generated image vs. the selected style moments
        loss_style = self.mse_criterion(means_gen_11, means_style_11) + self.mse_criterion(means_gen_enc, means_style_enc) + \
            self.mse_criterion(stds_gen_11, stds_style_11) + self.mse_criterion(stds_gen_enc, stds_style_enc)

        return gen_img, loss_content, loss_style
|
|
|
|