File size: 12,107 Bytes
fd5c0a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random # Needed for random cropping

# --- Updated SRDataset Class ---
class SRDataset(Dataset):
    """
    Custom Dataset for Super-Resolution.

    Loads corresponding HR/LR image pairs from two directories and returns
    randomly cropped, fixed-size patch pairs as CHW float tensors in [0.0, 1.0].
    Items that fail to load return ``None``; pair with a collate_fn that
    filters ``None`` entries.
    """

    # Recognized image extensions; matched case-insensitively.
    _IMAGE_EXTS = ('.png', '.jpg', '.jpeg')

    def __init__(self, hr_dir, lr_dir, scale_factor, patch_size_lr=48, transform=None):
        """
        Args:
            hr_dir (str): Directory with all HR images.
            lr_dir (str): Directory with all LR images (corresponding to hr_dir).
            scale_factor (int): The upscaling factor.
            patch_size_lr (int): The size (height and width) of the LR patch to crop.
                                 HR patch size will be patch_size_lr * scale_factor.
            transform (callable, optional): Optional transform (e.g., data augmentation like flips).

        Raises:
            FileNotFoundError: If ``lr_dir`` is missing or contains no images.
        """
        super().__init__()
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.scale_factor = scale_factor
        self.patch_size_lr = patch_size_lr
        self.patch_size_hr = patch_size_lr * scale_factor
        self.transform = transform
        # Hoisted out of __getitem__: ToTensor is stateless, no need to
        # rebuild it per item. Converts PIL (HWC) [0, 255] -> Tensor (CHW) [0.0, 1.0].
        self._to_tensor = transforms.ToTensor()

        # Case-insensitive extension match (glob would miss e.g. '.PNG').
        if os.path.isdir(lr_dir):
            self.lr_image_files = sorted(
                os.path.join(lr_dir, name)
                for name in os.listdir(lr_dir)
                if os.path.splitext(name)[1].lower() in self._IMAGE_EXTS
            )
        else:
            self.lr_image_files = []

        if not self.lr_image_files:
            raise FileNotFoundError(f"No images found in LR directory: {lr_dir}. Check path and image extensions.")

        print(f"Found {len(self.lr_image_files)} image pairs in HR='{hr_dir}', LR='{lr_dir}'")
        print(f"Using LR patch size: {self.patch_size_lr}x{self.patch_size_lr}, HR patch size: {self.patch_size_hr}x{self.patch_size_hr}")

    def __len__(self):
        """Number of LR images found (one item per LR/HR pair)."""
        return len(self.lr_image_files)

    @staticmethod
    def get_patch(lr_img, hr_img, patch_size_lr, scale_factor):
        """
        Randomly crops corresponding patches from LR and HR images.

        Args:
            lr_img (PIL.Image): Low-resolution image.
            hr_img (PIL.Image): High-resolution image.
            patch_size_lr (int): The desired height/width of the LR patch.
            scale_factor (int): The upscaling factor.

        Returns:
            tuple: (lr_patch, hr_patch) PIL.Image objects.
        """
        lr_w, lr_h = lr_img.size
        hr_w, hr_h = hr_img.size
        patch_size_hr = patch_size_lr * scale_factor

        # Ensure HR image dimensions are consistent with LR and scale factor.
        # Mismatches happen with imperfect downscaling or odd original sizes;
        # fall back to resizing HR to the expected dimensions.
        if hr_w != lr_w * scale_factor or hr_h != lr_h * scale_factor:
            hr_img = hr_img.resize((lr_w * scale_factor, lr_h * scale_factor), resample=Image.BICUBIC)

        # If the LR image is smaller than the patch in either dimension,
        # upscale both images so __getitem__ always yields fixed-size tensors.
        if lr_w < patch_size_lr or lr_h < patch_size_lr:
            lr_img = lr_img.resize((max(lr_w, patch_size_lr), max(lr_h, patch_size_lr)), resample=Image.BICUBIC)
            hr_img = hr_img.resize((lr_img.width * scale_factor, lr_img.height * scale_factor), resample=Image.BICUBIC)
            lr_w, lr_h = lr_img.size  # update dimensions after resize

        # Choose a random top-left corner such that the patch fits.
        lr_x = random.randrange(0, lr_w - patch_size_lr + 1)
        lr_y = random.randrange(0, lr_h - patch_size_lr + 1)

        # Corresponding HR corner is the LR corner scaled up.
        hr_x = lr_x * scale_factor
        hr_y = lr_y * scale_factor

        # PIL crop format is (left, upper, right, lower).
        lr_patch = lr_img.crop((lr_x, lr_y, lr_x + patch_size_lr, lr_y + patch_size_lr))
        hr_patch = hr_img.crop((hr_x, hr_y, hr_x + patch_size_hr, hr_y + patch_size_hr))

        return lr_patch, hr_patch

    @staticmethod
    def augment_patch(lr_patch, hr_patch):
        """Applies simple random augmentations (currently horizontal flip only).

        Vertical flips and 90-degree rotations are possible extensions but are
        intentionally disabled here.
        """
        # Random horizontal flip, applied to both patches to keep them aligned.
        if random.random() < 0.5:
            lr_patch = lr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)

        return lr_patch, hr_patch

    def _find_hr_path(self, lr_path):
        """Return the HR path matching *lr_path*, or None if no candidate exists.

        Tries, in order: the same basename in hr_dir; the basename with a
        trailing 'x<scale>' suffix stripped (DIV2K-style '0001x4.png' ->
        '0001.png'); the basename with any 'x<scale>' occurrence removed
        (legacy fallback).
        """
        base_name = os.path.basename(lr_path)
        hr_path = os.path.join(self.hr_dir, base_name)
        if os.path.exists(hr_path):
            return hr_path

        base, ext = os.path.splitext(base_name)
        suffix = f'x{self.scale_factor}'
        candidates = []
        # Prefer stripping only a trailing suffix so names that merely contain
        # 'x4' mid-string are not mangled.
        if base.endswith(suffix):
            candidates.append(base[:-len(suffix)] + ext)
        if suffix in base:
            candidates.append(base.replace(suffix, '') + ext)
        for candidate in candidates:
            candidate_path = os.path.join(self.hr_dir, candidate)
            if os.path.exists(candidate_path):
                return candidate_path
        return None

    def __getitem__(self, idx):
        """Load the idx-th LR/HR pair and return a random aligned patch pair.

        Returns:
            dict: {'lr': Tensor(3, p, p), 'hr': Tensor(3, p*s, p*s)} on success,
            or None on any load/lookup failure (collate_fn must filter these).
        """
        lr_path = self.lr_image_files[idx]
        try:
            lr_img = Image.open(lr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening LR image {lr_path}: {e}")
            # Returning None requires a collate_fn (or training loop) that
            # filters None items.
            return None

        hr_path = self._find_hr_path(lr_path)
        if hr_path is None:
            print(f"ERROR in __getitem__: Cannot find corresponding HR for LR: {lr_path}")
            return None

        try:
            hr_img = Image.open(hr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening HR image {hr_path}: {e}")
            return None

        # --- Get Corresponding Patches ---
        try:
            lr_patch, hr_patch = self.get_patch(lr_img, hr_img, self.patch_size_lr, self.scale_factor)
        except ValueError as e:  # randrange can raise if the patch still doesn't fit
            print(f"Error getting patch for {lr_path} (maybe image is smaller than patch size?): {e}")
            return None

        # --- Apply Augmentations ---
        lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)

        # --- Apply Custom Transform if provided ---
        # Hook for external pipelines (e.g. albumentations); currently a no-op.
        if self.transform:
            pass  # Placeholder

        # --- Convert Patches to Tensors ---
        lr_tensor = self._to_tensor(lr_patch)
        hr_tensor = self._to_tensor(hr_patch)

        return {'lr': lr_tensor, 'hr': hr_tensor}

# --- Example Usage (for testing the definition) ---
# Smoke test for the SRDataset definition: loads a few patch pairs and runs
# them through a DataLoader with a None-filtering collate function.
if __name__ == '__main__':
    print("--- Testing SRDataset with Patching ---")

    # Paths to the DIV2K training split; adjust to your local layout.
    hr_root = './datasets/DIV2K/HR_extracted/DIV2K_train_HR'
    lr_root = './datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
    upscale = 4
    patch_lr = 48  # Common LR patch size for SR tasks

    if not os.path.isdir(hr_root):
        print(f"ERROR: HR dir not found: '{hr_root}'")
    if not os.path.isdir(lr_root):
        print(f"ERROR: LR dir not found: '{lr_root}'")

    try:
        ds = SRDataset(hr_dir=hr_root, lr_dir=lr_root,
                       scale_factor=upscale, patch_size_lr=patch_lr)

        if len(ds) == 0:
            print("\nDataset loaded but is empty.")
        else:
            print(f"\nSuccessfully loaded dataset with {len(ds)} image pairs.")

            # Fetch a handful of single items and sanity-check their shapes.
            print("\n--- Testing __getitem__ ---")
            for idx in range(min(5, len(ds))):
                sample = ds[idx]
                if sample is None:
                    print(f"Item {idx}: Returned None (Error occurred)")
                    continue

                lr_t = sample['lr']
                hr_t = sample['hr']
                print(f"Item {idx}: LR Patch Shape={lr_t.shape}, HR Patch Shape={hr_t.shape}")

                want_hr = (3, patch_lr * upscale, patch_lr * upscale)
                if lr_t.shape != (3, patch_lr, patch_lr) or hr_t.shape != want_hr:
                    print(f"  WARNING: Shape mismatch! LR={lr_t.shape}, HR={hr_t.shape}, Expected HR={want_hr}")

            # Exercise batching through a DataLoader.
            print("\n--- Testing DataLoader with Patches ---")
            from torch.utils.data import DataLoader

            def drop_failed(batch):
                """Collate that discards None items produced by failed loads."""
                kept = [item for item in batch if item is not None]
                if not kept:  # every item in the batch failed
                    return None
                return torch.utils.data.dataloader.default_collate(kept)

            loader = DataLoader(ds, batch_size=4, shuffle=True,
                                num_workers=0, collate_fn=drop_failed)

            seen = 0
            for packed in loader:
                if seen >= 3:
                    break
                if packed is None:
                    print(f"Skipping an entirely problematic batch.")
                    continue

                print(f"Batch {seen}: LR Batch Shape={packed['lr'].shape}, HR Batch Shape={packed['hr'].shape}")
                seen += 1

            if seen > 0:
                print("DataLoader test with patches successful.")
            else:
                print("DataLoader test: Could not retrieve any valid batches.")

    except FileNotFoundError as e:
        print(f"\nERROR initializing dataset: {e}")
    except Exception as e:
        print(f"\nAn unexpected error occurred during dataset testing: {e}")

    print("\n--- SRDataset Test Finished ---")