from pathlib import Path from typing import Optional, Sequence import pandas as pd from PIL import Image import torch from torch.utils.data import Dataset from sklearn.model_selection import StratifiedGroupKFold class EyePACSDataset(Dataset): """ EyePACS diabetic retinopathy dataset. Expected structure: root/ ├── trainLabels.csv ├── testLabels.csv ├── train/ │ ├── xxx_left.jpeg │ └── xxx_right.jpeg └── test/ ├── xxx_left.jpeg └── xxx_right.jpeg Supported splits: train: Uses trainLabels.csv only. Applies fold CV. Keeps samples where fold != selected fold. val: Uses trainLabels.csv only. Applies fold CV. Keeps samples where fold == selected fold. test: Uses testLabels.csv only. Uses original test/ image folder. No fold filtering. all: Uses trainLabels.csv + testLabels.csv. Applies fold CV over the combined labeled pool. If fold is not None: keeps samples where fold != selected fold by default if all_mode="train" keeps samples where fold == selected fold if all_mode="val" If fold is None: keeps all combined labeled samples. Args: root: EyePACS root directory. split: One of {"train", "val", "test", "all"}. transform: Optional image transform. seed: Random seed for fold assignment. fold: Selected fold index. Required for split="train" and split="val". Optional for split="all". Ignored for split="test". n_folds: Number of folds. all_mode: Only used when split="all" and fold is not None. Options: "train": keep fold != selected fold "val": keep fold == selected fold "all": keep all folds return_path: If True, return metadata dictionary. image_exts: File extensions to try. """ def __init__( self, root, split: str = "train", transform=None, seed: int = 42, fold: Optional[int] = 0, n_folds: int = 5, all_mode: str = "all", return_path: bool = False, image_exts: Sequence[str] = (".jpeg", ".jpg", ".png"), ): self.root = Path(root) self.split = split self.transform = transform self.seed = seed self.fold = fold self.n_folds = n_folds self.all_mode = all_mode self.return_path = return_path self.image_exts = tuple(image_exts) if split not in {"train", "val", "test", "all"}: raise ValueError( f"split must be one of {{'train', 'val', 'test', 'all'}}, got {split}" ) if all_mode not in {"train", "val", "all"}: raise ValueError( f"all_mode must be one of {{'train', 'val', 'all'}}, got {all_mode}" ) if split in {"train", "val"} and fold is None: raise ValueError(f"fold must be provided for split='{split}'") if fold is not None and not (0 <= fold < n_folds): raise ValueError(f"fold must be in [0, {n_folds - 1}], got {fold}") if split == "train": df = self._load_train_dataframe() df = self._assign_folds(df) df = df[df["fold"] != fold].reset_index(drop=True) elif split == "val": df = self._load_train_dataframe() df = self._assign_folds(df) df = df[df["fold"] == fold].reset_index(drop=True) elif split == "test": df = self._load_test_dataframe() df["fold"] = -1 elif split == "all": df = self._load_combined_dataframe() df = self._assign_folds(df) if fold is not None: if all_mode == "train": df = df[df["fold"] != fold].reset_index(drop=True) elif all_mode == "val": df = df[df["fold"] == fold].reset_index(drop=True) elif all_mode == "all": df = df.reset_index(drop=True) else: df = df.reset_index(drop=True) self.df = df.reset_index(drop=True) self.samples = self._build_samples(self.df) if len(self.samples) == 0: raise RuntimeError( f"No images found for split='{split}'. " f"Check root path, CSV files, folders, and file extensions." ) self._print_summary() def _load_train_dataframe(self) -> pd.DataFrame: path = self.root / "trainLabels.csv" if not path.exists(): raise FileNotFoundError(f"Missing trainLabels.csv: {path}") df = pd.read_csv(path) return self._standardize_label_dataframe(df, source="train") def _load_test_dataframe(self) -> pd.DataFrame: path = self.root / "testLabels.csv" if not path.exists(): raise FileNotFoundError(f"Missing testLabels.csv: {path}") df = pd.read_csv(path) return self._standardize_label_dataframe(df, source="test") def _load_combined_dataframe(self) -> pd.DataFrame: train_df = self._load_train_dataframe() test_df = self._load_test_dataframe() df = pd.concat([train_df, test_df], axis=0, ignore_index=True) df = df.drop_duplicates(subset=["source", "image"]).reset_index(drop=True) return df @staticmethod def _standardize_label_dataframe(df: pd.DataFrame, source: str) -> pd.DataFrame: """ Standardize label dataframe to: image, level, source, patient_id EyePACS image names usually look like: 10_left 10_right patient_id is extracted as the part before the first underscore. """ if "image" not in df.columns: raise ValueError(f"{source} labels CSV must contain column 'image'") if "level" not in df.columns: raise ValueError(f"{source} labels CSV must contain column 'level'") df = df[["image", "level"]].copy() df["image"] = df["image"].astype(str) df["level"] = df["level"].astype(int) df["source"] = source df["patient_id"] = df["image"].str.split("_").str[0].astype(str) return df def _assign_folds(self, df: pd.DataFrame) -> pd.DataFrame: """ Assign stratified group folds. Grouping: patient_id Stratification label: max DR severity across all images for that patient_id. This keeps left/right eyes from the same patient in the same fold. """ df = df.copy() patient_df = ( df.groupby("patient_id", as_index=False) .agg(patient_level=("level", "max")) .reset_index(drop=True) ) groups = patient_df["patient_id"].values y = patient_df["patient_level"].values splitter = StratifiedGroupKFold( n_splits=self.n_folds, shuffle=True, random_state=self.seed, ) patient_df["fold"] = -1 for fold_idx, (_, val_idx) in enumerate( splitter.split(X=patient_df, y=y, groups=groups) ): patient_df.loc[val_idx, "fold"] = fold_idx if (patient_df["fold"] < 0).any(): raise RuntimeError("Some patients were not assigned to a fold") fold_map = dict(zip(patient_df["patient_id"], patient_df["fold"])) df["fold"] = df["patient_id"].map(fold_map).astype(int) return df def _build_samples(self, df: pd.DataFrame): samples = [] missing = [] for _, row in df.iterrows(): image_id = str(row["image"]) label = int(row["level"]) source = str(row["source"]) patient_id = str(row["patient_id"]) fold = int(row["fold"]) image_dir = self.root / source image_path = self._find_image_path(image_dir, image_id) if image_path is None: missing.append((source, image_id)) continue samples.append( { "image_id": image_id, "image_path": image_path, "label": label, "source": source, "patient_id": patient_id, "fold": fold, } ) if len(missing) > 0: print( f"[EyePACSDataset] Warning: {len(missing)} images listed in CSV " f"were not found on disk." ) print(f"[EyePACSDataset] First few missing: {missing[:5]}") return samples def _find_image_path(self, image_dir: Path, image_id: str): for ext in self.image_exts: path = image_dir / f"{image_id}{ext}" if path.exists(): return path return None def _print_summary(self): labels = [s["label"] for s in self.samples] counts = pd.Series(labels).value_counts().sort_index() print(f"[EyePACSDataset] split={self.split}") print(f"[EyePACSDataset] root={self.root}") print(f"[EyePACSDataset] n={len(self.samples)}") if self.split != "test": print( f"[EyePACSDataset] seed={self.seed}, " f"fold={self.fold}, " f"n_folds={self.n_folds}" ) if self.split == "all": print(f"[EyePACSDataset] all_mode={self.all_mode}") print("[EyePACSDataset] source counts:") source_counts = pd.Series([s["source"] for s in self.samples]).value_counts() for source, count in source_counts.items(): print(f" {source}: {int(count)}") print("[EyePACSDataset] class counts:") for cls in range(5): print(f" class {cls}: {int(counts.get(cls, 0))}") def __len__(self): return len(self.samples) def __getitem__(self, idx): sample = self.samples[idx] image = Image.open(sample["image_path"]).convert("RGB") if self.transform is not None: image = self.transform(image) label = torch.tensor(sample["label"], dtype=torch.long) if self.return_path: return { "image": image, "label": label, "image_id": sample["image_id"], "image_path": str(sample["image_path"]), "source": sample["source"], "patient_id": sample["patient_id"], "fold": sample["fold"], } return image, label