"""Dataset reader for the standardized `processed_unified` layout. Expected layout (see dataset/SEGMENTATION_WORKSPACE_README.md): ////images/*.png ////masks/*.png //metadata.json (optional, preferred) //manifest.jsonl (optional) Returns per item: {"image": FloatTensor[C,H,W], "mask": LongTensor[H,W], "name": str}. Binary and multi-class masks are both supported: masks keep their integer class ids (0..C-1). Auto-detection of in_channels / num_classes falls back to scanning files when metadata is absent, so the loader is robust to missing metadata. """ from __future__ import annotations import json import os from glob import glob from typing import Optional, Callable, List, Tuple import numpy as np import cv2 from torch.utils.data import Dataset _MODALITY_CHANNELS = { # hint table; only used when metadata lacks in_channels "rgb": 3, "fundus": 3, "colonoscopy": 3, "endoscopy": 3, "histopathology": 3, "ultrasound": 1, "mri": 1, "ct": 1, "grayscale": 1, } # Documented class counts (incl. background). metadata.json on the server has no # num_classes field, so this table is the fast, reliable primary source; unknown # datasets fall back to a FULL scan of the mask set (accurate but slower). 
_KNOWN_NUM_CLASSES = {
    "cvc_clinicdb": 2,
    "kvasir_seg": 2,
    "fives": 2,
    "busi": 2,
    "refuge2": 3,
    "acdc_png": 4,
    "idridd_segmentation": 6,
    "pannuke_semantic": 6,
}


def _read_metadata(data_root: str, dataset: str) -> dict:
    """Load ``<data_root>/<dataset>/metadata.json`` as a dict.

    Metadata is optional, so a missing file, an unreadable file, invalid
    JSON, or a JSON value that is not an object all yield ``{}``; callers can
    therefore always use ``meta.get(...)``.
    """
    path = os.path.join(data_root, dataset, "metadata.json")
    if not os.path.isfile(path):
        return {}
    try:
        with open(path, encoding="utf-8") as f:
            meta = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    return meta if isinstance(meta, dict) else {}


def _path_contains_dir(path: str, directory: str) -> bool:
    """Return True if normalized *directory* occurs in *path* as whole components.

    Like a substring test (so a relative split dir still matches an absolute
    manifest path), but ``.../val`` does not match ``.../val_extra/...``.
    """
    needle = os.sep + os.path.normpath(directory).strip(os.sep) + os.sep
    haystack = os.sep + os.path.normpath(path).strip(os.sep) + os.sep
    return needle in haystack


def _pair_from_manifest(split_dir: str, manifest: str) -> Optional[List[Tuple[str, str]]]:
    """Read (image, mask) pairs belonging to *split_dir* from a JSONL manifest.

    Each non-blank line is a JSON object naming the image under
    ``image``/``image_path``/``img`` and the mask under
    ``mask``/``mask_path``/``label``. Relative paths are resolved against the
    manifest's directory. Returns None when the manifest is missing, when any
    record lacks an image or mask entry, or when no entry lies under
    *split_dir*, so the caller can fall back to directory globbing.

    Raises:
        json.JSONDecodeError: a non-blank manifest line is not valid JSON.
    """
    if not os.path.isfile(manifest):
        return None
    base = os.path.dirname(manifest)
    pairs: List[Tuple[str, str]] = []
    with open(manifest, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            img = rec.get("image") or rec.get("image_path") or rec.get("img")
            msk = rec.get("mask") or rec.get("mask_path") or rec.get("label")
            if img is None or msk is None:
                return None
            # manifest paths may be relative to the manifest dir or absolute
            ip = img if os.path.isabs(img) else os.path.join(base, img)
            mp = msk if os.path.isabs(msk) else os.path.join(base, msk)
            # only keep entries that fall under this split dir (whole path
            # components, so sibling splits sharing a prefix are excluded)
            if _path_contains_dir(ip, split_dir):
                pairs.append((ip, mp))
    return pairs or None


def _pair_by_glob(split_dir: str) -> List[Tuple[str, str]]:
    """Pair ``<split_dir>/images/*`` with ``<split_dir>/masks/*`` by file stem.

    Images without a mask of the same stem are skipped. When several masks
    share a stem (e.g. ``x.npy`` and ``x.png``) the ``.png`` one is chosen,
    otherwise the lexicographically first, so pairing is deterministic.
    """
    img_dir = os.path.join(split_dir, "images")
    msk_dir = os.path.join(split_dir, "masks")
    # Index masks once by exact stem: one directory scan instead of one glob
    # per image, and stems containing "[", "*" or "?" match literally.
    masks_by_stem = {}
    for mp in sorted(glob(os.path.join(msk_dir, "*"))):
        if not os.path.isfile(mp):
            continue
        stem = os.path.splitext(os.path.basename(mp))[0]
        masks_by_stem.setdefault(stem, []).append(mp)
    pairs = []
    for ip in sorted(glob(os.path.join(img_dir, "*"))):
        cands = masks_by_stem.get(os.path.splitext(os.path.basename(ip))[0])
        if not cands:
            continue
        # mask may share the image's extension or be .png; prefer .png
        png = [c for c in cands if c.lower().endswith(".png")]
        pairs.append((ip, (png or cands)[0]))
    return pairs


def detect_in_channels(meta: dict, sample_img: Optional[str]) -> int:
    """Infer the number of input channels.

    Order of precedence: a truthy ``meta["in_channels"]``; the first
    ``_MODALITY_CHANNELS`` key found as a substring of ``meta["modality"]``;
    3 if *sample_img* exists and has at least 3 channels; otherwise 1.
    """
    if meta.get("in_channels"):
        return int(meta["in_channels"])
    mod = str(meta.get("modality", "")).lower()
    # NOTE: substring match, so short keys (e.g. "ct") can also hit longer
    # modality names; the table's key order decides ties.
    for k, v in _MODALITY_CHANNELS.items():
        if k in mod:
            return v
    if sample_img and os.path.isfile(sample_img):
        im = cv2.imread(sample_img, cv2.IMREAD_UNCHANGED)
        if im is not None and im.ndim == 3 and im.shape[2] >= 3:
            return 3
    return 1


def detect_num_classes(meta: dict, mask_paths: List[str], dataset: str = "") -> int:
    """Infer the number of classes (including background).

    Order of precedence: ``_KNOWN_NUM_CLASSES[dataset]``; a truthy
    ``meta["num_classes"]``; otherwise ``max(mask value) + 1`` over ALL masks
    in *mask_paths* (at least 2, and 2 when nothing could be read).
    """
    if dataset in _KNOWN_NUM_CLASSES:
        return _KNOWN_NUM_CLASSES[dataset]
    if meta.get("num_classes"):
        return int(meta["num_classes"])
    # unknown dataset: scan ALL masks so a rare class is never missed
    # NOTE(review): a binary mask stored as 0/255 would yield 256 here; the
    # layout is assumed to store class ids 0..C-1 -- confirm for new datasets.
    vals = set()
    for mp in mask_paths:
        m = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if m is not None:
            vals.update(np.unique(m).tolist())
    if not vals:
        return 2
    maxv = max(vals)
    return int(maxv) + 1 if maxv >= 1 else 2


class UnifiedSegDataset(Dataset):
    """Segmentation dataset over ``<data_root>/<dataset>/<protocol>/<split>``.

    Pairs come from ``<data_root>/<dataset>/manifest.jsonl`` when it lists
    entries under this split directory, otherwise from matching file stems in
    ``images/`` and ``masks/``.

    Args:
        data_root: directory holding one sub-directory per dataset.
        dataset: dataset directory name (also looked up in _KNOWN_NUM_CLASSES).
        protocol: protocol sub-directory name.
        split: split sub-directory name.
        transform: optional callable ``(image, mask) -> (image, mask)``.
        in_channels: 1 loads grayscale, any other non-zero value loads RGB;
            0 means auto-detect.
        num_classes: class count; 0 means auto-detect.
        synth_dir: optional directory of synthetic pairs appended to this
            split (globbed directly if it has ``images/``, else its parent).

    Raises:
        FileNotFoundError: the split directory does not exist.
        RuntimeError: no (image, mask) pairs were found in the split.
    """

    def __init__(self, data_root: str, dataset: str, protocol: str, split: str,
                 transform: Optional[Callable] = None, in_channels: int = 0,
                 num_classes: int = 0, synth_dir: str = ""):
        self.data_root = data_root
        self.dataset = dataset
        self.protocol = protocol
        self.split = split
        self.transform = transform

        split_dir = os.path.join(data_root, dataset, protocol, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(
                f"split dir not found: {split_dir}\n"
                f"(data is prepared separately; see dataset/ scripts)")

        manifest = os.path.join(data_root, dataset, "manifest.jsonl")
        pairs = _pair_from_manifest(split_dir, manifest) or _pair_by_glob(split_dir)
        if not pairs:
            raise RuntimeError(f"no (image,mask) pairs found in {split_dir}")

        # Optionally merge synthetic (image, mask) pairs. This is applied to
        # whichever split is being built; presumably callers pass synth_dir
        # only for the training split -- verify at call sites.
        if synth_dir and os.path.isdir(synth_dir):
            has_images = os.path.isdir(os.path.join(synth_dir, "images"))
            synth_root = synth_dir if has_images else os.path.dirname(synth_dir)
            pairs = pairs + _pair_by_glob(synth_root)
        self.pairs = pairs

        meta = _read_metadata(data_root, dataset)
        self.in_channels = in_channels or detect_in_channels(meta, pairs[0][0])
        self.num_classes = num_classes or detect_num_classes(
            meta, [p[1] for p in pairs], dataset)

    def __len__(self) -> int:
        return len(self.pairs)

    def _load_image(self, path: str) -> np.ndarray:
        """Read *path* as (H, W, 1) grayscale if in_channels == 1, else (H, W, 3) RGB.

        Raises:
            IOError: OpenCV could not decode the file.
        """
        if self.in_channels == 1:
            im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if im is None:
                raise IOError(f"cannot read image {path}")
            return im[:, :, None]  # H,W,1
        im = cv2.imread(path, cv2.IMREAD_COLOR)  # BGR
        if im is None:
            raise IOError(f"cannot read image {path}")
        return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # H,W,3

    def __getitem__(self, idx: int):
        """Return ``{"image", "mask", "name"}`` for pair *idx*.

        The mask is read as single-channel and cast to int64; *transform*, if
        given, is applied to ``(image, mask)`` jointly. ``name`` is the image
        file stem.

        Raises:
            IOError: the image or mask could not be decoded.
        """
        ip, mp = self.pairs[idx]
        image = self._load_image(ip)
        mask = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise IOError(f"cannot read mask {mp}")
        mask = mask.astype(np.int64)
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return {"image": image, "mask": mask,
                "name": os.path.splitext(os.path.basename(ip))[0]}