"""FGADR Seg-set dataset for diabetic retinopathy lesion segmentation."""

import csv
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset


class FGADRDataset(Dataset):
    """
    FGADR Seg-set dataset for diabetic retinopathy lesion segmentation.

    Expected structure:
        Seg-set/
        ├── DR_Seg_Grading_Label.csv
        ├── Original_Images/
        ├── Microaneurysms_Masks/
        ├── Hemohedge_Masks/
        ├── HardExudate_Masks/
        ├── SoftExudate_Masks/
        ├── IRMA_Masks/
        └── Neovascularization_Masks/

    CSV format, no header:
        filename,dr_grade

    Output (dict returned by ``__getitem__``):
        image:      [3, H, W] float tensor
        label:      [6, H, W] float tensor, binarized to {0, 1}
        grade:      scalar long tensor
        case_id:    filename stem
        filename:   filename as listed in the CSV
        image_path: image path as a string
        mask_paths: dict mapping lesion name to its mask path (string)

    split:
        "train" = all folds except selected fold
        "val"   = selected fold
        "all"   = full dataset

    Args:
        root: Path to the Seg-set directory.
        split: One of "train", "val", "all".
        fold: Index of the validation fold, in [0, n_folds - 1].
        n_folds: Number of K-fold partitions.
        seed: Random state for the shuffled K-fold split.
        transform: Optional callable invoked as
            ``transform(image=<HWC uint8 array>, masks=<list of HW arrays>)``
            that returns a mapping with "image" and "masks" keys
            (albumentations-style interface).
        csv_name: Name of the grading CSV inside ``root``.
        image_dir_name: Name of the image directory inside ``root``.
        mask_suffix: Optional string inserted between the stem and the
            extension of the image filename to form the mask filename.

    Raises:
        ValueError: On an invalid ``split`` or ``fold``, or a CSV grade that
            is not an integer.
        FileNotFoundError: If the image directory, CSV file, a lesion mask
            directory, or an image listed in the CSV is missing.
        RuntimeError: If the CSV yields no samples.

    Notes:
        If a lesion-specific mask file is absent, it is treated as an empty
        all-zero mask, meaning no incidence of that lesion class.
    """

    # Lesion class name -> mask directory name under the dataset root.
    # The dict order defines the channel order of the "label" tensor.
    lesion_dirs = {
        "microaneurysm": "Microaneurysms_Masks",
        "hemorrhage": "Hemohedge_Masks",
        "hard_exudate": "HardExudate_Masks",
        "soft_exudate": "SoftExudate_Masks",
        "irma": "IRMA_Masks",
        "neovascularization": "Neovascularization_Masks",
    }

    def __init__(
        self,
        root,
        split="train",
        fold=0,
        n_folds=5,
        seed=42,
        transform=None,
        csv_name="DR_Seg_Grading_Label.csv",
        image_dir_name="Original_Images",
        mask_suffix="",
    ):
        self.root = Path(root)
        self.split = split
        self.fold = fold
        self.n_folds = n_folds
        self.seed = seed
        self.transform = transform

        self.csv_path = self.root / csv_name
        self.image_dir = self.root / image_dir_name
        self.mask_suffix = mask_suffix

        if split not in ("train", "val", "all"):
            raise ValueError("split must be one of: 'train', 'val', 'all'")

        if not (0 <= fold < n_folds):
            raise ValueError(f"fold must be in [0, {n_folds - 1}], got {fold}")

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")

        if not self.csv_path.exists():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        self.class_names = list(self.lesion_dirs)

        # Every lesion directory must exist even though individual mask
        # files inside it may be absent.
        for dirname in self.lesion_dirs.values():
            mask_dir = self.root / dirname
            if not mask_dir.exists():
                raise FileNotFoundError(f"Mask directory not found: {mask_dir}")

        all_samples = self._read_csv()

        if not all_samples:
            raise RuntimeError(f"No samples found in {self.csv_path}")

        if split == "all":
            self.samples = all_samples
        else:
            kfold = KFold(
                n_splits=n_folds,
                shuffle=True,
                random_state=seed,
            )
            # The split depends only on the CSV row order and the seed, so
            # the "train" and "val" instances of the same fold are disjoint.
            train_indices, val_indices = list(kfold.split(all_samples))[fold]
            chosen = train_indices if split == "train" else val_indices
            self.samples = [all_samples[i] for i in chosen]

    def _read_csv(self):
        """Parse the grading CSV into a list of sample-info dicts.

        Rows with fewer than two fields or an empty filename are skipped.

        Returns:
            List of dicts with keys "filename", "case_id", "image_path"
            (a ``Path``) and "grade" (an ``int``).

        Raises:
            ValueError: If a grade field cannot be parsed as an integer.
            FileNotFoundError: If a listed image does not exist.
        """
        samples = []

        # "utf-8-sig" drops a leading BOM, which would otherwise become part
        # of the first filename; newline="" is what the csv module expects.
        with open(self.csv_path, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.reader(f)
            for row in reader:
                if len(row) < 2:
                    continue

                filename = row[0].strip()
                if not filename:
                    # An empty name would resolve to the image directory
                    # itself and pass the existence check below.
                    continue

                try:
                    grade = int(row[1].strip())
                except ValueError as err:
                    raise ValueError(
                        f"Invalid DR grade {row[1]!r} in "
                        f"{self.csv_path} (line {reader.line_num})"
                    ) from err

                image_path = self.image_dir / filename
                if not image_path.exists():
                    raise FileNotFoundError(f"Image not found: {image_path}")

                samples.append(
                    {
                        "filename": filename,
                        "case_id": Path(filename).stem,
                        "image_path": image_path,
                        "grade": grade,
                    }
                )

        return samples

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def _load_image(self, path):
        """Load ``path`` as an RGB image and return it as an HWC array."""
        # The context manager closes the file handle deterministically.
        with Image.open(path) as image:
            return np.array(image.convert("RGB"))

    def _load_mask(self, path, shape):
        """Load a grayscale mask, or an all-zero mask if the file is absent.

        Args:
            path: Path of the mask file.
            shape: ``(H, W)`` used for the all-zero fallback mask.
        """
        if not path.exists():
            # Absent file means the lesion class does not occur in the image.
            return np.zeros(shape, dtype=np.uint8)

        with Image.open(path) as mask:
            return np.array(mask.convert("L"))

    def _get_mask_path(self, lesion_name, filename):
        """Return the mask path for ``lesion_name`` and image ``filename``."""
        mask_dir = self.root / self.lesion_dirs[lesion_name]

        if self.mask_suffix:
            stem = Path(filename).stem
            suffix = Path(filename).suffix
            filename = f"{stem}{self.mask_suffix}{suffix}"

        return mask_dir / filename

    @staticmethod
    def _image_to_tensor(image):
        """Convert an HWC array to a CHW float tensor.

        uint8 arrays are scaled to [0, 1]; other dtypes are only cast to
        float. Anything that is not a 3-D ndarray (for example a tensor
        already produced by the transform) is returned unchanged.
        """
        if not isinstance(image, np.ndarray) or image.ndim != 3:
            return image

        # ascontiguousarray guards against negative-stride views, which
        # torch.from_numpy rejects.
        tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1)
        if image.dtype == np.uint8:
            return tensor.float() / 255.0
        return tensor.float()

    @staticmethod
    def _mask_to_tensor(mask):
        """Convert a mask (tensor or array) to a float tensor."""
        if isinstance(mask, torch.Tensor):
            return mask.float()
        return torch.from_numpy(np.ascontiguousarray(mask)).float()

    def __getitem__(self, idx):
        """Load one sample: image, stacked lesion masks, grade and metadata.

        Returns:
            Dict with keys "image", "label", "grade", "case_id", "filename",
            "image_path" and "mask_paths".
        """
        sample_info = self.samples[idx]

        filename = sample_info["filename"]
        image_path = sample_info["image_path"]
        case_id = sample_info["case_id"]
        grade = sample_info["grade"]

        image = self._load_image(image_path)
        height, width = image.shape[:2]

        masks = []
        mask_paths = {}

        for lesion_name in self.class_names:
            mask_path = self._get_mask_path(lesion_name, filename)
            masks.append(self._load_mask(mask_path, shape=(height, width)))
            mask_paths[lesion_name] = str(mask_path)

        if self.transform is not None:
            transformed = self.transform(
                image=image,
                masks=masks,
            )
            image = transformed["image"]
            masks = transformed["masks"]

        # Both branches end with a tensor image: a transform that returns
        # NumPy arrays is handled the same way as the no-transform path.
        image = self._image_to_tensor(image)
        label = torch.stack([self._mask_to_tensor(m) for m in masks], dim=0)

        # Masks are stored as grayscale images; any non-zero pixel is lesion.
        label = (label > 0).float()

        return {
            "image": image,
            "label": label,
            "grade": torch.tensor(grade, dtype=torch.long),
            "case_id": case_id,
            "filename": filename,
            "image_path": str(image_path),
            "mask_paths": mask_paths,
        }


def _smoke_test(root="/data/MIDS/datasets/retina/FGADR/Seg-set", image_size=512):
    """Check dataset files, load one batch and plot the first sample.

    Args:
        root: Path to the Seg-set directory.
        image_size: Size passed to ``get_train_transforms``.
    """
    import matplotlib.pyplot as plt
    from tqdm import tqdm

    try:
        from augmentations import get_train_transforms, IMAGENET_MEAN, IMAGENET_STD
    except ImportError:
        import sys

        # Fall back to the directory one level above this file's directory.
        project_root = Path(__file__).resolve().parents[1]
        sys.path.append(str(project_root))

        from augmentations import get_train_transforms, IMAGENET_MEAN, IMAGENET_STD

    dataset = FGADRDataset(
        root=root,
        split="train",
        fold=0,
        n_folds=5,
        seed=42,
        transform=get_train_transforms(image_size=image_size),
    )

    print("\nChecking all FGADR files...")

    missing_images = 0
    absent_masks = 0

    for sample in tqdm(dataset.samples, desc="Checking files"):
        filename = sample["filename"]

        if not sample["image_path"].exists():
            print(f"Missing image: {sample['image_path']}")
            missing_images += 1

        absent_masks += sum(
            not dataset._get_mask_path(lesion_name, filename).exists()
            for lesion_name in dataset.class_names
        )

    print("File check complete.")
    print(f"Missing images: {missing_images}")
    print(f"Absent lesion masks treated as empty: {absent_masks}")

    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,
    )

    batch = next(iter(loader))

    print("\nSmoke test batch:")
    print("Number of samples:", len(dataset))
    print("Split:", dataset.split)
    print("Fold:", dataset.fold)
    print("Number of folds:", dataset.n_folds)
    print("Class names:", dataset.class_names)
    print("Batch keys:", batch.keys())
    print("Image shape:", batch["image"].shape)
    print("Label shape:", batch["label"].shape)
    print("Grade shape:", batch["grade"].shape)
    print("Label min/max:", batch["label"].min().item(), batch["label"].max().item())
    print("Case IDs:", batch["case_id"])

    image = batch["image"][0].cpu()
    label = batch["label"][0].cpu()
    grade = batch["grade"][0].item()

    # Undo the normalization for display.
    # NOTE(review): assumes get_train_transforms normalizes with
    # IMAGENET_MEAN / IMAGENET_STD — confirm against augmentations.py.
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)

    image_vis = (image * std + mean).clamp(0, 1)
    image_vis = image_vis.permute(1, 2, 0).numpy()

    combined_mask = (label.sum(dim=0) > 0).float().numpy()

    # 2 x 5 grid: image, any-lesion mask, overlay, then one panel per class.
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()

    axes[0].imshow(image_vis)
    axes[0].set_title(f"Image | Grade {grade}")
    axes[0].axis("off")

    axes[1].imshow(combined_mask, cmap="gray")
    axes[1].set_title("Any Lesion")
    axes[1].axis("off")

    axes[2].imshow(image_vis)
    axes[2].imshow(combined_mask, cmap="Reds", alpha=0.45)
    axes[2].set_title("Overlay")
    axes[2].axis("off")

    for ax in axes[3:]:
        ax.axis("off")

    for i, class_name in enumerate(dataset.class_names):
        ax = axes[i + 3]
        ax.imshow(label[i].numpy(), cmap="gray")
        ax.set_title(class_name)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    import sys

    # An optional first argument overrides the default dataset root.
    if len(sys.argv) > 1:
        _smoke_test(root=sys.argv[1])
    else:
        _smoke_test()