"""
PyTorch dataset for LIDC-IDRI flat format (image + majority-vote mask).
"""
import os
import glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
import random
class LIDCFlatDataset(Dataset):
    """Dataset for a flat directory structure with matched image/mask pairs.

    Every ``*.png`` file under ``<root_dir>/images`` is one sample. Its mask
    is looked up under ``<root_dir>/masks`` using the same file name.
    """

    def __init__(self, root_dir, augment=False, img_size=128):
        """
        Args:
            root_dir: Directory containing 'images/' and 'masks/' subdirectories
            augment: Whether to apply data augmentation
            img_size: Target image size (images should already be this size).
                The value is only stored; no resizing or size check is done
                by this class.

        Raises:
            RuntimeError: If no ``*.png`` file is found in ``images/``.
        """
        self.root_dir = root_dir
        self.augment = augment
        self.img_size = img_size
        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")
        # Sorted so that index -> sample mapping is stable across runs.
        self.image_files = sorted(glob.glob(os.path.join(self.image_dir, "*.png")))
        if not self.image_files:
            raise RuntimeError(f"No images found in {self.image_dir}")
        print(f"Dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self) -> int:
        """Return the number of image files discovered at construction."""
        return len(self.image_files)

    @staticmethod
    def _load_grayscale(path):
        """Open *path*, return a mode-"L" (grayscale) copy, and close the file."""
        # Image.open is lazy and keeps the file handle open; convert() reads
        # the pixel data into a new image, so the handle can be released
        # as soon as the conversion is done.
        with Image.open(path) as img:
            return img.convert("L")

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into the sorted list of image files.

        Returns:
            Tuple ``(image, mask, sample_id)`` where ``image`` and ``mask`` are
            the outputs of ``TF.to_tensor`` (the mask thresholded at 0.5 and
            cast to float) and ``sample_id`` is the image file name without
            its extension.
        """
        img_path = self.image_files[idx]
        basename = os.path.basename(img_path)
        # The mask shares the image's file name, just in the masks directory.
        mask_path = os.path.join(self.mask_dir, basename)

        # Load image and mask (grayscale)
        image = self._load_grayscale(img_path)
        mask = self._load_grayscale(mask_path)

        # Apply augmentation; every geometric transform is applied to image
        # and mask together so they stay aligned.
        if self.augment:
            # Random horizontal flip
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)
            # Random vertical flip
            if random.random() > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)
            # Random rotation (±15 degrees). No interpolation mode is passed,
            # so torchvision's default applies (nearest-neighbour per its
            # docs — confirm for the installed version); the mask is
            # re-binarized below regardless.
            angle = random.uniform(-15, 15)
            image = TF.rotate(image, angle, fill=0)
            mask = TF.rotate(mask, angle, fill=0)

        # Convert to tensors
        image = TF.to_tensor(image)  # [1, H, W], range [0, 1]
        mask = TF.to_tensor(mask)  # [1, H, W], range [0, 1]
        mask = (mask > 0.5).float()  # Binarize

        # Get sample ID for evaluation
        sample_id = os.path.splitext(basename)[0]
        return image, mask, sample_id

    def get_sample_id(self, idx) -> str:
        """Get sample ID (file name without extension) without loading the image."""
        return os.path.splitext(os.path.basename(self.image_files[idx]))[0]
class LIDCTestDataset(Dataset):
    """Test dataset - loads only images (no masks needed for prediction)."""

    def __init__(self, root_dir, img_size=128):
        """
        Args:
            root_dir: Directory containing an 'images/' subdirectory
            img_size: Stored on the instance only; no resizing or size check
                is done by this class.

        Raises:
            RuntimeError: If no ``*.png`` file is found in ``images/``.
        """
        self.root_dir = root_dir
        self.img_size = img_size
        self.image_dir = os.path.join(root_dir, "images")
        # Sorted so that index -> sample mapping is stable across runs.
        self.image_files = sorted(glob.glob(os.path.join(self.image_dir, "*.png")))
        if not self.image_files:
            raise RuntimeError(f"No images found in {self.image_dir}")
        print(f"Test dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self) -> int:
        """Return the number of image files discovered at construction."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one image.

        Args:
            idx: Index into the sorted list of image files.

        Returns:
            Tuple ``(image, sample_id)`` where ``image`` is the output of
            ``TF.to_tensor`` on the grayscale image and ``sample_id`` is the
            file name without its extension.
        """
        img_path = self.image_files[idx]
        basename = os.path.basename(img_path)
        # Image.open is lazy and keeps the file handle open; convert() reads
        # the pixel data into a new image, so the handle is released as soon
        # as the with-block exits.
        with Image.open(img_path) as raw:
            image = raw.convert("L")
        image = TF.to_tensor(image)
        sample_id = os.path.splitext(basename)[0]
        return image, sample_id
|