"""SRDataset: paired HR/LR patch dataset for image super-resolution training."""
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random # Needed for random cropping
# --- Updated SRDataset Class ---
class SRDataset(Dataset):
    """
    Custom Dataset for Super-Resolution.

    Loads corresponding HR/LR image pairs from two directories and returns
    randomly cropped, flip-augmented, fixed-size patch pairs as tensors.
    """

    # Extensions searched for in the LR directory (lowercase only, matching
    # the original behavior).
    _EXTENSIONS = ('*.png', '*.jpg', '*.jpeg')

    def __init__(self, hr_dir, lr_dir, scale_factor, patch_size_lr=48, transform=None):
        """
        Args:
            hr_dir (str): Directory with all HR images.
            lr_dir (str): Directory with all LR images (corresponding to hr_dir).
            scale_factor (int): The upscaling factor.
            patch_size_lr (int): The size (height and width) of the LR patch to crop.
                HR patch size will be patch_size_lr * scale_factor.
            transform (callable, optional): Optional transform hook (e.g. data
                augmentation); currently a placeholder in __getitem__.

        Raises:
            FileNotFoundError: If no images are found in ``lr_dir``.
        """
        super().__init__()
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.scale_factor = scale_factor
        self.patch_size_lr = patch_size_lr
        self.patch_size_hr = patch_size_lr * scale_factor
        self.transform = transform
        # ToTensor is stateless: build it once here instead of per __getitem__
        # call (it converts PIL HWC [0, 255] to tensor CHW [0.0, 1.0]).
        self._to_tensor = transforms.ToTensor()

        # Find all image files (png, jpg, jpeg) in the LR directory; sorting
        # gives a deterministic index -> file mapping.
        found = []
        for pattern in self._EXTENSIONS:
            found.extend(glob.glob(os.path.join(lr_dir, pattern)))
        self.lr_image_files = sorted(found)

        if not self.lr_image_files:
            raise FileNotFoundError(f"No images found in LR directory: {lr_dir}. Check path and image extensions.")

        print(f"Found {len(self.lr_image_files)} image pairs in HR='{hr_dir}', LR='{lr_dir}'")
        print(f"Using LR patch size: {self.patch_size_lr}x{self.patch_size_lr}, HR patch size: {self.patch_size_hr}x{self.patch_size_hr}")

    def __len__(self):
        # One sample per LR file found at construction time.
        return len(self.lr_image_files)

    @staticmethod
    def get_patch(lr_img, hr_img, patch_size_lr, scale_factor):
        """
        Randomly crops corresponding patches from LR and HR images.

        Args:
            lr_img (PIL.Image): Low-resolution image.
            hr_img (PIL.Image): High-resolution image.
            patch_size_lr (int): The desired height/width of the LR patch.
            scale_factor (int): The upscaling factor.

        Returns:
            tuple: (lr_patch, hr_patch) PIL.Image objects.
        """
        lr_w, lr_h = lr_img.size
        hr_w, hr_h = hr_img.size
        patch_size_hr = patch_size_lr * scale_factor

        # Ensure HR image dimensions are consistent with LR and scale factor.
        # Mismatches can happen with imperfect downscaling or odd original
        # dimensions; fall back to resizing HR to the expected size.
        if hr_w != lr_w * scale_factor or hr_h != lr_h * scale_factor:
            hr_img = hr_img.resize((lr_w * scale_factor, lr_h * scale_factor), resample=Image.BICUBIC)

        # If the LR image is smaller than the patch size in either dimension,
        # upscale both images so __getitem__ always returns tensors of the
        # target patch size.
        if lr_w < patch_size_lr or lr_h < patch_size_lr:
            lr_img = lr_img.resize((max(lr_w, patch_size_lr), max(lr_h, patch_size_lr)), resample=Image.BICUBIC)
            hr_img = hr_img.resize((lr_img.width * scale_factor, lr_img.height * scale_factor), resample=Image.BICUBIC)
            lr_w, lr_h = lr_img.size  # Update dimensions after resize

        # Choose a random top-left corner so the LR patch fits in bounds; the
        # HR corner is the same location scaled up.
        lr_x = random.randrange(0, lr_w - patch_size_lr + 1)
        lr_y = random.randrange(0, lr_h - patch_size_lr + 1)
        hr_x = lr_x * scale_factor
        hr_y = lr_y * scale_factor

        # PIL crop format is (left, upper, right, lower).
        lr_patch = lr_img.crop((lr_x, lr_y, lr_x + patch_size_lr, lr_y + patch_size_lr))
        hr_patch = hr_img.crop((hr_x, hr_y, hr_x + patch_size_hr, hr_y + patch_size_hr))
        return lr_patch, hr_patch

    @staticmethod
    def augment_patch(lr_patch, hr_patch):
        """Applies simple random augmentations (horizontal flip) to both patches."""
        # Random horizontal flip, applied identically to both patches so they
        # stay spatially aligned. (Vertical flips / 90-degree rotations were
        # deliberately left disabled in the original implementation.)
        if random.random() < 0.5:
            lr_patch = lr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)
        return lr_patch, hr_patch

    def __getitem__(self, idx):
        """
        Returns a dict {'lr': tensor, 'hr': tensor} of aligned patches, or
        None when the pair cannot be loaded (callers must filter None, e.g.
        via a collate_fn).
        """
        lr_path = self.lr_image_files[idx]
        try:
            lr_img = Image.open(lr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening LR image {lr_path}: {e}")
            # Returning None requires handling in the DataLoader collate_fn
            # or training loop.
            return None

        # Construct the corresponding full HR image path; by default the HR
        # file shares the LR file's basename.
        base_name = os.path.basename(lr_path)
        hr_path = os.path.join(self.hr_dir, base_name)
        if not os.path.exists(hr_path):
            # Handle DIV2K-style LR names that embed the scale, e.g.
            # '0001x4.png' -> HR '0001.png'.
            base, ext = os.path.splitext(base_name)
            if f'x{self.scale_factor}' in base:
                hr_name = base.replace(f'x{self.scale_factor}', '') + ext
                hr_path_alt = os.path.join(self.hr_dir, hr_name)
                if os.path.exists(hr_path_alt):
                    hr_path = hr_path_alt
                else:
                    print(f"ERROR in __getitem__: Cannot find corresponding HR for LR: {lr_path}")
                    return None
            else:
                print(f"ERROR in __getitem__: Cannot find corresponding HR for LR: {lr_path}")
                return None

        try:
            hr_img = Image.open(hr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening HR image {hr_path}: {e}")
            return None

        # Crop aligned random patches from the pair.
        try:
            lr_patch, hr_patch = self.get_patch(lr_img, hr_img, self.patch_size_lr, self.scale_factor)
        except ValueError as e:
            # randrange raises ValueError if the patch does not fit even
            # after the resize fallback.
            print(f"Error getting patch for {lr_path} (maybe image is smaller than patch size?): {e}")
            return None

        # Apply paired random augmentations.
        lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)

        # Custom transform hook (placeholder; this is where e.g. an
        # albumentations pipeline would be integrated).
        if self.transform:
            pass

        # Convert PIL patches to CHW float tensors in [0.0, 1.0].
        lr_tensor = self._to_tensor(lr_patch)
        hr_tensor = self._to_tensor(hr_patch)
        return {'lr': lr_tensor, 'hr': hr_tensor}
# --- Example Usage (for testing the definition) ---
if __name__ == '__main__':
    # Smoke-test the SRDataset patch pipeline against a local DIV2K layout.
    print("--- Testing SRDataset with Patching ---")
    hr_data_dir = './datasets/DIV2K/HR_extracted/DIV2K_train_HR'  # Modify if needed
    lr_data_dir = './datasets/DIV2K/DIV2K_train_LR_bicubic/X4'  # Modify if needed
    scale = 4
    lr_patch_size = 48  # Common LR patch size for SR tasks

    if not os.path.isdir(hr_data_dir):
        print(f"ERROR: HR dir not found: '{hr_data_dir}'")
    if not os.path.isdir(lr_data_dir):
        print(f"ERROR: LR dir not found: '{lr_data_dir}'")

    try:
        dataset = SRDataset(hr_dir=hr_data_dir, lr_dir=lr_data_dir,
                            scale_factor=scale, patch_size_lr=lr_patch_size)
        if len(dataset) == 0:
            print("\nDataset loaded but is empty.")
        else:
            print(f"\nSuccessfully loaded dataset with {len(dataset)} image pairs.")

            # Pull a handful of individual samples and sanity-check shapes.
            print("\n--- Testing __getitem__ ---")
            num_test_items = 5
            for idx in range(min(num_test_items, len(dataset))):
                sample = dataset[idx]
                if sample is None:
                    print(f"Item {idx}: Returned None (Error occurred)")
                    continue
                lr_p, hr_p = sample['lr'], sample['hr']
                print(f"Item {idx}: LR Patch Shape={lr_p.shape}, HR Patch Shape={hr_p.shape}")
                expected_hr_shape = (3, lr_patch_size * scale, lr_patch_size * scale)
                bad_lr = lr_p.shape != (3, lr_patch_size, lr_patch_size)
                bad_hr = hr_p.shape != expected_hr_shape
                if bad_lr or bad_hr:
                    print(f" WARNING: Shape mismatch! LR={lr_p.shape}, HR={hr_p.shape}, Expected HR={expected_hr_shape}")

            # Exercise batching through a DataLoader with a None-filtering
            # collate function (since __getitem__ may return None on errors).
            print("\n--- Testing DataLoader with Patches ---")
            from torch.utils.data import DataLoader

            def collate_fn_filter_none(batch):
                """Drop failed (None) samples; return None if the whole batch failed."""
                kept = [item for item in batch if item is not None]
                if not kept:
                    return None
                return torch.utils.data.dataloader.default_collate(kept)

            dataloader = DataLoader(dataset, batch_size=4, shuffle=True,
                                    num_workers=0, collate_fn=collate_fn_filter_none)
            num_test_batches = 3
            batch_count = 0
            for batch in dataloader:
                if batch_count >= num_test_batches:
                    break
                if batch is None:
                    print("Skipping an entirely problematic batch.")
                    continue
                print(f"Batch {batch_count}: LR Batch Shape={batch['lr'].shape}, HR Batch Shape={batch['hr'].shape}")
                batch_count += 1

            if batch_count > 0:
                print("DataLoader test with patches successful.")
            else:
                print("DataLoader test: Could not retrieve any valid batches.")
    except FileNotFoundError as e:
        print(f"\nERROR initializing dataset: {e}")
    except Exception as e:
        print(f"\nAn unexpected error occurred during dataset testing: {e}")

    print("\n--- SRDataset Test Finished ---")