|
|
|
|
|
try: |
|
|
from skimage.measure import compare_ssim, compare_psnr |
|
|
except: |
|
|
from skimage.metrics import peak_signal_noise_ratio as compare_psnr |
|
|
from skimage.metrics import structural_similarity as compare_ssim |
|
|
|
|
|
from functools import partial |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class Bandwise(object):
    """Apply a pairwise metric independently to every channel of two images.

    The wrapped ``index_fn`` is called once per channel with the slices
    ``X[..., c]`` and ``Y[..., c]``, where ``c`` ranges over the size of the
    LAST axis of ``X``. ``Y`` is indexed with the same channel range, so it is
    expected to have at least as many entries along its last axis as ``X``.
    """

    def __init__(self, index_fn):
        # Callable ``index_fn(x_band, y_band)`` returning one score per band.
        self.index_fn = index_fn

    def __call__(self, X, Y):
        """Return a list of per-channel scores, in channel order."""
        num_bands = X.shape[-1]
        return [
            self.index_fn(X[..., band], Y[..., band])
            for band in range(num_bands)
        ]
|
|
|
|
|
|
|
|
# Per-band PSNR / SSIM: each call returns a list with one score per channel
# (last axis), comparing the first argument's band against the second's.
# NOTE(review): data_range=255 assumes pixel values span [0, 255]; confirm
# that callers never pass images normalized to [0, 1].
cal_bwpsnr = Bandwise(index_fn=partial(compare_psnr, data_range=255))
cal_bwssim = Bandwise(index_fn=partial(compare_ssim, data_range=255))
|
|
|
|
|
|
|
|
def compare_ncc(x, y):
    """Return the normalized cross-correlation of ``x`` and ``y``.

    Means and standard deviations are taken over the whole array (no axis),
    so a multi-channel image is correlated jointly across all channels.
    Identical up to scale/offset gives 1.0; perfectly anti-correlated gives
    -1.0. A constant input makes the denominator zero.
    """
    x_centered = x - np.mean(x)
    y_centered = y - np.mean(y)
    covariance = np.mean(x_centered * y_centered)
    return covariance / (np.std(x) * np.std(y))
|
|
|
|
|
|
|
|
def ssq_error(correct, estimate):
    """Return the sum-squared error after optimally rescaling the estimate.

    The estimate is first multiplied by the least-squares scalar
    ``alpha = sum(correct * estimate) / sum(estimate ** 2)``, then the squared
    residual ``(correct - alpha * estimate) ** 2`` is summed over all pixels.
    When the estimate has (near-)zero energy (``sum(estimate ** 2) <= 1e-5``),
    ``alpha`` is 0 and the result is simply ``sum(correct ** 2)``.

    Args:
        correct: 2-D reference array (array-like is accepted).
        estimate: array broadcast-compatible with ``correct``.

    Returns:
        The scale-invariant sum-squared error as a float64 scalar.

    Raises:
        AssertionError: if ``correct`` is not 2-D.
    """
    # Compute in float64: integer inputs such as uint8 images would otherwise
    # wrap around silently when squared or multiplied, corrupting alpha.
    correct = np.asarray(correct, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    assert correct.ndim == 2, "correct must be a 2-D array"

    energy = np.sum(estimate ** 2)
    # Guard against dividing by a (near-)zero estimate energy.
    alpha = np.sum(correct * estimate) / energy if energy > 1e-5 else 0.
    return np.sum((correct - alpha * estimate) ** 2)
|
|
|
|
|
|
|
|
def local_error(correct, estimate, window_size, window_shift):
    """Return the local scale-invariant squared error, normalized by energy.

    Square windows of ``window_size x window_size`` pixels are placed every
    ``window_shift`` pixels along both spatial axes, independently per
    channel; windows that would extend past the image border are skipped.
    Within each window the estimate is rescaled to best fit the reference
    (see ``ssq_error``), and the resulting squared errors are summed. The
    result is that sum divided by the summed reference energy
    ``sum(correct ** 2)`` over the same windows.

    Args:
        correct: reference image, shape ``(M, N)`` or ``(M, N, C)``.
        estimate: image to evaluate, same shape as ``correct``.
        window_size: side length of each square window, in pixels.
        window_shift: stride between consecutive windows, in pixels.

    Returns:
        The ratio of summed local errors to summed local reference energy.
        It is NaN (with a RuntimeWarning) when every visited window of
        ``correct`` is all zeros.

    Raises:
        ZeroDivisionError: if no window fits inside the image.
    """
    # float64 avoids integer wrap-around when squaring e.g. uint8 images.
    correct = np.asarray(correct, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    # Treat a single-channel 2-D image as (M, N, 1).
    if correct.ndim == 2:
        correct = correct[..., np.newaxis]
    if estimate.ndim == 2:
        estimate = estimate[..., np.newaxis]

    M, N, C = correct.shape
    ssq = total = 0.
    for c in range(C):
        for i in range(0, M - window_size + 1, window_shift):
            for j in range(0, N - window_size + 1, window_shift):
                correct_curr = correct[i:i + window_size, j:j + window_size, c]
                estimate_curr = estimate[i:i + window_size, j:j + window_size, c]
                ssq += ssq_error(correct_curr, estimate_curr)
                total += np.sum(correct_curr ** 2)

    return ssq / total
|
|
|
|
|
|
|
|
def quality_assess(X, Y):
    """Score the estimated image ``X`` against the reference image ``Y``.

    ``Y`` is passed first (as the reference / ``correct`` argument) to every
    metric. The returned dict contains:

    - ``'PSNR'``: mean of the per-band PSNR values from ``cal_bwpsnr``.
    - ``'SSIM'``: mean of the per-band SSIM values from ``cal_bwssim``.
    - ``'LMSE'``: ``local_error`` with 20x20 windows at stride 10.
    - ``'NCC'``: ``compare_ncc`` computed over the whole array.
    """
    return {
        'PSNR': np.mean(cal_bwpsnr(Y, X)),
        'SSIM': np.mean(cal_bwssim(Y, X)),
        'LMSE': local_error(Y, X, 20, 10),
        'NCC': compare_ncc(Y, X),
    }
|
|
|