| """Compute framework-format metrics.json from nnU-Net / U-Mamba test predictions. |
| |
| nnU-Net only reports validation Dice during training. To compare on the SAME |
| held-out test set with the SAME 7 metrics as the framework, we: predict on |
| imagesTs (done separately via nnUNetv2_predict), then run THIS script to score the |
| predicted masks against labelsTs using framework/metrics.py, writing a metrics.json |
| in the exact framework format so report/aggregate.py includes nnU-Net/U-Mamba rows. |
| |
| python framework/nnunet_eval.py --data_root <processed_unified> --dataset <ds> \ |
| --protocol <proto> --raw <nnUNet_raw> --dataset_id <ID> --fold <f> \ |
| --pred_dir <predictions> --arch nnunet --exp_name baselines |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import sys |
| import json |
| import glob |
| import argparse |
|
|
| import numpy as np |
| import cv2 |
|
|
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
| from framework.metrics.metrics import per_image_metrics, aggregate |
| from framework.data.unified_dataset import _read_metadata, detect_num_classes |
|
|
|
|
def _load_mask_pair(pred_path, gt_path, eval_size):
    """Read one (prediction, ground-truth) mask pair as single-channel arrays.

    Args:
        pred_path: path of the predicted mask PNG.
        gt_path: path of the ground-truth mask PNG.
        eval_size: if > 0, both masks are resized to ``eval_size x eval_size``
            after the prediction has been aligned to the ground-truth shape.

    Returns:
        ``(pred, gt)`` as arrays with identical shape, or ``None`` when either
        file could not be decoded by OpenCV.
    """
    pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
    gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
    if pred is None or gt is None:
        return None
    if pred.shape != gt.shape:
        # cv2.resize expects dsize as (width, height), i.e. (cols, rows).
        # Nearest-neighbour keeps the values discrete class labels.
        pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]), interpolation=cv2.INTER_NEAREST)
    if eval_size > 0:
        square = (eval_size, eval_size)
        pred = cv2.resize(pred, square, interpolation=cv2.INTER_NEAREST)
        gt = cv2.resize(gt, square, interpolation=cv2.INTER_NEAREST)
    return pred, gt


def main():
    """Score predicted masks against labelsTs and write a framework-format metrics.json.

    Command-line driven (see the argparse options below). Predictions in
    ``--pred_dir`` are matched to ground-truth masks in
    ``<raw>/Dataset<ID>_<name>/labelsTs`` by file name. Each matched pair is
    scored with ``per_image_metrics`` and the records are combined with
    ``aggregate``. The result is written to
    ``<out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<fold>/metrics.json``
    and a one-line summary is printed to stdout.

    Predictions without a ground-truth file, and pairs that cannot be decoded,
    are skipped and counted as "missing" in the summary. Ground-truth masks
    that have no prediction are reported with a warning on stderr, because
    they would otherwise silently shrink the evaluated test set.

    Raises:
        SystemExit: if no (prediction, ground-truth) pair could be scored.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--dataset", required=True)
    ap.add_argument("--protocol", required=True)
    ap.add_argument("--raw", required=True, help="nnUNet_raw root")
    ap.add_argument("--dataset_id", type=int, required=True)
    ap.add_argument("--name", default="", help="Dataset<DDD>_<name> suffix; default <dataset>_<protocol>")
    ap.add_argument("--fold", type=int, required=True)
    ap.add_argument("--pred_dir", required=True)
    ap.add_argument("--arch", default="nnunet")
    ap.add_argument("--exp_name", default="baselines")
    ap.add_argument("--out_root", default="results")
    ap.add_argument("--include_background", action="store_true")
    ap.add_argument("--eval_size", type=int, default=0,
                    help="resize pred+gt to R×R (nearest) before scoring; 0 = native GT resolution")
    args = ap.parse_args()

    # Locate the ground-truth folder of the nnU-Net raw dataset.
    name = args.name or f"{args.dataset}_{args.protocol}"
    dsname = f"Dataset{args.dataset_id:03d}_{name}"
    lab_dir = os.path.join(args.raw, dsname, "labelsTs")

    meta = _read_metadata(args.data_root, args.dataset)
    gt_masks = sorted(glob.glob(os.path.join(lab_dir, "*.png")))
    num_classes = detect_num_classes(meta, gt_masks, args.dataset)

    pred_paths = sorted(glob.glob(os.path.join(args.pred_dir, "*.png")))

    # The scoring loop is driven by the predictions, so a test case that was
    # never predicted would vanish from the aggregate without a trace. Surface
    # it so partial prediction runs are not mistaken for full-test-set scores.
    predicted_names = {os.path.basename(p) for p in pred_paths}
    unpredicted = [os.path.basename(g) for g in gt_masks
                   if os.path.basename(g) not in predicted_names]
    if unpredicted:
        preview = ", ".join(unpredicted[:5]) + (", ..." if len(unpredicted) > 5 else "")
        print(f"[nnunet_eval] WARNING: {len(unpredicted)} of {len(gt_masks)} labelsTs masks "
              f"have no prediction in {args.pred_dir} ({preview}); "
              f"metrics cover only the predicted subset", file=sys.stderr)

    records = []
    n_missing = 0
    for pp in pred_paths:
        gp = os.path.join(lab_dir, os.path.basename(pp))
        if not os.path.isfile(gp):
            n_missing += 1
            continue
        pair = _load_mask_pair(pp, gp, args.eval_size)
        if pair is None:
            # Undecodable image on either side: counted like a missing pair.
            n_missing += 1
            continue
        pred, gt = pair
        records.append(per_image_metrics(pred.astype(np.int64), gt.astype(np.int64),
                                         num_classes, include_background=args.include_background,
                                         compute_hd95=True))
    if not records:
        raise SystemExit(f"no matched (pred,gt) pairs in {args.pred_dir} vs {lab_dir}")

    agg = aggregate(records)
    # The fold index is stored under the "seed" key and in the seed<fold>
    # output directory name.
    out = {"dataset": args.dataset, "protocol": args.protocol, "arch": args.arch,
           "seed": args.fold, "num_classes": num_classes, "metrics": agg, "per_image": records}
    out_dir = os.path.join(args.out_root, args.exp_name, f"{args.dataset}_{args.protocol}",
                           args.arch, f"seed{args.fold}")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"[nnunet_eval] {dsname} fold{args.fold}: n={len(records)} (missing {n_missing}) "
          f"dice={agg['dice_mean']:.4f} iou={agg['iou_mean']:.4f} hd95={agg['hd95_mean']:.2f} "
          f"-> {out_dir}/metrics.json")
|
|
|
|
# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or file output.
if __name__ == "__main__":
    main()
|
|