siddharthdhara17's picture
Upload baselines/dataset.py with huggingface_hub
09967a3 verified
"""
PyTorch dataset for LIDC-IDRI flat format (image + majority-vote mask).
"""
import os
import glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
import random
class LIDCFlatDataset(Dataset):
    """Dataset for flat directory structure with matched image/mask pairs.

    Every ``images/<name>.png`` is paired with ``masks/<name>.png``. The file
    stem (name without extension) is used as the sample ID.
    """

    def __init__(self, root_dir, augment=False, img_size=128):
        """
        Args:
            root_dir: Directory containing 'images/' and 'masks/' subdirectories
            augment: Whether to apply data augmentation
            img_size: Target image size (images should already be this size).
                The value is only stored on the instance; no resizing is
                performed by this class.

        Raises:
            RuntimeError: If no PNG images are found in 'images/', or if any
                image has no mask with the same file name in 'masks/'.
        """
        self.root_dir = root_dir
        self.augment = augment
        self.img_size = img_size
        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")
        # Sorted so that index -> sample mapping is stable across runs.
        self.image_files = sorted(glob.glob(os.path.join(self.image_dir, "*.png")))
        if not self.image_files:
            raise RuntimeError(f"No images found in {self.image_dir}")

        # Fail fast on unpaired images: otherwise the first missing mask only
        # surfaces inside __getitem__, possibly deep into a training epoch.
        missing = [
            name
            for name in map(os.path.basename, self.image_files)
            if not os.path.isfile(os.path.join(self.mask_dir, name))
        ]
        if missing:
            preview = ", ".join(missing[:5])
            raise RuntimeError(
                f"{len(missing)} image(s) in {self.image_dir} have no matching "
                f"mask in {self.mask_dir} (e.g. {preview})"
            )

        print(f"Dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into the sorted list of image files.

        Returns:
            Tuple ``(image, mask, sample_id)`` where ``image`` and ``mask`` are
            float tensors produced by ``TF.to_tensor`` from grayscale images
            (the mask is binarized to 0.0 / 1.0) and ``sample_id`` is the image
            file name without its extension.
        """
        img_path = self.image_files[idx]
        basename = os.path.basename(img_path)
        mask_path = os.path.join(self.mask_dir, basename)

        # Load both files as grayscale. convert() returns a fully loaded
        # image, so the underlying file can be closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        # Apply augmentation. Every geometric transform is applied to image
        # and mask together so they stay aligned.
        if self.augment:
            # Random horizontal flip
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)

            # Random vertical flip
            if random.random() > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)

            # Random rotation (±15 degrees); exposed corners are filled with 0
            angle = random.uniform(-15, 15)
            image = TF.rotate(image, angle, fill=0)
            mask = TF.rotate(mask, angle, fill=0)

        # Convert to tensors
        image = TF.to_tensor(image)  # [1, H, W], range [0, 1]
        mask = TF.to_tensor(mask)  # [1, H, W], range [0, 1]
        mask = (mask > 0.5).float()  # Binarize

        # Get sample ID for evaluation
        sample_id = os.path.splitext(basename)[0]
        return image, mask, sample_id

    def get_sample_id(self, idx):
        """Get sample ID without loading the image."""
        return os.path.splitext(os.path.basename(self.image_files[idx]))[0]
class LIDCTestDataset(Dataset):
    """Test dataset - loads only images (no masks needed for prediction)."""

    def __init__(self, root_dir, img_size=128):
        """
        Args:
            root_dir: Directory containing an 'images/' subdirectory
            img_size: Target image size. The value is only stored on the
                instance; no resizing is performed by this class.

        Raises:
            RuntimeError: If no PNG images are found in 'images/'.
        """
        self.root_dir = root_dir
        self.img_size = img_size
        self.image_dir = os.path.join(root_dir, "images")
        # Sorted so that index -> sample mapping is stable across runs.
        self.image_files = sorted(glob.glob(os.path.join(self.image_dir, "*.png")))
        if not self.image_files:
            raise RuntimeError(f"No images found in {self.image_dir}")
        print(f"Test dataset: {len(self.image_files)} samples from {root_dir}")

    def __len__(self):
        """Return the number of images."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one image.

        Args:
            idx: Index into the sorted list of image files.

        Returns:
            Tuple ``(image, sample_id)`` where ``image`` is the tensor produced
            by ``TF.to_tensor`` from the grayscale image and ``sample_id`` is
            the file name without its extension.
        """
        img_path = self.image_files[idx]

        # convert() returns a fully loaded image, so the underlying file can
        # be closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert("L")
        image = TF.to_tensor(image)

        return image, self.get_sample_id(idx)

    def get_sample_id(self, idx):
        """Get sample ID without loading the image."""
        return os.path.splitext(os.path.basename(self.image_files[idx]))[0]