FelixzeroSun's picture
Upload folder using huggingface_hub
19c1f58 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import SimpleITK
import numpy as np
from typing import Optional
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.util.arraycrop import crop
from scipy.signal import fftconvolve
from scipy.ndimage import uniform_filter
import torch
class ImageMetrics:
    """Image-similarity metrics (MAE, PSNR, masked MS-SSIM) between two volumes.

    All metrics accept an optional mask; voxels where the mask is ``> 0`` are
    the ones that count. PSNR (optionally) and MS-SSIM clip intensities to
    ``self.dynamic_range`` before comparing.
    """

    def __init__(self, debug=False):
        # Use fixed wide dynamic range
        # (presumably Hounsfield units, given the CT naming -- TODO confirm).
        self.dynamic_range = [-1024., 3000.]
        self.debug = debug

    def score_patient(self, gt_img, synthetic_ct, mask):
        """
        Compute MAE, PSNR and masked MS-SSIM for one case.

        Parameters
        ----------
        gt_img : np.ndarray
            Ground truth volume.
        synthetic_ct : np.ndarray
            Predicted volume, same shape as ``gt_img``.
        mask : np.ndarray or None
            Voxels to include (non-zero). When given, voxels where the mask
            equals 0 are set to -1024 in both images before scoring.

        Returns
        -------
        scores : dict
            Keys ``'mae'``, ``'psnr'`` and ``'ms_ssim'`` (the masked MS-SSIM).
        """
        assert gt_img.shape == synthetic_ct.shape, \
            "ground truth and prediction must have the same shape"
        if mask is not None:
            assert mask.shape == synthetic_ct.shape, \
                "mask and prediction must have the same shape"

        # perform masking on the images
        if mask is None:
            ground_truth, prediction = gt_img, synthetic_ct
        else:
            outside = mask == 0
            ground_truth = np.where(outside, -1024, gt_img)
            prediction = np.where(outside, -1024, synthetic_ct)

        # Compute image similarity within the mask
        mae_value = self.mae(ground_truth, prediction, mask)
        psnr_value = self.psnr(ground_truth, prediction, mask,
                               use_population_range=True)
        # Only the masked MS-SSIM is reported; the full-volume value is dropped.
        _, ms_ssim_mask_value = self.ms_ssim(ground_truth, prediction, mask)

        return {
            'mae': mae_value,
            'psnr': psnr_value,
            'ms_ssim': ms_ssim_mask_value,
        }

    def mae(self,
            gt: np.ndarray,
            pred: np.ndarray,
            mask: Optional[np.ndarray] = None) -> float:
        """
        Compute Mean Absolute Error (MAE)

        Parameters
        ----------
        gt : np.ndarray
            Ground truth
        pred : np.ndarray
            Prediction
        mask : np.ndarray, optional
            Mask for voxels to include. The default is None (including all voxels).

        Returns
        -------
        mae : float
            mean absolute error; NaN when the mask selects no voxel.
        """
        if mask is None:
            selected = np.ones(np.shape(gt), dtype=bool)
        else:
            # binarize mask
            selected = np.asarray(mask) > 0

        n_selected = np.count_nonzero(selected)
        if n_selected == 0:
            # Nothing to average over.
            return float('nan')

        abs_err = np.abs(np.asarray(gt, dtype=np.float64)
                         - np.asarray(pred, dtype=np.float64))
        # Select instead of multiplying by the mask: NaN/inf values located
        # outside the mask must not leak into the sum (NaN * 0 is NaN).
        mae_value = np.where(selected, abs_err, 0.0).sum() / n_selected
        return float(mae_value)

    def psnr(self,
             gt: np.ndarray,
             pred: np.ndarray,
             mask: Optional[np.ndarray] = None,
             use_population_range: Optional[bool] = False) -> float:
        """
        Compute Peak Signal to Noise Ratio metric (PSNR)

        Parameters
        ----------
        gt : np.ndarray
            Ground truth
        pred : np.ndarray
            Prediction
        mask : np.ndarray, optional
            Mask for voxels to include. The default is None (including all voxels).
        use_population_range : bool, optional
            When a predefined population wide dynamic range should be used.
            gt and pred will also be clipped to these values.

        Returns
        -------
        psnr : float
            Peak signal to noise ratio.
        """
        if mask is None:
            mask = np.ones(gt.shape)
        else:
            # binarize mask
            mask = np.where(mask > 0, 1., 0.)

        if use_population_range:
            lower, upper = self.dynamic_range[0], self.dynamic_range[1]
            # Clip gt and pred to the dynamic range
            gt = np.clip(gt, a_min=lower, a_max=upper)
            pred = np.clip(pred, a_min=lower, a_max=upper)
            dynamic_range = upper - lower
        else:
            # The range is taken from the whole ground truth (before masking).
            dynamic_range = gt.max() - gt.min()
            pred = np.clip(pred, a_min=gt.min(), a_max=gt.max())

        # apply mask
        inside = mask == 1
        psnr_value = peak_signal_noise_ratio(gt[inside], pred[inside],
                                             data_range=dynamic_range)
        return float(psnr_value)

    # Compute the luminance, contrast and structure components of the SSIM between two images
    def structural_similarity_at_scale(self, im1,
                                       im2,
                                       *,
                                       luminance_weight=1,
                                       contrast_weight=1,
                                       structure_weight=1,
                                       win_size=None,
                                       gradient=False,
                                       data_range=None,
                                       channel_axis=None,
                                       gaussian_weights=False,
                                       full=False,
                                       **kwargs,):
        """
        Compute a weighted SSIM (luminance**a * contrast**b * structure**c).

        Parameters
        ----------
        im1, im2 : np.ndarray
            Images to compare.
        luminance_weight, contrast_weight, structure_weight : float
            Exponents applied to the three SSIM components.
        win_size : int, optional
            Side length of the local window. Defaults to 7 for the uniform
            filter, or to the Gaussian support derived from ``sigma``.
        gradient, channel_axis
            Accepted for signature compatibility; not used by this method.
        data_range : float
            Intensity range used to scale the stabilising constants.
        gaussian_weights : bool
            Use a Gaussian window instead of a uniform one.
        full : bool
            Also return the per-voxel SSIM map.
        **kwargs
            Optional ``K1``, ``K2``, ``sigma``, ``use_sample_covariance``.

        Returns
        -------
        mssim : float
            Mean of the SSIM map after cropping the filter radius at the edges.
        result : np.ndarray
            The uncropped SSIM map (only when ``full`` is True).

        Raises
        ------
        ValueError
            If ``K1``, ``K2`` or ``sigma`` is negative.
        """
        K1 = kwargs.pop('K1', 0.01)
        K2 = kwargs.pop('K2', 0.03)
        sigma = kwargs.pop('sigma', 1.5)
        if K1 < 0:
            raise ValueError("K1 must be positive")
        if K2 < 0:
            raise ValueError("K2 must be positive")
        if sigma < 0:
            raise ValueError("sigma must be positive")
        use_sample_covariance = kwargs.pop('use_sample_covariance', True)

        if gaussian_weights:
            # Imported locally: the module only imports uniform_filter, and the
            # Gaussian branch previously referenced an undefined name.
            from scipy.ndimage import gaussian_filter

            # Set to give an 11-tap filter with the default sigma of 1.5 to match
            # Wang et. al. 2004.
            truncate = 3.5
            if win_size is None:
                # set win_size used by crop to match the filter size
                r = int(truncate * sigma + 0.5)  # radius as in ndimage
                win_size = 2 * r + 1
            filter_func = gaussian_filter
            filter_args = {'sigma': sigma, 'truncate': truncate, 'mode': 'reflect'}
        else:
            if win_size is None:
                win_size = 7  # backwards compatibility
            filter_func = uniform_filter
            filter_args = {'size': win_size}

        ndim = im1.ndim
        NP = win_size ** ndim

        # filter has already normalized by NP
        if use_sample_covariance:
            cov_norm = NP / (NP - 1)  # sample covariance
        else:
            cov_norm = 1.0  # population covariance to match Wang et. al. 2004

        # compute (weighted) means
        ux = filter_func(im1, **filter_args)
        uy = filter_func(im2, **filter_args)

        # compute (weighted) variances and covariances
        uxx = filter_func(im1 * im1, **filter_args)
        uyy = filter_func(im2 * im2, **filter_args)
        uxy = filter_func(im1 * im2, **filter_args)
        vx = cov_norm * (uxx - ux * ux)
        vy = cov_norm * (uyy - uy * uy)
        vxy = cov_norm * (uxy - ux * uy)
        # Rounding can push the variances slightly below zero; clip before
        # taking the square root.
        vxsqrt = np.clip(vx, a_min=0, a_max=None) ** 0.5
        vysqrt = np.clip(vy, a_min=0, a_max=None) ** 0.5

        R = data_range
        C1 = (K1 * R) ** 2
        C2 = (K2 * R) ** 2
        C3 = C2 / 2

        # Each component is clipped at 0 so that the (possibly fractional)
        # exponents below are applied to non-negative values only.
        L = np.clip((2 * ux * uy + C1) / (ux * ux + uy * uy + C1), a_min=0, a_max=None)
        C = np.clip((2 * vxsqrt * vysqrt + C2) / (vx + vy + C2), a_min=0, a_max=None)
        S = np.clip((vxy + C3) / (vxsqrt * vysqrt + C3), a_min=0, a_max=None)
        result = (L ** luminance_weight) * (C ** contrast_weight) * (S ** structure_weight)

        # to avoid edge effects will ignore filter radius strip around edges
        pad = (win_size - 1) // 2

        # compute (weighted) mean of ssim. Use float64 for accuracy.
        mssim = crop(result, pad).mean(dtype=np.float64)

        if full:
            return mssim, result
        return mssim

    # Compute the masked MS-SSIM by masking the SSIM at every resolution level
    def ms_ssim(self, gt: np.ndarray, pred: np.ndarray, mask: Optional[np.ndarray] = None, scale_weights: Optional[np.ndarray] = None) -> tuple:
        """
        Compute the multi-scale SSIM over the whole volume and within the mask.

        Parameters
        ----------
        gt : np.ndarray
            Ground truth, 3-D.
        pred : np.ndarray
            Prediction, same shape as ``gt``.
        mask : np.ndarray, optional
            Mask for voxels to include. The default is None (including all voxels).
        scale_weights : sequence of float, optional
            Per-level exponents; the number of entries sets the number of levels.

        Returns
        -------
        ms_ssim : float
            Product over levels of the (clipped) mean SSIM of the cropped volume.
        ms_ssim_mask : float
            Product over levels of the (clipped) mean SSIM inside the mask.

        Raises
        ------
        ValueError
            If ``gt`` is not 3-D.
        """
        if np.ndim(gt) != 3:
            raise ValueError("ms_ssim expects 3-D volumes")

        lower = min(self.dynamic_range)
        upper = max(self.dynamic_range)

        # Clip gt and pred to the dynamic range
        gt = np.clip(gt, lower, upper)
        pred = np.clip(pred, lower, upper)

        if mask is not None:
            # binarize mask
            mask = np.where(mask > 0, 1., 0.)
            # Mask gt and pred
            gt = np.where(mask == 0, lower, gt)
            pred = np.where(mask == 0, lower, pred)
        else:
            # No mask given: include every voxel, so the masked score is
            # computed over the same cropped region as the full score.
            mask = np.ones(gt.shape)

        # Make values non-negative
        if lower < 0:
            gt = gt - lower
            pred = pred - lower

        # Set dynamic range for ssim calculation and calculate ssim_map per pixel
        dynamic_range = self.dynamic_range[1] - self.dynamic_range[0]

        # see Eq. 7 in https://live.ece.utexas.edu/publications/2003/zw_asil2003_msssim.pdf
        # Also, the final sentence of section 3.2 (Results)
        if scale_weights is None:
            scale_weights = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
        # NOTE(review): the luminance exponents have always been identical to
        # scale_weights here -- the earlier conditional selecting a
        # "coarsest level only" vector could never trigger because
        # scale_weights was already defaulted. The effective behaviour is kept
        # so scores stay comparable; confirm whether luminance should only be
        # weighted at the last level.
        luminance_weights = scale_weights
        levels = len(scale_weights)
        downsample_filter = np.ones((2, 2, 2)) / 8

        # Due to the downsampling in the MS-SSIM, the minimum matrix size must be 97 in every dimension
        target_size = 97
        pad_values = []
        for dim in gt.shape:
            before = (target_size - dim) // 2
            after = target_size - dim - before
            pad_values.append((max(before, 0), max(after, 0)))
        gt = np.pad(gt, pad_values, mode='edge')
        pred = np.pad(pred, pad_values, mode='edge')
        mask = np.pad(mask, pad_values, mode='edge')

        # Border removed from the SSIM map and the mask before the masked mean;
        # equals the filter radius of the default 7-voxel window.
        border = 3

        ms_ssim_vals, ms_ssim_maps = [], []
        for level in range(levels):
            ssim_value_full, ssim_map = self.structural_similarity_at_scale(
                gt, pred,
                luminance_weight=luminance_weights[level],
                contrast_weight=scale_weights[level],
                structure_weight=scale_weights[level],
                data_range=dynamic_range, full=True)
            # at every level, we get the ssim_value_full, which is just mean SSIM at a level L, and the
            # SSIM map. The masked SSIM is the mean SSIM within this mask
            inside = crop(mask, border).astype(bool)
            ssim_value_masked = crop(ssim_map, border)[inside].mean(dtype=np.float64)
            ms_ssim_vals.append(ssim_value_full)
            ms_ssim_maps.append(ssim_value_masked)

            # The images are downsampled using a uniform (2x2x2 mean) filter;
            # the mask is just downsampled by selecting every other line in every dimension
            gt = fftconvolve(gt, downsample_filter, mode='same')[::2, ::2, ::2]
            pred = fftconvolve(pred, downsample_filter, mode='same')[::2, ::2, ::2]
            mask = mask[::2, ::2, ::2]

        ms_ssim_val = np.prod([np.clip(x, a_min=0, a_max=1) for x in ms_ssim_vals])
        ms_ssim_mask_val = np.prod([np.clip(x, a_min=0, a_max=1) for x in ms_ssim_maps])
        return float(ms_ssim_val), float(ms_ssim_mask_val)
# Compute image metrics for the prediction folders
class ImageMetricsCompute(ImageMetrics):
    """Score arrays/tensors with ImageMetrics and accumulate per-case results."""

    def __init__(self):
        super().__init__()
        # Create the storage up front so add()/aggregate()/reset() work
        # without a separate init_storage() call.
        self.init_storage(["mae", "psnr", "ms_ssim"])

    def init_storage(self, names: list):
        """Select the metrics to compute and start with an empty list for each."""
        self.storage = {name: [] for name in names}
        self.storage_id = []
        self.names = names

    def add(self, res: dict, patient_id=None):
        """
        Append one case's metric values to the storage.

        Parameters
        ----------
        res : dict
            Metric name -> value; every key must be a stored metric name.
        patient_id : optional
            Identifier recorded alongside the values when not None.
        """
        for key, value in res.items():
            self.storage[key].append(value)
        # Compare against None so that falsy ids such as 0 are still recorded.
        if patient_id is not None:
            self.storage_id.append(patient_id)

    def aggregate(self):
        """
        Summarise the stored values per metric.

        Returns
        -------
        res : dict
            Metric name -> dict with 'mean', 'std', 'max', 'min', '25pc',
            '50pc', '75pc' and 'count'. A metric without stored values gets
            NaN statistics and a count of 0.
        """
        res = {name: dict() for name in self.names}
        for key, value in self.storage.items():
            stats = res.setdefault(key, dict())
            if len(value) == 0:
                # np.max/np.min raise on empty input; report "no data" instead.
                for stat in ('mean', 'std', 'max', 'min', '25pc', '50pc', '75pc'):
                    stats[stat] = float('nan')
                stats['count'] = 0
                continue
            stats['mean'] = np.mean(value)
            stats['std'] = np.std(value)
            stats['max'] = np.max(value)
            stats['min'] = np.min(value)
            stats['25pc'] = np.percentile(value, 25)
            stats['50pc'] = np.percentile(value, 50)
            stats['75pc'] = np.percentile(value, 75)
            stats['count'] = len(value)
        return res

    def reset(self):
        """Drop all stored values and patient ids, keeping the metric names."""
        for key in self.storage:
            self.storage[key] = []
        # Ids must be cleared together with the values they belong to.
        self.storage_id = []

    @staticmethod
    def _to_numpy(array):
        """Convert a torch tensor to a squeezed numpy array; pass anything else through."""
        if torch.is_tensor(array):
            # detach() so tensors that require grad can be converted as well.
            return array.detach().cpu().numpy().squeeze()
        return array

    def score_array(self, gt_array, pred_array, mask_array=None):
        """
        Compute the selected metrics for one case.

        Parameters
        ----------
        gt_array, pred_array : np.ndarray or torch.Tensor
            Ground truth and prediction.
        mask_array : np.ndarray or torch.Tensor, optional
            Voxels to include. The default is None.

        Returns
        -------
        res : dict
            One entry per metric listed in ``self.names`` among 'mae', 'psnr'
            and 'ms_ssim' (the masked MS-SSIM).
        """
        gt_array = self._to_numpy(gt_array)
        pred_array = self._to_numpy(pred_array)
        mask_array = self._to_numpy(mask_array)

        names = self.names or ()

        # Calculate image metrics
        res = dict()
        if 'mae' in names:
            res['mae'] = self.mae(gt_array, pred_array, mask_array)
        if 'psnr' in names:
            res['psnr'] = self.psnr(gt_array, pred_array, mask_array,
                                    use_population_range=True)
        if 'ms_ssim' in names:
            _, ms_ssim_mask_value = self.ms_ssim(gt_array, pred_array, mask_array)
            res['ms_ssim'] = ms_ssim_mask_value
        return res
if __name__ == '__main__':
    # Example usage: score one prediction against its ground truth.
    metrics = ImageMetrics()
    ground_truth_path = "path/to/ground_truth.mha"
    predicted_path = "path/to/prediction.mha"
    mask_path = "path/to/mask.mha"
    # score_patient compares arrays (it checks .shape), so the volumes have to
    # be read from disk first instead of passing the file paths.
    ground_truth, prediction, mask = (
        SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(path))
        for path in (ground_truth_path, predicted_path, mask_path)
    )
    print(metrics.score_patient(ground_truth, prediction, mask))