MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
5.15 kB
"""Unified segmentation metrics.
Per-image (per-case), foreground-class macro-averaged unless noted:
  * Dice (DSC)            - overlap (headline)
  * IoU (Jaccard)         - overlap
  * HD95                  - 95th-percentile Hausdorff distance (boundary)
  * ASSD                  - average symmetric surface distance (boundary)
  * Sensitivity / Recall  - TP/(TP+FN)
  * Specificity           - TN/(TN+FP) (one-vs-rest, pixel-level)
  * Precision             - TP/(TP+FP)
Convention: masks are integer maps 0..C-1 (0 = background); binary == 2 classes.
Per-class values are also recorded (per_class[c]) for per-class paper tables.
Surface metrics use MONAI if available, else medpy, else NaN.
Aggregation: per-image -> mean±SD over the test set; report/aggregate.py then
does mean±SD over seeds.
"""
from __future__ import annotations
from typing import Dict, List
import warnings
import numpy as np
# Name of the surface-distance backend picked by _select_surface_backend().
# Stays None until the first lookup caches "monai", "medpy" or "none".
_SURF_BACKEND = None

# Region-overlap metrics derived from per-class confusion counts.
_OVERLAP_KEYS = ("dice", "iou", "sensitivity", "specificity", "precision")
# Boundary-distance metrics; these can be NaN and are averaged NaN-aware.
_SURFACE_KEYS = ("hd95", "assd")
# Every scalar metric name: overlap metrics first, then surface metrics.
SCALAR_KEYS = (*_OVERLAP_KEYS, *_SURFACE_KEYS)
def _select_surface_backend():
    """Return the name of the surface-distance backend, probing only once.

    The first call tries to import the MONAI surface metrics, then the medpy
    ones, and stores the outcome in the module-level ``_SURF_BACKEND``; later
    calls return that cached value without importing anything.

    Returns:
        ``"monai"``, ``"medpy"`` or ``"none"``. When neither library can be
        imported a warning is issued and ``"none"`` is cached.
    """
    global _SURF_BACKEND
    if _SURF_BACKEND is None:
        # Any failure while importing an optional backend (not only a missing
        # package) is treated as "backend unavailable".
        try:
            from monai.metrics import (  # noqa
                compute_hausdorff_distance,
                compute_average_surface_distance,
            )
        except Exception:
            try:
                from medpy.metric.binary import hd95, assd  # noqa
            except Exception:
                _SURF_BACKEND = "none"
                warnings.warn("Neither MONAI nor medpy available -> HD95/ASSD will be NaN.")
            else:
                _SURF_BACKEND = "medpy"
        else:
            _SURF_BACKEND = "monai"
    return _SURF_BACKEND
def _surface_binary(pred: np.ndarray, gt: np.ndarray) -> Dict[str, float]:
    """Compute HD95 and ASSD for a single pair of binary masks.

    Args:
        pred: Predicted binary mask.
        gt: Ground-truth binary mask.

    Returns:
        A dict with keys ``"hd95"`` and ``"assd"``:

        * both masks empty -> both values are 0.0;
        * exactly one mask empty -> both values are NaN;
        * otherwise the values come from MONAI or medpy, whichever
          ``_select_surface_backend()`` reports, or NaN when it reports
          no backend.
    """
    backend = _select_surface_backend()

    pred_has_fg = pred.any()
    gt_has_fg = gt.any()
    if not (pred_has_fg or gt_has_fg):
        # Nothing predicted and nothing to find: count as a perfect match.
        return {"hd95": 0.0, "assd": 0.0}
    if not (pred_has_fg and gt_has_fg):
        # A distance to an empty surface is undefined.
        return {"hd95": float("nan"), "assd": float("nan")}

    if backend == "monai":
        import torch
        from monai.metrics import compute_hausdorff_distance, compute_average_surface_distance

        def _as_tensor(mask):
            # Prepend two singleton axes before handing the mask to MONAI.
            return torch.from_numpy(mask[None, None].astype(np.uint8))

        pred_t = _as_tensor(pred)
        gt_t = _as_tensor(gt)
        hausdorff = compute_hausdorff_distance(pred_t, gt_t, percentile=95).item()
        avg_surface = compute_average_surface_distance(pred_t, gt_t, symmetric=True).item()
        return {"hd95": float(hausdorff), "assd": float(avg_surface)}

    if backend == "medpy":
        from medpy.metric.binary import hd95 as _hd95, assd as _assd

        pred_bool = pred.astype(bool)
        gt_bool = gt.astype(bool)
        return {
            "hd95": float(_hd95(pred_bool, gt_bool)),
            "assd": float(_assd(pred_bool, gt_bool)),
        }

    # No surface backend could be imported.
    return {"hd95": float("nan"), "assd": float("nan")}
def _nanmean(xs):
xs = [x for x in xs if not (isinstance(x, float) and np.isnan(x))]
return float(np.mean(xs)) if xs else float("nan")
def per_image_metrics(pred: np.ndarray, target: np.ndarray, num_classes: int,
                      include_background: bool = False,
                      compute_hd95: bool = True) -> Dict[str, object]:
    """Compute segmentation metrics for one image (2D integer class maps).

    Each evaluated class is scored one-vs-rest. A class that appears in
    neither ``pred`` nor ``target`` is skipped, so it does not dilute the
    macro average.

    Args:
        pred: Predicted class map.
        target: Ground-truth class map.
        num_classes: Exclusive upper bound of the class indices to evaluate.
        include_background: Start at class 0 when True, at class 1 otherwise.
        compute_hd95: When False the surface metrics are not computed and
            ``hd95``/``assd`` are recorded as NaN.

    Returns:
        A dict holding, for every name in ``SCALAR_KEYS``, the average over
        the evaluated classes (NaN if no class was evaluated; surface metrics
        are averaged with NaNs ignored), plus ``"per_class"``, which maps
        ``str(class_index)`` to that class's own metric dict.
    """
    first_class = 0 if include_background else 1
    collected = {name: [] for name in SCALAR_KEYS}
    per_class: Dict[str, Dict[str, float]] = {}

    def _count(a, b):
        # Number of pixels where both boolean masks are True.
        return float(np.logical_and(a, b).sum())

    for label in range(first_class, num_classes):
        pred_mask = pred == label
        gt_mask = target == label
        if not (pred_mask.any() or gt_mask.any()):
            continue

        # One-vs-rest confusion counts for this class.
        tp = _count(pred_mask, gt_mask)
        fp = _count(pred_mask, ~gt_mask)
        fn = _count(~pred_mask, gt_mask)
        tn = _count(~pred_mask, ~gt_mask)

        # The small constant in each denominator avoids division by zero.
        scores = {
            "dice": (2 * tp) / (2 * tp + fp + fn + 1e-8),
            "iou": tp / (tp + fp + fn + 1e-8),
            "sensitivity": tp / (tp + fn + 1e-8),
            "specificity": tn / (tn + fp + 1e-8),
            "precision": tp / (tp + fp + 1e-8),
        }
        if compute_hd95:
            surface = _surface_binary(pred_mask, gt_mask)
        else:
            surface = {"hd95": float("nan"), "assd": float("nan")}
        scores.update(surface)

        per_class[str(label)] = scores
        for name in SCALAR_KEYS:
            collected[name].append(scores[name])

    summary: Dict[str, object] = {}
    for name in _OVERLAP_KEYS:
        values = collected[name]
        summary[name] = float(np.mean(values)) if values else float("nan")
    for name in _SURFACE_KEYS:
        values = collected[name]
        # Surface metrics may be NaN for some classes; average the rest.
        summary[name] = _nanmean(values) if values else float("nan")
    summary["per_class"] = per_class
    return summary
def aggregate(records: List[Dict[str, object]]) -> Dict[str, float]:
    """Summarise per-image metric dicts as mean and std over the whole set.

    Args:
        records: Per-image metric dicts; each must hold every name in
            ``SCALAR_KEYS``.

    Returns:
        A dict with ``"<metric>_mean"`` and ``"<metric>_std"`` for each
        scalar metric, computed over the non-NaN values only (NaN when none
        remain), plus ``"n_images"``, the number of records as a float.
    """
    summary: Dict[str, float] = {}
    for metric in SCALAR_KEYS:
        column = np.array([rec[metric] for rec in records], dtype=np.float64)
        valid = column[~np.isnan(column)]
        if valid.size:
            mean_value = float(valid.mean())
            std_value = float(valid.std())
        else:
            # No usable value for this metric anywhere in the set.
            mean_value = float("nan")
            std_value = float("nan")
        summary[f"{metric}_mean"] = mean_value
        summary[f"{metric}_std"] = std_value
    summary["n_images"] = float(len(records))
    return summary