"""Super-resolution dataset yielding aligned, fixed-size LR/HR patch pairs."""
import os
import glob
import random  # Needed for random cropping

from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms


# --- Updated SRDataset Class ---
class SRDataset(Dataset):
    """
    Custom Dataset for Super-Resolution.

    Loads HR/LR pairs and returns fixed-size patches. Each item is a dict
    ``{'lr': Tensor(3, p, p), 'hr': Tensor(3, p * s, p * s)}`` (p = LR patch
    size, s = scale factor) with values in [0.0, 1.0], or ``None`` when the
    pair could not be loaded. Because of the ``None`` case, use a DataLoader
    ``collate_fn`` that filters out ``None`` entries.
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    # ToTensor is stateless, so a single shared instance serves every item.
    _to_tensor = transforms.ToTensor()

    def __init__(self, hr_dir, lr_dir, scale_factor, patch_size_lr=48, transform=None):
        """
        Args:
            hr_dir (str): Directory with all HR images.
            lr_dir (str): Directory with all LR images (corresponding to hr_dir).
            scale_factor (int): The upscaling factor.
            patch_size_lr (int): The size (height and width) of the LR patch to crop.
                HR patch size will be patch_size_lr * scale_factor.
            transform (callable, optional): Joint transform applied to every
                patch pair after the built-in flip augmentation. It is called as
                ``transform(lr_patch, hr_patch)`` with PIL images and must
                return a ``(lr_patch, hr_patch)`` tuple of PIL images. It must
                transform both patches consistently to keep them aligned.

        Raises:
            FileNotFoundError: If no image files are found in ``lr_dir``.
        """
        super().__init__()
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.scale_factor = scale_factor
        self.patch_size_lr = patch_size_lr
        self.patch_size_hr = patch_size_lr * scale_factor
        self.transform = transform

        # LR files drive the dataset; each HR partner is resolved lazily in
        # __getitem__ (see _find_hr_path).
        self.lr_image_files = self._list_images(lr_dir)
        if not self.lr_image_files:
            raise FileNotFoundError(f"No images found in LR directory: {lr_dir}. Check path and image extensions.")

        print(f"Found {len(self.lr_image_files)} image pairs in HR='{hr_dir}', LR='{lr_dir}'")
        print(f"Using LR patch size: {self.patch_size_lr}x{self.patch_size_lr}, HR patch size: {self.patch_size_hr}x{self.patch_size_hr}")

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted paths of image files located directly in ``directory``.

        Extensions in IMAGE_EXTENSIONS match case-insensitively. Hidden entries
        (names starting with '.') and sub-directories are skipped. A missing
        directory yields an empty list instead of raising.
        """
        # An empty string means "current directory", as it did with glob.
        search_dir = directory or os.curdir
        if not os.path.isdir(search_dir):
            return []
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(search_dir)
            if not name.startswith('.')
            and name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(search_dir, name))
        )

    def __len__(self):
        return len(self.lr_image_files)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image that stays valid
            # after the source file is closed.
            return img.convert('RGB')

    def _find_hr_path(self, lr_path):
        """Return the path of the HR image matching ``lr_path``, or None if none exists.

        Candidate file names in ``hr_dir`` are tried in this order:
          1. the LR file's own name;
          2. the name with a trailing ``x{scale_factor}`` tag removed
             (e.g. ``0001x4.png`` -> ``0001.png``);
          3. the name with every ``x{scale_factor}`` occurrence removed.
        """
        base_name = os.path.basename(lr_path)
        base, ext = os.path.splitext(base_name)
        tag = f'x{self.scale_factor}'

        candidates = [base_name]
        if base.endswith(tag):
            candidates.append(base[:-len(tag)] + ext)
        if tag in base:
            candidates.append(base.replace(tag, '') + ext)

        for name in candidates:
            path = os.path.join(self.hr_dir, name)
            if os.path.exists(path):
                return path
        return None

    @staticmethod
    def get_patch(lr_img, hr_img, patch_size_lr, scale_factor):
        """
        Randomly crops corresponding patches from LR and HR images.

        Args:
            lr_img (PIL.Image): Low-resolution image.
            hr_img (PIL.Image): High-resolution image.
            patch_size_lr (int): The desired height/width of the LR patch.
            scale_factor (int): The upscaling factor.

        Returns:
            tuple: (lr_patch, hr_patch) PIL.Image objects of size
            patch_size_lr and patch_size_lr * scale_factor respectively.
        """
        lr_w, lr_h = lr_img.size
        hr_w, hr_h = hr_img.size
        patch_size_hr = patch_size_lr * scale_factor

        # Make the HR image exactly scale_factor x the LR size so that crop
        # coordinates line up. A mismatch can come from imperfect downscaling
        # or odd original dimensions.
        expected_hr_size = (lr_w * scale_factor, lr_h * scale_factor)
        if (hr_w, hr_h) != expected_hr_size:
            hr_img = hr_img.resize(expected_hr_size, resample=Image.BICUBIC)

        # If the LR image is smaller than the patch, upscale both images so a
        # full-size crop always exists and the output tensors have the target size.
        if lr_w < patch_size_lr or lr_h < patch_size_lr:
            lr_img = lr_img.resize((max(lr_w, patch_size_lr), max(lr_h, patch_size_lr)), resample=Image.BICUBIC)
            hr_img = hr_img.resize((lr_img.width * scale_factor, lr_img.height * scale_factor), resample=Image.BICUBIC)
            lr_w, lr_h = lr_img.size

        # Random top-left corner for the LR patch. The HR corner is the same
        # point in HR pixel coordinates.
        lr_x = random.randrange(0, lr_w - patch_size_lr + 1)
        lr_y = random.randrange(0, lr_h - patch_size_lr + 1)
        hr_x = lr_x * scale_factor
        hr_y = lr_y * scale_factor

        # PIL crop box format is (left, upper, right, lower).
        lr_patch = lr_img.crop((lr_x, lr_y, lr_x + patch_size_lr, lr_y + patch_size_lr))
        hr_patch = hr_img.crop((hr_x, hr_y, hr_x + patch_size_hr, hr_y + patch_size_hr))
        return lr_patch, hr_patch

    @staticmethod
    def augment_patch(lr_patch, hr_patch):
        """Apply the same random horizontal flip (probability 0.5) to both patches.

        Vertical flips and 90-degree rotations are currently not applied.
        """
        if random.random() < 0.5:
            lr_patch = lr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)
        return lr_patch, hr_patch

    def __getitem__(self, idx):
        """Return ``{'lr': Tensor, 'hr': Tensor}`` for item ``idx``, or None on any load error."""
        lr_path = self.lr_image_files[idx]

        # Best-effort loading: report problems and return None so a filtering
        # collate_fn can drop the sample instead of crashing the epoch.
        try:
            lr_img = self._load_rgb(lr_path)
        except Exception as e:
            print(f"Error opening LR image {lr_path}: {e}")
            return None

        hr_path = self._find_hr_path(lr_path)
        if hr_path is None:
            print(f"ERROR in __getitem__: Cannot find corresponding HR for LR: {lr_path}")
            return None

        try:
            hr_img = self._load_rgb(hr_path)
        except Exception as e:
            print(f"Error opening HR image {hr_path}: {e}")
            return None

        try:
            lr_patch, hr_patch = self.get_patch(lr_img, hr_img, self.patch_size_lr, self.scale_factor)
        except ValueError as e:
            # randrange raises ValueError on an empty range. get_patch upscales
            # small images first, so this is a defensive guard.
            print(f"Error getting patch for {lr_path} (maybe image is smaller than patch size?): {e}")
            return None

        lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)

        if self.transform is not None:
            lr_patch, hr_patch = self.transform(lr_patch, hr_patch)

        # ToTensor: PIL image (HWC, [0, 255]) -> float Tensor (CHW, [0.0, 1.0]).
        return {'lr': self._to_tensor(lr_patch), 'hr': self._to_tensor(hr_patch)}


# --- Example Usage (for testing the definition) ---
if __name__ == '__main__':
    print("--- Testing SRDataset with Patching ---")
    hr_data_dir = './datasets/DIV2K/HR_extracted/DIV2K_train_HR'  # Modify if needed
    lr_data_dir = './datasets/DIV2K/DIV2K_train_LR_bicubic/X4'  # Modify if needed
    scale = 4
    lr_patch_size = 48  # Common LR patch size for SR tasks

    if not os.path.isdir(hr_data_dir):
        print(f"ERROR: HR dir not found: '{hr_data_dir}'")
    if not os.path.isdir(lr_data_dir):
        print(f"ERROR: LR dir not found: '{lr_data_dir}'")

    try:
        dataset = SRDataset(hr_dir=hr_data_dir, lr_dir=lr_data_dir,
                            scale_factor=scale, patch_size_lr=lr_patch_size)

        if len(dataset) > 0:
            print(f"\nSuccessfully loaded dataset with {len(dataset)} image pairs.")

            # Test getting single items (patch pairs).
            print("\n--- Testing __getitem__ ---")
            num_test_items = 5
            for i in range(min(num_test_items, len(dataset))):
                item = dataset[i]
                if item is None:
                    print(f"Item {i}: Returned None (Error occurred)")
                    continue
                lr_p = item['lr']
                hr_p = item['hr']
                print(f"Item {i}: LR Patch Shape={lr_p.shape}, HR Patch Shape={hr_p.shape}")
                expected_hr_shape = (3, lr_patch_size * scale, lr_patch_size * scale)
                if lr_p.shape != (3, lr_patch_size, lr_patch_size) or hr_p.shape != expected_hr_shape:
                    print(f"  WARNING: Shape mismatch! LR={lr_p.shape}, HR={hr_p.shape}, Expected HR={expected_hr_shape}")

            # Test DataLoader with a collate function that filters out Nones.
            print("\n--- Testing DataLoader with Patches ---")
            from torch.utils.data import DataLoader

            def collate_fn_filter_none(batch):
                """Drop None samples; return None if the whole batch failed."""
                batch = [sample for sample in batch if sample is not None]
                if not batch:
                    return None
                return torch.utils.data.dataloader.default_collate(batch)

            dataloader = DataLoader(dataset, batch_size=4, shuffle=True,
                                    num_workers=0, collate_fn=collate_fn_filter_none)

            num_test_batches = 3
            batch_count = 0
            for batch in dataloader:
                if batch_count >= num_test_batches:
                    break
                if batch is None:
                    print("Skipping an entirely problematic batch.")
                    continue
                lr_batch = batch['lr']
                hr_batch = batch['hr']
                print(f"Batch {batch_count}: LR Batch Shape={lr_batch.shape}, HR Batch Shape={hr_batch.shape}")
                batch_count += 1

            if batch_count > 0:
                print("DataLoader test with patches successful.")
            else:
                print("DataLoader test: Could not retrieve any valid batches.")
        else:
            print("\nDataset loaded but is empty.")

    except FileNotFoundError as e:
        print(f"\nERROR initializing dataset: {e}")
    except Exception as e:
        print(f"\nAn unexpected error occurred during dataset testing: {e}")

    print("\n--- SRDataset Test Finished ---")