| """Convert processed_unified/<dataset>/<protocol> -> nnU-Net v2 raw format. |
| |
| Serves BOTH nnU-Net and U-Mamba (U-Mamba is nnU-Net v2 under the hood; same data |
| format, only a different trainer/env). Honors our FIXED train/val/test split: |
| * train + val images go into imagesTr/labelsTr (nnU-Net needs them together) |
| * the exact train/val partition is emitted as splits_final.json (fold 0) |
| * test goes into imagesTs/labelsTs (excluded from training; for our evaluation) |
| |
| Key format rules (verified against sota/nnUNet): |
| * image file: <caseid>_0000.png (ONE file; RGB read as 3 channels, gray as 1) |
| * mask file : <caseid>.png (uint8, values 0..C-1, NEVER 0/255) |
| * channel_names: 3 entries (R,G,B) for RGB, 1 ("grayscale") for grayscale |
| * case ids are prefixed by split so train/val never collide inside imagesTr |
| |
| Usage: |
| python framework/nnunet_convert.py --data_root <processed_unified> \ |
| --dataset cvc_clinicdb --protocol official --nnunet_raw <nnUNet_raw> --dataset_id 1 |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import json |
| import argparse |
|
|
| import numpy as np |
| import cv2 |
|
|
| import sys |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
| from framework.data.unified_dataset import ( |
| _read_metadata, _pair_from_manifest, _pair_by_glob, |
| detect_in_channels, detect_num_classes, |
| ) |
|
|
|
|
def get_pairs(data_root, dataset, protocol, split):
    """Collect (image_path, mask_path) pairs for one split of a dataset protocol.

    Looks in <data_root>/<dataset>/<protocol>/<split>. A missing split directory
    yields an empty list. Otherwise pairs come from the dataset-level
    manifest.jsonl; if that produces nothing (falsy), the split directory is
    paired by globbing instead.
    """
    dataset_dir = os.path.join(data_root, dataset)
    split_dir = os.path.join(dataset_dir, protocol, split)
    if os.path.isdir(split_dir):
        manifest_path = os.path.join(dataset_dir, "manifest.jsonl")
        from_manifest = _pair_from_manifest(split_dir, manifest_path)
        return from_manifest if from_manifest else _pair_by_glob(split_dir)
    return []
|
|
|
|
| def _link(src, dst): |
| if os.path.lexists(dst): |
| os.remove(dst) |
| os.symlink(os.path.abspath(src), dst) |
|
|
|
|
def _emit_image(src, dst_png, in_ch):
    """Place an image as <...>_0000.png with a channel count matching `in_ch`.

    nnU-Net requires a consistent channel count across cases:
      * in_ch == 1: re-encoded as a true 1-channel PNG (some sources, e.g. kits19,
        store a mix of 1- and 3-channel PNGs);
      * otherwise, .png sources are symlinked (fast);
      * otherwise, non-.png sources are decoded as 3-channel colour and written
        out as PNG.

    Args:
        src: source image path.
        dst_png: destination path (must end in .png).
        in_ch: dataset channel count.

    Raises:
        IOError: the source cannot be decoded or the PNG cannot be written.
    """
    # Remove stale output first; cv2.imwrite would otherwise write *through* a
    # symlink left by a previous run and clobber the source image.
    if os.path.lexists(dst_png):
        os.remove(dst_png)
    if in_ch == 1:
        im = cv2.imread(src, cv2.IMREAD_GRAYSCALE)
    elif os.path.splitext(src)[1].lower() == ".png":
        _link(src, dst_png)
        return
    else:
        # cv2.imread returns BGR and cv2.imwrite expects BGR, so the array is written
        # unchanged. A BGR->RGB conversion here would swap R/B on disk, unlike the
        # symlinked .png sources.
        im = cv2.imread(src, cv2.IMREAD_COLOR)
    if im is None:
        raise IOError(f"cannot read image {src}")
    if not cv2.imwrite(dst_png, im):
        raise IOError(f"cannot write image {dst_png}")
|
|
|
|
| def _emit_mask(src, dst_png, num_classes): |
| """Place mask as <...>.png with values 0..C-1 (remap 0/255 binary etc.).""" |
| m = cv2.imread(src, cv2.IMREAD_GRAYSCALE) |
| if m is None: |
| raise IOError(f"cannot read mask {src}") |
| uniq = set(int(v) for v in np.unique(m)) |
| allowed = set(range(num_classes)) |
| if uniq <= allowed and os.path.splitext(src)[1].lower() == ".png": |
| _link(src, dst_png) |
| return |
| if uniq <= {0, 255} and num_classes == 2: |
| m = (m > 0).astype(np.uint8) |
| elif len(uniq) <= num_classes: |
| remap = {v: i for i, v in enumerate(sorted(uniq))} |
| m = np.vectorize(remap.get)(m).astype(np.uint8) |
| else: |
| raise ValueError(f"mask {src} has values {sorted(uniq)} outside 0..{num_classes-1}") |
| cv2.imwrite(dst_png, m) |
|
|
|
|
| def main(): |
| ap = argparse.ArgumentParser() |
| ap.add_argument("--data_root", required=True) |
| ap.add_argument("--dataset", required=True) |
| ap.add_argument("--protocol", required=True) |
| ap.add_argument("--nnunet_raw", required=True, help="output nnUNet_raw root") |
| ap.add_argument("--dataset_id", type=int, required=True) |
| ap.add_argument("--name", default="", help="override dataset name suffix") |
| args = ap.parse_args() |
|
|
| meta = _read_metadata(args.data_root, args.dataset) |
| tr = get_pairs(args.data_root, args.dataset, args.protocol, "train") |
| va = get_pairs(args.data_root, args.dataset, args.protocol, "val") |
| ts = get_pairs(args.data_root, args.dataset, args.protocol, "test") |
| if not tr: |
| raise SystemExit(f"no train pairs for {args.dataset}/{args.protocol}") |
|
|
| in_ch = detect_in_channels(meta, tr[0][0]) |
| num_classes = detect_num_classes(meta, [p[1] for p in tr + va + ts], args.dataset) |
|
|
| name = args.name or f"{args.dataset}_{args.protocol}" |
| dsname = f"Dataset{args.dataset_id:03d}_{name}" |
| root = os.path.join(args.nnunet_raw, dsname) |
| for d in ("imagesTr", "labelsTr", "imagesTs", "labelsTs"): |
| os.makedirs(os.path.join(root, d), exist_ok=True) |
|
|
| def emit(pairs, split, img_dir, lab_dir): |
| ids = [] |
| for ip, mp in pairs: |
| stem = os.path.splitext(os.path.basename(ip))[0] |
| cid = f"{split}_{stem}" |
| _emit_image(ip, os.path.join(root, img_dir, f"{cid}_0000.png"), in_ch) |
| _emit_mask(mp, os.path.join(root, lab_dir, f"{cid}.png"), num_classes) |
| ids.append(cid) |
| return ids |
|
|
| train_ids = emit(tr, "train", "imagesTr", "labelsTr") |
| val_ids = emit(va, "val", "imagesTr", "labelsTr") |
| emit(ts, "test", "imagesTs", "labelsTs") |
|
|
| |
| if in_ch == 3: |
| channel_names = {"0": "R", "1": "G", "2": "B"} |
| else: |
| channel_names = {"0": "grayscale"} |
| labels = {"background": 0} |
| for c in range(1, num_classes): |
| labels[f"label{c}"] = c |
| dataset_json = { |
| "channel_names": channel_names, |
| "labels": labels, |
| "numTraining": len(train_ids) + len(val_ids), |
| "file_ending": ".png", |
| "name": dsname, |
| "description": f"converted from processed_unified/{args.dataset}/{args.protocol}", |
| } |
| with open(os.path.join(root, "dataset.json"), "w") as f: |
| json.dump(dataset_json, f, indent=2) |
|
|
| |
| |
| |
| splits = [{"train": train_ids, "val": val_ids} for _ in range(3)] |
| with open(os.path.join(root, "splits_final.json"), "w") as f: |
| json.dump(splits, f, indent=2) |
|
|
| print(f"[ok] {dsname}: in_ch={in_ch} num_classes={num_classes} " |
| f"train={len(train_ids)} val={len(val_ids)} test={len(ts)} -> {root}") |
| print("Next:") |
| print(f" export nnUNet_raw=$(dirname {root})") |
| print(" export nnUNet_preprocessed=<fast_dir> nnUNet_results=<dir>") |
| print(f" nnUNetv2_plan_and_preprocess -d {args.dataset_id} -c 2d --verify_dataset_integrity") |
| print(f" cp {root}/splits_final.json $nnUNet_preprocessed/{dsname}/splits_final.json") |
| print(f" nnUNetv2_train {args.dataset_id} 2d 0 # nnU-Net") |
| print(f" nnUNetv2_train {args.dataset_id} 2d 0 -tr nnUNetTrainerUMambaBot # U-Mamba") |
|
|
|
|
if __name__ == "__main__":
    # Run the converter only when executed as a script, not when imported.
    main()
|
|