| """Shared metric computation for NisabaRelief evaluation scripts.""" |
|
|
| from concurrent.futures import ThreadPoolExecutor |
|
|
| import numpy as np |
| from image_similarity_measures.quality_metrics import ( |
| rmse, |
| psnr, |
| sre, |
| ) |
| import torch |
| from pytorch_msssim import ms_ssim as _pt_msssim |
| from util.psnr_hvsm import psnr_hvsm |
|
|
# Intensity cut-off used to binarize images for the Dice score: a pixel is
# foreground when its value is strictly greater than this.
DICE_THRESHOLD = 130


# Display label for each metric key. The insertion order of this dict is the
# canonical metric order. "dice" is wrapped in Markdown bold markers
# (presumably to highlight the primary metric in report tables).
LABELS = {
    "dice": "**Dice**",
    "rmse": "RMSE",
    "msssim": "MS-SSIM",
    "psnr": "PSNR",
    "psnr_hvsm": "PSNR-HVS-M",
    "sre": "SRE",
}


# Metric keys in reporting order, derived from LABELS so the two cannot drift.
METRIC_NAMES = list(LABELS)
|
|
|
|
def _to_tensor(arr: np.ndarray) -> torch.Tensor:
    """Convert an image array to a float32 tensor with two leading singleton axes.

    A 2-D (H, W) grayscale image becomes shape (1, 1, H, W), i.e. the
    (N, C, H, W) layout that ``pytorch_msssim`` consumes.

    The data is first cast to a C-contiguous float32 NumPy array. This matters
    because ``torch.from_numpy`` rejects arrays with negative strides (e.g.
    views produced by ``np.flip`` or ``arr[:, ::-1]``). It also does not
    accept every NumPy dtype. For inputs that already worked, the result
    (values, dtype, shape) is unchanged.
    """
    contiguous = np.ascontiguousarray(arr, dtype=np.float32)
    return torch.from_numpy(contiguous).unsqueeze(0).unsqueeze(0)
|
|
|
|
def _msssim(gt: np.ndarray, pred: np.ndarray) -> float:
    """Return the MS-SSIM score of ``pred`` against ``gt`` on a 0-255 scale.

    Both images are turned into float tensors via ``_to_tensor`` (for 2-D
    inputs that is (1, 1, H, W)). The scorer is asked for a batch-averaged
    result, which is returned as a Python float.

    NOTE(review): per the pytorch_msssim docs, the default window requires
    the smaller image side to be large enough for 5 scales (> 160 px with
    win_size=11). Smaller inputs fail inside the library.
    """
    reference = _to_tensor(gt)
    candidate = _to_tensor(pred)
    score = _pt_msssim(reference, candidate, data_range=255, size_average=True)
    return score.item()
|
|
|
|
def compute_metrics(pred: np.ndarray, gt: np.ndarray) -> dict[str, float]:
    """Compute all metrics for a pair of equal-shape grayscale uint8 images.

    Dice is cheap and is computed inline. The remaining metrics are
    independent and run concurrently on a thread pool.

    Args:
        pred: Predicted image (a 2-D uint8 array is expected).
        gt: Ground-truth image; must have the same shape as ``pred``.

    Returns:
        Mapping of metric key to a plain Python ``float``. Keys come in the
        order dice, rmse, msssim, psnr, psnr_hvsm, sre.

    Raises:
        ValueError: If ``pred`` and ``gt`` differ in shape.
    """
    # Check up front: NumPy would otherwise silently broadcast some
    # mismatched shapes (e.g. (1, W) against (H, W)) into a wrong Dice score.
    if pred.shape != gt.shape:
        raise ValueError(
            f"pred and gt must have the same shape, got {pred.shape} and {gt.shape}"
        )

    # rmse/psnr/sre receive (H, W, 1) arrays: the image_similarity_measures
    # metrics operate over an explicit trailing channel axis.
    pred_3d = pred[:, :, np.newaxis]
    gt_3d = gt[:, :, np.newaxis]

    # Dice on the binarized images; two empty masks count as a perfect match.
    pred_bin = pred > DICE_THRESHOLD
    gt_bin = gt > DICE_THRESHOLD
    denom = int(pred_bin.sum()) + int(gt_bin.sum())
    overlap = int(np.logical_and(pred_bin, gt_bin).sum())
    dice = 2.0 * overlap / denom if denom > 0 else 1.0

    # Ordered to match METRIC_NAMES (after "dice"), so the returned dict
    # iterates in the canonical reporting order.
    tasks = {
        "rmse": lambda: rmse(gt_3d, pred_3d, max_p=255),
        "msssim": lambda: _msssim(gt, pred),
        "psnr": lambda: psnr(gt_3d, pred_3d, max_p=255),
        "psnr_hvsm": lambda: psnr_hvsm(gt, pred)[0],
        "sre": lambda: sre(gt_3d, pred_3d),
    }

    results: dict[str, float] = {"dice": dice}
    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {name: executor.submit(fn) for name, fn in tasks.items()}
        # Cast library return values (NumPy scalars such as np.float32, which
        # json cannot serialize) to plain floats, as the annotation promises.
        for name, future in futures.items():
            results[name] = float(future.result())
    return results
|
|