Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Optional, Sequence | |
| import pandas as pd | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import Dataset | |
| from sklearn.model_selection import StratifiedGroupKFold | |
class EyePACSDataset(Dataset):
    """
    EyePACS diabetic retinopathy dataset.

    Expected structure:
        root/
        ├── trainLabels.csv
        ├── testLabels.csv
        ├── train/
        │   ├── xxx_left.jpeg
        │   └── xxx_right.jpeg
        └── test/
            ├── xxx_left.jpeg
            └── xxx_right.jpeg

    Both CSV files must contain the columns ``image`` (file stem, e.g.
    ``10_left``) and ``level`` (integer DR grade). Each image is looked up as
    ``root/<source>/<image><ext>`` for every extension in ``image_exts``.
    Rows whose image is not found on disk are skipped with a printed warning.

    Supported splits:
        train:
            Uses trainLabels.csv only.
            Applies patient-grouped, stratified fold CV.
            Keeps samples where fold != selected fold.
        val:
            Uses trainLabels.csv only.
            Applies patient-grouped, stratified fold CV.
            Keeps samples where fold == selected fold.
        test:
            Uses testLabels.csv only.
            Uses original test/ image folder.
            No fold filtering (every sample gets fold = -1).
        all:
            Uses trainLabels.csv + testLabels.csv.
            Applies fold CV over the combined labeled pool.
            If fold is not None, filtering depends on all_mode:
                all_mode="all" (default): keeps all folds
                all_mode="train": keeps samples where fold != selected fold
                all_mode="val":   keeps samples where fold == selected fold
            If fold is None:
                keeps all combined labeled samples.

    Args:
        root:
            EyePACS root directory.
        split:
            One of {"train", "val", "test", "all"}.
        transform:
            Optional callable applied to the RGB PIL image.
        seed:
            Random seed for fold assignment.
        fold:
            Selected fold index. Required for split="train" and split="val".
            Optional for split="all". Ignored (not validated) for split="test".
        n_folds:
            Number of folds.
        all_mode:
            Only used when split="all" and fold is not None.
            Options:
                "train": keep fold != selected fold
                "val": keep fold == selected fold
                "all": keep all folds (default)
        return_path:
            If True, ``__getitem__`` returns a metadata dictionary instead of
            an ``(image, label)`` tuple.
        image_exts:
            File extensions to try, in order.

    Raises:
        ValueError: invalid ``split``/``all_mode``, missing or out-of-range
            ``fold``, or a labels CSV lacking ``image``/``level`` columns.
        FileNotFoundError: a required labels CSV does not exist.
        RuntimeError: none of the listed images were found on disk.
    """

    def __init__(
        self,
        root,
        split: str = "train",
        transform=None,
        seed: int = 42,
        fold: Optional[int] = 0,
        n_folds: int = 5,
        all_mode: str = "all",
        return_path: bool = False,
        image_exts: Sequence[str] = (".jpeg", ".jpg", ".png"),
    ):
        self.root = Path(root)
        self.split = split
        self.transform = transform
        self.seed = seed
        self.fold = fold
        self.n_folds = n_folds
        self.all_mode = all_mode
        self.return_path = return_path
        self.image_exts = tuple(image_exts)

        if split not in {"train", "val", "test", "all"}:
            raise ValueError(
                f"split must be one of {{'train', 'val', 'test', 'all'}}, got {split}"
            )
        if all_mode not in {"train", "val", "all"}:
            raise ValueError(
                f"all_mode must be one of {{'train', 'val', 'all'}}, got {all_mode}"
            )
        if split in {"train", "val"} and fold is None:
            raise ValueError(f"fold must be provided for split='{split}'")
        # fold is documented as ignored for split="test", so only range-check
        # it for splits that actually use it.
        if split != "test" and fold is not None and not (0 <= fold < n_folds):
            raise ValueError(f"fold must be in [0, {n_folds - 1}], got {fold}")

        if split == "test":
            df = self._load_test_dataframe()
            df["fold"] = -1  # sentinel: test samples are not part of any fold
        else:
            if split == "all":
                df = self._load_combined_dataframe()
            else:
                df = self._load_train_dataframe()
            df = self._assign_folds(df)
            # "train"/"val" filter by their own name; "all" defers to all_mode.
            mode = split if split in {"train", "val"} else all_mode
            if fold is not None and mode == "train":
                df = df[df["fold"] != fold]
            elif fold is not None and mode == "val":
                df = df[df["fold"] == fold]

        self.df = df.reset_index(drop=True)
        self.samples = self._build_samples(self.df)

        if len(self.samples) == 0:
            raise RuntimeError(
                f"No images found for split='{split}'. "
                f"Check root path, CSV files, folders, and file extensions."
            )

        self._print_summary()

    def _read_labels_csv(self, filename: str, source: str) -> pd.DataFrame:
        """Read ``root/<filename>`` and standardize it, tagging rows with *source*."""
        path = self.root / filename
        if not path.exists():
            raise FileNotFoundError(f"Missing {filename}: {path}")
        return self._standardize_label_dataframe(pd.read_csv(path), source=source)

    def _load_train_dataframe(self) -> pd.DataFrame:
        """Load and standardize ``trainLabels.csv`` (source="train")."""
        return self._read_labels_csv("trainLabels.csv", source="train")

    def _load_test_dataframe(self) -> pd.DataFrame:
        """Load and standardize ``testLabels.csv`` (source="test")."""
        return self._read_labels_csv("testLabels.csv", source="test")

    def _load_combined_dataframe(self) -> pd.DataFrame:
        """
        Concatenate the standardized train and test label tables.

        Duplicate (source, image) rows are dropped. Note that patient_id is not
        namespaced by source, so identical ids in train and test are grouped
        together during fold assignment. That is the conservative choice for
        leakage. NOTE(review): confirm whether EyePACS ids overlap across splits.
        """
        train_df = self._load_train_dataframe()
        test_df = self._load_test_dataframe()
        df = pd.concat([train_df, test_df], axis=0, ignore_index=True)
        return df.drop_duplicates(subset=["source", "image"]).reset_index(drop=True)

    @staticmethod
    def _standardize_label_dataframe(df: pd.DataFrame, source: str) -> pd.DataFrame:
        """
        Standardize a label dataframe to columns: image, level, source, patient_id.

        EyePACS image names usually look like:
            10_left
            10_right
        patient_id is extracted as the part before the first underscore.
        The input dataframe is not modified; extra columns are dropped.

        Declared as a staticmethod because it is invoked as
        ``self._standardize_label_dataframe(df, source=...)`` but uses no
        instance state.

        Raises:
            ValueError: if the ``image`` or ``level`` column is missing.
        """
        if "image" not in df.columns:
            raise ValueError(f"{source} labels CSV must contain column 'image'")
        if "level" not in df.columns:
            raise ValueError(f"{source} labels CSV must contain column 'level'")

        df = df[["image", "level"]].copy()
        df["image"] = df["image"].astype(str)
        df["level"] = df["level"].astype(int)
        df["source"] = source
        df["patient_id"] = df["image"].str.split("_").str[0].astype(str)
        return df

    def _assign_folds(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of *df* with an integer ``fold`` column.

        Grouping:
            patient_id
        Stratification label:
            max DR severity across all images for that patient_id.

        Folds are computed over one row per patient and mapped back to every
        image. This keeps left/right eyes of the same patient in the same fold.
        """
        df = df.copy()
        patient_df = (
            df.groupby("patient_id", as_index=False)
            .agg(patient_level=("level", "max"))
            .reset_index(drop=True)
        )
        groups = patient_df["patient_id"].values
        y = patient_df["patient_level"].values

        splitter = StratifiedGroupKFold(
            n_splits=self.n_folds,
            shuffle=True,
            random_state=self.seed,
        )

        patient_df["fold"] = -1
        # patient_df has a RangeIndex, so the positional val_idx from the
        # splitter can be used directly with .loc.
        for fold_idx, (_, val_idx) in enumerate(
            splitter.split(X=patient_df, y=y, groups=groups)
        ):
            patient_df.loc[val_idx, "fold"] = fold_idx

        if (patient_df["fold"] < 0).any():
            raise RuntimeError("Some patients were not assigned to a fold")

        fold_map = dict(zip(patient_df["patient_id"], patient_df["fold"]))
        df["fold"] = df["patient_id"].map(fold_map).astype(int)
        return df

    def _build_samples(self, df: pd.DataFrame) -> list:
        """
        Resolve every row of *df* to an on-disk image and build sample dicts.

        Each dict has the keys image_id, image_path, label, source, patient_id
        and fold. Rows whose image is not found are skipped and reported.
        """
        samples = []
        missing = []
        # itertuples avoids building a Series per row (much faster than iterrows).
        for row in df.itertuples(index=False):
            image_id = str(row.image)
            source = str(row.source)
            image_path = self._find_image_path(self.root / source, image_id)
            if image_path is None:
                missing.append((source, image_id))
                continue
            samples.append(
                {
                    "image_id": image_id,
                    "image_path": image_path,
                    "label": int(row.level),
                    "source": source,
                    "patient_id": str(row.patient_id),
                    "fold": int(row.fold),
                }
            )

        if len(missing) > 0:
            print(
                f"[EyePACSDataset] Warning: {len(missing)} images listed in CSV "
                f"were not found on disk."
            )
            print(f"[EyePACSDataset] First few missing: {missing[:5]}")

        return samples

    def _find_image_path(self, image_dir: Path, image_id: str) -> Optional[Path]:
        """Return the first existing ``image_dir/<image_id><ext>``, or None."""
        for ext in self.image_exts:
            path = image_dir / f"{image_id}{ext}"
            if path.exists():
                return path
        return None

    def _print_summary(self):
        """Print split configuration plus per-source and per-class counts."""
        labels = [s["label"] for s in self.samples]
        counts = pd.Series(labels).value_counts().sort_index()

        print(f"[EyePACSDataset] split={self.split}")
        print(f"[EyePACSDataset] root={self.root}")
        print(f"[EyePACSDataset] n={len(self.samples)}")

        if self.split != "test":
            print(
                f"[EyePACSDataset] seed={self.seed}, "
                f"fold={self.fold}, "
                f"n_folds={self.n_folds}"
            )

        if self.split == "all":
            print(f"[EyePACSDataset] all_mode={self.all_mode}")

        print("[EyePACSDataset] source counts:")
        source_counts = pd.Series([s["source"] for s in self.samples]).value_counts()
        for source, count in source_counts.items():
            print(f"  {source}: {int(count)}")

        print("[EyePACSDataset] class counts:")
        # EyePACS grades are 0-4; classes absent from this split print as 0.
        for cls in range(5):
            print(f"  class {cls}: {int(counts.get(cls, 0))}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return ``(image, label)``, or a metadata dict if ``return_path`` is True.

        ``label`` is a scalar ``torch.long`` tensor; ``image`` is the RGB PIL
        image after ``transform`` (if any).
        """
        sample = self.samples[idx]
        # Context manager releases the file handle deterministically;
        # convert() returns an independent, fully-loaded image.
        with Image.open(sample["image_path"]) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        label = torch.tensor(sample["label"], dtype=torch.long)

        if self.return_path:
            return {
                "image": image,
                "label": label,
                "image_id": sample["image_id"],
                "image_path": str(sample["image_path"]),
                "source": sample["source"],
                "patient_id": sample["patient_id"],
                "fold": sample["fold"],
            }

        return image, label