RDNet / util /index.py
lime-j's picture
Upload 89 files
347b44e
# Metrics/Indexes
try:
from skimage.measure import compare_ssim, compare_psnr
except:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from functools import partial
import numpy as np
class Bandwise(object):
    """Wrap a two-argument metric so it is evaluated once per channel.

    The wrapped callable ``index_fn(a, b)`` is applied to matching slices
    taken along the last axis of the two inputs.
    """

    def __init__(self, index_fn):
        # Callable invoked as index_fn(x_channel, y_channel).
        self.index_fn = index_fn

    def __call__(self, X, Y):
        """Return a list with one ``index_fn`` result per last-axis slice.

        The number of slices is taken from ``X.shape[-1]``; ``Y`` is indexed
        with the same channel positions.
        """
        n_channels = X.shape[-1]
        return [
            self.index_fn(X[..., band], Y[..., band])
            for band in range(n_channels)
        ]
# Channel-wise PSNR and SSIM: each call returns a list with one score per
# last-axis channel. Both metrics are bound to data_range=255
# (presumably inputs are on a 0-255 intensity scale -- confirm against callers).
cal_bwpsnr = Bandwise(partial(compare_psnr, data_range=255))
cal_bwssim = Bandwise(partial(compare_ssim, data_range=255))
def compare_ncc(x, y):
    """Return the normalized cross-correlation between ``x`` and ``y``.

    Computed as the mean of the product of the mean-centred inputs divided by
    the product of their (population) standard deviations, taken over all
    elements.

    Args:
        x: Array-like of numeric values.
        y: Array-like broadcast-compatible with ``x``.

    Returns:
        The correlation coefficient as a float. When either input has zero
        variance the coefficient is undefined; ``0.0`` is returned instead of
        NaN so that averages over many images are not poisoned.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)

    denom = np.std(x) * np.std(y)
    if denom == 0:
        # Constant input: avoid the 0/0 division (NaN + RuntimeWarning).
        return 0.0

    return np.mean((x - np.mean(x)) * (y - np.mean(y))) / denom
def ssq_error(correct, estimate):
    """Return the scale-invariant sum of squared errors for a 2-D image.

    The estimate is multiplied by the scalar ``alpha`` that minimizes
    ``sum((correct - alpha * estimate) ** 2)`` before the error is summed
    over all pixels. If the estimate has (almost) no energy
    (``sum(estimate ** 2) <= 1e-5``), ``alpha`` is taken to be 0.

    Args:
        correct: 2-D array-like holding the reference values.
        estimate: Array-like of the same shape holding the estimate.

    Returns:
        The minimized sum of squared errors as a float.

    Raises:
        AssertionError: If ``correct`` is not 2-D.
    """
    # Work in float64: with integer images (e.g. uint8) the element-wise
    # products and squares below would wrap around before being summed.
    correct = np.asarray(correct, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    assert correct.ndim == 2

    energy = np.sum(estimate ** 2)
    if energy > 1e-5:
        # Closed-form least-squares scale factor.
        alpha = np.sum(correct * estimate) / energy
    else:
        alpha = 0.
    return np.sum((correct - alpha * estimate) ** 2)
def local_error(correct, estimate, window_size, window_shift):
    """Return the local scale-invariant squared error, normalized by energy.

    Every channel is scanned with ``window_size`` x ``window_size`` windows
    whose top-left corners are spaced ``window_shift`` apart; only windows
    lying fully inside the image are used. Within each window the estimate
    may be rescaled to minimize the error (see ``ssq_error``). The summed
    window errors are divided by the summed squared reference values of the
    same windows.

    Args:
        correct: Array-like of shape (M, N, C) holding the reference image.
        estimate: Array-like of the same shape holding the estimate.
        window_size: Side length of the square window.
        window_shift: Step between consecutive windows along each axis.

    Returns:
        The ratio of accumulated error to accumulated reference energy.
        ``0.0`` is returned when the reference is zero in every window
        (the accumulated error is then zero as well), instead of NaN.

    Raises:
        ZeroDivisionError: If no window fits inside the image, so there is
            nothing to normalize by.
    """
    # Work in float64: integer images (e.g. uint8) would otherwise wrap
    # around in the element-wise squares and products.
    correct = np.asarray(correct, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)

    M, N, C = correct.shape
    ssq = total = 0.
    n_windows = 0
    for c in range(C):
        for i in range(0, M - window_size + 1, window_shift):
            for j in range(0, N - window_size + 1, window_shift):
                correct_curr = correct[i:i + window_size, j:j + window_size, c]
                estimate_curr = estimate[i:i + window_size, j:j + window_size, c]
                ssq += ssq_error(correct_curr, estimate_curr)
                total += np.sum(correct_curr ** 2)
                n_windows += 1

    if n_windows == 0:
        # Same exception type the bare 0/0 division used to raise here.
        raise ZeroDivisionError(
            "no %dx%d window fits inside an image of shape %r"
            % (window_size, window_size, (M, N, C))
        )
    if total == 0:
        # All-zero reference: every window error is zero too, so report 0.
        return 0.0
    return ssq / total
def quality_assess(X, Y):
    """Score an estimate ``X`` against the reference ``Y``.

    Returns a dict with the keys ``'PSNR'`` and ``'SSIM'`` (means of the
    per-channel scores), ``'LMSE'`` (local error with window size 20 and
    shift 10) and ``'NCC'`` (normalized cross-correlation).
    """
    # Every metric receives the reference (Y) first and the estimate (X) second.
    return {
        'PSNR': np.mean(cal_bwpsnr(Y, X)),
        'SSIM': np.mean(cal_bwssim(Y, X)),
        'LMSE': local_error(Y, X, 20, 10),
        'NCC': compare_ncc(Y, X),
    }