code: complete eval pipeline (7 metrics + per-class + Wilcoxon) + Swin-UNet/TransUNet networks; remove backups/obsolete
1a18f22 verified | """Dataset reader for the standardized `processed_unified` layout. | |
| Expected layout (see dataset/SEGMENTATION_WORKSPACE_README.md): | |
| <data_root>/<dataset>/<protocol>/<split>/images/*.png | |
| <data_root>/<dataset>/<protocol>/<split>/masks/*.png | |
| <data_root>/<dataset>/metadata.json (optional, preferred) | |
| <data_root>/<dataset>/manifest.jsonl (optional) | |
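| Returns per item: {"image", "mask", "name": str}. Without a transform, image is an H,W,C uint8 numpy array and mask an H,W int64 numpy array; the transform is expected to turn them into FloatTensor[C,H,W] / LongTensor[H,W]. | |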
| Binary and multi-class masks are both supported: masks keep their integer class | |
| ids (0..C-1). Auto-detection of in_channels / num_classes falls back to scanning | |
| files when metadata is absent, so the loader is robust to missing metadata. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from glob import glob | |
| from typing import Optional, Callable, List, Tuple | |
| import numpy as np | |
| import cv2 | |
| from torch.utils.data import Dataset | |
# Channel-count hints keyed by modality keyword. detect_in_channels consults this
# table only when metadata.json has no in_channels, and it returns the first key
# (in insertion order) that matches the metadata's modality text, so keep the
# order stable.
_MODALITY_CHANNELS = dict(
    rgb=3,
    fundus=3,
    colonoscopy=3,
    endoscopy=3,
    histopathology=3,
    ultrasound=1,
    mri=1,
    ct=1,
    grayscale=1,
)

# Documented class counts per dataset, background included. The server-side
# metadata.json carries no num_classes field, so detect_num_classes checks this
# table first (fast and reliable); datasets not listed here fall back to a FULL
# scan of every mask file (accurate but slower).
_KNOWN_NUM_CLASSES = dict(
    cvc_clinicdb=2,
    kvasir_seg=2,
    fives=2,
    busi=2,
    refuge2=3,
    acdc_png=4,
    idridd_segmentation=6,
    pannuke_semantic=6,
)
def _read_metadata(data_root: str, dataset: str) -> dict:
    """Load ``<data_root>/<dataset>/metadata.json`` as a dict.

    Metadata is optional, so this is best-effort: a missing, unreadable or
    malformed file, or one whose top-level JSON value is not an object, yields
    ``{}`` instead of an error.

    Args:
        data_root: root of the ``processed_unified`` tree.
        dataset: dataset directory name under ``data_root``.

    Returns:
        The parsed metadata object, or ``{}``.
    """
    path = os.path.join(data_root, dataset, "metadata.json")
    if not os.path.isfile(path):
        return {}
    try:
        with open(path, encoding="utf-8") as f:
            meta = json.load(f)
    except (OSError, ValueError):
        # OSError: unreadable file; ValueError covers json.JSONDecodeError and
        # UnicodeDecodeError. Anything else is a real bug and should surface.
        return {}
    # Callers do meta.get(...); a JSON list/number/string would crash there.
    return meta if isinstance(meta, dict) else {}
def _pair_from_manifest(split_dir: str, manifest: str) -> Optional[List[Tuple[str, str]]]:
    """Collect the (image, mask) pairs a JSONL manifest lists under ``split_dir``.

    Each non-blank line is a JSON record holding an image path (key ``image``,
    ``image_path`` or ``img``) and a mask path (``mask``, ``mask_path`` or
    ``label``). Relative paths are resolved against the manifest's directory;
    absolute paths are kept as-is. Only records whose image lies inside
    ``split_dir`` are kept.

    Args:
        split_dir: ``<data_root>/<dataset>/<protocol>/<split>`` directory.
        manifest: path to ``manifest.jsonl``.

    Returns:
        The matching pairs in manifest order, or ``None`` when the manifest is
        missing, any record lacks an image or mask path, or nothing matched
        (callers then fall back to directory scanning).

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    if not os.path.isfile(manifest):
        return None
    base = os.path.dirname(manifest)
    # Compare real path prefixes ending in a separator: a plain substring test
    # would let ".../train" also claim entries under ".../train_aug/...".
    split_prefix = os.path.normcase(os.path.join(os.path.abspath(split_dir), ""))
    pairs = []
    with open(manifest, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            img = rec.get("image") or rec.get("image_path") or rec.get("img")
            msk = rec.get("mask") or rec.get("mask_path") or rec.get("label")
            if img is None or msk is None:
                # schema we do not understand: let the caller fall back to globbing
                return None
            # manifest paths may be relative to dataset root or absolute
            ip = img if os.path.isabs(img) else os.path.join(base, img)
            mp = msk if os.path.isabs(msk) else os.path.join(base, msk)
            if os.path.normcase(os.path.abspath(ip)).startswith(split_prefix):
                pairs.append((ip, mp))
    return pairs or None
def _pair_by_glob(split_dir: str) -> List[Tuple[str, str]]:
    """Pair ``split_dir/images/*`` with ``split_dir/masks/*`` by exact file stem.

    The mask may share the image's extension or use another one. When several
    masks share a stem, the ``.png`` one is preferred, otherwise the
    alphabetically first. Images without a mask are skipped. Hidden
    (dot-prefixed) files are ignored, mirroring ``glob("*")``.

    The mask directory is listed once and indexed by stem. Matching with
    per-image ``glob(stem + ".*")`` patterns had three problems:
    it broke on names containing glob metacharacters such as ``[``;
    it matched dotted siblings (``a`` -> ``a.b.png``);
    and it rescanned the directory for every image.

    Args:
        split_dir: directory containing ``images/`` and ``masks/``.

    Returns:
        ``(image_path, mask_path)`` tuples sorted by image file name; ``[]`` if
        either sub-directory is missing.
    """
    img_dir = os.path.join(split_dir, "images")
    msk_dir = os.path.join(split_dir, "masks")
    if not (os.path.isdir(img_dir) and os.path.isdir(msk_dir)):
        return []

    def _visible_files(folder: str) -> List[str]:
        # sorted, non-hidden regular files directly inside `folder`
        return sorted(
            n for n in os.listdir(folder)
            if not n.startswith(".") and os.path.isfile(os.path.join(folder, n))
        )

    masks_by_stem = {}
    for name in _visible_files(msk_dir):
        stem, ext = os.path.splitext(name)
        if ext:  # the old "<stem>.*" pattern required an extension too
            masks_by_stem.setdefault(stem, []).append(name)

    pairs = []
    for name in _visible_files(img_dir):
        cands = masks_by_stem.get(os.path.splitext(name)[0])
        if not cands:
            continue
        # prefer the canonical PNG mask; otherwise take the first by name
        chosen = next((c for c in cands if c.lower().endswith(".png")), cands[0])
        pairs.append((os.path.join(img_dir, name), os.path.join(msk_dir, chosen)))
    return pairs
def detect_in_channels(meta: dict, sample_img: Optional[str]) -> int:
    """Infer the channel count of the dataset's input images.

    Resolution order:
      1. ``meta["in_channels"]`` when present and non-zero;
      2. the first ``_MODALITY_CHANNELS`` keyword that appears as a whole word
         in ``meta["modality"]`` (case-insensitive);
      3. decoding ``sample_img``: 3 when it has at least 3 channels;
      4. otherwise 1.

    Args:
        meta: parsed metadata.json (may be empty).
        sample_img: path to one image of the split, or ``None``.

    Returns:
        The channel count.
    """
    if meta.get("in_channels"):
        return int(meta["in_channels"])
    # Whole-word matching: a plain substring test lets the short key "ct" fire
    # on unrelated modality text such as "electron microscopy" or "detection".
    modality = str(meta.get("modality", "")).lower()
    words = set("".join(ch if ch.isalnum() else " " for ch in modality).split())
    for keyword, channels in _MODALITY_CHANNELS.items():
        if keyword in words:
            return channels
    if sample_img and os.path.isfile(sample_img):
        im = cv2.imread(sample_img, cv2.IMREAD_UNCHANGED)
        # a 2-channel (gray + alpha) image still counts as single-channel
        if im is not None and im.ndim == 3 and im.shape[2] >= 3:
            return 3
    return 1
def detect_num_classes(meta: dict, mask_paths: List[str], dataset: str = "") -> int:
    """Infer the number of segmentation classes, background included.

    Resolution order:
      1. the documented ``_KNOWN_NUM_CLASSES`` entry for ``dataset``;
      2. ``meta["num_classes"]`` when present and non-zero;
      3. ``max label over ALL masks + 1``, with a floor of 2.
         The floor applies when no mask is readable or every mask is all-zero.

    Assumes masks store class ids directly (0..C-1), as the module docstring
    states. A 0/255 binary mask would be reported as 256 classes.

    Args:
        meta: parsed metadata.json (may be empty).
        mask_paths: every mask file to scan in the fallback case.
        dataset: dataset directory name used for the table lookup.

    Returns:
        The class count.
    """
    if dataset in _KNOWN_NUM_CLASSES:
        return _KNOWN_NUM_CLASSES[dataset]
    if meta.get("num_classes"):
        return int(meta["num_classes"])
    # Unknown dataset: scan ALL masks so a rare class is never missed. Only the
    # largest label matters, so a per-mask max (linear) replaces np.unique (a
    # sort). IMREAD_ANYDEPTH keeps 16-bit label PNGs from being down-converted
    # to 8 bits by OpenCV.
    maxv = -1
    for mp in mask_paths:
        m = cv2.imread(mp, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH)
        if m is not None and m.size:
            maxv = max(maxv, int(m.max()))
    return maxv + 1 if maxv >= 1 else 2
class UnifiedSegDataset(Dataset):
    """(image, mask) segmentation dataset over the ``processed_unified`` layout.

    Pairs come from ``<data_root>/<dataset>/manifest.jsonl`` when it lists
    entries inside ``<data_root>/<dataset>/<protocol>/<split>``. Otherwise they
    come from matching ``images/`` and ``masks/`` file stems in that directory.
    Synthetic pairs from ``synth_dir`` are appended when given. The merge is
    not gated on ``split``; the caller decides where synthetic data is used.

    ``in_channels`` / ``num_classes`` of 0 mean "auto-detect". Detection uses
    metadata.json, the module-level tables, or the files themselves.

    Each item is ``{"image", "mask", "name"}``. Without a transform, image is an
    H,W,C uint8 numpy array (C == 1 when ``in_channels == 1``, else RGB) and
    mask an H,W int64 array of class ids.

    Raises:
        FileNotFoundError: the split directory does not exist.
        RuntimeError: no (image, mask) pairs were found for the split.
    """

    def __init__(self, data_root: str, dataset: str, protocol: str, split: str,
                 transform: Optional[Callable] = None,
                 in_channels: int = 0, num_classes: int = 0,
                 synth_dir: str = ""):
        self.data_root = data_root
        self.dataset = dataset
        self.protocol = protocol
        self.split = split
        # called as transform(image, mask) and must return the (image, mask) pair
        self.transform = transform
        split_dir = os.path.join(data_root, dataset, protocol, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(
                f"split dir not found: {split_dir}\n"
                f"(data is prepared separately; see dataset/ scripts)")
        manifest = os.path.join(data_root, dataset, "manifest.jsonl")
        pairs = _pair_from_manifest(split_dir, manifest) or _pair_by_glob(split_dir)
        if not pairs:
            raise RuntimeError(f"no (image,mask) pairs found in {split_dir}")
        if synth_dir:
            pairs = pairs + self._synthetic_pairs(synth_dir)
        self.pairs = pairs
        meta = _read_metadata(data_root, dataset)
        self.in_channels = in_channels or detect_in_channels(meta, pairs[0][0])
        self.num_classes = num_classes or detect_num_classes(meta, [p[1] for p in pairs], dataset)

    @staticmethod
    def _synthetic_pairs(synth_dir: str) -> List[Tuple[str, str]]:
        """Return synthetic (image, mask) pairs found at or just above ``synth_dir``.

        ``synth_dir`` may be the directory holding ``images/`` and ``masks/``
        or the ``images/`` directory itself. A warning is emitted when nothing
        usable is found, instead of silently training without synthetic data.
        """
        import warnings
        # normpath drops a trailing separator; otherwise dirname(".../images/")
        # returns ".../images" and the synthetic data is silently skipped
        root = os.path.normpath(synth_dir)
        if not os.path.isdir(root):
            warnings.warn(f"synth_dir does not exist, ignoring it: {synth_dir}")
            return []
        if not os.path.isdir(os.path.join(root, "images")):
            root = os.path.dirname(root)
        found = _pair_by_glob(root)
        if not found:
            warnings.warn(f"no synthetic (image,mask) pairs found under {synth_dir}")
        return found

    def __len__(self) -> int:
        return len(self.pairs)

    def _load_image(self, path: str) -> np.ndarray:
        """Read ``path`` as H,W,1 grayscale when ``in_channels == 1``, else H,W,3 RGB.

        NOTE(review): any in_channels other than 1 (e.g. 4) still loads 3-channel
        RGB; confirm no caller relies on other counts.
        """
        if self.in_channels == 1:
            im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if im is None:
                raise IOError(f"cannot read image {path}")
            return im[:, :, None]  # H,W,1
        im = cv2.imread(path, cv2.IMREAD_COLOR)  # BGR
        if im is None:
            raise IOError(f"cannot read image {path}")
        return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # H,W,3

    def __getitem__(self, idx: int):
        """Load pair ``idx``, apply the joint transform if set, and return the item dict."""
        ip, mp = self.pairs[idx]
        image = self._load_image(ip)
        # IMREAD_ANYDEPTH keeps 16-bit label PNGs intact; IMREAD_GRAYSCALE alone
        # makes OpenCV convert them to 8 bits, which destroys small class ids.
        mask = cv2.imread(mp, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH)
        if mask is None:
            raise IOError(f"cannot read mask {mp}")
        mask = mask.astype(np.int64)
        if self.transform is not None:
            image, mask = self.transform(image, mask)
        return {"image": image, "mask": mask,
                "name": os.path.splitext(os.path.basename(ip))[0]}