| """Unified segmentation metrics. |
| |
| Per-image (per-case), foreground-class macro-averaged unless noted: |
| * Dice (DSC) overlap (headline) |
| * IoU (Jaccard) overlap |
| * HD95 95th-percentile Hausdorff distance (boundary) |
| * ASSD average symmetric surface distance (boundary) |
| * Sensitivity / Recall TP/(TP+FN) |
| * Specificity TN/(TN+FP) (one-vs-rest, pixel-level) |
| * Precision TP/(TP+FP) |
| |
| Convention: masks are integer maps 0..C-1 (0 = background); binary == 2 classes. |
| Per-class values are also recorded (per_class[c]) for per-class paper tables. |
| Surface metrics use MONAI if available, else medpy, else NaN. |
| Aggregation: per-image -> mean±SD over the test set; report/aggregate.py then |
| does mean±SD over seeds. |
| """ |
from __future__ import annotations

import warnings
from typing import Dict, List, Optional, Sequence

import numpy as np
|
|
# Cached name of the surface-distance backend ("monai" / "medpy" / "none");
# stays None until _select_surface_backend() has probed the installed libraries.
_SURF_BACKEND = None

# Count-based overlap metrics: averaged over scored classes with a plain mean.
_OVERLAP_KEYS = ("dice", "iou", "sensitivity", "specificity", "precision")
# Boundary metrics: can be NaN for individual classes, so averaged NaN-aware.
_SURFACE_KEYS = ("hd95", "assd")
# Every scalar metric produced per image, in the order aggregate() walks them.
SCALAR_KEYS = (*_OVERLAP_KEYS, *_SURFACE_KEYS)
|
|
|
|
def _select_surface_backend():
    """Return the library used for HD95/ASSD, probing the environment only once.

    MONAI is preferred over medpy. The outcome ("monai", "medpy" or "none") is
    memoised in the module-level ``_SURF_BACKEND``. When neither library can be
    imported, "none" is cached and a warning is emitted on that first probe.
    Surface metrics are then reported as NaN.
    """
    global _SURF_BACKEND
    if _SURF_BACKEND is None:
        _SURF_BACKEND = _probe_surface_backend()
        if _SURF_BACKEND == "none":
            warnings.warn("Neither MONAI nor medpy available -> HD95/ASSD will be NaN.")
    return _SURF_BACKEND


def _probe_surface_backend():
    """Return the first importable surface-distance backend name, or "none"."""
    # Any failure (not only ImportError) disqualifies a backend. A broken
    # torch/MONAI install should fall back rather than abort evaluation.
    try:
        from monai.metrics import compute_hausdorff_distance, compute_average_surface_distance  # noqa: F401
    except Exception:
        pass
    else:
        return "monai"
    try:
        from medpy.metric.binary import hd95, assd  # noqa: F401
    except Exception:
        return "none"
    return "medpy"
|
|
|
|
def _surface_binary(pred: np.ndarray, gt: np.ndarray) -> Dict[str, float]:
    """HD95 and ASSD for one binary mask pair.

    Distances are in pixel/voxel index units because no spacing is passed to
    either backend.

    Empty-mask conventions:
      * both masks empty  -> 0.0 for both metrics (perfect agreement);
      * exactly one empty -> NaN for both (a surface distance is undefined).

    Args:
        pred: predicted mask; any dtype, non-zero counts as foreground.
        gt: reference mask with the same shape as ``pred``.

    Returns:
        ``{"hd95": float, "assd": float}``. Both values are NaN when no surface
        backend (MONAI or medpy) is available.
    """
    pb = np.asarray(pred, dtype=bool)
    gb = np.asarray(gt, dtype=bool)
    p_any, g_any = bool(pb.any()), bool(gb.any())
    # Resolve the empty-mask cases first. They need no distance computation,
    # and handling them here avoids the backend import probe entirely.
    if not (p_any or g_any):
        return {"hd95": 0.0, "assd": 0.0}
    if not (p_any and g_any):
        return {"hd95": float("nan"), "assd": float("nan")}

    backend = _select_surface_backend()
    if backend == "monai":
        import torch
        from monai.metrics import compute_hausdorff_distance, compute_average_surface_distance

        # MONAI takes one-hot (batch, channel, *spatial) input and, by default,
        # discards channel 0 as "background". With a single channel that
        # leaves an empty result and .item() raises, so keep the channel.
        p = torch.from_numpy(pb[None, None].astype(np.uint8))
        g = torch.from_numpy(gb[None, None].astype(np.uint8))
        hd = compute_hausdorff_distance(p, g, include_background=True, percentile=95).item()
        asd = compute_average_surface_distance(p, g, include_background=True, symmetric=True).item()
        return {"hd95": float(hd), "assd": float(asd)}
    if backend == "medpy":
        from medpy.metric.binary import hd95 as _hd95, assd as _assd

        return {"hd95": float(_hd95(pb, gb)), "assd": float(_assd(pb, gb))}
    return {"hd95": float("nan"), "assd": float("nan")}
|
|
|
|
def _nanmean(xs):
    """Return the mean of ``xs`` ignoring NaNs, or NaN if nothing remains.

    Every element is converted to float64 before filtering. NaNs of any numeric
    type are therefore skipped, including ``np.float32('nan')``, which is not a
    Python ``float`` instance. Infinite values are kept.

    Args:
        xs: any iterable of real numbers (list, tuple or generator).

    Returns:
        The arithmetic mean of the non-NaN values as a Python float. NaN when
        ``xs`` is empty or all-NaN.
    """
    vals = np.asarray(list(xs), dtype=np.float64).ravel()
    vals = vals[~np.isnan(vals)]
    return float(vals.mean()) if vals.size else float("nan")
|
|
|
|
def per_image_metrics(pred: np.ndarray, target: np.ndarray, num_classes: int,
                      include_background: bool = False,
                      compute_hd95: bool = True) -> Dict[str, object]:
    """Compute segmentation metrics for a single image.

    Args:
        pred: predicted integer label map (labels 0..num_classes-1).
        target: ground-truth integer label map; must have the same shape as
            ``pred``.
        num_classes: number of classes C, including background label 0.
        include_background: also score class 0 when True.
        compute_hd95: when False, skip the surface metrics and report NaN for
            ``hd95`` and ``assd``.

    Returns:
        Dict with one macro-averaged float per name in ``SCALAR_KEYS``, plus
        ``"per_class"`` mapping ``str(class_id)`` to that class's metric dict.

        Classes absent from BOTH pred and target are skipped so they do not
        dilute the average. Overlap metrics use a plain mean over the scored
        classes. HD95/ASSD use a NaN-aware mean because they can be NaN for a
        class (e.g. when only one of its masks is empty). If no class is
        scored, every scalar is NaN.

    Raises:
        ValueError: if ``pred`` and ``target`` shapes differ. Without this
            check, numpy would broadcast them silently and the counts would be
            meaningless.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    if pred.shape != target.shape:
        raise ValueError(
            f"pred and target must have the same shape, got {pred.shape} and {target.shape}"
        )

    eps = 1e-8  # keeps every ratio finite when its denominator is 0
    n_pixels = float(pred.size)
    start = 0 if include_background else 1
    acc: Dict[str, List[float]] = {k: [] for k in SCALAR_KEYS}
    per_class: Dict[str, Dict[str, float]] = {}

    for c in range(start, num_classes):
        p = pred == c
        g = target == c
        # Absent from both maps: nothing to score for this class.
        if not (p.any() or g.any()):
            continue
        tp = float(np.count_nonzero(p & g))
        fp = float(np.count_nonzero(p & ~g))
        fn = float(np.count_nonzero(~p & g))
        # One-vs-rest: every pixel is exactly one of TP/FP/FN/TN.
        tn = n_pixels - tp - fp - fn
        cls: Dict[str, float] = {
            "dice": (2 * tp) / (2 * tp + fp + fn + eps),
            "iou": tp / (tp + fp + fn + eps),
            # Ratios with a zero denominator come out as 0.0. Examples:
            # sensitivity when the class is absent from target, or specificity
            # when the class covers every pixel (tn + fp == 0).
            "sensitivity": tp / (tp + fn + eps),
            "specificity": tn / (tn + fp + eps),
            "precision": tp / (tp + fp + eps),
        }
        if compute_hd95:
            cls.update(_surface_binary(p, g))
        else:
            cls.update({"hd95": float("nan"), "assd": float("nan")})
        per_class[str(c)] = cls
        for k in SCALAR_KEYS:
            acc[k].append(cls[k])

    out: Dict[str, object] = {
        k: float(np.mean(acc[k])) if acc[k] else float("nan") for k in _OVERLAP_KEYS
    }
    # Surface metrics may be NaN for individual classes; average the rest.
    # _nanmean already returns NaN for an empty list.
    out.update({k: _nanmean(acc[k]) for k in _SURFACE_KEYS})
    out["per_class"] = per_class
    return out
|
|
|
|
def aggregate(records: List[Dict[str, object]],
              keys: Optional[Sequence[str]] = None,
              ddof: int = 0) -> Dict[str, float]:
    """Summarise per-image metric dicts as a mean and SD per metric.

    NaN values are dropped per metric before summarising. Different metrics
    may therefore be averaged over different numbers of images, while
    ``n_images`` always counts every record.

    Args:
        records: per-image dicts, each containing every name in ``keys``.
        keys: metric names to summarise; defaults to ``SCALAR_KEYS``.
        ddof: delta degrees of freedom for the SD. 0 (the default, and the
            historical behaviour) gives the population SD; 1 gives the sample SD.

    Returns:
        ``"{key}_mean"`` and ``"{key}_std"`` for each key, plus ``"n_images"``.
        A mean is NaN when no non-NaN value exists. An SD is NaN when fewer
        than ``ddof + 1`` non-NaN values exist.
    """
    if keys is None:
        keys = SCALAR_KEYS
    out: Dict[str, float] = {}
    for key in keys:
        vals = np.array([r[key] for r in records], dtype=np.float64)
        vals = vals[~np.isnan(vals)]
        out[f"{key}_mean"] = float(vals.mean()) if vals.size else float("nan")
        # numpy would warn and yield NaN/inf when ddof >= n; return NaN quietly.
        out[f"{key}_std"] = float(vals.std(ddof=ddof)) if vals.size > ddof else float("nan")
    out["n_images"] = float(len(records))
    return out
|
|