File size: 1,841 Bytes
3050f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Shared metric computation for NisabaRelief evaluation scripts."""

from concurrent.futures import ThreadPoolExecutor

import numpy as np
from image_similarity_measures.quality_metrics import (
    rmse,
    psnr,
    sre,
)
import torch
from pytorch_msssim import ms_ssim as _pt_msssim
from util.psnr_hvsm import psnr_hvsm

# Intensity cut-off used to binarise both images for the Dice score:
# compute_metrics treats a pixel as foreground when value > DICE_THRESHOLD.
DICE_THRESHOLD = 130

# Display label for every metric key. The "**...**" around Dice is presumably
# Markdown bold for report tables — confirm against the consuming script.
LABELS = {
    "dice": "**Dice**",
    "rmse": "RMSE",
    "msssim": "MS-SSIM",
    "psnr": "PSNR",
    "psnr_hvsm": "PSNR-HVS-M",
    "sre": "SRE",
}

# Canonical metric ordering, taken from LABELS' insertion order so the two
# collections cannot drift apart.
METRIC_NAMES = list(LABELS)


def _to_tensor(arr: np.ndarray) -> torch.Tensor:
    """Convert an image array to a float32 tensor with two leading unit axes.

    For a 2-D ``(H, W)`` input the result has shape ``(1, 1, H, W)``.

    The array is first copied into a fresh, writable, C-contiguous float32
    buffer. ``torch.from_numpy`` rejects arrays with negative strides (e.g.
    ``arr[:, ::-1]`` or ``np.flip(arr)``) and warns on read-only arrays (such
    as ``np.asarray`` of a PIL image); the copy avoids both. For non-float32
    inputs this is the same single copy that ``.float()`` made before.
    """
    dense = np.array(arr, dtype=np.float32, order="C")
    return torch.from_numpy(dense).unsqueeze(0).unsqueeze(0)


def _msssim(gt: np.ndarray, pred: np.ndarray) -> float:
    """Return the MS-SSIM score of ``pred`` against ``gt`` as a Python float.

    Both images go through ``_to_tensor`` (two leading unit axes, float32)
    and are compared with ``data_range=255``, i.e. assuming 8-bit pixel
    values. ``size_average=True`` is requested, and the single-element
    result is unwrapped with ``.item()``.
    """
    reference = _to_tensor(gt)
    candidate = _to_tensor(pred)
    score = _pt_msssim(reference, candidate, data_range=255, size_average=True)
    return score.item()


def compute_metrics(pred: np.ndarray, gt: np.ndarray) -> dict[str, float]:
    """Compute all metrics for a pair of equal-shape grayscale uint8 images.

    Args:
        pred: Predicted image, a 2-D ``(H, W)`` array.
        gt: Ground-truth image with exactly the same shape as ``pred``.

    Returns:
        One built-in ``float`` per name in ``METRIC_NAMES``. Keys are inserted
        in the order rmse, psnr, msssim, sre, psnr_hvsm, dice.

    Raises:
        ValueError: If either image is not 2-D, or the two shapes differ.
    """
    # Validate up front: a 3-D (e.g. RGB) input would otherwise flow into
    # the metrics with the wrong axis layout instead of failing clearly.
    if pred.ndim != 2 or gt.ndim != 2:
        raise ValueError(
            f"expected 2-D grayscale images, got shapes {pred.shape} and {gt.shape}"
        )
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch: pred {pred.shape} vs gt {gt.shape}")

    # The image_similarity_measures metrics are called with (H, W, 1) arrays:
    # add a trailing singleton channel axis.
    pred_3d = pred[:, :, np.newaxis]
    gt_3d = gt[:, :, np.newaxis]

    # Dice on thresholded masks is cheap, so compute it inline rather than
    # dispatching a no-op task to the pool. Two empty masks count as a
    # perfect match (1.0) instead of dividing by zero.
    pred_bin = pred > DICE_THRESHOLD
    gt_bin = gt > DICE_THRESHOLD
    denom = pred_bin.sum() + gt_bin.sum()
    dice = float(2 * np.logical_and(pred_bin, gt_bin).sum() / denom) if denom > 0 else 1.0

    tasks = {
        "rmse": lambda: rmse(gt_3d, pred_3d, max_p=255),
        "psnr": lambda: psnr(gt_3d, pred_3d, max_p=255),
        "msssim": lambda: _msssim(gt, pred),
        "sre": lambda: sre(gt_3d, pred_3d),
        "psnr_hvsm": lambda: psnr_hvsm(gt, pred)[0],
    }

    with ThreadPoolExecutor(max_workers=len(tasks)) as executor:
        futures = {name: executor.submit(fn) for name, fn in tasks.items()}
        # Coerce to built-in float: the metric libraries may return numpy
        # scalars (np.float32 is not a float subclass and is not
        # JSON-serialisable), while the signature promises dict[str, float].
        results = {name: float(future.result()) for name, future in futures.items()}

    # Appended last to keep the original key order.
    results["dice"] = dice
    return results