|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import nn |
|
|
import numpy as np |
|
|
from util import box_ops |
|
|
from util.misc import (NestedTensor, nested_tensor_from_tensor_list, |
|
|
accuracy, get_world_size, interpolate, |
|
|
is_dist_avail_and_initialized) |
|
|
from function import normal,normal_style |
|
|
from function import calc_mean_std |
|
|
import scipy.stats as stats |
|
|
from models.ViT_helper import DropPath, to_2tuple, trunc_normal_ |
|
|
|
|
|
class PatchEmbed(nn.Module):
    """Project an image onto a grid of patch embeddings with one strided conv.

    A ``Conv2d`` whose kernel size equals its stride maps every
    non-overlapping ``patch_size`` patch to an ``embed_dim``-channel vector.
    The result stays a 4-D feature map; it is not flattened into a sequence.

    Args:
        img_size: nominal input size (int or pair); only used to compute
            ``num_patches``.
        patch_size: patch edge length (int or pair); also the conv stride.
        in_chans: number of input channels.
        embed_dim: number of output channels.
    """

    def __init__(self, img_size=256, patch_size=8, in_chans=3, embed_dim=512):
        super().__init__()
        size_hw = to_2tuple(img_size)
        patch_hw = to_2tuple(patch_size)
        self.img_size = size_hw
        self.patch_size = patch_hw
        # Whole patches that fit along each axis of the nominal image.
        rows = size_hw[0] // patch_hw[0]
        cols = size_hw[1] // patch_hw[1]
        self.num_patches = rows * cols
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        # Not used by forward(); kept so the module's children are unchanged.
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        """Embed ``x`` with ``self.proj``.

        The four-way unpack of ``x.shape`` raises ``ValueError`` for inputs
        that are not 4-D.
        """
        _batch, _chans, _height, _width = x.shape
        print(f"PatchEmbed Input: {x.shape}")
        embedded = self.proj(x)
        print(f"PatchEmbed Output: {embedded.shape}")
        return embedded
|
|
|
|
|
|
|
|
# Decoder network: maps a 512-channel feature map to a 3-channel output.
# Each ReflectionPad2d((1, 1, 1, 1)) followed by a 3x3, stride-1 Conv2d keeps
# H and W unchanged. Only the three nearest-neighbour Upsample(x2) layers
# change the spatial size, so the output is 8x the input height and width.
# NOTE(review): nn.Sequential names its children by position ("1.weight", ...),
# so the layer order below must not change if saved decoder weights are loaded.
decoder = nn.Sequential(
    # Input resolution: 512 -> 256 channels.
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # now 2x input resolution
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # now 4x input resolution
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # now 8x input resolution
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    # Final projection to 3 channels; no activation follows it.
    nn.Conv2d(64, 3, (3, 3)),
)
|
|
|
|
|
def _print_decoder_io_shape(layer, inputs, output):
    """Forward hook: print ``layer``'s class name with its input and output shapes.

    ``inputs`` is the tuple of positional arguments passed to ``layer``; only
    its first element is reported.
    """
    layer_name = layer.__class__.__name__
    print(f"{layer_name} Input: {inputs[0].shape}")
    print(f"{layer_name} Output: {output.shape}")


# Debug instrumentation, installed at import time: every direct child of
# ``decoder`` prints its input/output shapes on each forward pass. The handles
# returned by register_forward_hook are kept so the hooks can be detached:
#     for handle in decoder_hook_handles:
#         handle.remove()
decoder_hook_handles = [
    layer.register_forward_hook(_print_decoder_io_shape)
    for layer in decoder.children()
]
|
|
|
|
|
|
|
|
# VGG-style encoder. The 3x3 conv layout matches VGG-19: 64-64 | 128-128 |
# 256 x4 | 512 x4 | 512 x4, with a 2x2 max-pool between groups.
# StyTrans slices this module's children by index into five frozen stages:
#   enc_1 = [0:4]   -> ends at the ReLU after the first 64-channel conv
#   enc_2 = [4:11]  -> ends at the ReLU after the first 128-channel conv
#   enc_3 = [11:18] -> ends at the ReLU after the first 256-channel conv
#   enc_4 = [18:31] -> ends at the ReLU after the first 512-channel conv
#   enc_5 = [31:44] -> ends at the ReLU after the first conv of the last group
# Layers [44:53] are not part of any StyTrans stage. Those slices and the
# positional state_dict keys both depend on the exact layer order, so do not
# insert or remove layers here.
vgg = nn.Sequential(
    # [0] 1x1 conv, 3 -> 3 channels. NOTE(review): presumably an input
    # normalization layer whose weights come from a pretrained checkpoint;
    # confirm.
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # [3] end of enc_1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    # ceil_mode=True: halves H and W, rounding odd sizes up.
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # [10] end of enc_2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # [17] end of enc_3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # [30] end of enc_4
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # [43] end of enc_5
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()
)
|
|
|
|
|
|
|
|
def _print_vgg_io_shape(layer, inputs, output):
    """Forward hook: print ``layer``'s class name with its input and output shapes.

    ``inputs`` is the tuple of positional arguments passed to ``layer``; only
    its first element is reported.
    """
    layer_name = layer.__class__.__name__
    print(f"{layer_name} Input: {inputs[0].shape}")
    print(f"{layer_name} Output: {output.shape}")


# Debug instrumentation, installed at import time: every direct child of
# ``vgg`` prints its input/output shapes on each forward pass. The same module
# objects are reused when StyTrans slices ``vgg`` into enc_1..enc_5, so the
# hooks fire there too. The handles are kept so the hooks can be detached:
#     for handle in vgg_hook_handles:
#         handle.remove()
vgg_hook_handles = [
    layer.register_forward_hook(_print_vgg_io_shape)
    for layer in vgg.children()
]
|
|
|
|
|
class MLP(nn.Module):
    """Stack of ``num_layers`` Linear layers with ReLU between them (an FFN).

    Every layer except the last is followed by a ReLU. Hidden layers have
    ``hidden_dim`` units; the first takes ``input_dim`` features and the last
    produces ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim, *hidden]
        out_dims = [*hidden, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        """Apply each Linear in turn, with ReLU on all but the final layer."""
        last_index = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            print(f"MLP Layer {idx} Input: {x.shape}")
            x = linear(x)
            if idx < last_index:
                x = F.relu(x)
            print(f"MLP Layer {idx} Output: {x.shape}")
        return x
|
|
class StyTrans(nn.Module):
    """Style-transfer model: frozen encoder, patch embedding, transformer, decoder.

    The encoder's children are split into five sequential, frozen stages whose
    outputs feed the content/style losses. The transformer fuses the style and
    content embeddings, and the decoder maps the result back to an image.
    """

    def __init__(self, encoder, decoder, PatchEmbed, transformer, args):
        """
        Args:
            encoder: module whose ``children()`` are sliced into the stages
                ``enc_1`` .. ``enc_5`` (indices [:4], [4:11], [11:18],
                [18:31], [31:44]); all stage parameters are frozen.
            decoder: module mapping the transformer output to an image.
            PatchEmbed: callable embedding module (an instance, not the class)
                applied to both content and style images.
            transformer: fusion module called as
                ``transformer(style, mask, content, pos_c, pos_s)``.
            args: not used here; kept for interface compatibility.
        """
        super().__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])
        self.enc_2 = nn.Sequential(*enc_layers[4:11])
        self.enc_3 = nn.Sequential(*enc_layers[11:18])
        self.enc_4 = nn.Sequential(*enc_layers[18:31])
        self.enc_5 = nn.Sequential(*enc_layers[31:44])

        # The encoder is a fixed feature extractor for the losses only.
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

        self.mse_loss = nn.MSELoss()
        self.transformer = transformer
        self.decode = decoder
        self.embedding = PatchEmbed

    def encode_with_intermediate(self, input):
        """Run ``input`` through enc_1..enc_5 in order; return the 5 stage outputs."""
        results = [input]
        for i in range(5):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def calc_content_loss(self, input, target):
        """MSE between ``input`` and ``target``.

        Asserts that both have the same size and that ``target`` does not
        require grad.
        """
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """Sum of MSEs between the two statistics ``calc_mean_std`` returns.

        The statistics are presumably channel-wise mean and std (from the
        helper's name; confirm in ``function``). Asserts equal sizes and that
        ``target`` does not require grad.
        """
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
            self.mse_loss(input_std, target_std)

    @staticmethod
    def _to_nested(samples):
        """Return ``(nested, target)`` for one forward input.

        ``nested`` is a NestedTensor: lists and tensors are converted with
        ``nested_tensor_from_tensor_list``, while NestedTensors pass through.
        ``target`` is the pixel tensor used by the identity loss. It is the
        input itself when that is already a tensor. Otherwise it is
        ``nested.tensors``, because lists and NestedTensors have no ``.size()``
        for ``calc_content_loss``.
        """
        if isinstance(samples, torch.Tensor):
            return nested_tensor_from_tensor_list(samples), samples
        if isinstance(samples, list):
            samples = nested_tensor_from_tensor_list(samples)
        return samples, samples.tensors

    def _identity_losses(self, content, style, content_img, style_img,
                         content_feats, style_feats):
        """Compute the two identity losses from self-reconstructions.

        Content is fed to the transformer as both inputs, and likewise style.
        The decoded images are compared with the original images (lambda1) and
        with their encoder features at all five stages (lambda2).
        """
        mask = None
        Icc = self.decode(self.transformer(content, mask, content, None, None))
        Iss = self.decode(self.transformer(style, mask, style, None, None))

        loss_lambda1 = (self.calc_content_loss(Icc, content_img)
                        + self.calc_content_loss(Iss, style_img))

        Icc_feats = self.encode_with_intermediate(Icc)
        Iss_feats = self.encode_with_intermediate(Iss)
        loss_lambda2 = (self.calc_content_loss(Icc_feats[0], content_feats[0])
                        + self.calc_content_loss(Iss_feats[0], style_feats[0]))
        for i in range(1, 5):
            loss_lambda2 += (self.calc_content_loss(Icc_feats[i], content_feats[i])
                             + self.calc_content_loss(Iss_feats[i], style_feats[i]))
        return loss_lambda1, loss_lambda2

    def forward(self, samples_c: NestedTensor, samples_s: NestedTensor):
        """Stylize the content images with the style images and compute losses.

        Args:
            samples_c: content images, as a NestedTensor, a batched tensor, or
                a list of tensors (the last two are converted with
                ``nested_tensor_from_tensor_list``).
            samples_s: style images, in the same accepted forms.

        Returns:
            Tuple ``(Ics, loss_c, loss_s, loss_lambda1, loss_lambda2)``:
            the stylized output, the content loss, the style loss, and the
            pixel-level and feature-level identity losses.
        """
        samples_c, content_input = self._to_nested(samples_c)
        samples_s, style_input = self._to_nested(samples_s)

        content_feats = self.encode_with_intermediate(samples_c.tensors)
        style_feats = self.encode_with_intermediate(samples_s.tensors)

        print(f"Embedding Content Input: {samples_c.tensors.shape}")
        style = self.embedding(samples_s.tensors)
        print(f"Style Output: {style.shape}")
        content = self.embedding(samples_c.tensors)
        print(f"Embedding Content Output: {content.shape}")

        # Positional encodings and the attention mask are passed as None.
        pos_s = None
        pos_c = None
        mask = None

        hs = self.transformer(style, mask, content, pos_c, pos_s)
        Ics = self.decode(hs)

        Ics_feats = self.encode_with_intermediate(Ics)
        # Content loss: the last two encoder stages, after ``normal``.
        loss_c = (self.calc_content_loss(normal(Ics_feats[-1]), normal(content_feats[-1]))
                  + self.calc_content_loss(normal(Ics_feats[-2]), normal(content_feats[-2])))
        # Style loss: summed over all five encoder stages.
        loss_s = self.calc_style_loss(Ics_feats[0], style_feats[0])
        for i in range(1, 5):
            loss_s += self.calc_style_loss(Ics_feats[i], style_feats[i])

        loss_lambda1, loss_lambda2 = self._identity_losses(
            content, style, content_input, style_input, content_feats, style_feats)

        return Ics, loss_c, loss_s, loss_lambda1, loss_lambda2
|
|
|