# NisabaRelief: dev_scripts/util/metrics.py
# Author: boatbomber. Commit 3050f1b ("Initial release").
"""Shared metric computation for NisabaRelief evaluation scripts."""
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from image_similarity_measures.quality_metrics import (
rmse,
psnr,
sre,
)
import torch
from pytorch_msssim import ms_ssim as _pt_msssim
from util.psnr_hvsm import psnr_hvsm
# Pixel intensity used to binarize images for the Dice coefficient: pixels
# strictly greater than this value count as foreground.
DICE_THRESHOLD = 130

# Display labels for each metric, keyed by metric name. Dice is wrapped in
# ``**`` (Markdown bold) so it stands out in rendered tables.
LABELS = {
    "dice": "**Dice**",
    "rmse": "RMSE",
    "msssim": "MS-SSIM",
    "psnr": "PSNR",
    "psnr_hvsm": "PSNR-HVS-M",
    "sre": "SRE",
}

# Canonical metric ordering, taken from the insertion order of LABELS so the
# two cannot drift apart.
METRIC_NAMES = list(LABELS)
def _to_tensor(arr: np.ndarray) -> torch.Tensor:
    """Wrap ``arr`` as a float32 tensor with two extra leading singleton axes.

    A 2-D ``(H, W)`` image becomes ``(1, 1, H, W)``. That is the
    batch/channel layout this module passes to MS-SSIM. ``torch.from_numpy``
    shares memory with ``arr``. ``.float()`` only copies when the data is not
    already float32.
    """
    as_float = torch.from_numpy(arr).float()
    # Indexing with None inserts new axes, equivalent to two unsqueeze(0)s.
    return as_float[None, None]
def _msssim(gt: np.ndarray, pred: np.ndarray) -> float:
    """Return the MS-SSIM score between ``gt`` and ``pred``.

    Both arrays are converted with ``_to_tensor`` and scored by
    ``pytorch_msssim.ms_ssim``. The data range is 255, i.e. 8-bit pixel
    values are assumed. ``size_average=True`` reduces the score to a single
    value, which is returned as a Python float.
    """
    reference = _to_tensor(gt)
    candidate = _to_tensor(pred)
    score = _pt_msssim(reference, candidate, data_range=255, size_average=True)
    return score.item()
def _dice(pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
    """Return the Dice coefficient of ``pred`` and ``gt`` binarized at ``threshold``.

    Pixels strictly greater than ``threshold`` count as foreground. If
    neither image has any foreground pixels, the masks agree trivially and
    1.0 is returned.
    """
    pred_bin = pred > threshold
    gt_bin = gt > threshold
    denom = int(pred_bin.sum()) + int(gt_bin.sum())
    if denom == 0:
        return 1.0
    overlap = int(np.logical_and(pred_bin, gt_bin).sum())
    return 2.0 * overlap / denom


def compute_metrics(pred: np.ndarray, gt: np.ndarray) -> dict[str, float]:
    """Compute all metrics for a pair of equal-shape grayscale uint8 images.

    Args:
        pred: Predicted image.
        gt: Ground-truth image. Must have the same shape as ``pred``.

    Returns:
        Mapping from each metric name to a plain Python ``float``. Keys
        appear in the order dice, rmse, msssim, psnr, psnr_hvsm, sre, which
        matches ``METRIC_NAMES``.

    Raises:
        ValueError: If ``pred`` and ``gt`` have different shapes.
    """
    if pred.shape != gt.shape:
        raise ValueError(
            f"pred and gt must have the same shape, got {pred.shape} and {gt.shape}"
        )

    # The image_similarity_measures metrics are given an explicit trailing
    # channel axis: (H, W) -> (H, W, 1).
    pred_3d = pred[:, :, np.newaxis]
    gt_3d = gt[:, :, np.newaxis]

    # Dice is a cheap vectorized reduction, so compute it directly instead
    # of handing an already-computed value to a worker thread.
    dice = _dice(pred, gt, DICE_THRESHOLD)

    # Listed in METRIC_NAMES order so the returned dict follows it too.
    tasks = {
        "rmse": lambda: rmse(gt_3d, pred_3d, max_p=255),
        "msssim": lambda: _msssim(gt, pred),
        "psnr": lambda: psnr(gt_3d, pred_3d, max_p=255),
        "psnr_hvsm": lambda: psnr_hvsm(gt, pred)[0],
        "sre": lambda: sre(gt_3d, pred_3d),
    }
    # Evaluate the remaining metrics concurrently. future.result() re-raises
    # any exception thrown inside a task.
    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {name: executor.submit(fn) for name, fn in tasks.items()}
        results = {name: future.result() for name, future in futures.items()}

    # Library metrics return numpy scalars (e.g. np.float32, which the json
    # module cannot serialize). Coerce them so the result honours the
    # ``dict[str, float]`` annotation.
    return {"dice": dice, **{name: float(value) for name, value in results.items()}}