from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image


class FIVESDataset(Dataset):
    """
    PyTorch Dataset for FIVES retinal vessel segmentation.

    Expected structure:

        FIVES_dataset/
        ├── train/
        │   ├── Original/
        │   └── Ground truth/
        └── test/
            ├── Original/
            └── Ground truth/

    Each image in Original/ should have a matching vessel mask with the same
    filename in Ground truth/.

    Output sample:
        {
            "image": Tensor [3, H, W],
            "label": Tensor [1, H, W],
            "case_id": str,
            "image_path": str,
            "label_path": str,
        }

    If transform is provided, it should be an Albumentations transform. It is
    called as ``transform(image=..., mask=...)`` with an RGB uint8 array of
    shape [H, W, 3] and a grayscale uint8 mask of shape [H, W]. The pipeline
    is expected to end in ``ToTensorV2``; tensor images it returns are passed
    through unchanged. If it returns a NumPy image instead, the image is
    converted to a float [3, H, W] tensor here (uint8 images scaled to [0, 1]).

    Without a transform, the image is returned as a float tensor in [0, 1].
    The label is always binarised: every non-zero mask pixel becomes 1.0.
    """

    def __init__(
        self,
        root,
        split="train",
        transform=None,
        image_dir_name="Original",
        label_dir_name="Ground truth",
    ):
        """
        Index all image/mask pairs for one split.

        Args:
            root: Dataset root directory containing ``train/`` and ``test/``.
            split: Either ``"train"`` or ``"test"``.
            transform: Optional Albumentations transform (see class docstring).
            image_dir_name: Sub-directory of the split holding input images.
            label_dir_name: Sub-directory of the split holding vessel masks.

        Raises:
            ValueError: If ``split`` is not ``"train"`` or ``"test"``.
            FileNotFoundError: If the image or label directory is missing,
                or if an image has no mask with the same filename.
            RuntimeError: If the image directory contains no PNG files.
        """
        # Validate the split before doing any path/filesystem work.
        if split not in ("train", "test"):
            raise ValueError("split must be either 'train' or 'test'")

        self.root = Path(root)
        self.split = split
        self.transform = transform

        self.split_dir = self.root / split
        self.image_dir = self.split_dir / image_dir_name
        self.label_dir = self.split_dir / label_dir_name

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
        if not self.label_dir.exists():
            raise FileNotFoundError(f"Label directory not found: {self.label_dir}")

        # Match the ".png" suffix case-insensitively (Path.glob("*.png") is
        # case-sensitive on Linux and would silently skip "*.PNG"), keep only
        # regular files, and skip hidden files such as macOS "._*" forks.
        # Non-PNG clutter like Thumbs.db is excluded by the suffix check.
        # Sorting makes the sample order deterministic.
        self.image_paths = sorted(
            p
            for p in self.image_dir.iterdir()
            if p.is_file()
            and p.suffix.lower() == ".png"
            and not p.name.startswith(".")
        )

        if not self.image_paths:
            raise RuntimeError(f"No PNG images found in {self.image_dir}")

        self.samples = []
        for image_path in self.image_paths:
            # Masks are paired with images purely by identical filename.
            label_path = self.label_dir / image_path.name
            if not label_path.exists():
                raise FileNotFoundError(
                    f"Missing label for image:\n"
                    f"image: {image_path}\n"
                    f"label: {label_path}"
                )
            self.samples.append(
                {
                    "image_path": image_path,
                    "label_path": label_path,
                    "case_id": image_path.stem,
                }
            )

    def __len__(self):
        return len(self.samples)

    def _load_image(self, path):
        """Load ``path`` as an RGB uint8 array of shape [H, W, 3]."""
        # The context manager closes the file handle deterministically.
        with Image.open(path) as image:
            return np.array(image.convert("RGB"))

    def _load_mask(self, path):
        """Load ``path`` as a single-channel (mode "L") uint8 array [H, W]."""
        with Image.open(path) as mask:
            return np.array(mask.convert("L"))

    @staticmethod
    def _image_to_tensor(image):
        """
        Return ``image`` as a channel-first tensor.

        Tensors (e.g. from Albumentations ``ToTensorV2``) are returned
        unchanged. NumPy [H, W, 3] arrays are made contiguous first, because
        ``torch.from_numpy`` rejects views with negative strides (such as
        ``arr[:, ::-1]`` flips). They are then permuted to [3, H, W] and cast
        to float. uint8 arrays are additionally scaled to [0, 1].
        """
        if isinstance(image, torch.Tensor):
            return image
        scale_to_unit = image.dtype == np.uint8
        tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float()
        if scale_to_unit:
            tensor = tensor / 255.0
        return tensor

    @staticmethod
    def _mask_to_tensor(mask):
        """
        Return ``mask`` as a float tensor with a leading channel dimension.

        Accepts a NumPy array or a tensor. A 2-D [H, W] mask becomes
        [1, H, W]. A mask that already has three dimensions is left as-is
        rather than gaining a spurious extra axis.
        """
        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        mask = mask.float()
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        return mask

    def __getitem__(self, idx):
        """
        Load, optionally transform, and tensorise one image/mask pair.

        Returns:
            dict with keys "image", "label", "case_id", "image_path" and
            "label_path" (see class docstring). Paths are returned as ``str``
            so the default collate function can batch them.
        """
        sample_info = self.samples[idx]
        image_path = sample_info["image_path"]
        label_path = sample_info["label_path"]
        case_id = sample_info["case_id"]

        image = self._load_image(image_path)
        label = self._load_mask(label_path)

        if self.transform is not None:
            transformed = self.transform(
                image=image,
                mask=label,
            )
            image = transformed["image"]
            label = transformed["mask"]

        image = self._image_to_tensor(image)
        label = self._mask_to_tensor(label)

        # Convert vessel mask to binary {0, 1}
        label = (label > 0).float()

        return {
            "image": image,
            "label": label,
            "case_id": case_id,
            "image_path": str(image_path),
            "label_path": str(label_path),
        }


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    try:
        from augmentations import get_train_transforms, get_val_transforms
    except ImportError:
        import sys

        # Fall back to importing from the parent directory of this file.
        project_root = Path(__file__).resolve().parents[1]
        sys.path.append(str(project_root))
        from augmentations import get_train_transforms, get_val_transforms

    root = "/data/MIDS/datasets/retina/FIVES_dataset"
    image_size = 512

    dataset = FIVESDataset(
        root=root,
        split="train",
        transform=get_train_transforms(image_size=image_size),
    )

    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,
    )

    batch = next(iter(loader))

    print("Number of samples:", len(dataset))
    print("Batch keys:", batch.keys())
    print("Image shape:", batch["image"].shape)
    print("Label shape:", batch["label"].shape)
    print("Label min/max:", batch["label"].min().item(), batch["label"].max().item())
    print("Case IDs:", batch["case_id"])

    # -------------------------
    # Matplotlib visualization
    # -------------------------
    image = batch["image"][0]
    label = batch["label"][0, 0]

    # Undo ImageNet normalization for visualization.
    # NOTE(review): assumes get_train_transforms normalizes with ImageNet
    # mean/std and ends in ToTensorV2 — confirm in augmentations.py.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    image_vis = image.cpu() * std + mean
    image_vis = image_vis.clamp(0, 1)
    image_vis = image_vis.permute(1, 2, 0).numpy()

    label_vis = label.cpu().numpy()

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    axes[0].imshow(image_vis)
    axes[0].set_title("Image")
    axes[0].axis("off")

    axes[1].imshow(label_vis, cmap="gray")
    axes[1].set_title("Vessel Label")
    axes[1].axis("off")

    axes[2].imshow(image_vis)
    axes[2].imshow(label_vis, cmap="Reds", alpha=0.45)
    axes[2].set_title("Overlay")
    axes[2].axis("off")

    plt.tight_layout()
    plt.show()