"""Shared metric computation for NisabaRelief evaluation scripts.""" from concurrent.futures import ThreadPoolExecutor import numpy as np from image_similarity_measures.quality_metrics import ( rmse, psnr, sre, ) import torch from pytorch_msssim import ms_ssim as _pt_msssim from util.psnr_hvsm import psnr_hvsm DICE_THRESHOLD = 130 METRIC_NAMES = [ "dice", "rmse", "msssim", "psnr", "psnr_hvsm", "sre", ] LABELS = { "dice": "**Dice**", "rmse": "RMSE", "msssim": "MS-SSIM", "psnr": "PSNR", "psnr_hvsm": "PSNR-HVS-M", "sre": "SRE", } def _to_tensor(arr: np.ndarray) -> torch.Tensor: return torch.from_numpy(arr).float().unsqueeze(0).unsqueeze(0) def _msssim(gt: np.ndarray, pred: np.ndarray) -> float: return _pt_msssim( _to_tensor(gt), _to_tensor(pred), data_range=255, size_average=True ).item() def compute_metrics(pred: np.ndarray, gt: np.ndarray) -> dict[str, float]: """Compute all metrics for a pair of equal-shape grayscale uint8 images.""" pred_3d = pred[:, :, np.newaxis] gt_3d = gt[:, :, np.newaxis] pred_bin = pred > DICE_THRESHOLD gt_bin = gt > DICE_THRESHOLD denom = pred_bin.sum() + gt_bin.sum() dice = float(2 * np.logical_and(pred_bin, gt_bin).sum() / denom) if denom > 0 else 1.0 tasks = { "rmse": lambda: rmse(gt_3d, pred_3d, max_p=255), "psnr": lambda: psnr(gt_3d, pred_3d, max_p=255), "msssim": lambda: _msssim(gt, pred), "sre": lambda: sre(gt_3d, pred_3d), "psnr_hvsm": lambda: psnr_hvsm(gt, pred)[0], "dice": lambda: dice, } with ThreadPoolExecutor(max_workers=len(tasks)) as executor: futures = {name: executor.submit(fn) for name, fn in tasks.items()} return {name: future.result() for name, future in futures.items()}