|
|
import os |
|
|
import glob |
|
|
from PIL import Image |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import transforms |
|
|
import random |
|
|
|
|
|
|
|
|
class SRDataset(Dataset):
    """
    Custom Dataset for Super-Resolution.

    Loads HR/LR pairs and returns fixed-size patches as a dict
    ``{'lr': FloatTensor, 'hr': FloatTensor}`` with values in [0, 1].
    The HR patch is ``scale_factor`` times the LR patch in each spatial
    dimension. Items that cannot be loaded yield ``None`` (filter these
    out in the DataLoader's ``collate_fn``).
    """

    def __init__(self, hr_dir, lr_dir, scale_factor, patch_size_lr=48, transform=None):
        """
        Args:
            hr_dir (str): Directory with all HR images.
            lr_dir (str): Directory with all LR images (corresponding to hr_dir).
            scale_factor (int): The upscaling factor.
            patch_size_lr (int): The size (height and width) of the LR patch to crop.
                HR patch size will be patch_size_lr * scale_factor.
            transform (callable, optional): Extra paired transform, called as
                ``transform(lr_patch, hr_patch)`` on PIL images AFTER the
                built-in flip/rotation augmentation; must return the
                (lr_patch, hr_patch) pair. (Previously this argument was
                accepted but silently ignored.)

        Raises:
            FileNotFoundError: If no .png/.jpg/.jpeg files exist in lr_dir.
        """
        super().__init__()
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.scale_factor = scale_factor
        self.patch_size_lr = patch_size_lr
        self.patch_size_hr = patch_size_lr * scale_factor
        self.transform = transform

        # Index by LR files; the HR counterpart is resolved per item in
        # __getitem__ (same basename, or basename with the 'x<scale>'
        # suffix stripped — DIV2K naming, e.g. '0001x4.png' -> '0001.png').
        self.lr_image_files = sorted(
            glob.glob(os.path.join(lr_dir, '*.png')) +
            glob.glob(os.path.join(lr_dir, '*.jpg')) +
            glob.glob(os.path.join(lr_dir, '*.jpeg'))
        )

        if not self.lr_image_files:
            raise FileNotFoundError(f"No images found in LR directory: {lr_dir}. Check path and image extensions.")

        print(f"Found {len(self.lr_image_files)} image pairs in HR='{hr_dir}', LR='{lr_dir}'")
        print(f"Using LR patch size: {self.patch_size_lr}x{self.patch_size_lr}, HR patch size: {self.patch_size_hr}x{self.patch_size_hr}")

    def __len__(self):
        """Return the number of LR/HR pairs (one per LR file found)."""
        return len(self.lr_image_files)

    @staticmethod
    def get_patch(lr_img, hr_img, patch_size_lr, scale_factor):
        """
        Randomly crops corresponding patches from LR and HR images.

        Args:
            lr_img (PIL.Image): Low-resolution image.
            hr_img (PIL.Image): High-resolution image.
            patch_size_lr (int): The desired height/width of the LR patch.
            scale_factor (int): The upscaling factor.

        Returns:
            tuple: (lr_patch, hr_patch) PIL.Image objects.
        """
        lr_w, lr_h = lr_img.size
        patch_size_hr = patch_size_lr * scale_factor

        # If the HR image is not exactly scale_factor times the LR image
        # (rounding during dataset generation), force it so crops align.
        hr_w, hr_h = hr_img.size
        if hr_w != lr_w * scale_factor or hr_h != lr_h * scale_factor:
            hr_img = hr_img.resize((lr_w * scale_factor, lr_h * scale_factor), resample=Image.BICUBIC)

        # If the LR image is smaller than the requested patch, upsample it
        # (and the HR image correspondingly) so a crop is always possible.
        if lr_w < patch_size_lr or lr_h < patch_size_lr:
            lr_img = lr_img.resize((max(lr_w, patch_size_lr), max(lr_h, patch_size_lr)), resample=Image.BICUBIC)
            hr_img = hr_img.resize((lr_img.width * scale_factor, lr_img.height * scale_factor), resample=Image.BICUBIC)
            lr_w, lr_h = lr_img.size

        # Pick the LR crop origin uniformly; the HR origin is the scaled
        # coordinate so the two patches cover the same scene content.
        lr_x = random.randrange(0, lr_w - patch_size_lr + 1)
        lr_y = random.randrange(0, lr_h - patch_size_lr + 1)
        hr_x = lr_x * scale_factor
        hr_y = lr_y * scale_factor

        lr_patch = lr_img.crop((lr_x, lr_y, lr_x + patch_size_lr, lr_y + patch_size_lr))
        hr_patch = hr_img.crop((hr_x, hr_y, hr_x + patch_size_hr, hr_y + patch_size_hr))

        return lr_patch, hr_patch

    @staticmethod
    def augment_patch(lr_patch, hr_patch):
        """Applies simple random augmentations (flip, rotation).

        Horizontal flip, vertical flip and a 90-degree rotation are each
        applied with probability 0.5, identically to both patches so they
        stay spatially aligned. Patches are square, so the rotation does
        not change their size.
        """
        if random.random() < 0.5:
            lr_patch = lr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)

        if random.random() < 0.5:
            lr_patch = lr_patch.transpose(Image.FLIP_TOP_BOTTOM)
            hr_patch = hr_patch.transpose(Image.FLIP_TOP_BOTTOM)

        # FIX: the rotation promised by the docstring was missing.
        if random.random() < 0.5:
            lr_patch = lr_patch.transpose(Image.ROTATE_90)
            hr_patch = hr_patch.transpose(Image.ROTATE_90)

        return lr_patch, hr_patch

    def _resolve_hr_path(self, lr_path):
        """Return the HR file path matching *lr_path*, or None if not found.

        Tries the identical basename first, then the basename with the
        'x<scale>' token removed (e.g. DIV2K '0001x4.png' -> '0001.png').
        """
        base_name = os.path.basename(lr_path)
        candidate = os.path.join(self.hr_dir, base_name)
        if os.path.exists(candidate):
            return candidate
        base, ext = os.path.splitext(base_name)
        token = f'x{self.scale_factor}'
        if token in base:
            candidate = os.path.join(self.hr_dir, base.replace(token, '') + ext)
            if os.path.exists(candidate):
                return candidate
        return None

    def __getitem__(self, idx):
        """Load pair *idx*, crop aligned random patches, augment, tensorize.

        Returns:
            dict | None: ``{'lr': Tensor, 'hr': Tensor}``, or ``None`` when
            the pair cannot be loaded/cropped (use a collate_fn that drops
            None items).
        """
        lr_path = self.lr_image_files[idx]
        try:
            lr_img = Image.open(lr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening LR image {lr_path}: {e}")
            return None

        hr_path = self._resolve_hr_path(lr_path)
        if hr_path is None:
            print(f"ERROR in __getitem__: Cannot find corresponding HR for LR: {lr_path}")
            return None

        try:
            hr_img = Image.open(hr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening HR image {hr_path}: {e}")
            return None

        try:
            lr_patch, hr_patch = self.get_patch(lr_img, hr_img, self.patch_size_lr, self.scale_factor)
        except ValueError as e:
            print(f"Error getting patch for {lr_path} (maybe image is smaller than patch size?): {e}")
            return None

        lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)

        if self.transform:
            # FIX: `transform` used to be accepted but never applied. It is
            # called on the pair so both patches receive the same transform.
            lr_patch, hr_patch = self.transform(lr_patch, hr_patch)

        to_tensor = transforms.ToTensor()
        lr_tensor = to_tensor(lr_patch)
        hr_tensor = to_tensor(hr_patch)

        return {'lr': lr_tensor, 'hr': hr_tensor}
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the dataset, fetch a few items, then drive a DataLoader.
    print("--- Testing SRDataset with Patching ---")

    hr_root = './datasets/DIV2K/HR_extracted/DIV2K_train_HR'
    lr_root = './datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
    upscale = 4
    lr_crop = 48

    if not os.path.isdir(hr_root):
        print(f"ERROR: HR dir not found: '{hr_root}'")
    if not os.path.isdir(lr_root):
        print(f"ERROR: LR dir not found: '{lr_root}'")

    try:
        ds = SRDataset(hr_dir=hr_root, lr_dir=lr_root,
                       scale_factor=upscale, patch_size_lr=lr_crop)

        if len(ds) == 0:
            print("\nDataset loaded but is empty.")
        else:
            print(f"\nSuccessfully loaded dataset with {len(ds)} image pairs.")

            print("\n--- Testing __getitem__ ---")
            # Shapes every item must produce for the configured patch sizes.
            expected_hr_shape = (3, lr_crop * upscale, lr_crop * upscale)
            for i in range(min(5, len(ds))):
                item = ds[i]
                if item is None:
                    print(f"Item {i}: Returned None (Error occurred)")
                    continue
                lr_p, hr_p = item['lr'], item['hr']
                print(f"Item {i}: LR Patch Shape={lr_p.shape}, HR Patch Shape={hr_p.shape}")
                if lr_p.shape != (3, lr_crop, lr_crop) or hr_p.shape != expected_hr_shape:
                    print(f"  WARNING: Shape mismatch! LR={lr_p.shape}, HR={hr_p.shape}, Expected HR={expected_hr_shape}")

            print("\n--- Testing DataLoader with Patches ---")
            from torch.utils.data import DataLoader

            def drop_failed_samples(batch):
                # __getitem__ yields None for unreadable pairs; collate the rest.
                kept = [sample for sample in batch if sample is not None]
                if not kept:
                    return None
                return torch.utils.data.dataloader.default_collate(kept)

            loader = DataLoader(ds, batch_size=4, shuffle=True,
                                num_workers=0, collate_fn=drop_failed_samples)

            good_batches = 0
            for batch in loader:
                if good_batches >= 3:
                    break
                if batch is None:
                    print("Skipping an entirely problematic batch.")
                    continue
                lr_batch, hr_batch = batch['lr'], batch['hr']
                print(f"Batch {good_batches}: LR Batch Shape={lr_batch.shape}, HR Batch Shape={hr_batch.shape}")
                good_batches += 1

            if good_batches > 0:
                print("DataLoader test with patches successful.")
            else:
                print("DataLoader test: Could not retrieve any valid batches.")

    except FileNotFoundError as e:
        print(f"\nERROR initializing dataset: {e}")
    except Exception as e:
        print(f"\nAn unexpected error occurred during dataset testing: {e}")

    print("\n--- SRDataset Test Finished ---")