GenSeg-Baselines / code /framework /data /unified_dataset.py
MaybeRichard's picture
code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified
Raw
History Blame Contribute Delete
6.95 kB
"""Dataset reader for the standardized `processed_unified` layout.
Expected layout (see dataset/SEGMENTATION_WORKSPACE_README.md):
<data_root>/<dataset>/<protocol>/<split>/images/*.png
<data_root>/<dataset>/<protocol>/<split>/masks/*.png
<data_root>/<dataset>/metadata.json (optional, preferred)
<data_root>/<dataset>/manifest.jsonl (optional)
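Returns per item: {"image", "mask", "name": str}. With a tensor-producing
transform these are presumably FloatTensor[C,H,W] / LongTensor[H,W]; without a
transform they are the raw numpy arrays (image H,W,C uint8, mask H,W int64).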
Binary and multi-class masks are both supported: masks keep their integer class
ids (0..C-1). Auto-detection of in_channels / num_classes falls back to scanning
files when metadata is absent, so the loader is robust to missing metadata.
"""
from __future__ import annotations
import bisect
import json
import os
import warnings
from glob import glob
from typing import Callable, List, Optional, Tuple

import cv2
import numpy as np
from torch.utils.data import Dataset
# Channel-count hints keyed by modality keyword. detect_in_channels consults
# this table only when metadata lacks in_channels, taking the first key (in
# this order) that occurs as a substring of the lower-cased modality string.
_MODALITY_CHANNELS = dict(
    rgb=3,
    fundus=3,
    colonoscopy=3,
    endoscopy=3,
    histopathology=3,
    ultrasound=1,
    mri=1,
    ct=1,
    grayscale=1,
)

# Class counts including background, keyed by dataset directory name.
# metadata.json on the server carries no num_classes field, so this table is
# the fast, reliable primary source for detect_num_classes; datasets missing
# from it fall back to a FULL scan of the mask set (accurate but slower).
_KNOWN_NUM_CLASSES = dict(
    cvc_clinicdb=2,
    kvasir_seg=2,
    fives=2,
    busi=2,
    refuge2=3,
    acdc_png=4,
    idridd_segmentation=6,
    pannuke_semantic=6,
)
def _read_metadata(data_root: str, dataset: str) -> dict:
    """Load ``<data_root>/<dataset>/metadata.json`` as a dict.

    Metadata is optional, so a missing file yields ``{}``. An unreadable or
    malformed file, or one whose top level is not a JSON object, also yields
    ``{}`` (callers only use ``meta.get``), but with a warning so a broken
    file is not silently mistaken for an absent one.

    Args:
        data_root: root of the processed layout.
        dataset: dataset directory name under ``data_root``.

    Returns:
        The parsed metadata object, or ``{}``.
    """
    path = os.path.join(data_root, dataset, "metadata.json")
    if not os.path.isfile(path):
        return {}
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(path, encoding="utf-8") as f:
            meta = json.load(f)
    except (OSError, ValueError) as err:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        warnings.warn(f"ignoring unreadable metadata file {path}: {err}")
        return {}
    if not isinstance(meta, dict):
        # callers index it with .get(); a list/scalar top level would crash them
        warnings.warn(
            f"ignoring metadata file {path}: top level is "
            f"{type(meta).__name__}, expected a JSON object")
        return {}
    return meta
def _pair_from_manifest(split_dir: str, manifest: str) -> Optional[List[Tuple[str, str]]]:
    """Read (image, mask) path pairs for one split from a JSONL manifest.

    Each non-blank line is a JSON object naming the image under ``image`` /
    ``image_path`` / ``img`` and the mask under ``mask`` / ``mask_path`` /
    ``label``. Relative paths are resolved against the manifest's directory;
    absolute paths are used as-is. Only entries whose image lies inside
    *split_dir* (as a directory, not a string prefix) are kept.

    Args:
        split_dir: ``<data_root>/<dataset>/<protocol>/<split>`` directory.
        manifest: path to the ``manifest.jsonl`` file.

    Returns:
        The matching pairs in manifest order, or ``None`` when the manifest is
        missing, a record is not an object or lacks an image/mask path, or no
        entry belongs to this split (the caller then falls back to globbing).

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
    """
    if not os.path.isfile(manifest):
        return None
    base = os.path.dirname(manifest)
    # Directory-aware containment: a plain substring test would let ".../train"
    # also claim entries under ".../train_extra" or ".../trainval".
    # join(..., "") appends exactly one trailing separator.
    split_prefix = os.path.join(os.path.abspath(split_dir), "")
    pairs = []
    with open(manifest, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            if not isinstance(rec, dict):
                return None  # unexpected schema: let the caller glob instead
            img = rec.get("image") or rec.get("image_path") or rec.get("img")
            msk = rec.get("mask") or rec.get("mask_path") or rec.get("label")
            if img is None or msk is None:
                return None
            # manifest paths may be relative to dataset root or absolute
            ip = img if os.path.isabs(img) else os.path.join(base, img)
            mp = msk if os.path.isabs(msk) else os.path.join(base, msk)
            if os.path.abspath(ip).startswith(split_prefix):
                pairs.append((ip, mp))
    return pairs or None
def _list_visible(directory: str) -> List[str]:
    """Return the sorted non-hidden entry names of *directory*, or [] if it is
    not a directory: the names ``glob(directory + "/*")`` would yield, but
    without interpreting glob metacharacters in *directory*."""
    if not os.path.isdir(directory):
        return []
    return sorted(name for name in os.listdir(directory) if not name.startswith("."))


def _pair_by_glob(split_dir: str) -> List[Tuple[str, str]]:
    """Pair ``<split_dir>/images/*`` with ``<split_dir>/masks/*`` by file stem.

    A mask matches image ``<stem><ext>`` when its file name starts with
    ``<stem>.``. When several masks match, ``<stem><ext>`` (same extension as
    the image) wins, then ``<stem>.png``, then the alphabetically first name,
    so the choice is deterministic. Images without a matching mask are
    skipped.

    The mask directory is listed once and searched with ``bisect``. Image
    stems containing glob metacharacters (``[``, ``]``, ``*``, ``?``) are
    therefore matched literally, and pairing costs O(n log n) instead of one
    directory scan per image.

    Returns:
        ``(image_path, mask_path)`` tuples ordered by image file name; ``[]``
        if either directory is missing or nothing pairs up.
    """
    img_dir = os.path.join(split_dir, "images")
    msk_dir = os.path.join(split_dir, "masks")
    mask_names = _list_visible(msk_dir)
    pairs = []
    for img_name in _list_visible(img_dir):
        stem, ext = os.path.splitext(img_name)
        prefix = stem + "."
        # mask_names is sorted, so all names starting with prefix are contiguous
        lo = bisect.bisect_left(mask_names, prefix)
        hi = lo
        while hi < len(mask_names) and mask_names[hi].startswith(prefix):
            hi += 1
        cands = mask_names[lo:hi]
        if not cands:
            continue
        # mask may share the image's extension or be .png; prefer those exact
        # names over looser matches such as "<stem>.aug.png"
        for preferred in (stem + ext, stem + ".png"):
            if preferred in cands:
                chosen = preferred
                break
        else:
            chosen = cands[0]
        pairs.append((os.path.join(img_dir, img_name), os.path.join(msk_dir, chosen)))
    return pairs
def detect_in_channels(meta: dict, sample_img: Optional[str]) -> int:
    """Infer how many input channels the images have.

    Resolution order:
      1. ``meta["in_channels"]`` when present and truthy (cast to int).
      2. The first ``_MODALITY_CHANNELS`` key occurring as a substring of the
         lower-cased ``meta["modality"]``.
      3. The sample image on disk: 3 if it decodes with at least 3 channels.
      4. Otherwise 1, which is also the result when the sample is missing or
         cannot be decoded.
    """
    explicit = meta.get("in_channels")
    if explicit:
        return int(explicit)
    modality = str(meta.get("modality", "")).lower()
    hint = next((ch for key, ch in _MODALITY_CHANNELS.items() if key in modality), None)
    if hint is not None:
        return hint
    if not (sample_img and os.path.isfile(sample_img)):
        return 1
    decoded = cv2.imread(sample_img, cv2.IMREAD_UNCHANGED)
    is_color = decoded is not None and decoded.ndim == 3 and decoded.shape[2] >= 3
    return 3 if is_color else 1
def detect_num_classes(meta: dict, mask_paths: List[str], dataset: str = "") -> int:
    """Infer the number of classes, background included.

    Resolution order:
      1. ``_KNOWN_NUM_CLASSES[dataset]`` for documented datasets.
      2. ``meta["num_classes"]`` when present and truthy (cast to int).
      3. A full scan of every readable mask (decoded as 8-bit grayscale):
         largest label value + 1, never less than 2. Unreadable masks are
         skipped; if none can be read the result is 2.
    """
    known = _KNOWN_NUM_CLASSES.get(dataset)
    if known is not None:
        return known
    declared = meta.get("num_classes")
    if declared:
        return int(declared)
    # Unknown dataset without metadata: union the label values of EVERY mask,
    # so a class that appears in only a few images is still counted.
    seen = set()
    for path in mask_paths:
        label = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            continue
        seen |= set(np.unique(label).tolist())
    # 8-bit labels are non-negative, so this is "max + 1, at least 2"
    # (all-background or empty scans give 2).
    return max(int(max(seen, default=0)) + 1, 2)
class UnifiedSegDataset(Dataset):
    """Segmentation dataset over ``<data_root>/<dataset>/<protocol>/<split>``.

    Pairs come from ``<data_root>/<dataset>/manifest.jsonl`` when it lists
    entries for this split, otherwise from stem-matching ``images/`` against
    ``masks/``.

    Args:
        data_root: root of the processed layout.
        dataset: dataset directory name under ``data_root``.
        protocol: protocol directory name under the dataset.
        split: split directory name under the protocol.
        transform: optional callable ``(image, mask) -> (image, mask)``. Without
            it, items hold the raw numpy arrays: image ``H,W,C`` uint8 (RGB when
            3-channel) and mask ``H,W`` int64.
        in_channels: 1 loads images as grayscale; any other non-zero value
            loads 3-channel RGB; 0 auto-detects from metadata / a sample image.
        num_classes: explicit class count; 0 auto-detects.
        synth_dir: optional directory of synthetic pairs appended to this split
            (intended for training). It may be the pair root (containing
            ``images/`` and ``masks/``) or that root's ``images/`` directory.
            If it is missing or yields no pairs, a warning is emitted.

    Raises:
        FileNotFoundError: the split directory does not exist.
        RuntimeError: no (image, mask) pairs were found for the split.
    """

    def __init__(self, data_root: str, dataset: str, protocol: str, split: str,
                 transform: Optional[Callable] = None,
                 in_channels: int = 0, num_classes: int = 0,
                 synth_dir: str = ""):
        self.data_root = data_root
        self.dataset = dataset
        self.protocol = protocol
        self.split = split
        self.transform = transform
        split_dir = os.path.join(data_root, dataset, protocol, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(
                f"split dir not found: {split_dir}\n"
                f"(data is prepared separately; see dataset/ scripts)")
        # the manifest wins when it lists entries for this split; otherwise
        # pair images/ and masks/ by file stem
        manifest = os.path.join(data_root, dataset, "manifest.jsonl")
        pairs = _pair_from_manifest(split_dir, manifest) or _pair_by_glob(split_dir)
        if not pairs:
            raise RuntimeError(f"no (image,mask) pairs found in {split_dir}")
        # optionally merge synthetic (image,mask) pairs into the (train) split
        if synth_dir:
            pairs = pairs + self._synthetic_pairs(synth_dir)
        self.pairs = pairs
        meta = _read_metadata(data_root, dataset)
        # explicit non-zero arguments override detection; num_classes detection
        # also sees the synthetic masks
        self.in_channels = in_channels or detect_in_channels(meta, pairs[0][0])
        self.num_classes = num_classes or detect_num_classes(meta, [p[1] for p in pairs], dataset)

    @staticmethod
    def _synthetic_pairs(synth_dir: str) -> List[Tuple[str, str]]:
        """Collect (image, mask) pairs from a synthetic-data directory.

        Accepts either the pair root or its ``images/`` subdirectory. Warns and
        returns ``[]`` when the directory is missing or contains no pairs.
        """
        if not os.path.isdir(synth_dir):
            warnings.warn(f"synth_dir not found, no synthetic pairs added: {synth_dir}")
            return []
        # normpath first: with a trailing separator ("…/images/"), dirname()
        # would return the images dir itself instead of its parent.
        root = os.path.normpath(synth_dir)
        if not os.path.isdir(os.path.join(root, "images")):
            root = os.path.dirname(root)
        found = _pair_by_glob(root)
        if not found:
            warnings.warn(f"no synthetic (image,mask) pairs found under {root}")
        return found

    def __len__(self) -> int:
        return len(self.pairs)

    def _load_image(self, path: str) -> np.ndarray:
        """Decode *path* as ``H,W,1`` grayscale when ``in_channels == 1``,
        otherwise as ``H,W,3`` RGB (OpenCV decodes BGR; converted here).

        Raises:
            IOError: OpenCV could not read the file.
        """
        if self.in_channels == 1:
            im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if im is None:
                raise IOError(f"cannot read image {path}")
            return im[:, :, None]  # H,W,1
        im = cv2.imread(path, cv2.IMREAD_COLOR)  # BGR
        if im is None:
            raise IOError(f"cannot read image {path}")
        return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # H,W,3

    def __getitem__(self, idx: int):
        """Return ``{"image", "mask", "name"}`` for pair *idx*.

        The mask is read as 8-bit grayscale and cast to int64. This assumes
        masks are stored as single-channel class-id images; verify for new
        datasets. ``name`` is the image file stem. When a transform is set,
        image and mask are whatever it returns.

        Raises:
            IOError: the image or mask could not be read.
        """
        ip, mp = self.pairs[idx]
        image = self._load_image(ip)
        mask = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise IOError(f"cannot read mask {mp}")
        mask = mask.astype(np.int64)
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return {"image": image, "mask": mask,
                "name": os.path.splitext(os.path.basename(ip))[0]}