GenSeg-Baselines / code /framework /nnunet_eval.py
MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
4.32 kB
"""Compute framework-format metrics.json from nnU-Net / U-Mamba test predictions.
nnU-Net only reports validation Dice during training. To compare on the SAME
held-out test set with the SAME 7 metrics as the framework, we: predict on
imagesTs (done separately via nnUNetv2_predict), then run THIS script to score the
predicted masks against labelsTs using framework/metrics.py, writing a metrics.json
in the exact framework format so report/aggregate.py includes nnU-Net/U-Mamba rows.
Usage:
    python framework/nnunet_eval.py --data_root <processed_unified> --dataset <ds> \
        --protocol <proto> --raw <nnUNet_raw> --dataset_id <ID> --fold <f> \
        --pred_dir <predictions> --arch nnunet --exp_name baselines
"""
from __future__ import annotations
import os
import sys
import json
import glob
import argparse
import numpy as np
import cv2
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from framework.metrics.metrics import per_image_metrics, aggregate
from framework.data.unified_dataset import _read_metadata, detect_num_classes
def _match_files(gt_paths, pred_paths):
    """Pair ground-truth and prediction PNGs by file basename.

    Matching is driven by the ground truth (the held-out test set), so test
    cases without a prediction are reported rather than silently skipped.

    Args:
        gt_paths: ground-truth label paths; pairs are returned in this order.
        pred_paths: predicted-mask paths.

    Returns:
        ``(pairs, missing, extra)``: ``pairs`` is a list of
        ``(pred_path, gt_path)`` tuples in ``gt_paths`` order, ``missing``
        lists GT paths that have no prediction, and ``extra`` lists prediction
        paths that have no GT counterpart.
    """
    pred_by_name = {os.path.basename(p): p for p in pred_paths}
    gt_names = {os.path.basename(g) for g in gt_paths}
    pairs, missing = [], []
    for gp in gt_paths:
        pp = pred_by_name.get(os.path.basename(gp))
        if pp is None:
            missing.append(gp)
        else:
            pairs.append((pp, gp))
    extra = [p for p in pred_paths if os.path.basename(p) not in gt_names]
    return pairs, missing, extra


def _load_pair(pred_path, gt_path, eval_size):
    """Read a (prediction, ground-truth) mask pair as int64 label maps.

    Both files are read with ``cv2.IMREAD_GRAYSCALE``. When the shapes differ,
    the prediction is resized to the GT resolution with nearest-neighbour
    interpolation. If ``eval_size > 0``, both masks are then resized (nearest)
    to ``eval_size x eval_size``.

    Returns:
        ``(pred, gt)`` as ``np.int64`` arrays, or ``None`` if OpenCV could not
        decode either file.
    """
    pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    if pred is None or gt is None:
        return None
    if pred.shape != gt.shape:
        # cv2.resize takes (width, height); nearest keeps label ids intact.
        pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST)
    if eval_size > 0:  # resolution-fair common-R scoring
        size = (eval_size, eval_size)
        pred = cv2.resize(pred, size, interpolation=cv2.INTER_NEAREST)
        gt = cv2.resize(gt, size, interpolation=cv2.INTER_NEAREST)
    return pred.astype(np.int64), gt.astype(np.int64)


def main():
    """Score nnU-Net/U-Mamba test predictions and write a framework metrics.json.

    Ground-truth masks are read from ``<raw>/Dataset<ID>_<name>/labelsTs``.
    Predictions are read from ``--pred_dir``, and files are paired by basename.
    Each matched pair is scored with ``per_image_metrics`` (HD95 included) and
    the scores are combined with ``aggregate``. The result is written to
    ``<out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<fold>/metrics.json``.

    Test cases with no prediction, and pairs that cannot be decoded, are
    excluded from scoring but reported on stderr.

    Raises:
        SystemExit: if ``labelsTs`` contains no PNGs, or if no
            (prediction, GT) pair could be scored.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--dataset", required=True)
    ap.add_argument("--protocol", required=True)
    ap.add_argument("--raw", required=True, help="nnUNet_raw root")
    ap.add_argument("--dataset_id", type=int, required=True)
    ap.add_argument("--name", default="", help="Dataset<DDD>_<name> suffix; default <dataset>_<protocol>")
    ap.add_argument("--fold", type=int, required=True)
    ap.add_argument("--pred_dir", required=True)
    ap.add_argument("--arch", default="nnunet")
    ap.add_argument("--exp_name", default="baselines")
    ap.add_argument("--out_root", default="results")
    ap.add_argument("--include_background", action="store_true")
    ap.add_argument("--eval_size", type=int, default=0,
                    help="resize pred+gt to R×R (nearest) before scoring; 0 = native GT resolution")
    args = ap.parse_args()

    name = args.name or f"{args.dataset}_{args.protocol}"
    dsname = f"Dataset{args.dataset_id:03d}_{name}"
    lab_dir = os.path.join(args.raw, dsname, "labelsTs")
    meta = _read_metadata(args.data_root, args.dataset)
    gt_masks = sorted(glob.glob(os.path.join(lab_dir, "*.png")))
    if not gt_masks:
        # Fail clearly instead of handing detect_num_classes an empty list.
        raise SystemExit(f"no ground-truth PNGs found in {lab_dir}")
    num_classes = detect_num_classes(meta, gt_masks, args.dataset)

    pred_masks = sorted(glob.glob(os.path.join(args.pred_dir, "*.png")))
    pairs, missing, extra = _match_files(gt_masks, pred_masks)

    records = []
    n_unreadable = 0
    for pp, gp in pairs:
        loaded = _load_pair(pp, gp, args.eval_size)
        if loaded is None:
            n_unreadable += 1
            continue
        pred, gt = loaded
        records.append(per_image_metrics(pred, gt, num_classes,
                                         include_background=args.include_background,
                                         compute_hd95=True))
    if not records:
        raise SystemExit(f"no matched (pred,gt) pairs in {args.pred_dir} vs {lab_dir}")
    if missing or n_unreadable:
        # The scored set is smaller than the held-out test set; make that loud,
        # since it makes this run non-comparable with full-test-set results.
        print(f"[nnunet_eval] WARNING: {len(missing)} test case(s) have no prediction and "
              f"{n_unreadable} pair(s) could not be read; scored {len(records)}/{len(gt_masks)} "
              f"test cases", file=sys.stderr)

    agg = aggregate(records)
    # The fold index is stored under the framework's "seed" key / seed<k> directory.
    out = {"dataset": args.dataset, "protocol": args.protocol, "arch": args.arch,
           "seed": args.fold, "num_classes": num_classes, "metrics": agg, "per_image": records}
    out_dir = os.path.join(args.out_root, args.exp_name, f"{args.dataset}_{args.protocol}",
                           args.arch, f"seed{args.fold}")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"[nnunet_eval] {dsname} fold{args.fold}: n={len(records)}/{len(gt_masks)} "
          f"(missing_pred {len(missing)}, extra_pred {len(extra)}, unreadable {n_unreadable}) "
          f"dice={agg['dice_mean']:.4f} iou={agg['iou_mean']:.4f} hd95={agg['hd95_mean']:.2f} "
          f"-> {out_dir}/metrics.json")
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    sys.exit(main())