# CFPVesselSeg/datasets/FGADR.py
# Source: farrell236's repository, commit e99a83c ("add src")
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.model_selection import KFold
class FGADRDataset(Dataset):
    """FGADR Seg-set dataset for diabetic retinopathy lesion segmentation.

    Expected structure:
        Seg-set/
        ├── DR_Seg_Grading_Label.csv
        ├── Original_Images/
        ├── Microaneurysms_Masks/
        ├── Hemohedge_Masks/
        ├── HardExudate_Masks/
        ├── SoftExudate_Masks/
        ├── IRMA_Masks/
        └── Neovascularization_Masks/

    CSV format, no header:
        filename,dr_grade

    Output (one dict per item, see ``__getitem__``):
        image: [3, H, W] float tensor
        label: [6, H, W] binary float tensor, one channel per lesion class
        grade: scalar long tensor
        case_id: filename stem

    split:
        "train" = all folds except selected fold
        "val" = selected fold
        "all" = full dataset

    Notes:
        If a lesion-specific mask file is absent, it is treated as an empty
        all-zero mask, meaning no incidence of that lesion class.
    """

    # Maps lesion class name -> mask sub-directory under root.
    # Dict order fixes the channel order of the returned label tensor.
    lesion_dirs = {
        "microaneurysm": "Microaneurysms_Masks",
        "hemorrhage": "Hemohedge_Masks",
        "hard_exudate": "HardExudate_Masks",
        "soft_exudate": "SoftExudate_Masks",
        "irma": "IRMA_Masks",
        "neovascularization": "Neovascularization_Masks",
    }

    def __init__(
        self,
        root,
        split="train",
        fold=0,
        n_folds=5,
        seed=42,
        transform=None,
        csv_name="DR_Seg_Grading_Label.csv",
        image_dir_name="Original_Images",
        mask_suffix="",
    ):
        """Index the dataset and build the requested K-fold split.

        Args:
            root: Path to the ``Seg-set`` directory.
            split: "train", "val", or "all".
            fold: 0-based index of the fold held out for validation.
            n_folds: Total number of K-fold splits.
            seed: Random state for shuffling samples before splitting, so the
                same (seed, n_folds) always yields the same partition.
            transform: Optional albumentations-style callable accepting
                ``image=...`` and ``masks=[...]`` keyword arguments.
            csv_name: Name of the grading CSV inside ``root``.
            image_dir_name: Name of the image directory inside ``root``.
            mask_suffix: Optional suffix inserted between the mask filename
                stem and its extension.

        Raises:
            ValueError: On invalid ``split`` or ``fold``.
            FileNotFoundError: If a required file or directory is missing.
            RuntimeError: If the CSV yields no samples.
        """
        self.root = Path(root)
        self.split = split
        self.fold = fold
        self.n_folds = n_folds
        self.seed = seed
        self.transform = transform
        self.csv_path = self.root / csv_name
        self.image_dir = self.root / image_dir_name
        self.mask_suffix = mask_suffix

        if split not in ["train", "val", "all"]:
            raise ValueError("split must be one of: 'train', 'val', 'all'")
        if not (0 <= fold < n_folds):
            raise ValueError(f"fold must be in [0, {n_folds - 1}], got {fold}")
        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
        if not self.csv_path.exists():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        self.class_names = list(self.lesion_dirs.keys())

        # Fail fast if any lesion's mask directory is missing entirely
        # (individual missing mask *files* are tolerated; see _load_mask).
        for dirname in self.lesion_dirs.values():
            mask_dir = self.root / dirname
            if not mask_dir.exists():
                raise FileNotFoundError(f"Mask directory not found: {mask_dir}")

        all_samples = self._read_csv()
        if len(all_samples) == 0:
            raise RuntimeError(f"No samples found in {self.csv_path}")

        if split == "all":
            self.samples = all_samples
        else:
            # Deterministic shuffled K-fold over the full sample list.
            kfold = KFold(
                n_splits=n_folds,
                shuffle=True,
                random_state=seed,
            )
            splits = list(kfold.split(all_samples))
            train_indices, val_indices = splits[fold]
            if split == "train":
                self.samples = [all_samples[i] for i in train_indices]
            else:
                self.samples = [all_samples[i] for i in val_indices]

    def _read_csv(self):
        """Parse the grading CSV into a list of per-sample dicts.

        Blank lines and lines with fewer than two comma-separated fields are
        skipped. Each kept row is verified to have its image on disk.

        Returns:
            List of dicts with keys: filename, case_id, image_path, grade.

        Raises:
            FileNotFoundError: If a listed image file does not exist.
        """
        samples = []
        with open(self.csv_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(",")
                if len(parts) < 2:
                    continue
                filename = parts[0].strip()
                grade = int(parts[1].strip())
                image_path = self.image_dir / filename
                if not image_path.exists():
                    raise FileNotFoundError(f"Image not found: {image_path}")
                samples.append(
                    {
                        "filename": filename,
                        "case_id": Path(filename).stem,
                        "image_path": image_path,
                        "grade": grade,
                    }
                )
        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def _load_image(self, path):
        """Load an image as an HxWx3 uint8 RGB array.

        Uses a context manager so the underlying file handle is closed
        promptly instead of lingering until GC — avoids file-descriptor
        exhaustion with many DataLoader workers.
        """
        with Image.open(path) as image:
            return np.array(image.convert("RGB"))

    def _load_mask(self, path, shape):
        """Load a grayscale mask, or an all-zero uint8 mask of ``shape`` if
        the file is absent (absence means no incidence of that lesion)."""
        if path.exists():
            with Image.open(path) as mask_image:
                return np.array(mask_image.convert("L"))
        return np.zeros(shape, dtype=np.uint8)

    def _get_mask_path(self, lesion_name, filename):
        """Resolve the mask path for ``lesion_name``/``filename``, inserting
        ``mask_suffix`` before the extension when configured."""
        mask_dir = self.root / self.lesion_dirs[lesion_name]
        if self.mask_suffix:
            stem = Path(filename).stem
            suffix = Path(filename).suffix
            filename = f"{stem}{self.mask_suffix}{suffix}"
        return mask_dir / filename

    def __getitem__(self, idx):
        """Return one sample: image, stacked binary lesion masks, DR grade,
        and bookkeeping identifiers/paths."""
        sample_info = self.samples[idx]
        filename = sample_info["filename"]
        image_path = sample_info["image_path"]
        case_id = sample_info["case_id"]
        grade = sample_info["grade"]

        image = self._load_image(image_path)
        h, w = image.shape[:2]

        # One mask per lesion class, in class_names (channel) order.
        masks = []
        mask_paths = {}
        for lesion_name in self.class_names:
            mask_path = self._get_mask_path(lesion_name, filename)
            mask = self._load_mask(mask_path, shape=(h, w))
            masks.append(mask)
            mask_paths[lesion_name] = str(mask_path)

        if self.transform is not None:
            # Transform image and all masks jointly so spatial augmentations
            # stay aligned across lesion channels.
            transformed = self.transform(
                image=image,
                masks=masks,
            )
            image = transformed["image"]
            masks = transformed["masks"]
            # Transforms may return numpy arrays or tensors; normalize both.
            masks = [
                m.float() if isinstance(m, torch.Tensor) else torch.from_numpy(m).float()
                for m in masks
            ]
            label = torch.stack(masks, dim=0)
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            label = torch.stack(
                [torch.from_numpy(m).float() for m in masks],
                dim=0,
            )

        # Binarize: any positive mask value counts as lesion presence.
        label = (label > 0).float()

        return {
            "image": image,
            "label": label,
            "grade": torch.tensor(grade, dtype=torch.long),
            "case_id": case_id,
            "filename": filename,
            "image_path": str(image_path),
            "mask_paths": mask_paths,
        }
if __name__ == "__main__":
    # Smoke test: verify files on disk, pull one batch, and visualize it.
    import matplotlib.pyplot as plt
    from tqdm import tqdm

    try:
        from augmentations import get_train_transforms, IMAGENET_MEAN, IMAGENET_STD
    except ImportError:
        # Running the file directly: make the project root importable first.
        import sys

        repo_root = Path(__file__).resolve().parents[1]
        sys.path.append(str(repo_root))
        from augmentations import get_train_transforms, IMAGENET_MEAN, IMAGENET_STD

    data_root = "/data/MIDS/datasets/retina/FGADR/Seg-set"
    img_size = 512
    ds = FGADRDataset(
        root=data_root,
        split="train",
        fold=0,
        n_folds=5,
        seed=42,
        transform=get_train_transforms(image_size=img_size),
    )

    # Pass 1: existence check over every image and lesion mask path.
    print("\nChecking all FGADR files...")
    n_missing_images = 0
    n_absent_masks = 0
    for rec in tqdm(ds.samples, desc="Checking files"):
        fname = rec["filename"]
        if not rec["image_path"].exists():
            print(f"Missing image: {rec['image_path']}")
            n_missing_images += 1
        n_absent_masks += sum(
            1
            for lesion in ds.class_names
            if not ds._get_mask_path(lesion, fname).exists()
        )
    print("File check complete.")
    print(f"Missing images: {n_missing_images}")
    print(f"Absent lesion masks treated as empty: {n_absent_masks}")

    # Pass 2: pull a single batch through a DataLoader and report shapes.
    loader = DataLoader(
        ds,
        batch_size=4,
        shuffle=True,
        num_workers=0,
    )
    batch = next(iter(loader))
    print("\nSmoke test batch:")
    print("Number of samples:", len(ds))
    print("Split:", ds.split)
    print("Fold:", ds.fold)
    print("Number of folds:", ds.n_folds)
    print("Class names:", ds.class_names)
    print("Batch keys:", batch.keys())
    print("Image shape:", batch["image"].shape)
    print("Label shape:", batch["label"].shape)
    print("Grade shape:", batch["grade"].shape)
    print("Label min/max:", batch["label"].min().item(), batch["label"].max().item())
    print("Case IDs:", batch["case_id"])

    # Visualize the first sample: undo ImageNet normalization for display.
    first_image = batch["image"][0].cpu()
    first_label = batch["label"][0].cpu()
    first_grade = batch["grade"][0].item()
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    rgb = (first_image * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
    any_lesion = (first_label.sum(dim=0) > 0).float().numpy()

    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    panels = axes.flatten()
    for panel in panels:
        panel.axis("off")

    panels[0].imshow(rgb)
    panels[0].set_title(f"Image | Grade {first_grade}")
    panels[1].imshow(any_lesion, cmap="gray")
    panels[1].set_title("Any Lesion")
    panels[2].imshow(rgb)
    panels[2].imshow(any_lesion, cmap="Reds", alpha=0.45)
    panels[2].set_title("Overlay")

    # One panel per lesion channel, in class order.
    for ch, lesion in enumerate(ds.class_names):
        panels[ch + 3].imshow(first_label[ch].numpy(), cmap="gray")
        panels[ch + 3].set_title(lesion)

    plt.tight_layout()
    plt.show()