from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms.functional as TF


class DRIVEDataset(Dataset):
    """
    PyTorch Dataset for the DRIVE retinal vessel segmentation dataset.

    Expected structure:

        DRIVE/
        ├── training/
        │   ├── images/
        │   ├── 1st_manual/
        │   └── mask/
        └── test/
            ├── images/
            └── mask/

    For training split:
        image:       21_training.tif
        vessel mask: 21_manual1.gif
        FOV mask:    21_training_mask.gif

    For test split:
        image:    01_test.tif
        FOV mask: 01_test_mask.gif
        no vessel mask is included in the provided tree

    Each item is a dict with the keys ``"image"`` and ``"case_id"``, plus
    ``"label"`` (training split only) and ``"fov"`` (when ``return_fov`` is
    true).
    """

    def __init__(
        self,
        root,
        split="training",
        image_size=None,
        return_fov=True,
        transform=None,
    ):
        """
        Index the images of one DRIVE split.

        Args:
            root: Path to the DRIVE directory (the one containing
                ``training/`` and ``test/``).
            split: Either ``"training"`` or ``"test"``.
            image_size: ``None`` to keep the original resolution, an int for
                a square ``(size, size)`` resize, or a size value that is
                handed to ``TF.resize`` unchanged.
            return_fov: Whether each sample includes the ``"fov"`` mask.
            transform: Optional callable invoked as
                ``transform(image, label, fov)`` after resizing; it must
                return an ``(image, label, fov)`` triple. ``label`` is
                ``None`` for the test split.

        Raises:
            ValueError: If ``split`` is not ``"training"`` or ``"test"``.
            FileNotFoundError: If the image directory, or the label
                directory of the training split, does not exist.
            RuntimeError: If the image directory contains no ``.tif`` files.
        """
        # Validate first so that no path is ever built from a bad split name.
        if split not in ("training", "test"):
            raise ValueError("split must be either 'training' or 'test'")

        self.root = Path(root)
        self.split = split
        self.image_size = image_size
        self.return_fov = return_fov
        self.transform = transform

        self.split_dir = self.root / split
        self.image_dir = self.split_dir / "images"
        self.fov_dir = self.split_dir / "mask"

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")

        # Sorted so that an index always maps to the same case.
        self.image_paths = sorted(self.image_dir.glob("*.tif"))

        if not self.image_paths:
            raise RuntimeError(f"No .tif images found in {self.image_dir}")

        if split == "training":
            self.label_dir = self.split_dir / "1st_manual"
            if not self.label_dir.exists():
                raise FileNotFoundError(f"Label directory not found: {self.label_dir}")
        else:
            # The test split has no vessel annotations in the expected tree.
            self.label_dir = None

    def __len__(self):
        """Return the number of images found for the selected split."""
        return len(self.image_paths)

    def _get_case_id(self, image_path):
        """
        Return the part of the file stem before the first underscore.

        Examples:
            21_training.tif -> 21
            01_test.tif -> 01
        """
        return image_path.stem.split("_")[0]

    def _load_image(self, path):
        """Read the file at ``path`` and return it as an RGB PIL image."""
        # ``convert`` returns a new, fully loaded image, so the source image
        # (and its file handle) can be closed as soon as the block exits.
        with Image.open(path) as source:
            return source.convert("RGB")

    def _load_mask(self, path):
        """Read the file at ``path`` and return it as an 8-bit grayscale PIL image."""
        # Closed explicitly for the same reason as in ``_load_image``.
        with Image.open(path) as source:
            return source.convert("L")

    def _resize_if_needed(self, image, label=None, fov=None):
        """
        Resize the image and any provided masks to ``self.image_size``.

        The image is resized with bilinear interpolation; ``label`` and
        ``fov`` use nearest-neighbour so that no intermediate mask values are
        introduced. Inputs that are ``None`` stay ``None``. When
        ``self.image_size`` is ``None`` everything is returned untouched.

        Returns:
            The ``(image, label, fov)`` triple after resizing.
        """
        if self.image_size is None:
            return image, label, fov

        size = self.image_size
        if isinstance(size, int):
            # An int means a square output, not "resize the shorter edge".
            size = (size, size)

        image = TF.resize(image, size, interpolation=TF.InterpolationMode.BILINEAR)

        if label is not None:
            label = TF.resize(label, size, interpolation=TF.InterpolationMode.NEAREST)

        if fov is not None:
            fov = TF.resize(fov, size, interpolation=TF.InterpolationMode.NEAREST)

        return image, label, fov

    def __getitem__(self, idx):
        """
        Load, resize, transform and tensorize the sample at position ``idx``.

        Returns:
            A dict with:
                ``"image"``: result of ``TF.to_tensor`` on the RGB image.
                ``"case_id"``: case identifier string (e.g. ``"21"``).
                ``"label"``: vessel mask binarized to 0.0 / 1.0 floats; only
                    present when a label exists (training split).
                ``"fov"``: field-of-view mask binarized to 0.0 / 1.0 floats;
                    only present when ``return_fov`` is true.
        """
        image_path = self.image_paths[idx]
        case_id = self._get_case_id(image_path)

        image = self._load_image(image_path)

        if self.split == "training":
            label = self._load_mask(self.label_dir / f"{case_id}_manual1.gif")
        else:
            label = None

        # The FOV mask is only needed when it is returned or when a transform
        # is going to receive it; otherwise skip the read and resize.
        if self.return_fov or self.transform is not None:
            fov = self._load_mask(self.fov_dir / f"{case_id}_{self.split}_mask.gif")
        else:
            fov = None

        image, label, fov = self._resize_if_needed(image, label, fov)

        if self.transform is not None:
            image, label, fov = self.transform(image, label, fov)

        sample = {
            "image": TF.to_tensor(image),
            "case_id": case_id,
        }

        if label is not None:
            # Threshold after to_tensor so the mask holds exactly 0.0 or 1.0.
            sample["label"] = (TF.to_tensor(label) > 0.5).float()

        if self.return_fov:
            sample["fov"] = (TF.to_tensor(fov) > 0.5).float()

        return sample


def main(root="/data/MIDS/datasets/retina/DRIVE"):
    """
    Smoke-test the dataset: load one training batch, print its summary and
    plot the first sample with its vessel and FOV overlays.

    Args:
        root: Path to the DRIVE directory to load from.
    """
    # Imported here so that importing this module does not need matplotlib.
    import matplotlib.pyplot as plt

    dataset = DRIVEDataset(
        root=root,
        split="training",
        image_size=512,
        return_fov=True,
    )

    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,
    )

    batch = next(iter(loader))

    print("Number of samples:", len(dataset))
    print("Batch keys:", batch.keys())
    print("Image shape:", batch["image"].shape)

    if "label" in batch:
        print("Label shape:", batch["label"].shape)
        print("Label min/max:", batch["label"].min().item(), batch["label"].max().item())

    if "fov" in batch:
        print("FOV shape:", batch["fov"].shape)
        print("FOV min/max:", batch["fov"].min().item(), batch["fov"].max().item())

    print("Case IDs:", batch["case_id"])

    # -------------------------
    # Matplotlib visualization
    # -------------------------
    image = batch["image"][0]  # [3, H, W]
    label = batch.get("label", None)
    fov = batch.get("fov", None)

    # Channels-last layout is what imshow expects for colour images.
    image_np = image.permute(1, 2, 0).cpu().numpy()

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    axes[0].imshow(image_np)
    axes[0].set_title("Image")
    axes[0].axis("off")

    if label is not None:
        label_np = label[0, 0].cpu().numpy()

        axes[1].imshow(label_np, cmap="gray")
        axes[1].set_title("Vessel Label")
        axes[1].axis("off")

        axes[2].imshow(image_np)
        axes[2].imshow(label_np, cmap="Reds", alpha=0.45)
        axes[2].set_title("Image + Vessel Overlay")
        axes[2].axis("off")
    else:
        axes[1].axis("off")
        axes[2].axis("off")

    if fov is not None:
        fov_np = fov[0, 0].cpu().numpy()

        axes[3].imshow(image_np)
        axes[3].imshow(fov_np, cmap="gray", alpha=0.25)
        axes[3].set_title("Image + FOV Overlay")
        axes[3].axis("off")
    else:
        axes[3].axis("off")

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()