GenSeg-Baselines / code /framework /nnunet_convert.py
MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
7.01 kB
"""Convert processed_unified/<dataset>/<protocol> -> nnU-Net v2 raw format.
Serves BOTH nnU-Net and U-Mamba (U-Mamba is nnU-Net v2 under the hood; same data
format, only a different trainer/env). Honors our FIXED train/val/test split:
* train + val images go into imagesTr/labelsTr (nnU-Net needs them together)
* the exact train/val partition is emitted as splits_final.json (the same fixed split repeated as folds 0, 1 and 2)
* test goes into imagesTs/labelsTs (excluded from training; for our evaluation)
Key format rules (verified against sota/nnUNet):
* image file: <caseid>_0000.png (ONE file; RGB read as 3 channels, gray as 1)
* mask file : <caseid>.png (uint8, values 0..C-1, NEVER 0/255)
* channel_names: 3 entries (R,G,B) for RGB, 1 ("grayscale") for grayscale
* case ids are prefixed by split so train/val never collide inside imagesTr
Usage:
python framework/nnunet_convert.py --data_root <processed_unified> \
--dataset cvc_clinicdb --protocol official --nnunet_raw <nnUNet_raw> --dataset_id 1
"""
from __future__ import annotations
import os
import json
import argparse
import numpy as np
import cv2
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from framework.data.unified_dataset import (
_read_metadata, _pair_from_manifest, _pair_by_glob,
detect_in_channels, detect_num_classes,
)
def get_pairs(data_root, dataset, protocol, split):
    """Return the (image, mask) pairs of one split, or ``[]`` if the split is absent.

    The split directory is ``<data_root>/<dataset>/<protocol>/<split>``. Pairs are
    taken from ``<data_root>/<dataset>/manifest.jsonl`` via ``_pair_from_manifest``;
    when that yields nothing truthy, ``_pair_by_glob`` on the split directory is
    used instead.
    """
    dataset_dir = os.path.join(data_root, dataset)
    split_dir = os.path.join(dataset_dir, protocol, split)
    if os.path.isdir(split_dir):
        manifest_path = os.path.join(dataset_dir, "manifest.jsonl")
        pairs = _pair_from_manifest(split_dir, manifest_path)
        if not pairs:
            # Manifest gave nothing usable: fall back to globbing the directory.
            pairs = _pair_by_glob(split_dir)
        return pairs
    # A missing split (e.g. no val/test folder) is not an error.
    return []
def _link(src, dst):
    """Make ``dst`` a symlink to the absolute path of ``src``.

    Anything already present at ``dst`` (including a dangling symlink, hence
    ``lexists`` rather than ``exists``) is removed first.
    """
    target = os.path.abspath(src)
    already_there = os.path.lexists(dst)
    if already_there:
        os.unlink(dst)
    os.symlink(target, dst)
def _emit_image(src, dst_png, in_ch):
    """Place the image ``src`` at ``dst_png`` (the ``<caseid>_0000.png`` file).

    Behaviour by case:
      * ``in_ch == 1``: the source is decoded as grayscale and re-encoded, so the
        output is always a true 1-channel PNG (some sources, e.g. kits19, store a
        mix of 1- and 3-channel PNGs).
      * otherwise, ``.png`` sources are symlinked (fast). NOTE: the linked file is
        used as-is; its channel count is not verified here.
      * otherwise, non-PNG sources are decoded as 3-channel colour and re-encoded
        as PNG.

    Args:
        src: path of the source image.
        dst_png: destination path; anything already there is replaced.
        in_ch: channel count the dataset is expected to have (1 or 3).

    Raises:
        IOError: if the source cannot be decoded or the destination cannot be
            written.
    """
    if os.path.lexists(dst_png):
        os.remove(dst_png)  # never write THROUGH an existing symlink (would corrupt source)
    if in_ch == 1:
        im = cv2.imread(src, cv2.IMREAD_GRAYSCALE)
        if im is None:
            raise IOError(f"cannot read image {src}")
        if not cv2.imwrite(dst_png, im):
            raise IOError(f"cannot write image {dst_png}")
    elif os.path.splitext(src)[1].lower() == ".png":
        _link(src, dst_png)
    else:
        im = cv2.imread(src, cv2.IMREAD_COLOR)
        if im is None:
            raise IOError(f"cannot read image {src}")
        # cv2.imread returns BGR and cv2.imwrite expects BGR, so the array is
        # written unchanged. Converting to RGB first would store a file with the
        # red and blue channels swapped relative to the symlinked PNG sources.
        if not cv2.imwrite(dst_png, im):
            raise IOError(f"cannot write image {dst_png}")
def _emit_mask(src, dst_png, num_classes):
    """Place the mask ``src`` at ``dst_png`` with values in ``0..num_classes-1``.

    The mask is decoded as 8-bit grayscale and handled as follows:
      * values already within ``0..num_classes-1``: kept unchanged (a ``.png``
        source is symlinked, any other format is re-encoded as PNG);
      * a 0/255 binary mask with ``num_classes == 2``: mapped to 0/1;
      * otherwise, if there are at most ``num_classes`` distinct values, they are
        mapped to ``0..k-1`` by ascending rank;
      * anything else is rejected.

    Args:
        src: path of the source mask.
        dst_png: destination path; anything already there is replaced.
        num_classes: number of classes, background included.

    Raises:
        IOError: if the source cannot be decoded or the destination cannot be
            written.
        ValueError: if the mask holds more distinct values than ``num_classes``.
    """
    m = cv2.imread(src, cv2.IMREAD_GRAYSCALE)
    if m is None:
        raise IOError(f"cannot read mask {src}")
    uniq = {int(v) for v in np.unique(m)}
    is_png = os.path.splitext(src)[1].lower() == ".png"

    if uniq <= set(range(num_classes)):
        if is_png:
            _link(src, dst_png)
            return
        # Already valid class ids in a non-PNG container: only re-encode. Rank
        # remapping here would relabel e.g. {0, 2} to {0, 1}.
    elif uniq <= {0, 255} and num_classes == 2:
        m = (m > 0).astype(np.uint8)
    elif len(uniq) <= num_classes:
        # NOTE(review): the rank is computed per image, so a mask that lacks one
        # of the dataset's values gets different ids than a mask that has them
        # all -- confirm this is acceptable for the multi-class datasets in use.
        ordered = sorted(uniq)
        lut = np.zeros(256, dtype=np.uint8)  # 8-bit decode => values are 0..255
        lut[ordered] = np.arange(len(ordered), dtype=np.uint8)
        m = lut[m]
    else:
        raise ValueError(f"mask {src} has values {sorted(uniq)} outside 0..{num_classes-1}")

    # Never write THROUGH a symlink left by an earlier run: that would overwrite
    # the source mask it points to.
    if os.path.lexists(dst_png):
        os.remove(dst_png)
    if not cv2.imwrite(dst_png, m):
        raise IOError(f"cannot write mask {dst_png}")
def main():
    """Command-line entry point: convert one dataset/protocol to nnU-Net v2 raw layout.

    Steps: parse the arguments, collect the train/val/test pairs, detect the input
    channel count and number of classes, create
    ``<nnunet_raw>/Dataset<ID>_<name>/{imagesTr,labelsTr,imagesTs,labelsTs}``,
    emit every image/mask, write ``dataset.json`` and ``splits_final.json``, and
    print the follow-up commands.

    Raises:
        SystemExit: if the train split yields no pairs (or on argument errors).
    """
    parser = argparse.ArgumentParser()
    for required_flag in ("--data_root", "--dataset", "--protocol"):
        parser.add_argument(required_flag, required=True)
    parser.add_argument("--nnunet_raw", required=True, help="output nnUNet_raw root")
    parser.add_argument("--dataset_id", type=int, required=True)
    parser.add_argument("--name", default="", help="override dataset name suffix")
    args = parser.parse_args()

    metadata = _read_metadata(args.data_root, args.dataset)
    tr, va, ts = [
        get_pairs(args.data_root, args.dataset, args.protocol, part)
        for part in ("train", "val", "test")
    ]
    if not tr:
        raise SystemExit(f"no train pairs for {args.dataset}/{args.protocol}")

    # Channel count comes from the first training image; classes from all masks.
    in_ch = detect_in_channels(metadata, tr[0][0])
    mask_paths = [pair[1] for pair in tr + va + ts]
    num_classes = detect_num_classes(metadata, mask_paths, args.dataset)

    name = args.name if args.name else f"{args.dataset}_{args.protocol}"
    dsname = f"Dataset{args.dataset_id:03d}_{name}"
    root = os.path.join(args.nnunet_raw, dsname)
    for sub in ("imagesTr", "labelsTr", "imagesTs", "labelsTs"):
        os.makedirs(os.path.join(root, sub), exist_ok=True)

    def emit(pairs, split, img_dir, lab_dir):
        # Case ids are "<split>_<image stem>" so train/val never collide in imagesTr.
        case_ids = []
        for image_path, mask_path in pairs:
            stem, _ext = os.path.splitext(os.path.basename(image_path))
            cid = f"{split}_{stem}"
            _emit_image(image_path, os.path.join(root, img_dir, f"{cid}_0000.png"), in_ch)
            _emit_mask(mask_path, os.path.join(root, lab_dir, f"{cid}.png"), num_classes)
            case_ids.append(cid)
        return case_ids

    def write_json(filename, payload):
        with open(os.path.join(root, filename), "w") as fh:
            json.dump(payload, fh, indent=2)

    train_ids = emit(tr, "train", "imagesTr", "labelsTr")
    val_ids = emit(va, "val", "imagesTr", "labelsTr")  # val also in imagesTr
    emit(ts, "test", "imagesTs", "labelsTs")

    # dataset.json
    channel_names = {"0": "R", "1": "G", "2": "B"} if in_ch == 3 else {"0": "grayscale"}
    labels = {"background": 0, **{f"label{c}": c for c in range(1, num_classes)}}
    write_json("dataset.json", {
        "channel_names": channel_names,
        "labels": labels,
        "numTraining": len(train_ids) + len(val_ids),
        "file_ending": ".png",
        "name": dsname,
        "description": f"converted from processed_unified/{args.dataset}/{args.protocol}",
    })

    # Staged splits_final.json (copy into nnUNet_preprocessed/<dsname>/ AFTER
    # preprocessing). 3 IDENTICAL folds => training folds 0/1/2 gives 3 runs of the
    # SAME fixed split (run-to-run variance for mean±SD, matching the framework's
    # 3-seed protocol).
    write_json("splits_final.json", [{"train": train_ids, "val": val_ids}] * 3)

    print(f"[ok] {dsname}: in_ch={in_ch} num_classes={num_classes} "
          f"train={len(train_ids)} val={len(val_ids)} test={len(ts)} -> {root}")
    next_steps = (
        "Next:",
        f" export nnUNet_raw=$(dirname {root})",
        " export nnUNet_preprocessed=<fast_dir> nnUNet_results=<dir>",
        f" nnUNetv2_plan_and_preprocess -d {args.dataset_id} -c 2d --verify_dataset_integrity",
        f" cp {root}/splits_final.json $nnUNet_preprocessed/{dsname}/splits_final.json",
        f" nnUNetv2_train {args.dataset_id} 2d 0 # nnU-Net",
        f" nnUNetv2_train {args.dataset_id} 2d 0 -tr nnUNetTrainerUMambaBot # U-Mamba",
    )
    for line in next_steps:
        print(line)
# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()