"""Compute framework-format metrics.json from nnU-Net / U-Mamba test predictions. nnU-Net only reports validation Dice during training. To compare on the SAME held-out test set with the SAME 7 metrics as the framework, we: predict on imagesTs (done separately via nnUNetv2_predict), then run THIS script to score the predicted masks against labelsTs using framework/metrics.py, writing a metrics.json in the exact framework format so report/aggregate.py includes nnU-Net/U-Mamba rows. python framework/nnunet_eval.py --data_root --dataset \ --protocol --raw --dataset_id --fold \ --pred_dir --arch nnunet --exp_name baselines """ from __future__ import annotations import os import sys import json import glob import argparse import numpy as np import cv2 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from framework.metrics.metrics import per_image_metrics, aggregate from framework.data.unified_dataset import _read_metadata, detect_num_classes def main(): ap = argparse.ArgumentParser() ap.add_argument("--data_root", required=True) ap.add_argument("--dataset", required=True) ap.add_argument("--protocol", required=True) ap.add_argument("--raw", required=True, help="nnUNet_raw root") ap.add_argument("--dataset_id", type=int, required=True) ap.add_argument("--name", default="", help="Dataset_ suffix; default _") ap.add_argument("--fold", type=int, required=True) ap.add_argument("--pred_dir", required=True) ap.add_argument("--arch", default="nnunet") ap.add_argument("--exp_name", default="baselines") ap.add_argument("--out_root", default="results") ap.add_argument("--include_background", action="store_true") ap.add_argument("--eval_size", type=int, default=0, help="resize pred+gt to R×R (nearest) before scoring; 0 = native GT resolution") args = ap.parse_args() name = args.name or f"{args.dataset}_{args.protocol}" dsname = f"Dataset{args.dataset_id:03d}_{name}" lab_dir = os.path.join(args.raw, dsname, "labelsTs") meta = _read_metadata(args.data_root, args.dataset) gt_masks = sorted(glob.glob(os.path.join(lab_dir, "*.png"))) num_classes = detect_num_classes(meta, gt_masks, args.dataset) records = [] n_missing = 0 for pp in sorted(glob.glob(os.path.join(args.pred_dir, "*.png"))): base = os.path.basename(pp) gp = os.path.join(lab_dir, base) if not os.path.isfile(gp): n_missing += 1 continue pred = cv2.imread(pp, cv2.IMREAD_GRAYSCALE) gt = cv2.imread(gp, cv2.IMREAD_GRAYSCALE) if pred is None or gt is None: n_missing += 1 continue if pred.shape != gt.shape: pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST) if args.eval_size > 0: # resolution-fair common-R scoring R = args.eval_size pred = cv2.resize(pred, (R, R), interpolation=cv2.INTER_NEAREST) gt = cv2.resize(gt, (R, R), interpolation=cv2.INTER_NEAREST) records.append(per_image_metrics(pred.astype(np.int64), gt.astype(np.int64), num_classes, include_background=args.include_background, compute_hd95=True)) if not records: raise SystemExit(f"no matched (pred,gt) pairs in {args.pred_dir} vs {lab_dir}") agg = aggregate(records) out = {"dataset": args.dataset, "protocol": args.protocol, "arch": args.arch, "seed": args.fold, "num_classes": num_classes, "metrics": agg, "per_image": records} out_dir = os.path.join(args.out_root, args.exp_name, f"{args.dataset}_{args.protocol}", args.arch, f"seed{args.fold}") os.makedirs(out_dir, exist_ok=True) with open(os.path.join(out_dir, "metrics.json"), "w") as f: json.dump(out, f, indent=2) print(f"[nnunet_eval] {dsname} fold{args.fold}: n={len(records)} (missing {n_missing}) " f"dice={agg['dice_mean']:.4f} iou={agg['iou_mean']:.4f} hd95={agg['hd95_mean']:.2f} " f"-> {out_dir}/metrics.json") if __name__ == "__main__": main()