CFPVesselSeg / datasets /DRIVE.py
farrell236's picture
add src
e99a83c
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms.functional as TF
class DRIVEDataset(Dataset):
    """
    PyTorch Dataset for the DRIVE retinal vessel segmentation dataset.

    Expected structure:

        DRIVE/
        ├── training/
        │   ├── images/
        │   ├── 1st_manual/
        │   └── mask/
        └── test/
            ├── images/
            └── mask/

    For training split:
        image:        21_training.tif
        vessel mask:  21_manual1.gif
        FOV mask:     21_training_mask.gif

    For test split:
        image:        01_test.tif
        FOV mask:     01_test_mask.gif
        no vessel mask is included in the provided tree

    Each item is a dict with the keys:
        "image":   tensor produced by ``TF.to_tensor`` from the RGB image
        "case_id": leading token of the image file name, e.g. "21"
        "label":   binarised vessel mask (training split only)
        "fov":     binarised field-of-view mask (only when ``return_fov``)
    """

    def __init__(
        self,
        root,
        split: str = "training",
        image_size=None,
        return_fov: bool = True,
        transform=None,
    ):
        """
        Index the images of one DRIVE split.

        Args:
            root: Path to the DRIVE root directory (str or Path).
            split: Either "training" or "test".
            image_size: ``None`` to keep the native resolution, an int for a
                square ``(size, size)`` output, or a size accepted by
                ``TF.resize``.
            return_fov: Whether each sample carries the "fov" mask.
            transform: Optional callable invoked as
                ``transform(image, label, fov)`` after resizing and before
                tensor conversion; it must return the three values in the
                same order. ``label`` is ``None`` for the test split.

        Raises:
            ValueError: If ``split`` is not "training" or "test".
            FileNotFoundError: If the image, label (training only) or FOV
                mask directory is missing.
            RuntimeError: If the image directory contains no ``.tif`` files.
        """
        # Validate before any path is derived from the split name.
        if split not in ("training", "test"):
            raise ValueError("split must be either 'training' or 'test'")

        self.root = Path(root)
        self.split = split
        self.image_size = image_size
        self.return_fov = return_fov
        self.transform = transform

        self.split_dir = self.root / split
        self.image_dir = self.split_dir / "images"
        self.fov_dir = self.split_dir / "mask"

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")

        # Sorted so that index -> case mapping is deterministic across runs.
        self.image_paths = sorted(self.image_dir.glob("*.tif"))
        if not self.image_paths:
            raise RuntimeError(f"No .tif images found in {self.image_dir}")

        if split == "training":
            self.label_dir = self.split_dir / "1st_manual"
            if not self.label_dir.exists():
                raise FileNotFoundError(f"Label directory not found: {self.label_dir}")
        else:
            self.label_dir = None

        # __getitem__ always reads the FOV mask (it is handed to `transform`
        # even when `return_fov` is False), so fail here rather than on the
        # first sample access.
        if not self.fov_dir.exists():
            raise FileNotFoundError(f"FOV mask directory not found: {self.fov_dir}")

    def __len__(self) -> int:
        """Return the number of images found for the split."""
        return len(self.image_paths)

    def _get_case_id(self, image_path) -> str:
        """
        Return the part of the file stem before the first underscore.

        Examples:
            21_training.tif -> 21
            01_test.tif -> 01
        """
        return image_path.stem.split("_")[0]

    def _load_image(self, path):
        """Load ``path`` as an RGB PIL image and close the underlying file."""
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read and returns an independent image, so the
        # source can be closed as soon as the block exits.
        with Image.open(path) as source:
            return source.convert("RGB")

    def _load_mask(self, path):
        """Load ``path`` as a single-channel ("L") PIL image and close the file."""
        with Image.open(path) as source:
            return source.convert("L")

    def _resize_if_needed(self, image, label=None, fov=None):
        """
        Resize image and masks to ``self.image_size`` when it is set.

        The image is resized bilinearly; masks use nearest-neighbour so no
        intermediate grey values are introduced. ``None`` masks pass through.

        Returns:
            The ``(image, label, fov)`` triple, resized or unchanged.
        """
        if self.image_size is None:
            return image, label, fov

        size = self.image_size
        if isinstance(size, int):
            size = (size, size)

        image = TF.resize(image, size, interpolation=TF.InterpolationMode.BILINEAR)
        if label is not None:
            label = TF.resize(label, size, interpolation=TF.InterpolationMode.NEAREST)
        if fov is not None:
            fov = TF.resize(fov, size, interpolation=TF.InterpolationMode.NEAREST)
        return image, label, fov

    def __getitem__(self, idx):
        """
        Load, resize, transform and tensorise the sample at ``idx``.

        Returns:
            dict with "image" and "case_id", plus "label" for the training
            split and "fov" when ``return_fov`` is True.
        """
        image_path = self.image_paths[idx]
        case_id = self._get_case_id(image_path)
        image = self._load_image(image_path)

        if self.split == "training":
            label = self._load_mask(self.label_dir / f"{case_id}_manual1.gif")
        else:
            label = None

        # FOV file names embed the split name: 21_training_mask.gif / 01_test_mask.gif
        fov = self._load_mask(self.fov_dir / f"{case_id}_{self.split}_mask.gif")

        image, label, fov = self._resize_if_needed(image, label, fov)
        if self.transform is not None:
            image, label, fov = self.transform(image, label, fov)

        sample = {
            "image": TF.to_tensor(image),
            "case_id": case_id,
        }

        # Masks are thresholded after tensor conversion so they hold only 0.0 / 1.0.
        if label is not None:
            label = TF.to_tensor(label)
            sample["label"] = (label > 0.5).float()

        if self.return_fov:
            fov = TF.to_tensor(fov)
            sample["fov"] = (fov > 0.5).float()

        return sample
if __name__ == "__main__":
    # Manual smoke test: load one shuffled training batch, print its basic
    # statistics and show the first sample with its masks overlaid.
    import matplotlib.pyplot as plt

    root = "/data/MIDS/datasets/retina/DRIVE"

    dataset = DRIVEDataset(root=root, split="training", image_size=512, return_fov=True)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    batch = next(iter(loader))

    print("Number of samples:", len(dataset))
    print("Batch keys:", batch.keys())
    print("Image shape:", batch["image"].shape)

    if "label" in batch:
        vessel_batch = batch["label"]
        print("Label shape:", vessel_batch.shape)
        print("Label min/max:", vessel_batch.min().item(), vessel_batch.max().item())

    if "fov" in batch:
        fov_batch = batch["fov"]
        print("FOV shape:", fov_batch.shape)
        print("FOV min/max:", fov_batch.min().item(), fov_batch.max().item())

    print("Case IDs:", batch["case_id"])

    # -------------------------
    # Matplotlib visualization
    # -------------------------
    # First sample of the batch, [3, H, W] -> [H, W, 3] for imshow.
    rgb = batch["image"][0].permute(1, 2, 0).cpu().numpy()
    vessels = batch.get("label")
    fov_mask = batch.get("fov")

    _, panels = plt.subplots(1, 4, figsize=(16, 4))
    image_ax, label_ax, overlay_ax, fov_ax = panels

    image_ax.imshow(rgb)
    image_ax.set_title("Image")

    if vessels is not None:
        # First sample, first channel of the mask batch.
        vessels_2d = vessels[0, 0].cpu().numpy()

        label_ax.imshow(vessels_2d, cmap="gray")
        label_ax.set_title("Vessel Label")

        overlay_ax.imshow(rgb)
        overlay_ax.imshow(vessels_2d, cmap="Reds", alpha=0.45)
        overlay_ax.set_title("Image + Vessel Overlay")

    if fov_mask is not None:
        fov_2d = fov_mask[0, 0].cpu().numpy()

        fov_ax.imshow(rgb)
        fov_ax.imshow(fov_2d, cmap="gray", alpha=0.25)
        fov_ax.set_title("Image + FOV Overlay")

    # Every panel hides its axes, whether or not something was drawn on it.
    for panel in panels:
        panel.axis("off")

    plt.tight_layout()
    plt.show()