"""Convert processed_unified// -> nnU-Net v2 raw format. Serves BOTH nnU-Net and U-Mamba (U-Mamba is nnU-Net v2 under the hood; same data format, only a different trainer/env). Honors our FIXED train/val/test split: * train + val images go into imagesTr/labelsTr (nnU-Net needs them together) * the exact train/val partition is emitted as splits_final.json (fold 0) * test goes into imagesTs/labelsTs (excluded from training; for our evaluation) Key format rules (verified against sota/nnUNet): * image file: _0000.png (ONE file; RGB read as 3 channels, gray as 1) * mask file : .png (uint8, values 0..C-1, NEVER 0/255) * channel_names: 3 entries (R,G,B) for RGB, 1 ("grayscale") for grayscale * case ids are prefixed by split so train/val never collide inside imagesTr Usage: python framework/nnunet_convert.py --data_root \ --dataset cvc_clinicdb --protocol official --nnunet_raw --dataset_id 1 """ from __future__ import annotations import os import json import argparse import numpy as np import cv2 import sys sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from framework.data.unified_dataset import ( _read_metadata, _pair_from_manifest, _pair_by_glob, detect_in_channels, detect_num_classes, ) def get_pairs(data_root, dataset, protocol, split): split_dir = os.path.join(data_root, dataset, protocol, split) if not os.path.isdir(split_dir): return [] manifest = os.path.join(data_root, dataset, "manifest.jsonl") return _pair_from_manifest(split_dir, manifest) or _pair_by_glob(split_dir) def _link(src, dst): if os.path.lexists(dst): os.remove(dst) os.symlink(os.path.abspath(src), dst) def _emit_image(src, dst_png, in_ch): """Place image as <...>_0000.png, ensuring a CONSISTENT channel count matching in_ch (nnU-Net requires it). Grayscale datasets are re-encoded to true 1-channel (some sources, e.g. kits19, store a mix of 1- and 3-channel PNGs); RGB .png are symlinked (fast).""" if os.path.lexists(dst_png): os.remove(dst_png) # never write THROUGH an existing symlink (would corrupt source) if in_ch == 1: im = cv2.imread(src, cv2.IMREAD_GRAYSCALE) if im is None: raise IOError(f"cannot read image {src}") cv2.imwrite(dst_png, im) elif os.path.splitext(src)[1].lower() == ".png": _link(src, dst_png) else: im = cv2.imread(src, cv2.IMREAD_COLOR) cv2.imwrite(dst_png, cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) def _emit_mask(src, dst_png, num_classes): """Place mask as <...>.png with values 0..C-1 (remap 0/255 binary etc.).""" m = cv2.imread(src, cv2.IMREAD_GRAYSCALE) if m is None: raise IOError(f"cannot read mask {src}") uniq = set(int(v) for v in np.unique(m)) allowed = set(range(num_classes)) if uniq <= allowed and os.path.splitext(src)[1].lower() == ".png": _link(src, dst_png) return if uniq <= {0, 255} and num_classes == 2: m = (m > 0).astype(np.uint8) elif len(uniq) <= num_classes: remap = {v: i for i, v in enumerate(sorted(uniq))} m = np.vectorize(remap.get)(m).astype(np.uint8) else: raise ValueError(f"mask {src} has values {sorted(uniq)} outside 0..{num_classes-1}") cv2.imwrite(dst_png, m) def main(): ap = argparse.ArgumentParser() ap.add_argument("--data_root", required=True) ap.add_argument("--dataset", required=True) ap.add_argument("--protocol", required=True) ap.add_argument("--nnunet_raw", required=True, help="output nnUNet_raw root") ap.add_argument("--dataset_id", type=int, required=True) ap.add_argument("--name", default="", help="override dataset name suffix") args = ap.parse_args() meta = _read_metadata(args.data_root, args.dataset) tr = get_pairs(args.data_root, args.dataset, args.protocol, "train") va = get_pairs(args.data_root, args.dataset, args.protocol, "val") ts = get_pairs(args.data_root, args.dataset, args.protocol, "test") if not tr: raise SystemExit(f"no train pairs for {args.dataset}/{args.protocol}") in_ch = detect_in_channels(meta, tr[0][0]) num_classes = detect_num_classes(meta, [p[1] for p in tr + va + ts], args.dataset) name = args.name or f"{args.dataset}_{args.protocol}" dsname = f"Dataset{args.dataset_id:03d}_{name}" root = os.path.join(args.nnunet_raw, dsname) for d in ("imagesTr", "labelsTr", "imagesTs", "labelsTs"): os.makedirs(os.path.join(root, d), exist_ok=True) def emit(pairs, split, img_dir, lab_dir): ids = [] for ip, mp in pairs: stem = os.path.splitext(os.path.basename(ip))[0] cid = f"{split}_{stem}" _emit_image(ip, os.path.join(root, img_dir, f"{cid}_0000.png"), in_ch) _emit_mask(mp, os.path.join(root, lab_dir, f"{cid}.png"), num_classes) ids.append(cid) return ids train_ids = emit(tr, "train", "imagesTr", "labelsTr") val_ids = emit(va, "val", "imagesTr", "labelsTr") # val also in imagesTr emit(ts, "test", "imagesTs", "labelsTs") # dataset.json if in_ch == 3: channel_names = {"0": "R", "1": "G", "2": "B"} else: channel_names = {"0": "grayscale"} labels = {"background": 0} for c in range(1, num_classes): labels[f"label{c}"] = c dataset_json = { "channel_names": channel_names, "labels": labels, "numTraining": len(train_ids) + len(val_ids), "file_ending": ".png", "name": dsname, "description": f"converted from processed_unified/{args.dataset}/{args.protocol}", } with open(os.path.join(root, "dataset.json"), "w") as f: json.dump(dataset_json, f, indent=2) # staged splits_final.json (copy into nnUNet_preprocessed// AFTER preprocessing). # 3 IDENTICAL folds => train folds 0/1/2 = 3 runs of the SAME fixed split # (run-to-run variance for mean±SD, matching the framework's 3-seed protocol). splits = [{"train": train_ids, "val": val_ids} for _ in range(3)] with open(os.path.join(root, "splits_final.json"), "w") as f: json.dump(splits, f, indent=2) print(f"[ok] {dsname}: in_ch={in_ch} num_classes={num_classes} " f"train={len(train_ids)} val={len(val_ids)} test={len(ts)} -> {root}") print("Next:") print(f" export nnUNet_raw=$(dirname {root})") print(" export nnUNet_preprocessed= nnUNet_results=") print(f" nnUNetv2_plan_and_preprocess -d {args.dataset_id} -c 2d --verify_dataset_integrity") print(f" cp {root}/splits_final.json $nnUNet_preprocessed/{dsname}/splits_final.json") print(f" nnUNetv2_train {args.dataset_id} 2d 0 # nnU-Net") print(f" nnUNetv2_train {args.dataset_id} 2d 0 -tr nnUNetTrainerUMambaBot # U-Mamba") if __name__ == "__main__": main()