"""
PyTorch dataset for LIDC-IDRI flat format (image + majority-vote mask).
"""

import os
import glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
import random


def _list_png_files(image_dir):
    """Return the sorted ``*.png`` paths found directly inside *image_dir*.

    Raises:
        RuntimeError: if no PNG files are found (this also covers a missing
            directory, because ``glob`` then matches nothing).
    """
    files = sorted(glob.glob(os.path.join(image_dir, "*.png")))
    if not files:
        raise RuntimeError(f"No images found in {image_dir}")
    return files


def _load_grayscale(path):
    """Load *path* as a single-channel ("L" mode) PIL image, closing the file."""
    # ``convert`` returns a new, fully loaded image, so closing the source
    # file afterwards is safe and avoids relying on garbage collection.
    with Image.open(path) as img:
        return img.convert("L")


def _sample_id_from_path(path):
    """Return the file name of *path* without directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


class LIDCFlatDataset(Dataset):
    """Dataset for flat directory structure with matched image/mask pairs.

    Each ``images/<name>.png`` is paired with ``masks/<name>.png`` by file name.
    """

    def __init__(self, root_dir, augment=False, img_size=128):
        """
        Args:
            root_dir: Directory containing 'images/' and 'masks/' subdirectories.
            augment: Whether to apply random flips and rotation to each pair.
            img_size: Expected image size. It is stored for reference only and
                is not enforced or used for resizing here; images are assumed
                to already be this size.

        Raises:
            RuntimeError: if ``root_dir/images`` contains no PNG files.
        """
        self.root_dir = root_dir
        self.augment = augment
        self.img_size = img_size

        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")

        self.image_files = _list_png_files(self.image_dir)

        print(f"Dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, mask, sample_id)`` for sample *idx*.

        ``image`` is a float tensor of shape [1, H, W] in [0, 1]. ``mask`` has
        the same shape and is binarized to {0.0, 1.0}. ``sample_id`` is the
        image file name without its extension.

        Raises:
            FileNotFoundError: if the matching mask file does not exist.
            ValueError: if the image and mask have different pixel sizes.
        """
        img_path = self.image_files[idx]
        basename = os.path.basename(img_path)
        mask_path = os.path.join(self.mask_dir, basename)

        # Fail with an explicit message instead of an opaque PIL error when the
        # image has no counterpart in masks/.
        if not os.path.isfile(mask_path):
            raise FileNotFoundError(
                f"Mask not found for image {img_path}: expected {mask_path}"
            )

        # Load image and mask as grayscale.
        image = _load_grayscale(img_path)
        mask = _load_grayscale(mask_path)

        # A size mismatch would otherwise yield silently misaligned tensors.
        if image.size != mask.size:
            raise ValueError(
                f"Image/mask size mismatch for {basename}: "
                f"image {image.size} vs mask {mask.size}"
            )

        # Apply augmentation. The same random draw is applied to image and
        # mask so the pair stays aligned.
        if self.augment:
            # Random horizontal flip
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)

            # Random vertical flip
            if random.random() > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)

            # Random rotation (±15 degrees). TF.rotate's default interpolation
            # is nearest-neighbour, so mask pixels are not blended across the
            # label boundary. Uncovered corners are filled with 0 (background).
            angle = random.uniform(-15, 15)
            image = TF.rotate(image, angle, fill=0)
            mask = TF.rotate(mask, angle, fill=0)

        # Convert to tensors
        image = TF.to_tensor(image)  # [1, H, W], range [0, 1]
        mask = TF.to_tensor(mask)    # [1, H, W], range [0, 1]
        mask = (mask > 0.5).float()  # Binarize

        # Sample ID used for evaluation
        sample_id = _sample_id_from_path(img_path)

        return image, mask, sample_id

    def get_sample_id(self, idx):
        """Get sample ID without loading the image."""
        return _sample_id_from_path(self.image_files[idx])


class LIDCTestDataset(Dataset):
    """Test dataset - loads only images (no masks needed for prediction)."""

    def __init__(self, root_dir, img_size=128):
        """
        Args:
            root_dir: Directory containing an 'images/' subdirectory.
            img_size: Expected image size. It is stored for reference only and
                is not enforced or used for resizing here.

        Raises:
            RuntimeError: if ``root_dir/images`` contains no PNG files.
        """
        self.root_dir = root_dir
        self.img_size = img_size
        self.image_dir = os.path.join(root_dir, "images")

        self.image_files = _list_png_files(self.image_dir)

        print(f"Test dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, sample_id)``.

        ``image`` is a float tensor of shape [1, H, W] in [0, 1].
        """
        img_path = self.image_files[idx]
        image = TF.to_tensor(_load_grayscale(img_path))
        sample_id = _sample_id_from_path(img_path)
        return image, sample_id