# LMAR / loss.py
# Uploaded by SuperPatrick ("Upload 20 files", revision 98feea6, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class VGG19(torch.nn.Module):
    """Pretrained torchvision VGG-19 feature stack split into five stages.

    Layers ``0-1``, ``2-6``, ``7-11``, ``12-20`` and ``21-29`` of
    ``torchvision.models.vgg19(pretrained=True).features`` are grouped into
    ``slice1`` ... ``slice5``.  Each layer keeps its original index as its
    sub-module name.

    Args:
        requires_grad: When false (the default) every parameter of the
            extractor has ``requires_grad`` switched off.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        features = torchvision.models.vgg19(pretrained=True).features

        # Half-open layer-index boundaries of the five consecutive stages.
        bounds = (0, 2, 7, 12, 21, 30)
        for stage_no, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:]), start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, "slice%d" % stage_no, stage)

        if requires_grad:
            return
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, X):
        """Return the list of the five stage outputs, shallowest first."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        activations = []
        hidden = X
        for stage in stages:
            hidden = stage(hidden)
            activations.append(hidden)
        return activations
class VGGLoss(nn.Module):
    """Weighted L1 distance between the VGG-19 feature maps of two batches.

    The five feature maps returned by ``VGG19`` for ``x`` and ``y`` are
    compared pairwise with ``nn.L1Loss`` and summed with the weights
    ``1/32, 1/16, 1/8, 1/4, 1`` (shallowest stage first).

    The extractor is moved to the current CUDA device when CUDA is available.
    On a CPU-only host it stays on the CPU instead of raising at
    construction time.
    """

    def __init__(self):
        super().__init__()
        vgg = VGG19()
        # Hard-coding .cuda() made the loss unusable without a GPU.
        self.vgg = vgg.cuda() if torch.cuda.is_available() else vgg
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted feature-space L1 loss between ``x`` and ``y``.

        The features of ``y`` are detached, so gradients only flow through
        ``x``.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for weight, x_feat, y_feat in zip(self.weights, x_vgg, y_vgg):
            loss += weight * self.criterion(x_feat, y_feat.detach())
        return loss
class PerceptualLoss(torch.nn.Module):
    """Perceptual loss that uses only block index 2 of ``VGGPerceptualLoss``.

    This wrapper used to be declared under the name ``VGGPerceptualLoss`` as
    well.  The class of that name defined further down the module rebinds
    the name, so the wrapper could never be referenced, and its
    ``super(VGGPerceptualLoss, self)`` call would have raised ``TypeError``
    had it been instantiated.  It is given its own name here so it is usable;
    ``VGGPerceptualLoss`` still resolves to the later class exactly as before.

    NOTE(review): the new name is a reviewer's choice; rename if the project
    already has a preferred one.

    Args:
        lam: Accepted for signature compatibility; not used.
        lam_p: Accepted for signature compatibility; not used.
    """

    def __init__(self, lam=1, lam_p=1):
        super().__init__()
        # Resolved at call time, i.e. to the multi-block VGG-16 loss class.
        self.loss_fn = VGGPerceptualLoss()

    def forward(self, out, gt):
        """Return the feature loss between ``out`` and ``gt`` at block 2."""
        return self.loss_fn(out, gt, feature_layers=[2])
class VGGPerceptualLoss(torch.nn.Module):
    """L1 feature and Gram-matrix (style) loss on four VGG-16 blocks.

    The blocks are ``features[:4]``, ``features[4:9]``, ``features[9:16]``
    and ``features[16:23]`` of ``torchvision.models.vgg16(pretrained=True)``,
    all in eval mode with frozen weights.

    Fixes relative to the previous revision:

    * The freezing loop iterated over the *layers* of each block and set a
      meaningless ``requires_grad`` attribute on the modules, so the VGG
      weights stayed trainable.  It now iterates over the parameters.
    * ``mean`` / ``std`` were ``Parameter(...).cuda()``: plain tensors that
      required grad and were not registered on the module.  They are now
      non-persistent buffers, so they follow ``.to()`` / ``.cuda()`` and
      ``state_dict()`` keys are unchanged.
    * VGG-16 is instantiated once instead of four times.
    * The module moves itself to CUDA only when CUDA is available.

    Args:
        resize: When true, both images are bilinearly resized to 224x224
            before being fed to the network.
    """

    def __init__(self, resize=True):
        super().__init__()
        features = torchvision.models.vgg16(pretrained=True).features
        bounds = ((0, 4), (4, 9), (9, 16), (16, 23))
        blocks = [features[start:stop].eval() for start, stop in bounds]
        for block in blocks:
            for param in block.parameters():
                param.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # Per-channel normalisation constants (the standard ImageNet statistics).
        self.register_buffer(
            "mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.resize = resize
        if torch.cuda.is_available():
            self.cuda()

    def forward(self, input, target, feature_layers=(0, 1, 2, 3), style_layers=()):
        """Return the summed feature/style loss between ``input`` and ``target``.

        Args:
            input: Image batch; if its channel dimension is not 3, both
                ``input`` and ``target`` are repeated three times along it.
            target: Reference image batch.
            feature_layers: Block indices whose activations contribute an L1
                feature loss.
            style_layers: Block indices whose Gram matrices contribute an L1
                style loss.

        Returns:
            The accumulated loss (``0.0`` when no layer is selected).
        """
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Flatten the spatial dimensions: (batch, channels, H*W).
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
def scharr(x):
    """Return the clamped Scharr edge magnitude of a single-channel batch.

    Average the RGB channels before calling this function: the filter runs
    on a grayscale image, and the matrix product below only matches the
    9-tap kernels when ``x`` has exactly one channel.

    The input is replication-padded by one pixel, multiplied by 255, and
    filtered with the horizontal and vertical 3x3 Scharr kernels.  Each
    response is clamped to ``[-255, 255]`` and the result is
    ``(0.5 * |gx| + 0.5 * |gy|) / 255``.

    The kernels are created with the dtype and device of ``x``; they used to
    be hard-wired to float32 on CUDA, which failed for CPU tensors and for
    any other floating dtype.

    Args:
        x: Tensor of shape ``(b, 1, h, w)``.

    Returns:
        Tensor of shape ``(b, 1, h, w)`` on the same device as ``x``.
    """
    b, c, h, w = x.shape
    padded = nn.ReplicationPad2d(padding=(1, 1, 1, 1))(x)
    # One flattened 3x3 neighbourhood (row-major) per output pixel.
    patches = F.unfold(padded, kernel_size=3, stride=1, padding=0)  # b, c*9, h*w
    patches = patches.permute([0, 2, 1]) * 255.0  # b, h*w, c*9
    kernel_x = torch.tensor([-3, 0, 3, -10, 0, 10, -3, 0, 3], dtype=x.dtype, device=x.device)
    kernel_y = torch.tensor([-3, -10, -3, 0, 0, 0, 3, 10, 3], dtype=x.dtype, device=x.device)
    # (b, h*w) -> (b, 1, h, w); identical to folding with a 1x1 kernel.
    grad_x = torch.matmul(patches, kernel_x).reshape(b, 1, h, w).clamp(-255, 255)
    grad_y = torch.matmul(patches, kernel_y).reshape(b, 1, h, w).clamp(-255, 255)
    return (0.5 * torch.abs(grad_x) + 0.5 * torch.abs(grad_y)) / 255.0
def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D feature tensor.

    The tensor is flattened to ``(batch * channels, height * width)``,
    multiplied by its own transpose, and divided by the total number of
    elements in ``input``.
    """
    batch, channels, height, width = input.size()
    flat = input.reshape(batch * channels, height * width)
    gram = flat @ flat.t()
    # Normalise by the element count so the scale does not depend on size.
    return gram / (batch * channels * height * width)
class StyleLoss(nn.Module):
    """Mean-squared error between the Gram matrices of two feature tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, input_fea, target_fea):
        """Return the MSE between ``gram_matrix`` of both feature tensors.

        The target Gram matrix is detached, so no gradient reaches
        ``target_fea``.
        """
        reference_gram = gram_matrix(target_fea).detach()
        candidate_gram = gram_matrix(input_fea)
        return F.mse_loss(candidate_gram, reference_gram)
def cos_loss(feat1, feat2):
    """Return the negated mean cosine similarity of two feature tensors.

    The similarity is taken along dimension 1 (the default of
    ``F.cosine_similarity``); minimising this value maximises the average
    similarity.
    """
    similarity = F.cosine_similarity(feat1, feat2)
    mean_similarity = similarity.mean()
    return -mean_similarity
def feat_scharr(x):
    """Return the Scharr edge map of a channel-averaged feature tensor.

    The channels are averaged, the result is min-max normalised over the
    whole tensor, scaled by 255, and passed to ``scharr``.

    A constant feature map used to divide 0 by 0 and turn the whole output
    into NaN.  The range is now clamped to the smallest positive normal
    float, so such a map normalises to zeros while every other input is
    unaffected.

    NOTE(review): ``scharr`` multiplies its input by 255 as well, so the
    values it filters here span 0..255*255 and will mostly hit its clamp.
    Confirm that the extra ``* 255`` is intended.
    """
    x = torch.mean(x, dim=1, keepdim=True)
    low = x.min()
    # Guard only against a zero range; normal ranges are left untouched.
    span = (x.max() - low).clamp_min(torch.finfo(x.dtype).tiny)
    x = (x - low) / span
    x = x * 255
    return scharr(x)
def feat_ssim(feat1, feat2, gt):
    """Return the edge-weighted mean absolute feature difference and its mask.

    The mask is the ``scharr`` response of the channel-averaged ``gt``,
    bicubically resized to the spatial size of ``feat1`` (dimensions 2 and
    3).  The absolute difference ``|feat1 - feat2|`` is multiplied by that
    mask and averaged.

    Returns:
        A ``(loss, mask)`` tuple.
    """
    gray_gt = torch.mean(gt, dim=1, keepdim=True)
    edge_map = scharr(gray_gt)
    # A 5x5 max-pool dilation of the mask was left disabled in the original:
    # mask = torch.nn.MaxPool2d(5, 1, 2)(mask)
    target_hw = (feat1.shape[2], feat1.shape[3])
    mask = F.interpolate(edge_map, size=target_hw, mode="bicubic")
    weighted_diff = torch.abs(feat1 - feat2) * mask
    return torch.mean(weighted_diff), mask
def similarity_loss(f_s, f_t):
    """Return the mean squared difference of two normalised attention maps.

    For each input, the squared activations are averaged over dimension 1,
    flattened per sample, and L2-normalised along the flattened dimension.
    The two resulting maps are then compared element-wise.
    """

    def attention_map(feat):
        energy = feat.pow(2).mean(1)
        per_sample = energy.view(feat.size(0), -1)
        return F.normalize(per_sample)

    difference = attention_map(f_s) - attention_map(f_t)
    return difference.pow(2).mean()
class RBF(nn.Module):
    """Sum of Gaussian (RBF) kernels evaluated at several bandwidths.

    The kernel matrix is ``sum_k exp(-||x_i - x_j||^2 / (bandwidth * m_k))``
    with multipliers ``m_k = mul_factor ** (k - n_kernels // 2)`` for
    ``k = 0 .. n_kernels - 1``.

    The computation now stays on the device of the input.  Previously every
    operand was forced through ``.cuda()``, which failed on CPU-only hosts
    and raised ``AttributeError`` whenever ``bandwidth`` was a plain number.

    Args:
        n_kernels: Number of bandwidth multipliers.
        mul_factor: Base of the geometric multiplier sequence.
        bandwidth: Fixed base bandwidth (number or tensor).  When ``None``
            it is estimated from the data on every call.
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        self.bandwidth_multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the fixed bandwidth, or a data-driven estimate.

        The estimate is the sum of all squared distances divided by
        ``n * (n - 1)`` (``n`` being the number of samples), computed
        without tracking gradients.
        """
        if self.bandwidth is None:
            n_samples = L2_distances.shape[0]
            return L2_distances.detach().sum() / (n_samples ** 2 - n_samples)
        return self.bandwidth

    def forward(self, X):
        """Return the multi-bandwidth kernel matrix of the rows of ``X``."""
        L2_distances = torch.cdist(X, X) ** 2
        device = L2_distances.device
        bandwidth = self.get_bandwidth(L2_distances)
        if torch.is_tensor(bandwidth):
            # A user-supplied tensor bandwidth may live on another device.
            bandwidth = bandwidth.to(device)
        scales = bandwidth * self.bandwidth_multipliers.to(device)
        # Broadcast to (n_kernels, n, n), then sum over the kernels.
        return torch.exp(-L2_distances[None, ...] / scales[:, None, None]).sum(dim=0)
class MMDLoss(nn.Module):
    """Maximum-mean-discrepancy style loss between two sample sets.

    Args:
        kernel: Module mapping an ``(n, d)`` tensor to an ``(n, n)`` kernel
            matrix.  Defaults to a fresh ``RBF()`` per instance; the previous
            default was a single ``RBF`` object created at definition time
            and shared by every ``MMDLoss``.

    The kernel is moved to CUDA only when CUDA is available, so construction
    no longer depends on a GPU being present.
    """

    def __init__(self, kernel=None):
        super().__init__()
        if kernel is None:
            kernel = RBF()
        self.kernel = kernel.cuda() if torch.cuda.is_available() else kernel

    def forward(self, X, Y):
        """Return ``mean(K_xx) - 2 * mean(K_xy) + mean(K_yy)``.

        ``K`` is the kernel matrix of ``X`` and ``Y`` stacked row-wise; the
        three terms are the means of its corresponding sub-blocks.
        """
        K = self.kernel(torch.vstack([X, Y]))
        n_x = X.shape[0]
        XX = K[:n_x, :n_x].mean()
        XY = K[:n_x, n_x:].mean()
        YY = K[n_x:, n_x:].mean()
        return XX - 2 * XY + YY