| """ |
| PyTorch dataset for LIDC-IDRI flat format (image + majority-vote mask). |
| """ |
| import os |
| import glob |
| import numpy as np |
| from PIL import Image |
| import torch |
| from torch.utils.data import Dataset |
| import torchvision.transforms.functional as TF |
| import random |
|
|
|
|
class LIDCFlatDataset(Dataset):
    """Paired image/mask dataset read from a flat two-folder layout.

    Each ``*.png`` file under ``<root_dir>/images`` is one sample. Its mask is
    looked up under ``<root_dir>/masks`` using the very same file name.
    """

    def __init__(self, root_dir, augment=False, img_size=128):
        """
        Args:
            root_dir: Directory containing 'images/' and 'masks/' subdirectories
            augment: When true, random flips and a small random rotation are
                applied to every sample that is fetched
            img_size: Expected image size; only stored on the instance, the
                images are never resized by this class

        Raises:
            RuntimeError: If the images directory yields no ``*.png`` file.
        """
        self.root_dir = root_dir
        self.augment = augment
        self.img_size = img_size

        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")

        # Sorted so that an index always maps to the same file.
        png_pattern = os.path.join(self.image_dir, "*.png")
        self.image_files = sorted(glob.glob(png_pattern))

        if not self.image_files:
            raise RuntimeError(f"No images found in {self.image_dir}")

        print(f"Dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self):
        """Return the number of image files discovered at construction."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A ``(image, mask, sample_id)`` tuple: the grayscale image as a
            tensor, the mask thresholded at 0.5 into a 0/1 float tensor, and
            the image file name without its extension.
        """
        image_path = self.image_files[idx]
        file_name = os.path.basename(image_path)

        # Both files are read as single-channel ("L") images.
        image = Image.open(image_path).convert("L")
        mask = Image.open(os.path.join(self.mask_dir, file_name)).convert("L")

        if self.augment:
            # Each flip is drawn independently; image and mask always receive
            # the same transform so they stay aligned.
            for flip in (TF.hflip, TF.vflip):
                if random.random() > 0.5:
                    image, mask = flip(image), flip(mask)

            # One shared angle in [-15, 15]; uncovered corners are filled
            # with 0 in both image and mask.
            angle = random.uniform(-15, 15)
            image = TF.rotate(image, angle, fill=0)
            mask = TF.rotate(mask, angle, fill=0)

        image = TF.to_tensor(image)
        # Binarise the mask: anything above 0.5 becomes 1.0, the rest 0.0.
        mask = (TF.to_tensor(mask) > 0.5).float()

        sample_id, _extension = os.path.splitext(file_name)
        return image, mask, sample_id

    def get_sample_id(self, idx):
        """Get sample ID (file name without extension) without loading the image."""
        file_name = os.path.basename(self.image_files[idx])
        stem, _extension = os.path.splitext(file_name)
        return stem
|
|
|
|
class LIDCTestDataset(Dataset):
    """Image-only dataset for prediction; no masks are read.

    Samples are the ``*.png`` files found under ``<root_dir>/images``.
    """

    def __init__(self, root_dir, img_size=128):
        """
        Args:
            root_dir: Directory containing an 'images/' subdirectory
            img_size: Expected image size; only stored on the instance, the
                images are never resized by this class

        Raises:
            RuntimeError: If the images directory yields no ``*.png`` file.
        """
        self.root_dir = root_dir
        self.img_size = img_size
        self.image_dir = os.path.join(root_dir, "images")

        # Sorted so that an index always maps to the same file.
        png_pattern = os.path.join(self.image_dir, "*.png")
        self.image_files = sorted(glob.glob(png_pattern))

        if not self.image_files:
            raise RuntimeError(f"No images found in {self.image_dir}")

        print(f"Test dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self):
        """Return the number of image files discovered at construction."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, sample_id)`` for the file at position ``idx``.

        The image is read as single-channel ("L") and converted to a tensor;
        ``sample_id`` is the file name without its extension.
        """
        image_path = self.image_files[idx]
        grayscale = Image.open(image_path).convert("L")
        image = TF.to_tensor(grayscale)

        sample_id, _extension = os.path.splitext(os.path.basename(image_path))
        return image, sample_id
|
|