# OxO_Image-Repair — dataset.py
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random # Needed for random cropping
# --- Updated SRDataset Class ---
class SRDataset(Dataset):
    """
    Custom Dataset for Super-Resolution.

    Loads corresponding HR/LR image pairs from two directories and returns
    randomly-cropped, fixed-size patch pairs as CHW float tensors in [0.0, 1.0].
    """

    def __init__(self, hr_dir, lr_dir, scale_factor, patch_size_lr=48, transform=None):
        """
        Args:
            hr_dir (str): Directory with all HR images.
            lr_dir (str): Directory with all LR images (corresponding to hr_dir).
            scale_factor (int): The upscaling factor.
            patch_size_lr (int): The size (height and width) of the LR patch to crop.
                HR patch size will be patch_size_lr * scale_factor.
            transform (callable, optional): Optional transform (e.g., data augmentation like flips).

        Raises:
            FileNotFoundError: If no png/jpg/jpeg images are found in `lr_dir`.
        """
        super().__init__()
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.scale_factor = scale_factor
        self.patch_size_lr = patch_size_lr
        self.patch_size_hr = patch_size_lr * scale_factor
        self.transform = transform
        # ToTensor is stateless, so build it once here instead of on every
        # __getitem__ call.
        self.to_tensor = transforms.ToTensor()
        # Find all image files (png, jpg, jpeg) in the LR directory.
        self.lr_image_files = sorted(
            path
            for pattern in ('*.png', '*.jpg', '*.jpeg')
            for path in glob.glob(os.path.join(lr_dir, pattern))
        )
        if not self.lr_image_files:
            raise FileNotFoundError(f"No images found in LR directory: {lr_dir}. Check path and image extensions.")
        print(f"Found {len(self.lr_image_files)} image pairs in HR='{hr_dir}', LR='{lr_dir}'")
        print(f"Using LR patch size: {self.patch_size_lr}x{self.patch_size_lr}, HR patch size: {self.patch_size_hr}x{self.patch_size_hr}")

    def __len__(self):
        """Number of LR images found (one sample per LR file)."""
        return len(self.lr_image_files)

    @staticmethod
    def get_patch(lr_img, hr_img, patch_size_lr, scale_factor):
        """
        Randomly crops corresponding patches from LR and HR images.

        Args:
            lr_img (PIL.Image): Low-resolution image.
            hr_img (PIL.Image): High-resolution image.
            patch_size_lr (int): The desired height/width of the LR patch.
            scale_factor (int): The upscaling factor.

        Returns:
            tuple: (lr_patch, hr_patch) PIL.Image objects.
        """
        lr_w, lr_h = lr_img.size
        hr_w, hr_h = hr_img.size
        patch_size_hr = patch_size_lr * scale_factor
        # Ensure HR image dimensions are consistent with LR and scale factor.
        # Mismatches can happen with imperfect downscaling or odd original
        # dimensions; fall back to resizing HR to the expected size.
        if hr_w != lr_w * scale_factor or hr_h != lr_h * scale_factor:
            hr_img = hr_img.resize((lr_w * scale_factor, lr_h * scale_factor), resample=Image.BICUBIC)
        # If the LR image is smaller than the patch size, upscale both images
        # so __getitem__ always returns tensors of the target patch size.
        if lr_w < patch_size_lr or lr_h < patch_size_lr:
            lr_img = lr_img.resize((max(lr_w, patch_size_lr), max(lr_h, patch_size_lr)), resample=Image.BICUBIC)
            hr_img = hr_img.resize((lr_img.width * scale_factor, lr_img.height * scale_factor), resample=Image.BICUBIC)
            lr_w, lr_h = lr_img.size  # Update dimensions after resize
        # Choose a random top-left corner for the LR patch such that the patch
        # fits within the image boundaries.
        lr_x = random.randrange(0, lr_w - patch_size_lr + 1)
        lr_y = random.randrange(0, lr_h - patch_size_lr + 1)
        # Corresponding HR corner is simply scaled by the factor.
        hr_x = lr_x * scale_factor
        hr_y = lr_y * scale_factor
        # PIL crop format is (left, upper, right, lower).
        lr_patch = lr_img.crop((lr_x, lr_y, lr_x + patch_size_lr, lr_y + patch_size_lr))
        hr_patch = hr_img.crop((hr_x, hr_y, hr_x + patch_size_hr, hr_y + patch_size_hr))
        return lr_patch, hr_patch

    @staticmethod
    def augment_patch(lr_patch, hr_patch):
        """Applies simple random augmentations (currently horizontal flip only)."""
        # Random horizontal flip, applied identically to both patches so the
        # LR/HR correspondence is preserved.
        if random.random() < 0.5:
            lr_patch = lr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)
        # NOTE: vertical flips and 90-degree rotations were deliberately left
        # disabled in the original implementation.
        return lr_patch, hr_patch

    def _resolve_hr_path(self, lr_path):
        """Return the HR file path matching `lr_path`, or None if not found.

        First tries the identical basename inside `hr_dir`; failing that,
        tries the DIV2K-style name with the 'x{scale}' suffix removed
        (e.g. '0001x4.png' -> '0001.png').
        """
        base_name = os.path.basename(lr_path)
        hr_path = os.path.join(self.hr_dir, base_name)
        if os.path.exists(hr_path):
            return hr_path
        base, ext = os.path.splitext(base_name)
        suffix = f'x{self.scale_factor}'
        if suffix in base:
            hr_path_alt = os.path.join(self.hr_dir, base.replace(suffix, '') + ext)
            if os.path.exists(hr_path_alt):
                return hr_path_alt
        return None

    def __getitem__(self, idx):
        """Return a {'lr': Tensor, 'hr': Tensor} patch pair, or None on error.

        A None return must be filtered out by the DataLoader's collate_fn
        (or handled in the training loop).
        """
        lr_path = self.lr_image_files[idx]
        try:
            lr_img = Image.open(lr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening LR image {lr_path}: {e}")
            return None  # Let collate_fn handle this potentially
        # Locate the corresponding HR file (same name, or 'xN' suffix stripped).
        hr_path = self._resolve_hr_path(lr_path)
        if hr_path is None:
            print(f"ERROR in __getitem__: Cannot find corresponding HR for LR: {lr_path}")
            return None  # Indicate error
        try:
            hr_img = Image.open(hr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening HR image {hr_path}: {e}")
            return None  # Indicate error
        # --- Get corresponding random patches ---
        try:
            lr_patch, hr_patch = self.get_patch(lr_img, hr_img, self.patch_size_lr, self.scale_factor)
        except ValueError as e:  # randrange can raise if patch size > image size after resize
            print(f"Error getting patch for {lr_path} (maybe image is smaller than patch size?): {e}")
            return None
        # --- Apply augmentations ---
        lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)
        # --- Apply custom transform if provided ---
        # (Currently unused; this is where e.g. albumentations would plug in.)
        if self.transform:
            pass  # Placeholder
        # --- Convert PIL patches (HWC, [0, 255]) to tensors (CHW, [0.0, 1.0]) ---
        return {'lr': self.to_tensor(lr_patch), 'hr': self.to_tensor(hr_patch)}
# --- Example Usage (for testing the definition) ---
if __name__ == '__main__':
    print("--- Testing SRDataset with Patching ---")

    hr_root = './datasets/DIV2K/HR_extracted/DIV2K_train_HR'  # Modify if needed
    lr_root = './datasets/DIV2K/DIV2K_train_LR_bicubic/X4'  # Modify if needed
    scale = 4
    lr_patch_size = 48  # Common LR patch size for SR tasks

    # Warn early about missing directories (dataset construction will raise
    # its own FileNotFoundError below if the LR dir is truly empty/absent).
    if not os.path.isdir(hr_root):
        print(f"ERROR: HR dir not found: '{hr_root}'")
    if not os.path.isdir(lr_root):
        print(f"ERROR: LR dir not found: '{lr_root}'")

    try:
        dataset = SRDataset(hr_dir=hr_root, lr_dir=lr_root,
                            scale_factor=scale, patch_size_lr=lr_patch_size)
        if len(dataset) == 0:
            print("\nDataset loaded but is empty.")
        else:
            print(f"\nSuccessfully loaded dataset with {len(dataset)} image pairs.")

            # Fetch a few individual patch pairs and sanity-check their shapes.
            print("\n--- Testing __getitem__ ---")
            expected_lr_shape = (3, lr_patch_size, lr_patch_size)
            expected_hr_shape = (3, lr_patch_size * scale, lr_patch_size * scale)
            for i in range(min(5, len(dataset))):
                sample = dataset[i]
                if sample is None:
                    print(f"Item {i}: Returned None (Error occurred)")
                    continue
                lr_p, hr_p = sample['lr'], sample['hr']
                print(f"Item {i}: LR Patch Shape={lr_p.shape}, HR Patch Shape={hr_p.shape}")
                if lr_p.shape != expected_lr_shape or hr_p.shape != expected_hr_shape:
                    print(f" WARNING: Shape mismatch! LR={lr_p.shape}, HR={hr_p.shape}, Expected HR={expected_hr_shape}")

            # Exercise a DataLoader with a collate_fn that drops failed samples.
            print("\n--- Testing DataLoader with Patches ---")
            from torch.utils.data import DataLoader

            def collate_fn_filter_none(batch):
                """Drop None samples; collate the rest (None if all failed)."""
                kept = [item for item in batch if item is not None]
                if not kept:
                    return None
                return torch.utils.data.dataloader.default_collate(kept)

            loader = DataLoader(dataset, batch_size=4, shuffle=True,
                                num_workers=0, collate_fn=collate_fn_filter_none)
            valid_batches = 0
            for batch in loader:
                if valid_batches >= 3:
                    break
                if batch is None:
                    print("Skipping an entirely problematic batch.")
                    continue
                print(f"Batch {valid_batches}: LR Batch Shape={batch['lr'].shape}, HR Batch Shape={batch['hr'].shape}")
                valid_batches += 1
            if valid_batches > 0:
                print("DataLoader test with patches successful.")
            else:
                print("DataLoader test: Could not retrieve any valid batches.")
    except FileNotFoundError as e:
        print(f"\nERROR initializing dataset: {e}")
    except Exception as e:
        print(f"\nAn unexpected error occurred during dataset testing: {e}")
    print("\n--- SRDataset Test Finished ---")