# FGADR Seg-set dataset loader for diabetic retinopathy lesion segmentation.
import csv
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset
class FGADRDataset(Dataset):
    """FGADR Seg-set dataset for diabetic retinopathy lesion segmentation.

    Expected structure:
        Seg-set/
        ├── DR_Seg_Grading_Label.csv
        ├── Original_Images/
        ├── Microaneurysms_Masks/
        ├── Hemohedge_Masks/
        ├── HardExudate_Masks/
        ├── SoftExudate_Masks/
        ├── IRMA_Masks/
        └── Neovascularization_Masks/

    CSV format, no header:
        filename,dr_grade

    Output (per item, as a dict):
        image: [3, H, W] float tensor
        label: [6, H, W] binary float tensor, one channel per lesion class
        grade: scalar long tensor
        case_id: filename stem

    split:
        "train" = all folds except selected fold
        "val" = selected fold
        "all" = full dataset

    Notes:
        If a lesion-specific mask file is absent, it is treated as an empty
        all-zero mask, meaning no incidence of that lesion class.
    """

    # Lesion class name -> mask sub-directory. Dict insertion order defines
    # the channel order of the returned label tensor (6 channels).
    lesion_dirs = {
        "microaneurysm": "Microaneurysms_Masks",
        "hemorrhage": "Hemohedge_Masks",
        "hard_exudate": "HardExudate_Masks",
        "soft_exudate": "SoftExudate_Masks",
        "irma": "IRMA_Masks",
        "neovascularization": "Neovascularization_Masks",
    }

    def __init__(
        self,
        root,
        split="train",
        fold=0,
        n_folds=5,
        seed=42,
        transform=None,
        csv_name="DR_Seg_Grading_Label.csv",
        image_dir_name="Original_Images",
        mask_suffix="",
    ):
        """
        Args:
            root: Path to the Seg-set directory.
            split: One of "train", "val", "all".
            fold: Validation fold index in [0, n_folds).
            n_folds: Number of K-fold cross-validation folds.
            seed: Random seed for shuffling the K-fold split; the same
                (seed, n_folds) always yields the same partition, so
                "train"/"val" datasets built with the same fold are disjoint.
            transform: Optional albumentations-style callable accepting
                (image=..., masks=[...]) and returning a dict with the same
                keys; may return numpy arrays or torch tensors.
            csv_name: Name of the grading-label CSV inside ``root``.
            image_dir_name: Name of the image directory inside ``root``.
            mask_suffix: Optional suffix inserted before the extension when
                building mask filenames.

        Raises:
            ValueError: Invalid ``split``/``fold``, or an unparsable grade.
            FileNotFoundError: Missing directories, CSV, or a listed image.
            RuntimeError: CSV yields no samples.
        """
        self.root = Path(root)
        self.split = split
        self.fold = fold
        self.n_folds = n_folds
        self.seed = seed
        self.transform = transform
        self.csv_path = self.root / csv_name
        self.image_dir = self.root / image_dir_name
        self.mask_suffix = mask_suffix

        if split not in ("train", "val", "all"):
            raise ValueError("split must be one of: 'train', 'val', 'all'")
        if not (0 <= fold < n_folds):
            raise ValueError(f"fold must be in [0, {n_folds - 1}], got {fold}")
        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
        if not self.csv_path.exists():
            raise FileNotFoundError(f"CSV file not found: {self.csv_path}")

        self.class_names = list(self.lesion_dirs.keys())
        # All six mask directories must exist, even though individual mask
        # files inside them are allowed to be absent.
        for dirname in self.lesion_dirs.values():
            mask_dir = self.root / dirname
            if not mask_dir.exists():
                raise FileNotFoundError(f"Mask directory not found: {mask_dir}")

        all_samples = self._read_csv()
        if not all_samples:
            raise RuntimeError(f"No samples found in {self.csv_path}")

        if split == "all":
            self.samples = all_samples
        else:
            # Deterministic split driven by (seed, n_folds): index lists are
            # identical across processes/runs for a given configuration.
            kfold = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
            train_indices, val_indices = list(kfold.split(all_samples))[fold]
            indices = train_indices if split == "train" else val_indices
            self.samples = [all_samples[i] for i in indices]

    def _read_csv(self):
        """Parse the grading CSV into a list of sample dicts.

        Blank rows, rows with fewer than two fields, and rows with an empty
        filename are skipped. Raises FileNotFoundError if a listed image
        does not exist, ValueError if a grade is not an integer.
        """
        samples = []
        # newline="" is required by the csv module; utf-8 keeps parsing
        # independent of the platform locale.
        with open(self.csv_path, "r", encoding="utf-8", newline="") as f:
            for row in csv.reader(f):
                if len(row) < 2:
                    continue
                filename = row[0].strip()
                if not filename:
                    # An empty filename would resolve to the image directory
                    # itself and wrongly pass the existence check below.
                    continue
                try:
                    grade = int(row[1].strip())
                except ValueError as exc:
                    raise ValueError(
                        f"Invalid grade {row[1]!r} for {filename!r} "
                        f"in {self.csv_path}"
                    ) from exc
                image_path = self.image_dir / filename
                if not image_path.exists():
                    raise FileNotFoundError(f"Image not found: {image_path}")
                samples.append(
                    {
                        "filename": filename,
                        "case_id": Path(filename).stem,
                        "image_path": image_path,
                        "grade": grade,
                    }
                )
        return samples

    def __len__(self):
        """Number of samples in the current split."""
        return len(self.samples)

    def _load_image(self, path):
        """Load an image as an RGB uint8 array of shape [H, W, 3]."""
        # Context manager closes the underlying file handle; PIL opens
        # lazily and would otherwise leak descriptors across the epoch.
        with Image.open(path) as image:
            return np.array(image.convert("RGB"))

    def _load_mask(self, path, shape):
        """Load a grayscale mask, or an all-zero mask if the file is absent.

        Args:
            path: Candidate mask file path.
            shape: (H, W) used for the zero mask when the file is missing.
        """
        if path.exists():
            with Image.open(path) as mask_img:
                return np.array(mask_img.convert("L"))
        return np.zeros(shape, dtype=np.uint8)

    def _get_mask_path(self, lesion_name, filename):
        """Build the mask path for ``lesion_name``, applying ``mask_suffix``."""
        mask_dir = self.root / self.lesion_dirs[lesion_name]
        if self.mask_suffix:
            stem = Path(filename).stem
            suffix = Path(filename).suffix
            filename = f"{stem}{self.mask_suffix}{suffix}"
        return mask_dir / filename

    def __getitem__(self, idx):
        """Return one sample dict (see class docstring for keys)."""
        sample_info = self.samples[idx]
        filename = sample_info["filename"]
        image_path = sample_info["image_path"]

        image = self._load_image(image_path)
        h, w = image.shape[:2]

        masks = []
        mask_paths = {}
        for lesion_name in self.class_names:
            mask_path = self._get_mask_path(lesion_name, filename)
            masks.append(self._load_mask(mask_path, shape=(h, w)))
            mask_paths[lesion_name] = str(mask_path)

        if self.transform is not None:
            transformed = self.transform(image=image, masks=masks)
            image = transformed["image"]
            # The transform may return masks as numpy arrays or tensors.
            masks = [
                m.float() if isinstance(m, torch.Tensor) else torch.from_numpy(m).float()
                for m in transformed["masks"]
            ]
            label = torch.stack(masks, dim=0)
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            label = torch.stack(
                [torch.from_numpy(m).float() for m in masks],
                dim=0,
            )

        # Binarize: any nonzero mask pixel counts as lesion presence.
        label = (label > 0).float()

        return {
            "image": image,
            "label": label,
            "grade": torch.tensor(sample_info["grade"], dtype=torch.long),
            "case_id": sample_info["case_id"],
            "filename": filename,
            "image_path": str(image_path),
            "mask_paths": mask_paths,
        }
if __name__ == "__main__":
    # Manual smoke test: audit file existence, pull one batch through a
    # DataLoader, and visualize the first sample with its lesion channels.
    import matplotlib.pyplot as plt
    from tqdm import tqdm

    try:
        from augmentations import get_train_transforms, IMAGENET_MEAN, IMAGENET_STD
    except ImportError:
        # Fall back to importing from the project root when run directly.
        import sys

        project_root = Path(__file__).resolve().parents[1]
        sys.path.append(str(project_root))
        from augmentations import get_train_transforms, IMAGENET_MEAN, IMAGENET_STD

    data_root = "/data/MIDS/datasets/retina/FGADR/Seg-set"
    img_size = 512

    ds = FGADRDataset(
        root=data_root,
        split="train",
        fold=0,
        n_folds=5,
        seed=42,
        transform=get_train_transforms(image_size=img_size),
    )

    # ---- file-existence audit ------------------------------------------
    print("\nChecking all FGADR files...")
    n_missing_images = 0
    n_absent_masks = 0
    for entry in tqdm(ds.samples, desc="Checking files"):
        fname = entry["filename"]
        if not entry["image_path"].exists():
            print(f"Missing image: {entry['image_path']}")
            n_missing_images += 1
        n_absent_masks += sum(
            1
            for cls in ds.class_names
            if not ds._get_mask_path(cls, fname).exists()
        )
    print("File check complete.")
    print(f"Missing images: {n_missing_images}")
    print(f"Absent lesion masks treated as empty: {n_absent_masks}")

    # ---- smoke-test one batch ------------------------------------------
    loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print("\nSmoke test batch:")
    print("Number of samples:", len(ds))
    print("Split:", ds.split)
    print("Fold:", ds.fold)
    print("Number of folds:", ds.n_folds)
    print("Class names:", ds.class_names)
    print("Batch keys:", batch.keys())
    print("Image shape:", batch["image"].shape)
    print("Label shape:", batch["label"].shape)
    print("Grade shape:", batch["grade"].shape)
    print("Label min/max:", batch["label"].min().item(), batch["label"].max().item())
    print("Case IDs:", batch["case_id"])

    # ---- visualize the first sample ------------------------------------
    img = batch["image"][0].cpu()
    lbl = batch["label"][0].cpu()
    dr_grade = batch["grade"][0].item()

    # Undo ImageNet normalization so the image displays with true colors.
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    img_vis = (img * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
    any_lesion = (lbl.sum(dim=0) > 0).float().numpy()

    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()

    axes[0].imshow(img_vis)
    axes[0].set_title(f"Image | Grade {dr_grade}")
    axes[1].imshow(any_lesion, cmap="gray")
    axes[1].set_title("Any Lesion")
    axes[2].imshow(img_vis)
    axes[2].imshow(any_lesion, cmap="Reds", alpha=0.45)
    axes[2].set_title("Overlay")
    for ax in axes:
        ax.axis("off")

    # One panel per lesion class, in label-channel order.
    for i, cls in enumerate(ds.class_names):
        panel = axes[i + 3]
        panel.imshow(lbl[i].numpy(), cmap="gray")
        panel.set_title(cls)
        panel.axis("off")

    plt.tight_layout()
    plt.show()