Gordon-H committed on
Commit
fd5c0a6
·
verified ·
1 Parent(s): 824c640

Upload 13 files

Browse files
checkpoints/discriminator_epoch_10.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a93bbd6522431f63abe8c6821a17efba7e8b7751314a69db2caf2a14e3bda5e
3
+ size 1106902
checkpoints/discriminator_epoch_15.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5af7468f085e6e6d8b8058a0d629151243199d0a994be72f1f9e270d241e77f
3
+ size 1106902
checkpoints/discriminator_epoch_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58d206e029ec8e09a3bda2034cbd1a1170848b5cdcc0861f280890186aa3043c
3
+ size 1106807
checkpoints/generator_epoch_10.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a8feff031f9337d18f659075b7a0db41f19d782f4367c026a4cc374f8de2232
3
+ size 6096658
checkpoints/generator_epoch_15.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce25de423b699eef2ffcc51585b049a95efae370520c640834e892a391070654
3
+ size 6096658
checkpoints/generator_epoch_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95f21a754252b0d804fe63660e802f2c7fe435d599d8ba431f58e420c704947d
3
+ size 6096580
dataset.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ from PIL import Image
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+ import random # Needed for random cropping
8
+
9
+ # --- Updated SRDataset Class ---
10
class SRDataset(Dataset):
    """
    Custom Dataset for Super-Resolution.

    Lists LR images in ``lr_dir``, pairs each with its HR counterpart in
    ``hr_dir``, and returns randomly cropped (and horizontally-flip augmented)
    fixed-size patch pairs as tensors in [0, 1].
    """

    # Extensions searched when listing the LR directory.
    _EXTENSIONS = ('*.png', '*.jpg', '*.jpeg')

    def __init__(self, hr_dir, lr_dir, scale_factor, patch_size_lr=48, transform=None):
        """
        Args:
            hr_dir (str): Directory with all HR images.
            lr_dir (str): Directory with all LR images (corresponding to hr_dir).
            scale_factor (int): The upscaling factor.
            patch_size_lr (int): Side length of the LR patch to crop.
                The HR patch size is patch_size_lr * scale_factor.
            transform (callable, optional): Optional transform hook
                (e.g. for data augmentation libraries); currently a no-op.
        """
        super(SRDataset, self).__init__()
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.scale_factor = scale_factor
        self.patch_size_lr = patch_size_lr
        self.patch_size_hr = patch_size_lr * scale_factor
        self.transform = transform
        # Hoisted out of __getitem__: ToTensor is stateless, so there is no
        # need to rebuild it for every item fetched.
        self._to_tensor = transforms.ToTensor()

        # Find all image files (png, jpg, jpeg) in the LR directory.
        self.lr_image_files = sorted(
            path
            for pattern in self._EXTENSIONS
            for path in glob.glob(os.path.join(lr_dir, pattern))
        )

        if not self.lr_image_files:
            raise FileNotFoundError(f"No images found in LR directory: {lr_dir}. Check path and image extensions.")

        print(f"Found {len(self.lr_image_files)} image pairs in HR='{hr_dir}', LR='{lr_dir}'")
        print(f"Using LR patch size: {self.patch_size_lr}x{self.patch_size_lr}, HR patch size: {self.patch_size_hr}x{self.patch_size_hr}")

    def __len__(self):
        return len(self.lr_image_files)

    @staticmethod
    def get_patch(lr_img, hr_img, patch_size_lr, scale_factor):
        """
        Randomly crops corresponding patches from LR and HR images.

        Args:
            lr_img (PIL.Image): Low-resolution image.
            hr_img (PIL.Image): High-resolution image.
            patch_size_lr (int): The desired height/width of the LR patch.
            scale_factor (int): The upscaling factor.

        Returns:
            tuple: (lr_patch, hr_patch) PIL.Image objects.
        """
        lr_w, lr_h = lr_img.size
        patch_size_hr = patch_size_lr * scale_factor

        # Ensure HR dimensions are exactly LR * scale; resize HR if not
        # (can happen with imperfect downscaling or odd original dimensions).
        if hr_img.size != (lr_w * scale_factor, lr_h * scale_factor):
            hr_img = hr_img.resize((lr_w * scale_factor, lr_h * scale_factor), resample=Image.BICUBIC)

        # If the LR image is smaller than the patch, upscale both images so a
        # full patch can always be cropped; __getitem__ then always returns
        # tensors of the target patch size.
        if lr_w < patch_size_lr or lr_h < patch_size_lr:
            lr_img = lr_img.resize((max(lr_w, patch_size_lr), max(lr_h, patch_size_lr)), resample=Image.BICUBIC)
            hr_img = hr_img.resize((lr_img.width * scale_factor, lr_img.height * scale_factor), resample=Image.BICUBIC)
            lr_w, lr_h = lr_img.size  # update dimensions after resize

        # Random top-left corner for the LR patch; the HR corner is its
        # scaled counterpart so the two patches stay spatially aligned.
        lr_x = random.randrange(0, lr_w - patch_size_lr + 1)
        lr_y = random.randrange(0, lr_h - patch_size_lr + 1)
        hr_x = lr_x * scale_factor
        hr_y = lr_y * scale_factor

        # PIL crop boxes are (left, upper, right, lower).
        lr_patch = lr_img.crop((lr_x, lr_y, lr_x + patch_size_lr, lr_y + patch_size_lr))
        hr_patch = hr_img.crop((hr_x, hr_y, hr_x + patch_size_hr, hr_y + patch_size_hr))

        return lr_patch, hr_patch

    @staticmethod
    def augment_patch(lr_patch, hr_patch):
        """Applies simple random augmentation (currently horizontal flip only)."""
        # Random horizontal flip, applied identically to both patches.
        if random.random() < 0.5:
            lr_patch = lr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)
        return lr_patch, hr_patch

    def _resolve_hr_path(self, lr_path):
        """
        Return the HR path matching ``lr_path``, or None if it cannot be found.

        Tries the same basename first, then the DIV2K-style alternative where
        the LR file carries an 'x<scale>' suffix (e.g. '0001x4.png' -> '0001.png').
        """
        base_name = os.path.basename(lr_path)
        hr_path = os.path.join(self.hr_dir, base_name)
        if os.path.exists(hr_path):
            return hr_path

        base, ext = os.path.splitext(base_name)
        suffix = f'x{self.scale_factor}'
        if suffix in base:
            hr_path_alt = os.path.join(self.hr_dir, base.replace(suffix, '') + ext)
            if os.path.exists(hr_path_alt):
                return hr_path_alt

        print(f"ERROR in __getitem__: Cannot find corresponding HR for LR: {lr_path}")
        return None

    def __getitem__(self, idx):
        """
        Return {'lr': tensor, 'hr': tensor} for index ``idx``, or None on error.

        None returns must be filtered by the DataLoader's collate_fn.
        """
        lr_path = self.lr_image_files[idx]
        try:
            lr_img = Image.open(lr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening LR image {lr_path}: {e}")
            return None  # let the collate_fn drop this sample

        hr_path = self._resolve_hr_path(lr_path)
        if hr_path is None:
            return None  # indicate error

        try:
            hr_img = Image.open(hr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening HR image {hr_path}: {e}")
            return None  # indicate error

        # --- Get corresponding patches ---
        try:
            lr_patch, hr_patch = self.get_patch(lr_img, hr_img, self.patch_size_lr, self.scale_factor)
        except ValueError as e:  # randrange can raise if patch size > image size
            print(f"Error getting patch for {lr_path} (maybe image is smaller than patch size?): {e}")
            return None

        # --- Apply augmentations ---
        lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)

        # --- Hook for an externally supplied transform (e.g. albumentations) ---
        # Intentionally a no-op for now; kept so callers can pass transform later.
        if self.transform:
            pass

        # --- Convert patches to tensors (HWC [0,255] -> CHW [0.0,1.0]) ---
        return {'lr': self._to_tensor(lr_patch), 'hr': self._to_tensor(hr_patch)}
193
+
194
+ # --- Example Usage (for testing the definition) ---
195
if __name__ == '__main__':
    # Smoke-test SRDataset: load pairs, fetch a few patch items, then run a
    # DataLoader pass with a None-filtering collate function.
    print("--- Testing SRDataset with Patching ---")
    hr_data_dir = './datasets/DIV2K/HR_extracted/DIV2K_train_HR'  # Modify if needed
    lr_data_dir = './datasets/DIV2K/DIV2K_train_LR_bicubic/X4'  # Modify if needed
    scale = 4
    lr_patch_size = 48  # Common LR patch size for SR tasks

    if not os.path.isdir(hr_data_dir):
        print(f"ERROR: HR dir not found: '{hr_data_dir}'")
    if not os.path.isdir(lr_data_dir):
        print(f"ERROR: LR dir not found: '{lr_data_dir}'")

    try:
        dataset = SRDataset(hr_dir=hr_data_dir, lr_dir=lr_data_dir,
                            scale_factor=scale, patch_size_lr=lr_patch_size)

        if len(dataset) > 0:
            print(f"\nSuccessfully loaded dataset with {len(dataset)} image pairs.")

            # Fetch a handful of single items (patch pairs) and check shapes.
            print("\n--- Testing __getitem__ ---")
            expected_hr_shape = (3, lr_patch_size * scale, lr_patch_size * scale)
            expected_lr_shape = (3, lr_patch_size, lr_patch_size)
            for item_idx in range(min(5, len(dataset))):
                sample = dataset[item_idx]
                if sample is None:
                    print(f"Item {item_idx}: Returned None (Error occurred)")
                    continue

                lr_patch, hr_patch = sample['lr'], sample['hr']
                print(f"Item {item_idx}: LR Patch Shape={lr_patch.shape}, HR Patch Shape={hr_patch.shape}")

                if lr_patch.shape != expected_lr_shape or hr_patch.shape != expected_hr_shape:
                    print(f"  WARNING: Shape mismatch! LR={lr_patch.shape}, HR={hr_patch.shape}, Expected HR={expected_hr_shape}")

            # DataLoader pass with a collate_fn that drops failed (None) items.
            print("\n--- Testing DataLoader with Patches ---")
            from torch.utils.data import DataLoader

            def collate_fn_filter_none(batch):
                # Drop None entries produced by failed __getitem__ calls.
                kept = [sample for sample in batch if sample is not None]
                if not kept:  # the entire batch failed
                    return None
                return torch.utils.data.dataloader.default_collate(kept)

            dataloader = DataLoader(dataset, batch_size=4, shuffle=True,
                                    num_workers=0, collate_fn=collate_fn_filter_none)

            batch_count = 0
            for batch in dataloader:
                if batch_count >= 3:
                    break
                if batch is None:
                    print("Skipping an entirely problematic batch.")
                    continue

                print(f"Batch {batch_count}: LR Batch Shape={batch['lr'].shape}, HR Batch Shape={batch['hr'].shape}")
                batch_count += 1

            if batch_count > 0:
                print("DataLoader test with patches successful.")
            else:
                print("DataLoader test: Could not retrieve any valid batches.")

        else:
            print("\nDataset loaded but is empty.")

    except FileNotFoundError as e:
        print(f"\nERROR initializing dataset: {e}")
    except Exception as e:
        print(f"\nAn unexpected error occurred during dataset testing: {e}")

    print("\n--- SRDataset Test Finished ---")
loss.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision.models import vgg19, VGG19_Weights
5
+ from torchvision import transforms
6
+
7
class PerceptualLoss(nn.Module):
    """
    VGG19-based perceptual loss.

    Feeds generated and target image batches (values in [0, 1]) through a
    frozen, ImageNet-pretrained VGG19 and sums a distance (L1 or MSE) between
    their activations at the selected feature layers.
    """

    # ImageNet channel statistics used to normalize inputs before VGG.
    _IMAGENET_MEAN = [0.485, 0.456, 0.406]
    _IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, feature_layers=None, use_l1=True, device='cpu'):
        """
        Args:
            feature_layers (list of int, optional): Indices into
                ``vgg19().features`` whose outputs are compared. Defaults to
                {1, 6, 11, 20} — the ReLU outputs before pooling stages 1-4
                (relu1_1, relu2_1, relu3_1, relu4_1 in common notation).
            use_l1 (bool): If True, use L1 distance between features;
                otherwise use L2 (MSE).
            device (str): 'cuda' or 'cpu'.
        """
        super(PerceptualLoss, self).__init__()

        # Load a pre-trained VGG19 feature extractor.
        try:
            # Recommended way with modern torchvision.
            weights = VGG19_Weights.IMAGENET1K_V1
            self.vgg = vgg19(weights=weights).features
        except AttributeError:
            # Fallback for older torchvision versions.
            print("Warning: Using older torchvision VGG19 loading method. Consider upgrading torchvision.")
            self.vgg = vgg19(pretrained=True).features

        # FIX: do NOT use weights.transforms() here. That is the image
        # *classification* preprocess (resize to 256 + center-crop to 224 +
        # normalize), which would silently rescale and crop the inputs and
        # destroy the pixel-wise spatial correspondence between generated and
        # target images. For a perceptual loss only channel normalization
        # with the ImageNet statistics is wanted.
        self.preprocess = transforms.Normalize(mean=self._IMAGENET_MEAN, std=self._IMAGENET_STD)

        self.vgg.eval()  # VGG is a fixed feature extractor
        for param in self.vgg.parameters():
            param.requires_grad = False  # freeze VGG parameters

        self.vgg = self.vgg.to(device)
        self.device = device

        if feature_layers is None:
            # ReLU outputs right before each of the first four pooling layers.
            self.feature_layers = {1, 6, 11, 20}
        else:
            self.feature_layers = set(feature_layers)

        self.loss_fn = nn.L1Loss() if use_l1 else nn.MSELoss()

        print(f"PerceptualLoss: Using VGG19 features from layers: {sorted(list(self.feature_layers))}")
        print(f"PerceptualLoss: Using {'L1' if use_l1 else 'L2'} distance.")

    def forward(self, generated, target):
        """
        Compute the perceptual loss between two image batches.

        Args:
            generated (torch.Tensor): Generated images (B, C, H, W), values in [0, 1].
            target (torch.Tensor): Ground-truth images (B, C, H, W), values in [0, 1].

        Returns:
            torch.Tensor: Scalar loss summed over the selected layers.
        """
        generated = generated.to(self.device)
        target = target.to(self.device)

        # Channel-normalize with ImageNet statistics (no resizing/cropping).
        gen_feat = self.preprocess(generated)
        tgt_feat = self.preprocess(target)

        loss = 0.0
        max_needed_layer = max(self.feature_layers) if self.feature_layers else 0

        # Run both batches through VGG layer by layer, accumulating the
        # distance at each requested layer and stopping once the deepest
        # requested layer has been processed.
        for layer_idx, layer in enumerate(self.vgg):
            gen_feat = layer(gen_feat)
            tgt_feat = layer(tgt_feat)

            if layer_idx in self.feature_layers:
                loss += self.loss_fn(gen_feat, tgt_feat)

            if layer_idx >= max_needed_layer:
                break

        return loss
112
+
113
+
114
+ # --- Example Usage (for testing the definition) ---
115
if __name__ == '__main__':
    # Quick smoke test of the PerceptualLoss definition on random data.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Random image batches in [0, 1], shape (B, C, H, W); sizes must match.
    fake_batch = torch.rand(2, 3, 96, 96).to(device)
    real_batch = torch.rand(2, 3, 96, 96).to(device)

    # Default layers {1, 6, 11, 20} (relu1_1/relu2_1/relu3_1/relu4_1 outputs).
    criterion = PerceptualLoss(device=device, use_l1=True)

    loss_value = criterion(fake_batch, real_batch)

    print(f"\nCalculated Perceptual Loss (L1, default layers): {loss_value.item()}")

    assert loss_value.item() >= 0, "Loss should be non-negative"
    print("\nPerceptualLoss definition test successful!")
models.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import os
5
+
6
+ # --- ResidualBlock, Upsampler, and Generator classes remain the same ---
7
class ResidualBlock(nn.Module):
    """Conv-act-conv residual block with optional BatchNorm and residual scaling."""

    def __init__(self, num_features, kernel_size=3, bn=False, act=nn.ReLU(True), res_scale=1.0):
        super(ResidualBlock, self).__init__()
        pad = kernel_size // 2
        layers = [nn.Conv2d(num_features, num_features, kernel_size, padding=pad)]
        if bn:
            layers.append(nn.BatchNorm2d(num_features))
        layers.append(act)
        layers.append(nn.Conv2d(num_features, num_features, kernel_size, padding=pad))
        if bn:
            layers.append(nn.BatchNorm2d(num_features))
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        # Scale the block's output, then add the identity shortcut.
        return x + self.body(x) * self.res_scale
23
+
24
class Upsampler(nn.Module):
    """Upscales feature maps by ``scale_factor`` via a conv followed by PixelShuffle."""

    def __init__(self, scale_factor, num_features, act=nn.ReLU(True)):
        super(Upsampler, self).__init__()
        # The conv expands channels by scale^2 so PixelShuffle can trade them
        # for spatial resolution, keeping the channel count unchanged overall.
        layers = [
            nn.Conv2d(num_features, num_features * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
        ]
        if act:
            layers.append(act)
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)
34
+
35
class Generator(nn.Module):
    """EDSR-style super-resolution generator: head conv, residual body, pixel-shuffle tail."""

    def __init__(self, scale_factor=4, in_channels=3, out_channels=3, num_features=64, num_res_blocks=16, res_scale=1.0):
        super(Generator, self).__init__()
        self.scale_factor = scale_factor
        act = nn.ReLU(True)

        # Shallow feature extraction.
        self.head = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)

        # Deep feature extraction: residual blocks plus one fusion conv.
        body_layers = [
            ResidualBlock(num_features, kernel_size=3, act=act, res_scale=res_scale)
            for _ in range(num_res_blocks)
        ]
        body_layers.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
        self.body = nn.Sequential(*body_layers)

        # Upsampling tail: powers of two are built from x2 stages; x3 is a single stage.
        tail_layers = []
        if (scale_factor & (scale_factor - 1)) == 0:
            for _ in range(int(math.log2(scale_factor))):
                tail_layers.append(Upsampler(scale_factor=2, num_features=num_features, act=None))
        elif scale_factor == 3:
            tail_layers.append(Upsampler(scale_factor=3, num_features=num_features, act=None))
        else:
            raise NotImplementedError(f"Scale factor {scale_factor} not directly supported by this simple upsampler.")
        self.tail = nn.Sequential(*tail_layers)
        self.final_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1)

    def forward(self, lr_img):
        shallow = self.head(lr_img)
        deep = self.body(shallow)
        deep += shallow  # global residual connection around the body
        upsampled = self.tail(deep)
        return self.final_conv(upsampled)
62
+
63
+ # +++ NEW Discriminator Class +++
64
class Discriminator(nn.Module):
    """
    Simple CNN discriminator.

    Stacks conv blocks (every second one strided), global-average-pools to a
    fixed-size vector, and maps it to a single real/fake logit. No sigmoid is
    applied; pair with a logit-based loss such as BCEWithLogitsLoss.
    """

    def __init__(self, in_channels=3, num_features_start=64, num_blocks=4):
        super(Discriminator, self).__init__()

        # Initial block.
        layers = [
            nn.Conv2d(in_channels, num_features_start, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        channels = num_features_start
        for block_idx in range(num_blocks):
            # Every second block halves the spatial size and doubles the channels.
            if block_idx % 2 == 0:
                stride, out_channels = 1, channels
            else:
                stride, out_channels = 2, channels * 2
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),  # BatchNorm is common in discriminators
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels = out_channels

        self.features = nn.Sequential(*layers)

        # AdaptiveAvgPool makes the classifier independent of input resolution,
        # so the Linear layer's input size is known regardless of image size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(channels, 100),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1),  # single logit, no sigmoid
        )

    def forward(self, img):
        """
        Args:
            img (torch.Tensor): Input image tensor (B, C, H, W), real HR or fake SR.

        Returns:
            torch.Tensor: Logits of shape (B, 1); higher -> more likely "real".
        """
        feats = self.features(img)
        pooled = self.avgpool(feats)
        flattened = pooled.view(img.size(0), -1)
        return self.classifier(flattened)
115
+
116
+ # --- Main block for testing and saving ---
117
if __name__ == '__main__':
    # --- Generator smoke test ---
    SCALE = 4
    GEN_FEATURES = 64
    GEN_RES_BLOCKS = 8
    save_dir = "saved_models"
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dummy LR input for the generator.
    gen_batch_size, in_channels = 1, 3
    lr_height = lr_width = 32
    dummy_lr = torch.randn(gen_batch_size, in_channels, lr_height, lr_width).to(device)
    print(f"Dummy LR input shape (Generator): {dummy_lr.shape}")

    generator = Generator(scale_factor=SCALE, num_features=GEN_FEATURES, num_res_blocks=GEN_RES_BLOCKS).to(device)
    generator.eval()
    with torch.no_grad():
        output_sr = generator(dummy_lr)
    print(f"Output SR shape (Generator): {output_sr.shape}")
    print("\nGenerator definition test successful!")
    num_params_gen = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Generator - Number of trainable parameters: {num_params_gen:,}")

    print("\n--- Testing Discriminator ---")
    # --- Discriminator smoke test ---
    DISC_FEATURES = 64  # starting feature count for the discriminator
    DISC_BLOCKS = 3     # number of conv blocks in the discriminator

    # Dummy HR/SR input sized to match the generator's output.
    disc_batch_size = 4  # independent of the generator test batch size
    hr_height, hr_width = output_sr.shape[2], output_sr.shape[3]
    dummy_hr = torch.randn(disc_batch_size, in_channels, hr_height, hr_width).to(device)
    print(f"Dummy HR/SR input shape (Discriminator): {dummy_hr.shape}")

    discriminator = Discriminator(in_channels=in_channels,
                                  num_features_start=DISC_FEATURES,
                                  num_blocks=DISC_BLOCKS).to(device)
    discriminator.eval()  # evaluation mode for testing

    with torch.no_grad():
        output_logits = discriminator(dummy_hr)

    print(f"Output Logits shape (Discriminator): {output_logits.shape}")

    # Verify the discriminator emits one logit per batch element.
    expected_disc_shape = (disc_batch_size, 1)
    assert output_logits.shape == expected_disc_shape, \
        f"Discriminator output shape mismatch! Expected {expected_disc_shape}, got {output_logits.shape}"

    print("Discriminator definition test successful!")

    num_params_disc = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Discriminator - Number of trainable parameters: {num_params_disc:,}")
prep.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import zipfile
4
+ import requests
5
+ import argparse
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+
9
+ # --- Helper Functions ---
10
+
11
def download_file(url, dest_path, chunk_size=8192):
    """Downloads a file from a URL to a destination path with progress bar."""
    try:
        response = requests.get(url, stream=True, timeout=30)  # timeout guards hung servers
        response.raise_for_status()  # bail out on 4xx/5xx status codes

        total_size = int(response.headers.get('content-length', 0))
        print(f"Downloading {os.path.basename(dest_path)} ({total_size / (1024*1024):.2f} MB)...")

        progress = tqdm(
            desc=os.path.basename(dest_path),
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        )
        with open(dest_path, 'wb') as out_file, progress as bar:
            for chunk in response.iter_content(chunk_size=chunk_size):
                bar.update(out_file.write(chunk))
        print(f"Download complete: {dest_path}")
        return True

    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
        # Remove any partially downloaded file.
        if os.path.exists(dest_path):
            os.remove(dest_path)
        return False
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
        if os.path.exists(dest_path):
            os.remove(dest_path)
        return False
44
+
45
+
46
def unzip_file(zip_path, extract_to):
    """Unzips a file to a specified directory."""
    print(f"Extracting {os.path.basename(zip_path)} to {extract_to}...")
    try:
        # extractall is efficient enough for typical dataset archives.
        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(extract_to)
    except zipfile.BadZipFile:
        print(f"Error: Invalid or corrupted zip file: {zip_path}")
        return False
    except Exception as e:
        print(f"An error occurred during extraction: {e}")
        return False
    print("Extraction complete.")
    return True
63
+
64
def find_image_dir(base_path, expected_subdir_suffix='_HR'):
    """
    Tries to find the actual directory containing images after extraction.
    Handles cases where unzip creates an extra top-level folder.
    """
    def _has_images(directory):
        # True when any png/jpg/jpeg file sits directly inside `directory`.
        return any(
            glob.glob(os.path.join(directory, pattern))
            for pattern in ('*.png', '*.jpg', '*.jpeg')
        )

    # Images sitting directly in base_path.
    if _has_images(base_path):
        return base_path

    # Common pattern: a single nested folder like base_path/DatasetName_HR/.
    subdirs = [entry for entry in glob.glob(os.path.join(base_path, '*')) if os.path.isdir(entry)]
    if len(subdirs) == 1:
        candidate = subdirs[0]
        # Accept the lone subdir when it matches the suffix or holds images.
        if candidate.endswith(expected_subdir_suffix) or _has_images(candidate):
            print(f"Found image directory: {candidate}")
            return candidate

    # Fallback: a single nested folder (even unverified) or base_path itself.
    print(f"Warning: Could not definitively locate image subdirectory in {base_path}. Assuming images are directly within or in a single nested folder.")
    return subdirs[0] if len(subdirs) == 1 else base_path
91
+
92
+
93
def downsample_images(hr_dir, lr_dir, scale_factor):
    """Downsamples HR images using bicubic interpolation."""
    if not os.path.exists(lr_dir):
        os.makedirs(lr_dir)
        print(f"Created LR directory: {lr_dir}")

    # Collect png/jpg/jpeg files, keeping the per-extension grouping order.
    hr_images = []
    for pattern in ('*.png', '*.jpg', '*.jpeg'):
        hr_images.extend(glob.glob(os.path.join(hr_dir, pattern)))

    if not hr_images:
        print(f"Error: No images found in the determined HR directory: {hr_dir}")
        return False

    print(f"Found {len(hr_images)} HR images in {hr_dir}. Starting downsampling (x{scale_factor})...")

    processed_count = 0
    for hr_path in tqdm(hr_images, desc=f"Downsampling x{scale_factor}"):
        try:
            hr_img = Image.open(hr_path).convert('RGB')  # force 3-channel RGB
            hr_width, hr_height = hr_img.size

            lr_width = hr_width // scale_factor
            lr_height = hr_height // scale_factor

            # Skip images that would collapse to a zero-sized dimension.
            if lr_width == 0 or lr_height == 0:
                print(f"\nWarning: Image {os.path.basename(hr_path)} is too small ({hr_width}x{hr_height}) for scale factor {scale_factor}. Skipping.")
                continue

            lr_img = hr_img.resize((lr_width, lr_height), resample=Image.BICUBIC)
            lr_img.save(os.path.join(lr_dir, os.path.basename(hr_path)))
            processed_count += 1

        except Exception as e:
            print(f"\nError processing {hr_path}: {e}")

    print(f"Downsampling complete. Processed {processed_count}/{len(hr_images)} images.")
    return processed_count > 0
134
+
135
+
136
# --- Main Execution ---

if __name__ == "__main__":
    # Hoisted: the original imported shutil twice inside separate branches.
    import shutil

    parser = argparse.ArgumentParser(description="Download and prepare dataset for Super-Resolution.")
    parser.add_argument('--url', type=str, default='https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip', help='URL of the dataset zip file (default: DIV2K Train HR).')
    parser.add_argument('--base_dir', type=str, default='./datasets', help='Base directory to store datasets.')
    parser.add_argument('--dataset_name', type=str, default='DIV2K', help='Name for the dataset folder.')
    parser.add_argument('--scale', type=int, default=4, help='Downsampling scale factor (e.g., 4 for x4).')
    parser.add_argument('--force', action='store_true', help='Force redownload and reprocessing even if data exists.')

    args = parser.parse_args()

    # --- Define Paths ---
    dataset_base_path = os.path.join(args.base_dir, args.dataset_name)
    zip_filename = os.path.basename(args.url)
    zip_save_path = os.path.join(dataset_base_path, zip_filename)
    hr_extract_base = os.path.join(dataset_base_path, 'HR_extracted')  # Temp extraction location
    # The *actual* HR image dir is determined after extraction (find_image_dir).
    lr_save_dir = os.path.join(dataset_base_path, f'DIV2K_train_LR_bicubic/X{args.scale}')  # Following previous convention

    print(f"--- Configuration ---")
    print(f"Dataset URL: {args.url}")
    print(f"Base Directory: {args.base_dir}")
    print(f"Dataset Name: {args.dataset_name}")
    print(f"Target Scale: x{args.scale}")
    print(f"Zip Save Path: {zip_save_path}")
    print(f"Initial Extract Path: {hr_extract_base}")
    print(f"LR Save Path: {lr_save_dir}")
    print(f"Force Re-run: {args.force}")
    print(f"--------------------")

    # --- Create Base Directory ---
    os.makedirs(dataset_base_path, exist_ok=True)

    # --- Step 1: Download ---
    hr_dir_exists = os.path.isdir(hr_extract_base)  # Check if base extraction dir exists
    download_needed = not os.path.exists(zip_save_path) or args.force

    if download_needed:
        if args.force and os.path.exists(zip_save_path):
            print("Force enabled: Removing existing zip file...")
            os.remove(zip_save_path)
        if not download_file(args.url, zip_save_path):
            print("Exiting due to download failure.")
            exit(1)
    elif hr_dir_exists:  # Zip and HR dir exist: assume download & unzip ok unless forced
        print("Zip file already exists. Skipping download (use --force to override).")
    else:  # Zip exists but HR dir doesn't - need to unzip
        print("Zip file found, but extraction directory missing. Will proceed to unzip.")

    # --- Step 2: Unzip ---
    # NOTE(review): this is a lenient check; a more robust version would
    # inspect the zip's member list before deciding whether to re-extract.
    unzip_needed = not hr_dir_exists or args.force

    actual_hr_dir = None  # Will store the path to the actual images

    if unzip_needed:
        if args.force and hr_dir_exists:
            print("Force enabled: Removing existing extraction directory...")
            shutil.rmtree(hr_extract_base)  # Careful! Removes directory and contents

        if not os.path.exists(zip_save_path):
            print("Error: Zip file not found, cannot unzip. Please check download step or path.")
            exit(1)

        os.makedirs(hr_extract_base, exist_ok=True)  # Ensure extraction target exists
        if not unzip_file(zip_save_path, hr_extract_base):
            print("Exiting due to extraction failure.")
            exit(1)

        # Find the actual directory containing images post-extraction.
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')  # e.g., DIV2K_HR
        # FIX: include *.jpeg in the sanity check, consistent with
        # find_image_dir() and downsample_images() elsewhere in this file.
        has_images = actual_hr_dir and any(
            glob.glob(os.path.join(actual_hr_dir, pattern))
            for pattern in ('*.png', '*.jpg', '*.jpeg')
        )
        if not has_images:
            print(f"Error: Could not locate the directory with HR images within {hr_extract_base} after extraction.")
            exit(1)
        print(f"Located HR images in: {actual_hr_dir}")

    else:
        print("HR extraction directory already exists. Skipping unzip (use --force to override).")
        # Try to find the HR dir even if we skipped unzipping.
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')
        if not actual_hr_dir:
            print(f"Error: Could not locate the directory with HR images within existing {hr_extract_base}.")
            exit(1)
        print(f"Using existing HR images from: {actual_hr_dir}")

    # --- Step 3: Process (Downsample) ---
    lr_dir_exists_and_populated = os.path.isdir(lr_save_dir) and len(os.listdir(lr_save_dir)) > 0
    processing_needed = not lr_dir_exists_and_populated or args.force

    if processing_needed:
        if args.force and lr_dir_exists_and_populated:
            print("Force enabled: Removing existing LR directory...")
            shutil.rmtree(lr_save_dir)  # Careful!

        if not actual_hr_dir:
            print("Error: Cannot proceed with downsampling, HR image directory not determined.")
            exit(1)

        if not downsample_images(actual_hr_dir, lr_save_dir, args.scale):
            print("Downsampling process failed or produced no images.")
            # Deliberately not exiting: partial output may still be usable.
        else:
            print("Downsampling finished successfully.")
    else:
        print("LR directory already exists and is populated. Skipping downsampling (use --force to override).")

    print("\n--- Script Finished ---")
    print(f"HR images should be available in/under: {actual_hr_dir}")
    print(f"LR images (x{args.scale}) should be available in: {lr_save_dir}")
    print("You can now use these directories with the SRDataset class.")
saved_models/generator_x4_f64_b8_untrained.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ea5ec7bb7ec436c504f98cf3380a7b2258bf3730cf4ae726b838dd7df52d0b1
3
+ size 3717459
saved_models/generator_x4_f64_b8_untrained.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de3bc0ab6790bede102d5a40fd5122bbff83e05b22331f9cc983eb76aace56db
3
+ size 3722508
train.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RUN python train.py --epochs 2 --batch_size 2 --subset 10 --num_workers 0 --cpu --patch_size 48
2
+ import torch
3
+ import torch.optim as optim
4
+ import torch.nn as nn
5
+ from torch.utils.data import DataLoader
6
+ import os
7
+ import argparse
8
+ from tqdm import tqdm
9
+ import time
10
+
11
+ # Import custom modules
12
+ from dataset import SRDataset # Make sure dataset.py is in the same directory
13
+ from models import Generator, Discriminator # Make sure models.py is in the same directory
14
+ from loss import PerceptualLoss # Make sure loss.py is in the same directory
15
+
16
def train(args):
    """Run the SRGAN training loop.

    Builds the dataset/dataloader, generator/discriminator, loss functions
    and optimizers from the parsed CLI ``args``, then alternates one
    discriminator update and one generator update per batch for
    ``args.epochs`` epochs, checkpointing every ``args.save_interval`` epochs
    into ``args.save_dir``.
    """
    # --- 1. Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    # Create directories for saving models and potentially logs/outputs
    os.makedirs(args.save_dir, exist_ok=True)

    # --- 2. Data ---
    print("Loading dataset...")
    # args.hr_dir / args.lr_dir were validated in the __main__ block.
    try:
        train_dataset = SRDataset(hr_dir=args.hr_dir, lr_dir=args.lr_dir,
                                  scale_factor=args.scale, patch_size_lr=args.patch_size)
    except FileNotFoundError as e:
        print(f"Error creating dataset: {e}")
        print("Please ensure the specified HR and LR directories contain correctly named image files.")
        exit(1)
    except Exception as e:
        print(f"An unexpected error occurred while creating the dataset: {e}")
        exit(1)

    # Use a smaller subset for initial testing on CPU if needed
    if 0 < args.subset < len(train_dataset):
        print(f"Using a subset of {args.subset} images for training.")
        indices = torch.randperm(len(train_dataset))[:args.subset]
        train_dataset = torch.utils.data.Subset(train_dataset, indices)
    elif args.subset >= len(train_dataset) and len(train_dataset) > 0:
        print(f"Subset size ({args.subset}) is >= dataset size ({len(train_dataset)}). Using full dataset.")

    if len(train_dataset) == 0:
        print(f"Error: Dataset is empty after attempting to load. Please check HR dir '{args.hr_dir}' and LR dir '{args.lr_dir}'")
        return

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,  # Set to 0 if you encounter issues on Windows/macOS
        # BUGFIX: `device` is a torch.device, so the original comparison
        # `device == 'cuda'` was always False and pinning never engaged.
        # Compare the device *type* instead (pinning only helps on GPU).
        pin_memory=(device.type == 'cuda'),
    )
    print(f"Dataset loaded: {len(train_dataset)} training images.")
    print(f"Dataloader: {len(train_loader)} batches per epoch.")

    # --- 3. Models ---
    print("Initializing models...")
    generator = Generator(scale_factor=args.scale,
                          num_features=args.gen_features,
                          num_res_blocks=args.gen_blocks).to(device)

    discriminator = Discriminator(in_channels=3,  # Assuming RGB input for discriminator
                                  num_features_start=args.disc_features,
                                  num_blocks=args.disc_blocks).to(device)

    print(f"Generator params: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

    # --- 4. Loss Functions ---
    print("Initializing loss functions...")
    # Content Loss (Pixel-wise) - L1 is common for SR
    content_loss_criterion = nn.L1Loss().to(device)

    # Adversarial Loss - Measures how well G fools D and D identifies fakes.
    # BCEWithLogitsLoss is more numerically stable than BCELoss + Sigmoid.
    adversarial_loss_criterion = nn.BCEWithLogitsLoss().to(device)

    # Perceptual Loss (VGG-based feature distance)
    try:
        perceptual_loss_criterion = PerceptualLoss(device=device, use_l1=True)  # L1 feature distance
    except Exception as e:
        print(f"Error initializing Perceptual Loss (check VGG weights download/torchvision install): {e}")
        exit(1)

    # --- 5. Optimizers ---
    print("Initializing optimizers...")
    optimizer_g = optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(0.9, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(0.9, 0.999))

    # --- 6. Training Loop ---
    print("\n--- Starting Training ---")
    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        generator.train()       # Training mode (affects BN/dropout layers if any)
        discriminator.train()
        epoch_loss_g = 0.0
        epoch_loss_d = 0.0
        epoch_start_time = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=True)

        for batch_idx, batch in enumerate(progress_bar):
            # Dataset __getitem__ may yield None on a bad sample; skip it.
            if batch is None:
                print(f"Warning: Skipping problematic batch at index {batch_idx}")
                continue

            try:
                lr_images = batch['lr'].to(device)  # Low-resolution inputs
                hr_images = batch['hr'].to(device)  # High-resolution ground truth
            except KeyError as e:
                print(f"Error accessing batch data: {e}. Check SRDataset's __getitem__ return format.")
                continue  # Skip this batch

            # Adversarial targets: real = 1, fake = 0.
            # (Soft labels, e.g. 0.9, can help stabilize GAN training.)
            real_labels = torch.ones((hr_images.size(0), 1)).to(device)
            fake_labels = torch.zeros((hr_images.size(0), 1)).to(device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_d.zero_grad()

            # Generate fakes without tracking generator gradients: D's update
            # must not backprop into G.
            with torch.no_grad():
                fake_sr_images = generator(lr_images)

            # Loss for real images
            real_logits = discriminator(hr_images)
            loss_d_real = adversarial_loss_criterion(real_logits, real_labels)

            # Loss for fake images
            fake_logits = discriminator(fake_sr_images)
            loss_d_fake = adversarial_loss_criterion(fake_logits, fake_labels)

            # Average so D's loss scale matches a single BCE term.
            loss_d = (loss_d_real + loss_d_fake) / 2

            loss_d.backward()
            optimizer_d.step()

            # -----------------
            #  Train Generator
            #  (done every step here; some schedules update G less often)
            # -----------------
            optimizer_g.zero_grad()

            # Regenerate fakes, this time tracking gradients for G.
            generated_sr_images = generator(lr_images)

            # 1. Content loss: pixel-wise L1 against ground truth.
            loss_content = content_loss_criterion(generated_sr_images, hr_images)

            # 2. Perceptual loss: VGG feature distance.
            loss_perceptual = perceptual_loss_criterion(generated_sr_images, hr_images)

            # 3. Adversarial loss: G wants D to label its output "real",
            #    hence real_labels as the target.
            generated_logits = discriminator(generated_sr_images)
            loss_adversarial = adversarial_loss_criterion(generated_logits, real_labels)

            # Weighted sum balances pixel accuracy, perceptual quality, realism.
            loss_g = (args.lambda_content * loss_content +
                      args.lambda_percep * loss_perceptual +
                      args.lambda_adv * loss_adversarial)

            loss_g.backward()
            optimizer_g.step()

            # --- Update running losses and progress bar ---
            epoch_loss_g += loss_g.item()
            epoch_loss_d += loss_d.item()
            progress_bar.set_postfix({
                'Loss G': f"{loss_g.item():.4f}",
                'Loss D': f"{loss_d.item():.4f}",
            })

        # --- End of Epoch ---
        num_batches = len(train_loader)
        avg_loss_g = epoch_loss_g / num_batches if num_batches > 0 else 0
        avg_loss_d = epoch_loss_d / num_batches if num_batches > 0 else 0
        epoch_time = time.time() - epoch_start_time

        print(f"\nEpoch {epoch}/{args.epochs} | Time: {epoch_time:.2f}s | Avg Loss G: {avg_loss_g:.4f} | Avg Loss D: {avg_loss_d:.4f}")

        # --- Save Checkpoint ---
        if epoch % args.save_interval == 0 or epoch == args.epochs:
            gen_path = os.path.join(args.save_dir, f"generator_epoch_{epoch}.pth")
            disc_path = os.path.join(args.save_dir, f"discriminator_epoch_{epoch}.pth")
            try:
                torch.save(generator.state_dict(), gen_path)
                torch.save(discriminator.state_dict(), disc_path)
                print(f"Checkpoint saved for epoch {epoch} to '{args.save_dir}'")
            except Exception as e:
                # Best-effort: a failed save shouldn't abort the training run.
                print(f"Error saving checkpoint for epoch {epoch}: {e}")

    # --- End of Training ---
    total_time = time.time() - start_time
    print(f"\n--- Training Finished ---")
    print(f"Total time: {total_time // 3600:.0f}h {(total_time % 3600) // 60:.0f}m {total_time % 60:.2f}s")
+
236
+
237
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SRGAN Model')

    # Dataset locations and loading behavior
    parser.add_argument('--hr_dir', type=str,
                        default='./datasets/DIV2K/HR_extracted/DIV2K_train_HR',
                        help='Path to high-resolution training images')
    parser.add_argument('--lr_dir', type=str, default=None,  # auto-derived from --scale when omitted
                        help='Path to low-resolution training images (auto-set if None)')
    parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
    parser.add_argument('--batch_size', type=int, default=16, help='Training batch size (reduce for CPU/low VRAM)')
    parser.add_argument('--subset', type=int, default=0, help='Use only N images for debugging (0 to use all)')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for DataLoader (set to 0 for Mac/Windows usually)')
    parser.add_argument('--patch_size', type=int, default=48, help='Size (height/width) of LR patches for training')

    # Network architecture knobs
    parser.add_argument('--gen_features', type=int, default=64, help='Number of features in Generator')
    parser.add_argument('--gen_blocks', type=int, default=16, help='Number of residual blocks in Generator (reduce for faster training/less memory)')
    parser.add_argument('--disc_features', type=int, default=64, help='Number of starting features in Discriminator')
    parser.add_argument('--disc_blocks', type=int, default=3, help='Number of conv blocks in Discriminator')

    # Optimization hyperparameters (loss weights follow the SRGAN paper)
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--lr_gen', type=float, default=1e-4, help='Learning rate for Generator')
    parser.add_argument('--lr_disc', type=float, default=1e-4, help='Learning rate for Discriminator')
    parser.add_argument('--lambda_content', type=float, default=0.01, help='Weight for Content Loss (L1)')
    parser.add_argument('--lambda_percep', type=float, default=1.0, help='Weight for Perceptual Loss')
    parser.add_argument('--lambda_adv', type=float, default=0.001, help='Weight for Adversarial Loss')

    # Checkpointing / device selection
    parser.add_argument('--save_dir', type=str, default='checkpoints', help='Directory to save model checkpoints')
    parser.add_argument('--save_interval', type=int, default=10, help='Save checkpoint every N epochs')
    parser.add_argument('--cpu', action='store_true', help='Force training on CPU')

    args = parser.parse_args()

    # Derive the LR directory from the scale when the user didn't supply one.
    if args.lr_dir is None:
        args.lr_dir = f'./datasets/DIV2K/DIV2K_train_LR_bicubic/X{args.scale}'
        print(f"LR directory not provided, automatically setting based on scale {args.scale} to: {args.lr_dir}")

    # Bail out early with a targeted message if either image directory is missing.
    directory_checks = (
        (args.hr_dir,
         f"\nERROR: High-Resolution directory not found at '{args.hr_dir}'",
         "Please ensure the directory exists or provide the correct path using --hr_dir."),
        (args.lr_dir,
         f"\nERROR: Low-Resolution directory not found at '{args.lr_dir}'",
         f"Please ensure the directory exists (check scale factor {args.scale}?) or provide the correct path using --lr_dir."),
    )
    for dir_path, headline, advice in directory_checks:
        if not os.path.isdir(dir_path):
            print(headline)
            print(advice)
            exit(1)

    print("\n--- Training Configuration ---")
    # Use the terminal width for the separator when it can be queried.
    try:
        line_width = os.get_terminal_size().columns
    except OSError:
        line_width = 80

    rule = "-" * line_width
    print(rule)
    for option, setting in vars(args).items():
        print(f"{option:<25}: {setting}")
    print(rule)

    # Kick off training with the validated configuration.
    train(args)