# fr-iqa/metrics/DeepWSD.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from ot.lp import wasserstein_1d

# Downsample the inputs so that the larger spatial side is close to `maxSize`,
# matching the input scale VGG16 expects.
def downsample(img1, img2, maxSize=256):
    _, channels, H, W = img1.shape
    f = int(max(1, np.round(max(H, W) / maxSize)))
    aveKernel = (torch.ones(channels, 1, f, f) / f ** 2).to(img1.device)
    img1 = F.conv2d(img1, aveKernel, stride=f, padding=0, groups=channels)
    img2 = F.conv2d(img2, aveKernel, stride=f, padding=0, groups=channels)
    # For extremely large images, a larger window is used to enlarge the receptive field.
    if f >= 5:
        win = 16
    else:
        win = 4
    return img1, img2, win, f
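
# Shape sketch (sizes assumed purely for illustration): a 1x3x1024x768 pair gives
# f = round(1024 / 256) = 4, so both images come back as 1x3x256x192 with win = 4;
# only for considerably larger inputs (f >= 5) does win grow to 16.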

# Use L2 pooling in place of VGG16's max pooling: max pooling introduces
# distortions in the color channels during optimization.
class L2pooling(nn.Module):
    def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
        super(L2pooling, self).__init__()
        self.padding = (filter_size - 2) // 2
        self.stride = stride
        self.channels = channels
        # Hann-window kernel (the zero endpoints are trimmed), normalized to sum to 1.
        a = np.hanning(filter_size)[1:-1]
        g = torch.Tensor(a[:, None] * a[None, :])
        g = g / torch.sum(g)
        self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1)))

    def forward(self, input):
        # Square, blur with the Hann kernel per channel, then take the root:
        # an L2 norm over each local window.
        input = input ** 2
        out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding,
                       groups=input.shape[1])
        return (out + 1e-12).sqrt()
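
# Shape sketch (illustrative): with the default 3x3 trimmed Hann kernel, stride 2,
# and padding 1, L2pooling halves even spatial sizes just like the MaxPool2d it
# replaces, e.g. L2pooling(channels=64)(torch.rand(1, 64, 128, 128)) has shape
# (1, 64, 64, 64).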

def ws_distance(X, Y, P=2, win=4):
    """Wasserstein distance between two feature maps over win x win patches,
    blended with an adaptively weighted L2 term."""
    # Use the inputs' device rather than a global CUDA check, so CPU tensors work.
    device = X.device
    chn_num = X.shape[1]
    X_sum = X.sum()
    Y_sum = Y.sum()
    X_patch = torch.reshape(X, [win, win, chn_num, -1])
    Y_patch = torch.reshape(Y, [win, win, chn_num, -1])
    patch_num = (X.shape[2] // win) * (X.shape[3] // win)
    X_1D = torch.reshape(X_patch, [-1, chn_num * patch_num])
    Y_1D = torch.reshape(Y_patch, [-1, chn_num * patch_num])
    X_1D_pdf = X_1D / (X_sum + 1e-6)
    Y_1D_pdf = Y_1D / (Y_sum + 1e-6)
    interval = np.arange(0, X_1D.shape[0], 1)
    all_samples = torch.from_numpy(interval).to(device).repeat([patch_num * chn_num, 1]).t()
    X_pdf = X_1D * X_1D_pdf
    Y_pdf = Y_1D * Y_1D_pdf
    # 1D Wasserstein distance on a shared support, one value per patch column.
    wsd = wasserstein_1d(all_samples, all_samples, X_pdf, Y_pdf, P)
    L2 = ((X_1D - Y_1D) ** 2).sum(dim=0)
    # Adaptive weight that shrinks as the Wasserstein term grows.
    w = 1 / (torch.sqrt(torch.exp(-1 / (wsd + 10))) * (wsd + 10) ** 2)
    final = wsd + L2 * w
    # final = wsd  # Wasserstein term alone, kept as a reference variant
    return final.sum()
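
# Usage sketch (random tensors, illustrative only): spatial sizes should be
# multiples of `win`; DeepWSD.forward guarantees this via zero-padding.
#   X = torch.rand(1, 64, 64, 64)
#   Y = torch.rand(1, 64, 64, 64)
#   d = ws_distance(X, Y, P=2, win=4)  # scalar tensor; larger = more dissimilar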

class DeepWSD(torch.nn.Module):
    def __init__(self, channels=3, load_weights=True):
        assert channels == 3
        super(DeepWSD, self).__init__()
        self.window = 4
        # torchvision >= 0.13 weights API; `load_weights` selects ImageNet weights.
        weights = models.VGG16_Weights.IMAGENET1K_V1 if load_weights else None
        vgg_pretrained_features = models.vgg16(weights=weights).features
        self.stage1 = torch.nn.Sequential()
        self.stage2 = torch.nn.Sequential()
        self.stage3 = torch.nn.Sequential()
        self.stage4 = torch.nn.Sequential()
        self.stage5 = torch.nn.Sequential()
        # Replace the pooling layer at the end of each VGG block: MaxPool -> L2pooling.
        for x in range(0, 4):
            self.stage1.add_module(str(x), vgg_pretrained_features[x])
        self.stage2.add_module(str(4), L2pooling(channels=64))
        for x in range(5, 9):
            self.stage2.add_module(str(x), vgg_pretrained_features[x])
        self.stage3.add_module(str(9), L2pooling(channels=128))
        for x in range(10, 16):
            self.stage3.add_module(str(x), vgg_pretrained_features[x])
        self.stage4.add_module(str(16), L2pooling(channels=256))
        for x in range(17, 23):
            self.stage4.add_module(str(x), vgg_pretrained_features[x])
        self.stage5.add_module(str(23), L2pooling(channels=512))
        for x in range(24, 30):
            self.stage5.add_module(str(x), vgg_pretrained_features[x])
        # The backbone is a fixed feature extractor.
        for param in self.parameters():
            param.requires_grad = False
        self.chns = [3, 64, 128, 256, 512, 512]

    def forward_once(self, x):
        h = x
        h = self.stage1(h)
        h_relu1_2 = h
        h = self.stage2(h)
        h_relu2_2 = h
        h = self.stage3(h)
        h_relu3_3 = h
        h = self.stage4(h)
        h_relu4_3 = h
        h = self.stage5(h)
        h_relu5_3 = h
        return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]

    def forward(self, x, y, as_loss=False, resize=True):
        assert x.shape == y.shape
        if resize:
            x, y, window, f = downsample(x, y)
        else:
            # Fall back to the default window; f = 1 selects the single-log
            # output below. (The original code left both undefined here.)
            window, f = self.window, 1
        if as_loss:
            feats0 = self.forward_once(x)
            feats1 = self.forward_once(y)
        else:
            with torch.no_grad():
                feats0 = self.forward_once(x)
                feats1 = self.forward_once(y)
        score = 0
        layer_score = []  # per-layer log scores, kept for inspection (e.g. in a debugger)
        for k in range(len(self.chns)):
            # Pad (or crop) each feature map so its spatial size is a multiple of `window`.
            row_padding = round(feats0[k].size(2) / window) * window - feats0[k].size(2)
            column_padding = round(feats0[k].size(3) / window) * window - feats0[k].size(3)
            pad = nn.ZeroPad2d((column_padding, 0, 0, row_padding))
            feats0_k = pad(feats0[k])
            feats1_k = pad(feats1[k])
            tmp = ws_distance(feats0_k, feats1_k, win=window)
            layer_score.append(torch.log(tmp + 1))
            score = score + tmp
        # Average over the six feature levels.
        score = score / (k + 1)
        # When used as a loss, the raw score is returned: the logarithm is
        # skipped during optimization.
        if as_loss:
            return score
        # We find the squared-log output yields higher PLCC, so two output
        # strategies are provided; they only affect the PLCC of quality-assessment results.
        elif f == 1:
            return torch.log(score + 1)
        else:
            return torch.log(score + 1) ** 2
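
# Minimal usage sketch (an illustrative addition, not from the original repo):
# inputs are assumed to be RGB tensors in [0, 1] with shape (N, 3, H, W);
# random tensors stand in for a real reference/distorted image pair.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DeepWSD().to(device)
    ref = torch.rand(1, 3, 512, 512, device=device)
    dist = (ref + 0.05 * torch.randn_like(ref)).clamp(0, 1)
    # DeepWSD is a distance: lower scores mean perceptually closer images.
    print('DeepWSD(ref, dist):', model(ref, dist).item())
    print('DeepWSD(ref, ref): ', model(ref, ref).item())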