# CFPVesselSeg / datasets / FIVES.py
# Repository page metadata: farrell236 -- "add src" -- e99a83c
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
class FIVESDataset(Dataset):
    """PyTorch Dataset for FIVES retinal vessel segmentation.

    Expected structure::

        FIVES_dataset/
        ├── train/
        │   ├── Original/
        │   └── Ground truth/
        └── test/
            ├── Original/
            └── Ground truth/

    Each PNG in ``Original/`` must have a matching vessel mask with the same
    filename in ``Ground truth/``.

    Output sample::

        {
            "image": Tensor [3, H, W],
            "label": Tensor [1, H, W],   # float32, binarised to {0, 1}
            "case_id": str,              # image filename without extension
            "image_path": str,
            "label_path": str,
        }

    If ``transform`` is provided, it should be an Albumentations transform:
    it is called as ``transform(image=..., mask=...)`` and must return a
    mapping with ``"image"`` and ``"mask"`` entries. A tensor image returned
    by the transform is passed through unchanged; a NumPy image is converted
    to a channel-first float tensor. Without a transform the image is
    converted to a float tensor scaled to ``[0, 1]``.

    Args:
        root: Dataset root directory containing the split folders.
        split: Either ``"train"`` or ``"test"``.
        transform: Optional callable following the convention above.
        image_dir_name: Name of the image sub-folder inside the split folder.
        label_dir_name: Name of the mask sub-folder inside the split folder.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"test"``.
        FileNotFoundError: If the image or label directory is missing, or an
            image has no mask with the same filename.
        RuntimeError: If the image directory contains no PNG files.
    """

    _SPLITS = ("train", "test")

    def __init__(
        self,
        root,
        split="train",
        transform=None,
        image_dir_name="Original",
        label_dir_name="Ground truth",
    ):
        if split not in self._SPLITS:
            raise ValueError("split must be either 'train' or 'test'")

        self.root = Path(root)
        self.split = split
        self.transform = transform

        self.split_dir = self.root / split
        self.image_dir = self.split_dir / image_dir_name
        self.label_dir = self.split_dir / label_dir_name

        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
        if not self.label_dir.exists():
            raise FileNotFoundError(f"Label directory not found: {self.label_dir}")

        # Path.glob also yields dot-files (e.g. macOS "._x.png" resource
        # forks), so those are filtered out explicitly. Sorting keeps the
        # sample order deterministic.
        self.image_paths = sorted(
            p for p in self.image_dir.glob("*.png") if not p.name.startswith(".")
        )
        if not self.image_paths:
            raise RuntimeError(f"No PNG images found in {self.image_dir}")

        # Pair every image with its mask up front so a missing mask is
        # reported at construction time instead of in the middle of training.
        self.samples = []
        for image_path in self.image_paths:
            label_path = self.label_dir / image_path.name
            if not label_path.exists():
                raise FileNotFoundError(
                    f"Missing label for image:\n"
                    f"image: {image_path}\n"
                    f"label: {label_path}"
                )
            self.samples.append(
                {
                    "image_path": image_path,
                    "label_path": label_path,
                    "case_id": image_path.stem,
                }
            )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.samples)

    def _load_image(self, path):
        """Read the file at ``path`` as an RGB uint8 array of shape [H, W, 3]."""
        # The context manager closes the file handle deterministically rather
        # than leaving it to garbage collection.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def _load_mask(self, path):
        """Read the file at ``path`` as a single-channel uint8 array [H, W]."""
        with Image.open(path) as img:
            return np.array(img.convert("L"))

    @staticmethod
    def _image_to_tensor(image):
        """Return ``image`` as a tensor, converting NumPy arrays if needed.

        Tensors are returned unchanged. A NumPy array is converted to a
        float32 tensor; a 3-D array is treated as [H, W, C] and moved to
        [C, H, W]; uint8 data is scaled to ``[0, 1]`` exactly as in the
        no-transform path.
        """
        if isinstance(image, torch.Tensor):
            return image
        scale_to_unit = image.dtype == np.uint8
        # ascontiguousarray guards against negative-stride views, which
        # torch.from_numpy rejects.
        tensor = torch.from_numpy(np.ascontiguousarray(image))
        if tensor.ndim == 3:
            tensor = tensor.permute(2, 0, 1)
        tensor = tensor.float()
        return tensor / 255.0 if scale_to_unit else tensor

    @staticmethod
    def _mask_to_tensor(label):
        """Return ``label`` as a binary float32 tensor with a leading channel.

        Accepts a tensor or a NumPy array; any value greater than zero is
        treated as vessel.
        """
        if not isinstance(label, torch.Tensor):
            # Transforms may hand back flipped views with negative strides,
            # which torch.from_numpy cannot wrap.
            label = torch.from_numpy(np.ascontiguousarray(label))
        # The mask arrives as [H, W]; add the channel dimension.
        return (label.unsqueeze(0) > 0).float()

    def __getitem__(self, idx):
        """Load, optionally transform, and return the sample at ``idx``."""
        sample_info = self.samples[idx]
        image_path = sample_info["image_path"]
        label_path = sample_info["label_path"]

        image = self._load_image(image_path)
        label = self._load_mask(label_path)

        if self.transform is not None:
            transformed = self.transform(image=image, mask=label)
            image = self._image_to_tensor(transformed["image"])
            label = transformed["mask"]
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

        return {
            "image": image,
            "label": self._mask_to_tensor(label),
            "case_id": sample_info["case_id"],
            "image_path": str(image_path),
            "label_path": str(label_path),
        }
if __name__ == "__main__":
    # Manual smoke test: load one training batch, print its basic statistics
    # and display the first sample next to its vessel mask.
    import matplotlib.pyplot as plt

    try:
        from augmentations import get_train_transforms, get_val_transforms
    except ImportError:
        # Retry with the directory two levels above this file on sys.path
        # (presumably the project root that holds augmentations.py).
        import sys

        repo_dir = Path(__file__).resolve().parents[1]
        sys.path.append(str(repo_dir))
        from augmentations import get_train_transforms, get_val_transforms

    root = "/data/MIDS/datasets/retina/FIVES_dataset"
    image_size = 512

    train_set = FIVESDataset(
        root=root,
        split="train",
        transform=get_train_transforms(image_size=image_size),
    )
    loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=0)
    batch = next(iter(loader))

    print("Number of samples:", len(train_set))
    print("Batch keys:", batch.keys())
    images = batch["image"]
    print("Image shape:", images.shape)
    labels = batch["label"]
    print("Label shape:", labels.shape)
    print("Label min/max:", labels.min().item(), labels.max().item())
    print("Case IDs:", batch["case_id"])

    # ------------------------------------------------------------------
    # Matplotlib visualization of the first sample in the batch
    # ------------------------------------------------------------------
    # Undo ImageNet normalization for display.
    # NOTE(review): assumes get_train_transforms normalizes with these
    # ImageNet statistics -- confirm against augmentations.py.
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb = images[0].cpu() * imagenet_std + imagenet_mean
    rgb = rgb.clamp(0, 1).permute(1, 2, 0).numpy()
    vessels = labels[0, 0].cpu().numpy()

    fig, (ax_image, ax_label, ax_overlay) = plt.subplots(1, 3, figsize=(12, 4))

    ax_image.imshow(rgb)
    ax_label.imshow(vessels, cmap="gray")
    # Overlay: draw the image first, then the mask semi-transparently on top.
    ax_overlay.imshow(rgb)
    ax_overlay.imshow(vessels, cmap="Reds", alpha=0.45)

    panels = zip(
        (ax_image, ax_label, ax_overlay),
        ("Image", "Vessel Label", "Overlay"),
    )
    for ax, title in panels:
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()