# Spaces: Running
| from pathlib import Path | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image | |
| import torchvision.transforms.functional as TF | |
class DRIVEDataset(Dataset):
    """
    PyTorch Dataset for the DRIVE retinal vessel segmentation dataset.

    Expected structure:
        DRIVE/
        ├── training/
        │   ├── images/
        │   ├── 1st_manual/
        │   └── mask/
        └── test/
            ├── images/
            └── mask/

    For training split:
        image:       21_training.tif
        vessel mask: 21_manual1.gif
        FOV mask:    21_training_mask.gif

    For test split:
        image:    01_test.tif
        FOV mask: 01_test_mask.gif
        no vessel mask is included in the provided tree

    Each item is a dict containing:
        "image":   tensor produced by ``TF.to_tensor`` from the RGB image
        "case_id": leading token of the image file name (e.g. "21")
        "label":   vessel mask binarized at 0.5 (training split only)
        "fov":     field-of-view mask binarized at 0.5 (only if ``return_fov``)

    Args:
        root: Path to the ``DRIVE`` directory.
        split: Either ``"training"`` or ``"test"``.
        image_size: ``None`` to keep the native size, an int for a square
            resize, or a size accepted by ``TF.resize``.
        return_fov: Whether to include the FOV mask in each sample.
        transform: Optional callable ``(image, label, fov) -> (image, label,
            fov)`` applied after resizing and before tensor conversion.
            ``label`` is ``None`` for the test split.

    Raises:
        ValueError: If ``split`` is not ``"training"`` or ``"test"``.
        FileNotFoundError: If the image directory, the label directory
            (training split), or the FOV mask directory (when ``return_fov``
            is true) does not exist.
        RuntimeError: If the image directory contains no ``.tif`` files.
    """

    def __init__(
        self,
        root,
        split="training",
        image_size=None,
        return_fov=True,
        transform=None,
    ):
        # Reject a bad split before any path is derived from it.
        if split not in ("training", "test"):
            raise ValueError("split must be either 'training' or 'test'")

        self.root = Path(root)
        self.split = split
        self.image_size = image_size
        self.return_fov = return_fov
        self.transform = transform

        self.split_dir = self.root / split
        self.image_dir = self.split_dir / "images"
        self.fov_dir = self.split_dir / "mask"

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")

        # Sorted so that index -> case mapping is deterministic across runs.
        self.image_paths = sorted(self.image_dir.glob("*.tif"))
        if not self.image_paths:
            raise RuntimeError(f"No .tif images found in {self.image_dir}")

        if split == "training":
            self.label_dir = self.split_dir / "1st_manual"
            if not self.label_dir.exists():
                raise FileNotFoundError(f"Label directory not found: {self.label_dir}")
        else:
            self.label_dir = None

        # Fail at construction time rather than on the first __getitem__,
        # consistent with the image/label directory checks above.
        if return_fov and not self.fov_dir.exists():
            raise FileNotFoundError(f"FOV mask directory not found: {self.fov_dir}")

    def __len__(self):
        """Return the number of images found for the selected split."""
        return len(self.image_paths)

    def _get_case_id(self, image_path):
        """
        Return the part of the file stem before the first underscore.

        Examples:
            21_training.tif -> 21
            01_test.tif -> 01
        """
        return image_path.stem.split("_")[0]

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image and release the file handle."""
        # convert() yields an independent, fully loaded image, so the source
        # file can be closed as soon as the conversion is done.
        with Image.open(path) as source:
            return source.convert("RGB")

    def _load_mask(self, path):
        """Load ``path`` as a single-channel ("L") PIL image and release the file handle."""
        # Explicit close matters for the .gif masks: without it the handle
        # lingers until the image object is garbage collected.
        with Image.open(path) as source:
            return source.convert("L")

    def _resize_if_needed(self, image, label=None, fov=None):
        """
        Resize image, label and FOV mask to ``self.image_size`` if it is set.

        The image is resized bilinearly; masks use nearest-neighbour so no
        intermediate gray values are introduced. ``None`` masks pass through.
        """
        if self.image_size is None:
            return image, label, fov

        size = self.image_size
        if isinstance(size, int):
            # An int means a square target, not "resize the shorter edge".
            size = (size, size)

        image = TF.resize(image, size, interpolation=TF.InterpolationMode.BILINEAR)
        if label is not None:
            label = TF.resize(label, size, interpolation=TF.InterpolationMode.NEAREST)
        if fov is not None:
            fov = TF.resize(fov, size, interpolation=TF.InterpolationMode.NEAREST)
        return image, label, fov

    def __getitem__(self, idx):
        """Load, resize, optionally transform and tensorize the sample at ``idx``."""
        image_path = self.image_paths[idx]
        case_id = self._get_case_id(image_path)
        image = self._load_image(image_path)

        label = None
        if self.split == "training":
            label = self._load_mask(self.label_dir / f"{case_id}_manual1.gif")

        # Only read the FOV mask when something can consume it: either it is
        # returned, or a user transform (which receives it as its third
        # argument) is installed.
        fov = None
        if self.return_fov or self.transform is not None:
            fov = self._load_mask(self.fov_dir / f"{case_id}_{self.split}_mask.gif")

        image, label, fov = self._resize_if_needed(image, label, fov)

        if self.transform is not None:
            image, label, fov = self.transform(image, label, fov)

        sample = {
            "image": TF.to_tensor(image),
            "case_id": case_id,
        }

        if label is not None:
            # Threshold after to_tensor so the mask is strictly {0., 1.}.
            sample["label"] = (TF.to_tensor(label) > 0.5).float()

        if self.return_fov:
            sample["fov"] = (TF.to_tensor(fov) > 0.5).float()

        return sample
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Smoke test: build the training dataset, pull one shuffled batch,
    # print its basic statistics and show the first sample.
    data_root = "/data/MIDS/datasets/retina/DRIVE"
    drive_train = DRIVEDataset(
        root=data_root,
        split="training",
        image_size=512,
        return_fov=True,
    )
    train_loader = DataLoader(drive_train, batch_size=4, shuffle=True, num_workers=0)
    first_batch = next(iter(train_loader))

    print("Number of samples:", len(drive_train))
    print("Batch keys:", first_batch.keys())
    print("Image shape:", first_batch["image"].shape)

    if "label" in first_batch:
        vessel_t = first_batch["label"]
        print("Label shape:", vessel_t.shape)
        print("Label min/max:", vessel_t.min().item(), vessel_t.max().item())

    if "fov" in first_batch:
        fov_t = first_batch["fov"]
        print("FOV shape:", fov_t.shape)
        print("FOV min/max:", fov_t.min().item(), fov_t.max().item())

    print("Case IDs:", first_batch["case_id"])

    # ---------------------------------------------
    # Visualize the first sample of the batch
    # ---------------------------------------------
    # Channels-last layout is what imshow expects for an RGB image.
    rgb = first_batch["image"][0].permute(1, 2, 0).cpu().numpy()
    vessel_batch = first_batch.get("label", None)
    fov_batch = first_batch.get("fov", None)

    fig, panels = plt.subplots(1, 4, figsize=(16, 4))
    ax_image, ax_label, ax_vessel, ax_fov = panels

    ax_image.imshow(rgb)
    ax_image.set_title("Image")

    if vessel_batch is not None:
        vessel_np = vessel_batch[0, 0].cpu().numpy()

        ax_label.imshow(vessel_np, cmap="gray")
        ax_label.set_title("Vessel Label")

        ax_vessel.imshow(rgb)
        ax_vessel.imshow(vessel_np, cmap="Reds", alpha=0.45)
        ax_vessel.set_title("Image + Vessel Overlay")

    if fov_batch is not None:
        fov_np = fov_batch[0, 0].cpu().numpy()

        ax_fov.imshow(rgb)
        ax_fov.imshow(fov_np, cmap="gray", alpha=0.25)
        ax_fov.set_title("Image + FOV Overlay")

    # Every panel hides its axes, whether or not it was drawn into.
    for panel in panels:
        panel.axis("off")

    plt.tight_layout()
    plt.show()