|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchvision |
|
|
|
|
|
|
|
|
class VGG19(torch.nn.Module):
    """Pretrained VGG-19 feature stack split into five consecutive stages.

    The ``features`` module of torchvision's pretrained VGG-19 is cut at layer
    indices 2, 7, 12, 21 and 30. Each stage is stored as ``slice1`` ...
    ``slice5`` and its layers keep their original index as submodule name.

    Args:
        requires_grad: When False (the default) every parameter of the
            extractor is frozen.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        features = torchvision.models.vgg19(pretrained=True).features

        # Half-open [start, stop) layer-index range covered by each stage.
        stage_bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        for stage_no, (start, stop) in enumerate(stage_bounds, start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, f"slice{stage_no}", stage)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run ``X`` through all five stages.

        Returns:
            A list with the output of each stage, in order (five tensors).
        """
        stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        activations = []
        hidden = X
        for stage in stages:
            hidden = stage(hidden)
            activations.append(hidden)
        return activations
|
|
|
|
|
|
|
|
class VGGLoss(nn.Module):
    """Weighted sum of L1 distances between VGG19 feature maps of two inputs.

    The five feature levels returned by ``VGG19`` are weighted by
    1/32, 1/16, 1/8, 1/4 and 1 respectively.
    """

    def __init__(self):
        super().__init__()
        # NOTE: the feature extractor is moved to CUDA unconditionally.
        self.vgg = VGG19().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted feature-space L1 loss between ``x`` and ``y``.

        The features of ``y`` are detached, so gradients only flow through
        the features of ``x``.
        """
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        loss = 0
        for weight, feat_x, feat_y in zip(self.weights, feats_x, feats_y):
            loss += weight * self.criterion(feat_x, feat_y.detach())
        return loss
|
|
|
|
|
|
|
|
class VGGPerceptualLoss(torch.nn.Module):
    """Wrapper that evaluates a perceptual loss on feature block 2 only.

    NOTE(review): a second class with the same name is defined further down in
    this module and rebinds the module-level name ``VGGPerceptualLoss``. After
    import, that name no longer refers to this wrapper, so this class is only
    reachable through a reference taken before the rebinding. It probably was
    meant to carry a different name -- confirm with the callers.

    Args:
        lam: Accepted for interface compatibility; not used.
        lam_p: Accepted for interface compatibility; not used.
    """

    def __init__(self, lam=1, lam_p=1):
        # Zero-argument super() binds to this class object. The explicit form
        # ``super(VGGPerceptualLoss, self)`` looks the name up at call time and
        # raises TypeError once the name has been rebound to another class.
        super().__init__()
        # Resolved through the module-level name at call time, i.e. whichever
        # class ``VGGPerceptualLoss`` refers to when the wrapper is built.
        self.loss_fn = VGGPerceptualLoss()

    def forward(self, out, gt):
        """Return ``loss_fn(out, gt)`` restricted to ``feature_layers=[2]``."""
        return self.loss_fn(out, gt, feature_layers=[2])
|
|
|
|
|
|
|
|
class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual (and optional Gram-matrix style) L1 loss on VGG-16 features.

    The pretrained torchvision VGG-16 ``features`` stack is split into four
    blocks (layers 0-3, 4-8, 9-15, 16-22). Inputs are normalised with fixed
    per-channel mean/std constants (the values commonly used for ImageNet
    normalisation), optionally resized to 224x224, and passed through the
    blocks; L1 distances between activations are accumulated.

    The module places itself on CUDA when available and otherwise stays on
    the CPU.

    Args:
        resize: When True (default), both images are bilinearly resized to
            224x224 before feature extraction.
    """

    def __init__(self, resize=True):
        super().__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the pretrained weights once and slice the same feature stack;
        # slicing keeps the original layer indices as submodule names.
        features = torchvision.models.vgg16(pretrained=True).features
        block_bounds = ((0, 4), (4, 9), (9, 16), (16, 23))
        blocks = [features[start:stop].eval() for start, stop in block_bounds]

        # Freeze the actual parameters. Iterating a Sequential yields layers,
        # and setting ``requires_grad`` on a layer object has no effect.
        for block in blocks:
            for param in block.parameters():
                param.requires_grad = False

        self.blocks = torch.nn.ModuleList(blocks).to(device)
        self.transform = torch.nn.functional.interpolate

        # Constant normalisation tensors: buffers follow ``.to()``/``.cuda()``
        # and do not require grad. ``persistent=False`` keeps them out of the
        # state dict, so existing checkpoints still load.
        self.register_buffer(
            "mean",
            torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1),
            persistent=False,
        )
        self.resize = resize

    def forward(self, input, target, feature_layers=(0, 1, 2, 3), style_layers=()):
        """Compute the loss between ``input`` and ``target``.

        Args:
            input: Image batch, channels in dimension 1.
            target: Image batch of the same shape as ``input``.
            feature_layers: Block indices whose activations contribute an L1
                feature loss.
            style_layers: Block indices whose Gram matrices contribute an L1
                style loss.

        Returns:
            The accumulated loss; the float ``0.0`` if no layer is selected.
        """
        # Inputs whose channel count is not 3 are tiled three times along the
        # channel axis (presumably meant for single-channel images).
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)

        # Blocks after the last requested one cannot influence the result.
        last_needed = max([*feature_layers, *style_layers], default=-1)

        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            if i > last_needed:
                break
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrices over the flattened spatial dimensions.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
|
|
|
|
|
|
|
|
def scharr(x):
    """Return a Scharr edge-magnitude map of a single-channel image batch.

    The input is replication-padded by one pixel, every 3x3 neighbourhood is
    extracted, scaled by 255 and multiplied with the horizontal and vertical
    Scharr kernels. Both responses are clamped to [-255, 255]; the result is
    the mean of their absolute values divided by 255.

    Args:
        x: Tensor of shape (b, 1, h, w). Only one channel is supported because
            each unfolded patch is multiplied with a 9-element kernel.

    Returns:
        Tensor of shape (b, 1, h, w) on the same device as ``x``.
    """
    b, _, h, w = x.shape
    padded = F.pad(x, (1, 1, 1, 1), mode="replicate")

    # (b, h*w, 9): one flattened 3x3 patch per output pixel, scaled once.
    patches = F.unfold(padded, kernel_size=3, stride=1, padding=0).permute([0, 2, 1]) * 255.0

    # Build the kernels on the input's device/dtype instead of hard-coding
    # float32 on CUDA, so CPU (or other-device) inputs work as well.
    kernel_x = torch.tensor(
        [-3, 0, 3, -10, 0, 10, -3, 0, 3], dtype=patches.dtype, device=patches.device
    )
    kernel_y = torch.tensor(
        [-3, -10, -3, 0, 0, 0, 3, 10, 3], dtype=patches.dtype, device=patches.device
    )

    # The (b, h*w) responses are laid out row-major, so reshaping restores the
    # image grid (same result as folding with kernel_size=1).
    grad_x = torch.matmul(patches, kernel_x).reshape(b, 1, h, w).clamp(-255, 255)
    grad_y = torch.matmul(patches, kernel_y).reshape(b, 1, h, w).clamp(-255, 255)

    return (0.5 * torch.abs(grad_x) + 0.5 * torch.abs(grad_y)) / 255.0
|
|
|
|
|
|
|
|
def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D feature tensor.

    The tensor of size (a, b, c, d) is flattened to (a*b, c*d); the product of
    that matrix with its transpose is divided by the total number of elements.

    Returns:
        Tensor of shape (a*b, a*b).
    """
    n_batch, n_channels, height, width = input.size()
    flat = input.reshape(n_batch * n_channels, height * width)
    gram = flat @ flat.t()
    return gram / (n_batch * n_channels * height * width)
|
|
|
|
|
|
|
|
class StyleLoss(nn.Module):
    """Mean-squared error between the Gram matrices of two feature tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, input_fea, target_fea):
        """Return the MSE between ``gram_matrix`` of both feature tensors.

        The target Gram matrix is detached, so no gradient reaches
        ``target_fea``.
        """
        gram_input = gram_matrix(input_fea)
        gram_target = gram_matrix(target_fea).detach()
        return F.mse_loss(gram_input, gram_target)
|
|
|
|
|
|
|
|
def cos_loss(feat1, feat2):
    """Return the negated mean cosine similarity of ``feat1`` and ``feat2``.

    The similarity is taken along dimension 1 (the default of
    ``F.cosine_similarity``) and averaged over all remaining positions.
    """
    similarity = F.cosine_similarity(feat1, feat2, dim=1)
    return -similarity.mean()
|
|
|
|
|
|
|
|
def feat_scharr(x):
    """Return the Scharr edge map of a channel-averaged feature tensor.

    The features are averaged over dimension 1, min-max normalised over the
    whole tensor to [0, 1], scaled to [0, 255] and passed to ``scharr``.

    NOTE(review): ``scharr`` itself multiplies its input by 255.0, so the map
    is scaled by 255 twice before clamping -- confirm this is intended.
    """
    x = torch.mean(x, dim=1, keepdim=True)
    lowest = x.min()
    # Guard the division: a constant feature map has zero range, which would
    # otherwise produce NaNs. With the guard such a map normalises to zeros.
    value_range = (x.max() - lowest).clamp_min(torch.finfo(x.dtype).tiny)
    x = (x - lowest) / value_range
    x = x * 255
    return scharr(x)
|
|
|
|
|
|
|
|
def feat_ssim(feat1, feat2, gt):
    """Return an edge-weighted mean absolute feature difference and its mask.

    The mask is the ``scharr`` edge map of the channel-averaged ``gt``,
    bicubically resized to the spatial size of ``feat1``.

    Returns:
        Tuple ``(loss, mask)``: the mean of ``|feat1 - feat2| * mask`` and the
        resized mask itself.
    """
    gray = torch.mean(gt, dim=1, keepdim=True)
    edges = scharr(gray)

    height, width = feat1.shape[2], feat1.shape[3]
    mask = F.interpolate(edges, size=(height, width), mode="bicubic")

    weighted_diff = torch.abs(feat1 - feat2) * mask
    return torch.mean(weighted_diff), mask
|
|
|
|
|
|
|
|
def similarity_loss(f_s, f_t):
    """Return the mean squared difference of the attention maps of two features.

    The attention map of a feature tensor is the mean of its squared values
    over dimension 1, flattened per sample and L2-normalised.
    """

    def _attention(feat):
        energy = feat.pow(2).mean(1)
        flat = energy.view(feat.size(0), -1)
        return F.normalize(flat)

    difference = _attention(f_s) - _attention(f_t)
    return difference.pow(2).mean()
|
|
|
|
|
|
|
|
class RBF(nn.Module):
    """Sum of Gaussian (RBF) kernels evaluated at several bandwidths.

    Args:
        n_kernels: Number of kernels that are summed.
        mul_factor: Base of the geometric bandwidth series; kernel ``k`` uses
            the multiplier ``mul_factor ** (k - n_kernels // 2)``.
        bandwidth: Fixed base bandwidth (number or tensor). When ``None``
            (default) it is estimated from the data on every call.
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        self.bandwidth_multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the base bandwidth for a matrix of squared distances.

        Without a fixed bandwidth, this is the sum of all entries divided by
        ``n**2 - n``. The estimate is detached from the autograd graph.
        """
        if self.bandwidth is not None:
            return self.bandwidth
        n_samples = L2_distances.shape[0]
        return L2_distances.detach().sum() / (n_samples ** 2 - n_samples)

    def forward(self, X):
        """Return the (n, n) kernel matrix of the rows of ``X``.

        The computation runs on the current CUDA device when CUDA is
        available (the historical behaviour) and otherwise on the device of
        ``X``.
        """
        L2_distances = torch.cdist(X, X) ** 2
        if torch.cuda.is_available():
            L2_distances = L2_distances.cuda()
        device = L2_distances.device

        # ``as_tensor`` also accepts a plain number for a fixed bandwidth.
        bandwidth = torch.as_tensor(self.get_bandwidth(L2_distances), device=device)
        scales = bandwidth * self.bandwidth_multipliers.to(device)

        # Broadcast to (n_kernels, n, n) and sum over the kernel axis.
        return torch.exp(-L2_distances[None, ...] / scales[:, None, None]).sum(dim=0)
|
|
|
|
|
|
|
|
class MMDLoss(nn.Module):
    """Maximum mean discrepancy between two sample sets under a kernel.

    Args:
        kernel: Module mapping an (n, d) sample matrix to an (n, n) kernel
            matrix. When ``None`` (default) a fresh ``RBF()`` is created for
            this instance instead of sharing one built at definition time.
    """

    def __init__(self, kernel=None):
        super().__init__()
        if kernel is None:
            kernel = RBF()
        # Move the kernel to CUDA when possible; stay on its device otherwise.
        self.kernel = kernel.cuda() if torch.cuda.is_available() else kernel

    def forward(self, X, Y):
        """Return ``mean(K_xx) - 2 * mean(K_xy) + mean(K_yy)``.

        ``K`` is the kernel matrix of the stacked samples ``[X; Y]``; the
        block means include the diagonal entries.
        """
        K = self.kernel(torch.vstack([X, Y]))

        n_x = X.shape[0]
        k_xx = K[:n_x, :n_x].mean()
        k_xy = K[:n_x, n_x:].mean()
        k_yy = K[n_x:, n_x:].mean()
        return k_xx - 2 * k_xy + k_yy
|
|
|