"""Unified segmentation metrics. Per-image (per-case), foreground-class macro-averaged unless noted: * Dice (DSC) overlap (headline) * IoU (Jaccard) overlap * HD95 95th-percentile Hausdorff distance (boundary) * ASSD average symmetric surface distance (boundary) * Sensitivity / Recall TP/(TP+FN) * Specificity TN/(TN+FP) (one-vs-rest, pixel-level) * Precision TP/(TP+FP) Convention: masks are integer maps 0..C-1 (0 = background); binary == 2 classes. Per-class values are also recorded (per_class[c]) for per-class paper tables. Surface metrics use MONAI if available, else medpy, else NaN. Aggregation: per-image -> mean±SD over the test set; report/aggregate.py then does mean±SD over seeds. """ from __future__ import annotations from typing import Dict, List import warnings import numpy as np _SURF_BACKEND = None _OVERLAP_KEYS = ("dice", "iou", "sensitivity", "specificity", "precision") _SURFACE_KEYS = ("hd95", "assd") SCALAR_KEYS = _OVERLAP_KEYS + _SURFACE_KEYS def _select_surface_backend(): global _SURF_BACKEND if _SURF_BACKEND is not None: return _SURF_BACKEND try: from monai.metrics import compute_hausdorff_distance, compute_average_surface_distance # noqa _SURF_BACKEND = "monai" except Exception: try: from medpy.metric.binary import hd95, assd # noqa _SURF_BACKEND = "medpy" except Exception: _SURF_BACKEND = "none" warnings.warn("Neither MONAI nor medpy available -> HD95/ASSD will be NaN.") return _SURF_BACKEND def _surface_binary(pred: np.ndarray, gt: np.ndarray) -> Dict[str, float]: """HD95 and ASSD for one binary 2D mask pair. Handles empty masks.""" backend = _select_surface_backend() p_any, g_any = pred.any(), gt.any() if not p_any and not g_any: return {"hd95": 0.0, "assd": 0.0} if not p_any or not g_any: return {"hd95": float("nan"), "assd": float("nan")} if backend == "monai": import torch from monai.metrics import compute_hausdorff_distance, compute_average_surface_distance p = torch.from_numpy(pred[None, None].astype(np.uint8)) g = torch.from_numpy(gt[None, None].astype(np.uint8)) hd = compute_hausdorff_distance(p, g, percentile=95).item() asd = compute_average_surface_distance(p, g, symmetric=True).item() return {"hd95": float(hd), "assd": float(asd)} if backend == "medpy": from medpy.metric.binary import hd95 as _hd95, assd as _assd pb, gb = pred.astype(bool), gt.astype(bool) return {"hd95": float(_hd95(pb, gb)), "assd": float(_assd(pb, gb))} return {"hd95": float("nan"), "assd": float("nan")} def _nanmean(xs): xs = [x for x in xs if not (isinstance(x, float) and np.isnan(x))] return float(np.mean(xs)) if xs else float("nan") def per_image_metrics(pred: np.ndarray, target: np.ndarray, num_classes: int, include_background: bool = False, compute_hd95: bool = True) -> Dict[str, object]: """Metrics for a single image (2D int class maps). Returns foreground-macro scalars plus a per_class breakdown. Classes absent in BOTH pred and gt are skipped so they don't dilute the average.""" start = 0 if include_background else 1 acc = {k: [] for k in SCALAR_KEYS} per_class: Dict[str, Dict[str, float]] = {} for c in range(start, num_classes): p = pred == c g = target == c if not p.any() and not g.any(): continue tp = float(np.logical_and(p, g).sum()) fp = float(np.logical_and(p, ~g).sum()) fn = float(np.logical_and(~p, g).sum()) tn = float(np.logical_and(~p, ~g).sum()) cls = { "dice": (2 * tp) / (2 * tp + fp + fn + 1e-8), "iou": tp / (tp + fp + fn + 1e-8), "sensitivity": tp / (tp + fn + 1e-8), "specificity": tn / (tn + fp + 1e-8), "precision": tp / (tp + fp + 1e-8), } if compute_hd95: cls.update(_surface_binary(p, g)) else: cls.update({"hd95": float("nan"), "assd": float("nan")}) per_class[str(c)] = cls for k in SCALAR_KEYS: acc[k].append(cls[k]) out: Dict[str, object] = {} for k in _OVERLAP_KEYS: out[k] = float(np.mean(acc[k])) if acc[k] else float("nan") for k in _SURFACE_KEYS: out[k] = _nanmean(acc[k]) if acc[k] else float("nan") out["per_class"] = per_class return out def aggregate(records: List[Dict[str, object]]) -> Dict[str, float]: """Aggregate per-image metric dicts into mean/std over the set (per metric).""" out: Dict[str, float] = {} for key in SCALAR_KEYS: vals = np.array([r[key] for r in records], dtype=np.float64) vals = vals[~np.isnan(vals)] out[f"{key}_mean"] = float(vals.mean()) if vals.size else float("nan") out[f"{key}_std"] = float(vals.std()) if vals.size else float("nan") out["n_images"] = float(len(records)) return out