"""Shared metric computation for NisabaRelief evaluation scripts."""
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from image_similarity_measures.quality_metrics import (
rmse,
psnr,
sre,
)
import torch
from pytorch_msssim import ms_ssim as _pt_msssim
from util.psnr_hvsm import psnr_hvsm
# Gray-level cutoff used when binarising images for the Dice score: pixels
# strictly greater than this value are treated as foreground.
DICE_THRESHOLD: int = 130

# Canonical metric identifiers; these are the keys used in LABELS below.
METRIC_NAMES: list[str] = ["dice", "rmse", "msssim", "psnr", "psnr_hvsm", "sre"]

# Human-readable label for each metric identifier. The asterisks around
# "Dice" are part of the label text (presumably Markdown bold -- confirm
# against the code that renders these labels).
LABELS: dict[str, str] = {
    "dice": "**Dice**",
    "rmse": "RMSE",
    "msssim": "MS-SSIM",
    "psnr": "PSNR",
    "psnr_hvsm": "PSNR-HVS-M",
    "sre": "SRE",
}
def _to_tensor(arr: np.ndarray) -> torch.Tensor:
    """Convert a NumPy image array into a float32 tensor with two leading axes.

    Args:
        arr: Image data as a NumPy array (callers in this module pass 2-D
            grayscale images).

    Returns:
        A ``torch.float32`` tensor whose shape is ``(1, 1, *arr.shape)``.
    """
    # torch.from_numpy refuses arrays with negative strides (e.g. an image
    # flipped via ``img[::-1]``). ascontiguousarray returns the input as-is
    # when it is already C-contiguous and copies only when it has to.
    contiguous = np.ascontiguousarray(arr)
    # Prepend two singleton axes -- presumably the batch and channel axes
    # expected by the MS-SSIM implementation this feeds.
    return torch.from_numpy(contiguous).float().unsqueeze(0).unsqueeze(0)
def _msssim(gt: np.ndarray, pred: np.ndarray) -> float:
    """Return the MS-SSIM score between a ground-truth and a predicted image.

    Both arrays are converted with ``_to_tensor`` and passed to
    ``pytorch_msssim.ms_ssim`` with ``data_range=255`` (i.e. pixel values are
    taken to span 0-255).

    Args:
        gt: Ground-truth image.
        pred: Predicted image.

    Returns:
        The score as a plain Python ``float``.
    """
    score = _pt_msssim(
        _to_tensor(gt),
        _to_tensor(pred),
        data_range=255,
        # Reduce to a single scalar so that .item() below is valid.
        size_average=True,
    )
    return score.item()
def compute_metrics(pred: np.ndarray, gt: np.ndarray) -> dict[str, float]:
    """Compute all metrics for a pair of equal-shape grayscale uint8 images.

    Args:
        pred: Predicted image (indexed here as a 2-D array).
        gt: Ground-truth image, expected to have the same shape as ``pred``.

    Returns:
        A dict mapping metric name to value, with keys inserted in the order
        ``"rmse"``, ``"psnr"``, ``"msssim"``, ``"sre"``, ``"psnr_hvsm"``,
        ``"dice"``.

    Raises:
        Exception: Whatever an individual metric function raises is re-raised
            when its result is collected.
    """
    # rmse / psnr / sre receive the images with a trailing singleton axis --
    # presumably the channel axis those functions expect (confirm against
    # image_similarity_measures).
    pred_3d = pred[:, :, np.newaxis]
    gt_3d = gt[:, :, np.newaxis]

    # Dice is evaluated on binary masks: pixels strictly above DICE_THRESHOLD
    # count as foreground.
    pred_bin = pred > DICE_THRESHOLD
    gt_bin = gt > DICE_THRESHOLD
    denom = pred_bin.sum() + gt_bin.sum()
    # When both masks are empty the ratio would be 0/0; the score is defined
    # as 1.0 in that case.
    dice = float(2 * np.logical_and(pred_bin, gt_bin).sum() / denom) if denom > 0 else 1.0

    # The remaining metrics are independent of each other, so each one is
    # submitted as its own task.
    tasks = {
        "rmse": lambda: rmse(gt_3d, pred_3d, max_p=255),
        "psnr": lambda: psnr(gt_3d, pred_3d, max_p=255),
        "msssim": lambda: _msssim(gt, pred),
        "sre": lambda: sre(gt_3d, pred_3d),
        # Only the first element of psnr_hvsm's return value is kept.
        "psnr_hvsm": lambda: psnr_hvsm(gt, pred)[0],
    }
    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {name: executor.submit(fn) for name, fn in tasks.items()}

    # Leaving the ``with`` block waits for every task to finish; result()
    # then returns the value or re-raises the task's exception.
    results = {name: future.result() for name, future in futures.items()}
    # Dice is already computed, so it is added directly rather than being
    # routed through the pool; it stays the last key.
    results["dice"] = dice
    return results
|