Spaces:
Running
Running
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image | |
class FIVESDataset(Dataset):
    """PyTorch Dataset for FIVES retinal vessel segmentation.

    Expected structure::

        FIVES_dataset/
        ├── train/
        │   ├── Original/
        │   └── Ground truth/
        └── test/
            ├── Original/
            └── Ground truth/

    Each image in ``Original/`` must have a matching vessel mask with the
    same filename in ``Ground truth/``.

    Output sample::

        {
            "image": Tensor [3, H, W],
            "label": Tensor [1, H, W],   # binary {0, 1}
            "case_id": str,
            "image_path": str,
            "label_path": str,
        }

    Args:
        root: Dataset root directory.
        split: Either ``"train"`` or ``"test"``.
        transform: Optional Albumentations transform; called as
            ``transform(image=..., mask=...)``.
        image_dir_name: Subdirectory containing the RGB images.
        label_dir_name: Subdirectory containing the vessel masks.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"test"``.
        FileNotFoundError: If a required directory or a matching label
            file does not exist.
        RuntimeError: If no PNG images are found in the image directory.
    """

    def __init__(
        self,
        root,
        split="train",
        transform=None,
        image_dir_name="Original",
        label_dir_name="Ground truth",
    ):
        # Validate split up front, before touching the filesystem.
        if split not in ("train", "test"):
            raise ValueError("split must be either 'train' or 'test'")

        self.root = Path(root)
        self.split = split
        self.transform = transform
        self.split_dir = self.root / split
        self.image_dir = self.split_dir / image_dir_name
        self.label_dir = self.split_dir / label_dir_name

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
        if not self.label_dir.exists():
            raise FileNotFoundError(f"Label directory not found: {self.label_dir}")

        # Sorted for deterministic ordering; skip hidden/system files.
        self.image_paths = sorted(
            p
            for p in self.image_dir.glob("*.png")
            if not p.name.startswith(".") and p.name.lower() != "thumbs.db"
        )
        if not self.image_paths:
            raise RuntimeError(f"No PNG images found in {self.image_dir}")

        # Pair every image with its identically named mask; fail fast on
        # any missing label so training never runs on an incomplete set.
        self.samples = []
        for image_path in self.image_paths:
            label_path = self.label_dir / image_path.name
            if not label_path.exists():
                raise FileNotFoundError(
                    f"Missing label for image:\n"
                    f"image: {image_path}\n"
                    f"label: {label_path}"
                )
            self.samples.append(
                {
                    "image_path": image_path,
                    "label_path": label_path,
                    "case_id": image_path.stem,
                }
            )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.samples)

    def _load_image(self, path):
        """Load an RGB image as a uint8 NumPy array of shape [H, W, 3]."""
        # Context manager ensures the OS file handle is closed promptly;
        # PIL otherwise keeps it open lazily, which leaks descriptors
        # when DataLoader workers iterate many images.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def _load_mask(self, path):
        """Load a grayscale mask as a uint8 NumPy array of shape [H, W]."""
        with Image.open(path) as img:
            return np.array(img.convert("L"))

    def __getitem__(self, idx):
        """Return the sample dict for index ``idx`` (see class docstring)."""
        sample_info = self.samples[idx]
        image_path = sample_info["image_path"]
        label_path = sample_info["label_path"]
        case_id = sample_info["case_id"]

        image = self._load_image(image_path)
        label = self._load_mask(label_path)

        if self.transform is not None:
            transformed = self.transform(
                image=image,
                mask=label,
            )
            image = transformed["image"]
            label = transformed["mask"]
            # Albumentations ToTensorV2 converts image to [3, H, W],
            # but mask remains [H, W], so add channel dimension.
            if isinstance(label, torch.Tensor):
                label = label.float().unsqueeze(0)
            else:
                label = torch.from_numpy(label).float().unsqueeze(0)
        else:
            # No transform: scale image to [0, 1] and move channels first.
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            label = torch.from_numpy(label).float().unsqueeze(0)

        # Convert vessel mask to binary {0, 1}
        label = (label > 0).float()

        return {
            "image": image,
            "label": label,
            "case_id": case_id,
            "image_path": str(image_path),
            "label_path": str(label_path),
        }
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    try:
        from augmentations import get_train_transforms, get_val_transforms
    except ImportError:
        # Running as a loose script: put the project root on sys.path.
        import sys

        project_root = Path(__file__).resolve().parents[1]
        sys.path.append(str(project_root))
        from augmentations import get_train_transforms, get_val_transforms

    root = "/data/MIDS/datasets/retina/FIVES_dataset"
    image_size = 512

    # Smoke test: build the training split and pull a single batch.
    dataset = FIVESDataset(
        root=root,
        split="train",
        transform=get_train_transforms(image_size=image_size),
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    batch = next(iter(loader))

    print("Number of samples:", len(dataset))
    print("Batch keys:", batch.keys())
    print("Image shape:", batch["image"].shape)
    print("Label shape:", batch["label"].shape)
    print("Label min/max:", batch["label"].min().item(), batch["label"].max().item())
    print("Case IDs:", batch["case_id"])

    # -------------------------
    # Matplotlib visualization
    # -------------------------
    first_image = batch["image"][0].cpu()
    first_mask = batch["label"][0, 0].cpu().numpy()

    # Undo ImageNet normalization for visualization.
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    denormed = (first_image * imagenet_std + imagenet_mean).clamp(0, 1)
    rgb = denormed.permute(1, 2, 0).numpy()

    fig, (ax_img, ax_lbl, ax_ovl) = plt.subplots(1, 3, figsize=(12, 4))

    ax_img.imshow(rgb)
    ax_img.set_title("Image")

    ax_lbl.imshow(first_mask, cmap="gray")
    ax_lbl.set_title("Vessel Label")

    ax_ovl.imshow(rgb)
    ax_ovl.imshow(first_mask, cmap="Reds", alpha=0.45)
    ax_ovl.set_title("Overlay")

    for ax in (ax_img, ax_lbl, ax_ovl):
        ax.axis("off")

    plt.tight_layout()
    plt.show()