Upload 13 files
Browse files
- checkpoints/discriminator_epoch_10.pth +3 -0
- checkpoints/discriminator_epoch_15.pth +3 -0
- checkpoints/discriminator_epoch_2.pth +3 -0
- checkpoints/generator_epoch_10.pth +3 -0
- checkpoints/generator_epoch_15.pth +3 -0
- checkpoints/generator_epoch_2.pth +3 -0
- dataset.py +273 -0
- loss.py +138 -0
- models.py +181 -0
- prep.py +252 -0
- saved_models/generator_x4_f64_b8_untrained.onnx +3 -0
- saved_models/generator_x4_f64_b8_untrained.pth +3 -0
- train.py +307 -0
checkpoints/discriminator_epoch_10.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0a93bbd6522431f63abe8c6821a17efba7e8b7751314a69db2caf2a14e3bda5e
|
| 3 |
+
size 1106902
|
checkpoints/discriminator_epoch_15.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5af7468f085e6e6d8b8058a0d629151243199d0a994be72f1f9e270d241e77f
|
| 3 |
+
size 1106902
|
checkpoints/discriminator_epoch_2.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58d206e029ec8e09a3bda2034cbd1a1170848b5cdcc0861f280890186aa3043c
|
| 3 |
+
size 1106807
|
checkpoints/generator_epoch_10.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a8feff031f9337d18f659075b7a0db41f19d782f4367c026a4cc374f8de2232
|
| 3 |
+
size 6096658
|
checkpoints/generator_epoch_15.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce25de423b699eef2ffcc51585b049a95efae370520c640834e892a391070654
|
| 3 |
+
size 6096658
|
checkpoints/generator_epoch_2.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:95f21a754252b0d804fe63660e802f2c7fe435d599d8ba431f58e420c704947d
|
| 3 |
+
size 6096580
|
dataset.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from torchvision import transforms
|
| 7 |
+
import random # Needed for random cropping
|
| 8 |
+
|
| 9 |
+
# --- Updated SRDataset Class ---
|
| 10 |
+
class SRDataset(Dataset):
    """Dataset of paired LR/HR images for super-resolution training.

    Each item is a dict with fixed-size 'lr' and 'hr' tensor patches
    cropped at spatially corresponding random locations.
    """

    def __init__(self, hr_dir, lr_dir, scale_factor, patch_size_lr=48, transform=None):
        """
        Args:
            hr_dir (str): Directory containing the HR images.
            lr_dir (str): Directory containing the matching LR images.
            scale_factor (int): Upscaling factor between LR and HR.
            patch_size_lr (int): Side length of the LR patch; the HR patch
                side length is patch_size_lr * scale_factor.
            transform (callable, optional): Hook for extra augmentation
                (currently unused in __getitem__).
        """
        super(SRDataset, self).__init__()
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.scale_factor = scale_factor
        self.patch_size_lr = patch_size_lr
        self.patch_size_hr = patch_size_lr * scale_factor
        self.transform = transform

        # Collect LR files with any of the supported extensions, sorted for a
        # stable index -> file mapping.
        found = []
        for pattern in ('*.png', '*.jpg', '*.jpeg'):
            found.extend(glob.glob(os.path.join(lr_dir, pattern)))
        self.lr_image_files = sorted(found)

        if not self.lr_image_files:
            raise FileNotFoundError(f"No images found in LR directory: {lr_dir}. Check path and image extensions.")

        print(f"Found {len(self.lr_image_files)} image pairs in HR='{hr_dir}', LR='{lr_dir}'")
        print(f"Using LR patch size: {self.patch_size_lr}x{self.patch_size_lr}, HR patch size: {self.patch_size_hr}x{self.patch_size_hr}")

    def __len__(self):
        """Number of LR/HR pairs (one per LR file found)."""
        return len(self.lr_image_files)

    @staticmethod
    def get_patch(lr_img, hr_img, patch_size_lr, scale_factor):
        """Crop a random LR patch and the spatially matching HR patch.

        Args:
            lr_img (PIL.Image): Low-resolution image.
            hr_img (PIL.Image): High-resolution image.
            patch_size_lr (int): Side length of the LR patch.
            scale_factor (int): Upscaling factor.

        Returns:
            tuple: (lr_patch, hr_patch) as PIL.Image objects.
        """
        width_lr, height_lr = lr_img.size
        width_hr, height_hr = hr_img.size
        patch_size_hr = patch_size_lr * scale_factor

        # If the HR image is not exactly scale_factor times the LR image
        # (imperfect downscaling, odd dimensions), force it to match.
        if (width_hr, height_hr) != (width_lr * scale_factor, height_lr * scale_factor):
            hr_img = hr_img.resize((width_lr * scale_factor, height_lr * scale_factor), resample=Image.BICUBIC)

        # If the LR image is smaller than the patch, enlarge both images so a
        # full-size patch can always be returned.
        if width_lr < patch_size_lr or height_lr < patch_size_lr:
            lr_img = lr_img.resize((max(width_lr, patch_size_lr), max(height_lr, patch_size_lr)), resample=Image.BICUBIC)
            hr_img = hr_img.resize((lr_img.width * scale_factor, lr_img.height * scale_factor), resample=Image.BICUBIC)
            width_lr, height_lr = lr_img.size

        # Random top-left corner for the LR patch; the HR corner is scaled.
        left_lr = random.randrange(0, width_lr - patch_size_lr + 1)
        top_lr = random.randrange(0, height_lr - patch_size_lr + 1)
        left_hr = left_lr * scale_factor
        top_hr = top_lr * scale_factor

        # PIL crop boxes are (left, upper, right, lower).
        lr_patch = lr_img.crop((left_lr, top_lr, left_lr + patch_size_lr, top_lr + patch_size_lr))
        hr_patch = hr_img.crop((left_hr, top_hr, left_hr + patch_size_hr, top_hr + patch_size_hr))

        return lr_patch, hr_patch

    @staticmethod
    def augment_patch(lr_patch, hr_patch):
        """Apply the same random horizontal flip to both patches.

        Vertical flips and 90-degree rotations are deliberately not applied.
        """
        if random.random() < 0.5:
            lr_patch = lr_patch.transpose(Image.FLIP_LEFT_RIGHT)
            hr_patch = hr_patch.transpose(Image.FLIP_LEFT_RIGHT)
        return lr_patch, hr_patch

    def __getitem__(self, idx):
        """Return {'lr': tensor, 'hr': tensor} for item idx, or None on error.

        None results must be filtered out by the DataLoader's collate_fn.
        """
        lr_path = self.lr_image_files[idx]
        try:
            lr_img = Image.open(lr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening LR image {lr_path}: {e}")
            return None

        # HR image with the same basename; fall back to the name without the
        # 'x<scale>' suffix (e.g. DIV2K-style naming 0001x4.png -> 0001.png).
        base_name = os.path.basename(lr_path)
        hr_path = os.path.join(self.hr_dir, base_name)
        if not os.path.exists(hr_path):
            resolved = False
            base, ext = os.path.splitext(base_name)
            if f'x{self.scale_factor}' in base:
                hr_name = base.replace(f'x{self.scale_factor}', '') + ext
                hr_path_alt = os.path.join(self.hr_dir, hr_name)
                if os.path.exists(hr_path_alt):
                    hr_path = hr_path_alt
                    resolved = True
            if not resolved:
                print(f"ERROR in __getitem__: Cannot find corresponding HR for LR: {lr_path}")
                return None

        try:
            hr_img = Image.open(hr_path).convert('RGB')
        except Exception as e:
            print(f"Error opening HR image {hr_path}: {e}")
            return None

        # Crop corresponding random patches.
        try:
            lr_patch, hr_patch = self.get_patch(lr_img, hr_img, self.patch_size_lr, self.scale_factor)
        except ValueError as e:
            # randrange raises ValueError if the patch cannot fit.
            print(f"Error getting patch for {lr_path} (maybe image is smaller than patch size?): {e}")
            return None

        # Paired random augmentation.
        lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)

        if self.transform:
            # Integration point for external pipelines (e.g. albumentations);
            # intentionally a no-op for now.
            pass

        # PIL image (HWC, [0, 255]) -> tensor (CHW, [0.0, 1.0]).
        to_tensor = transforms.ToTensor()
        return {'lr': to_tensor(lr_patch), 'hr': to_tensor(hr_patch)}
|
| 193 |
+
|
| 194 |
+
# --- Example Usage (for testing the definition) ---
if __name__ == '__main__':
    print("--- Testing SRDataset with Patching ---")
    hr_data_dir = './datasets/DIV2K/HR_extracted/DIV2K_train_HR'  # Modify if needed
    lr_data_dir = './datasets/DIV2K/DIV2K_train_LR_bicubic/X4'  # Modify if needed
    scale = 4
    lr_patch_size = 48  # Common LR patch size for SR tasks

    # Warn early about missing directories before attempting to load.
    if not os.path.isdir(hr_data_dir):
        print(f"ERROR: HR dir not found: '{hr_data_dir}'")
    if not os.path.isdir(lr_data_dir):
        print(f"ERROR: LR dir not found: '{lr_data_dir}'")

    try:
        dataset = SRDataset(hr_dir=hr_data_dir, lr_dir=lr_data_dir,
                            scale_factor=scale, patch_size_lr=lr_patch_size)

        if len(dataset) > 0:
            print(f"\nSuccessfully loaded dataset with {len(dataset)} image pairs.")

            # Inspect a handful of individual patch pairs.
            print("\n--- Testing __getitem__ ---")
            num_test_items = 5
            for i in range(min(num_test_items, len(dataset))):
                item = dataset[i]
                if item is None:
                    print(f"Item {i}: Returned None (Error occurred)")
                    continue

                lr_p = item['lr']
                hr_p = item['hr']
                print(f"Item {i}: LR Patch Shape={lr_p.shape}, HR Patch Shape={hr_p.shape}")

                # Every item should come back at the configured patch sizes.
                expected_hr_shape = (3, lr_patch_size * scale, lr_patch_size * scale)
                if lr_p.shape != (3, lr_patch_size, lr_patch_size) or hr_p.shape != expected_hr_shape:
                    print(f"  WARNING: Shape mismatch! LR={lr_p.shape}, HR={hr_p.shape}, Expected HR={expected_hr_shape}")

            # Batch a few items, dropping any failed (None) samples.
            print("\n--- Testing DataLoader with Patches ---")
            from torch.utils.data import DataLoader

            def collate_fn_filter_none(batch):
                # Drop failed samples; collapse to None if nothing survives.
                batch = [sample for sample in batch if sample is not None]
                if not batch:
                    return None
                return torch.utils.data.dataloader.default_collate(batch)

            dataloader = DataLoader(dataset, batch_size=4, shuffle=True,
                                    num_workers=0, collate_fn=collate_fn_filter_none)

            num_test_batches = 3
            batch_count = 0
            for batch in dataloader:
                if batch_count >= num_test_batches:
                    break
                if batch is None:
                    print(f"Skipping an entirely problematic batch.")
                    continue

                lr_batch = batch['lr']
                hr_batch = batch['hr']
                print(f"Batch {batch_count}: LR Batch Shape={lr_batch.shape}, HR Batch Shape={hr_batch.shape}")
                batch_count += 1

            if batch_count > 0:
                print("DataLoader test with patches successful.")
            else:
                print("DataLoader test: Could not retrieve any valid batches.")

        else:
            print("\nDataset loaded but is empty.")

    except FileNotFoundError as e:
        print(f"\nERROR initializing dataset: {e}")
    except Exception as e:
        print(f"\nAn unexpected error occurred during dataset testing: {e}")

    print("\n--- SRDataset Test Finished ---")
|
loss.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torchvision.models import vgg19, VGG19_Weights
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
|
| 7 |
+
class PerceptualLoss(nn.Module):
    """VGG19-based perceptual loss.

    Runs both images through a frozen, ImageNet-pretrained VGG19 feature
    trunk and sums the feature-space distances at a chosen set of layers.
    """

    def __init__(self, feature_layers=None, use_l1=True, device='cpu'):
        """
        Args:
            feature_layers (list of int, optional): Indices into VGG19's
                `features` module at which to compare activations. Defaults
                to {1, 6, 11, 20} (ReLU outputs before pool1..pool4).
            use_l1 (bool): L1 feature distance if True, otherwise MSE.
            device (str): 'cuda' or 'cpu'.
        """
        super(PerceptualLoss, self).__init__()

        # Prefer the modern torchvision weights API; fall back for old versions.
        try:
            weights = VGG19_Weights.IMAGENET1K_V1
            self.vgg = vgg19(weights=weights).features
            self.preprocess = weights.transforms()  # preprocessing the model expects
        except AttributeError:
            print("Warning: Using older torchvision VGG19 loading method. Consider upgrading torchvision.")
            self.vgg = vgg19(pretrained=True).features
            # Standard ImageNet normalization as a manual fallback.
            self.preprocess = transforms.Compose([
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

        # VGG is a fixed feature extractor: eval mode, gradients frozen.
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

        self.vgg = self.vgg.to(device)
        self.device = device

        # Default layer set: ReLU outputs just before each of pool1..pool4
        # (indices 1, 6, 11, 20 of VGG19's `features` sequential).
        if feature_layers is None:
            self.feature_layers = {1, 6, 11, 20}
        else:
            self.feature_layers = set(feature_layers)

        self.loss_fn = nn.L1Loss() if use_l1 else nn.MSELoss()

        print(f"PerceptualLoss: Using VGG19 features from layers: {sorted(list(self.feature_layers))}")
        print(f"PerceptualLoss: Using {'L1' if use_l1 else 'L2'} distance.")

    def forward(self, generated, target):
        """Compute the perceptual loss between two image batches.

        Args:
            generated (torch.Tensor): Generated images (B, C, H, W) in [0, 1].
            target (torch.Tensor): Ground-truth images (B, C, H, W) in [0, 1].

        Returns:
            torch.Tensor: Sum of feature distances over the selected layers.
        """
        generated = generated.to(self.device)
        target = target.to(self.device)

        # Normalize with the ImageNet statistics VGG expects.
        feat_gen = self.preprocess(generated)
        feat_tgt = self.preprocess(target)

        loss = 0.0
        last_needed = max(self.feature_layers) if self.feature_layers else 0

        # Walk the VGG trunk once, accumulating distances at the selected
        # layers and stopping as soon as the deepest needed layer is passed.
        for idx, layer in enumerate(self.vgg):
            feat_gen = layer(feat_gen)
            feat_tgt = layer(feat_tgt)

            if idx in self.feature_layers:
                loss += self.loss_fn(feat_gen, feat_tgt)

            if idx >= last_needed:
                break

        return loss
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# --- Example Usage (for testing the definition) ---
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Random [0, 1) image batches; generated and target sizes must match.
    dummy_generated = torch.rand(2, 3, 96, 96).to(device)
    dummy_target = torch.rand(2, 3, 96, 96).to(device)

    # Default layers {1, 6, 11, 20} with L1 feature distance.
    perceptual_loss_l1 = PerceptualLoss(device=device, use_l1=True)

    loss_val_l1 = perceptual_loss_l1(dummy_generated, dummy_target)

    print(f"\nCalculated Perceptual Loss (L1, default layers): {loss_val_l1.item()}")

    assert loss_val_l1.item() >= 0, "Loss should be non-negative"
    print("\nPerceptualLoss definition test successful!")
|
models.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# --- ResidualBlock, Upsampler, and Generator classes remain the same ---
|
| 7 |
+
class ResidualBlock(nn.Module):
    """Residual block: conv -> (bn) -> act -> conv -> (bn), plus a scaled skip.

    Spatial size and channel count are preserved, so the branch output can be
    added directly to the block input.
    """

    def __init__(self, num_features, kernel_size=3, bn=False, act=None, res_scale=1.0):
        """
        Args:
            num_features (int): Channels in and out.
            kernel_size (int): Convolution kernel size (odd, for same padding).
            bn (bool): Insert BatchNorm2d after each convolution.
            act (nn.Module, optional): Activation between the convolutions.
                Defaults to a fresh nn.ReLU(True) per block. (The previous
                `act=nn.ReLU(True)` default was a mutable default argument:
                one ReLU instance was evaluated once and shared by every
                block constructed with the default.)
            res_scale (float): Scale applied to the residual branch.
        """
        super(ResidualBlock, self).__init__()
        if act is None:
            act = nn.ReLU(True)
        padding = kernel_size // 2
        m = []
        m.append(nn.Conv2d(num_features, num_features, kernel_size, padding=padding))
        if bn:
            m.append(nn.BatchNorm2d(num_features))
        m.append(act)
        m.append(nn.Conv2d(num_features, num_features, kernel_size, padding=padding))
        if bn:
            m.append(nn.BatchNorm2d(num_features))
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        """Return x + res_scale * body(x)."""
        res = self.body(x).mul(self.res_scale)
        res += x
        return res
|
| 23 |
+
|
| 24 |
+
class Upsampler(nn.Module):
    """Upsample a feature map by an integer factor via conv + PixelShuffle.

    The convolution expands channels by scale_factor**2 and PixelShuffle
    rearranges them into space, so the overall channel count is unchanged.
    """

    def __init__(self, scale_factor, num_features, act=nn.ReLU(True)):
        """
        Args:
            scale_factor (int): Spatial upscaling factor.
            num_features (int): Channels in (and out).
            act (nn.Module or None): Optional activation appended after the
                shuffle; pass None to disable it.
        """
        super(Upsampler, self).__init__()
        layers = [
            nn.Conv2d(num_features, num_features * (scale_factor ** 2), kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
        ]
        if act:
            layers.append(act)
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        """Return the upsampled feature map."""
        return self.body(x)
|
| 34 |
+
|
| 35 |
+
class Generator(nn.Module):
    """EDSR-style super-resolution generator.

    Pipeline: head conv -> stack of residual blocks with a long skip ->
    pixel-shuffle upsampling tail -> final conv back to image channels.
    """

    def __init__(self, scale_factor=4, in_channels=3, out_channels=3, num_features=64, num_res_blocks=16, res_scale=1.0):
        """
        Args:
            scale_factor (int): Upscaling factor; powers of two and 3 supported.
            in_channels (int): Input image channels.
            out_channels (int): Output image channels.
            num_features (int): Width of the feature trunk.
            num_res_blocks (int): Number of residual blocks in the trunk.
            res_scale (float): Residual scaling inside each block.

        Raises:
            NotImplementedError: If scale_factor is neither a power of two
                nor 3.
        """
        super(Generator, self).__init__()
        self.scale_factor = scale_factor
        activation = nn.ReLU(True)

        self.head = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)

        trunk = [ResidualBlock(num_features, kernel_size=3, act=activation, res_scale=res_scale)
                 for _ in range(num_res_blocks)]
        trunk.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1))
        self.body = nn.Sequential(*trunk)

        tail = []
        if (scale_factor & (scale_factor - 1)) == 0:
            # Power of two: stack log2(scale) x2 pixel-shuffle stages.
            for _ in range(int(math.log2(scale_factor))):
                tail.append(Upsampler(scale_factor=2, num_features=num_features, act=None))
        elif scale_factor == 3:
            tail.append(Upsampler(scale_factor=3, num_features=num_features, act=None))
        else:
            raise NotImplementedError(f"Scale factor {scale_factor} not directly supported by this simple upsampler.")
        self.tail = nn.Sequential(*tail)

        self.final_conv = nn.Conv2d(num_features, out_channels, kernel_size=3, padding=1)

    def forward(self, lr_img):
        """Map an LR image batch to an SR batch scale_factor times larger."""
        shallow = self.head(lr_img)
        deep = self.body(shallow)
        deep += shallow  # long skip connection around the residual trunk
        upsampled = self.tail(deep)
        return self.final_conv(upsampled)
|
| 62 |
+
|
| 63 |
+
# +++ NEW Discriminator Class +++
|
| 64 |
+
class Discriminator(nn.Module):
    """Compact CNN discriminator producing one real/fake logit per image.

    A conv feature stack (stride-2 downsampling on every other block) feeds
    an adaptive average pool and a small MLP head, so any input resolution
    is accepted.
    """

    def __init__(self, in_channels=3, num_features_start=64, num_blocks=4):
        """
        Args:
            in_channels (int): Input image channels.
            num_features_start (int): Channels after the first convolution.
            num_blocks (int): Number of conv blocks after the initial one;
                the width doubles at every strided (odd-indexed) block.
        """
        super(Discriminator, self).__init__()

        # Initial non-strided conv block.
        layers = [
            nn.Conv2d(in_channels, num_features_start, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        width = num_features_start
        for block_idx in range(num_blocks):
            # Downsample (stride 2) and double the width on odd blocks.
            step = 1 if block_idx % 2 == 0 else 2
            out_width = width * 2 if step == 2 else width
            layers += [
                nn.Conv2d(width, out_width, kernel_size=3, stride=step, padding=1),
                nn.BatchNorm2d(out_width),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            width = out_width

        self.features = nn.Sequential(*layers)

        # Adaptive pooling makes the classifier input size-independent.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(width, 100),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(100, 1),  # raw logit; apply sigmoid externally if needed
        )

    def forward(self, img):
        """
        Args:
            img (torch.Tensor): Image batch (B, C, H, W), real HR or fake SR.

        Returns:
            torch.Tensor: Logits of shape (B, 1); higher means "more real".
        """
        feature_maps = self.features(img)
        pooled = self.avgpool(feature_maps)
        flat = pooled.view(img.size(0), -1)
        return self.classifier(flat)
|
| 115 |
+
|
| 116 |
+
# --- Main block for testing and saving ---
if __name__ == '__main__':
    # --- Generator smoke test ---
    SCALE = 4
    GEN_FEATURES = 64
    GEN_RES_BLOCKS = 8
    save_dir = "saved_models"
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dummy LR input for the generator.
    gen_batch_size = 1
    lr_height = 32
    lr_width = 32
    in_channels = 3
    dummy_lr = torch.randn(gen_batch_size, in_channels, lr_height, lr_width).to(device)
    print(f"Dummy LR input shape (Generator): {dummy_lr.shape}")

    generator = Generator(scale_factor=SCALE, num_features=GEN_FEATURES, num_res_blocks=GEN_RES_BLOCKS).to(device)
    generator.eval()
    with torch.no_grad():
        output_sr = generator(dummy_lr)
    print(f"Output SR shape (Generator): {output_sr.shape}")
    print("\nGenerator definition test successful!")
    num_params_gen = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Generator - Number of trainable parameters: {num_params_gen:,}")

    print("\n--- Testing Discriminator ---")
    # --- Discriminator smoke test ---
    DISC_FEATURES = 64  # starting feature width for the discriminator
    DISC_BLOCKS = 3  # number of conv blocks in the discriminator

    # Discriminator input must match the generator's output resolution.
    disc_batch_size = 4
    hr_height = output_sr.shape[2]
    hr_width = output_sr.shape[3]
    dummy_hr = torch.randn(disc_batch_size, in_channels, hr_height, hr_width).to(device)
    print(f"Dummy HR/SR input shape (Discriminator): {dummy_hr.shape}")

    discriminator = Discriminator(in_channels=in_channels,
                                  num_features_start=DISC_FEATURES,
                                  num_blocks=DISC_BLOCKS).to(device)
    discriminator.eval()

    with torch.no_grad():
        output_logits = discriminator(dummy_hr)

    print(f"Output Logits shape (Discriminator): {output_logits.shape}")

    # The discriminator must emit one logit per image.
    expected_disc_shape = (disc_batch_size, 1)
    assert output_logits.shape == expected_disc_shape, \
        f"Discriminator output shape mismatch! Expected {expected_disc_shape}, got {output_logits.shape}"

    print("Discriminator definition test successful!")

    num_params_disc = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Discriminator - Number of trainable parameters: {num_params_disc:,}")
|
prep.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
import zipfile
|
| 4 |
+
import requests
|
| 5 |
+
import argparse
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
# --- Helper Functions ---
|
| 10 |
+
|
| 11 |
+
def download_file(url, dest_path, chunk_size=8192):
    """Stream *url* to *dest_path*, showing a tqdm progress bar.

    Returns True on success. On any failure the partially written file is
    deleted and False is returned.
    """
    def _discard_partial():
        # Never leave a truncated download on disk.
        if os.path.exists(dest_path):
            os.remove(dest_path)

    try:
        resp = requests.get(url, stream=True, timeout=30)
        resp.raise_for_status()  # surface HTTP 4xx/5xx as exceptions

        total_size = int(resp.headers.get('content-length', 0))
        print(f"Downloading {os.path.basename(dest_path)} ({total_size / (1024*1024):.2f} MB)...")

        with open(dest_path, 'wb') as out, tqdm(
            desc=os.path.basename(dest_path),
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                # write() returns the byte count, which drives the bar.
                bar.update(out.write(chunk))

        print(f"Download complete: {dest_path}")
        return True

    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
        _discard_partial()
        return False
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
        _discard_partial()
        return False
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def unzip_file(zip_path, extract_to):
    """Extract every member of the archive at *zip_path* into *extract_to*.

    Returns True when extraction succeeds, False for a corrupt archive or
    any other extraction error.
    """
    print(f"Extracting {os.path.basename(zip_path)} to {extract_to}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as archive:
            # extractall is efficient enough; a per-member progress bar
            # could be added via archive.infolist() if ever needed.
            archive.extractall(extract_to)
    except zipfile.BadZipFile:
        print(f"Error: Invalid or corrupted zip file: {zip_path}")
        return False
    except Exception as e:
        print(f"An error occurred during extraction: {e}")
        return False
    print("Extraction complete.")
    return True
|
| 63 |
+
|
| 64 |
+
def find_image_dir(base_path, expected_subdir_suffix='_HR'):
    """Locate the directory that actually holds the extracted images.

    Some archives unpack straight into *base_path*; others add exactly one
    nested folder. Returns the best-guess image directory (never None).
    """
    def _has_images(path):
        # Any of the supported extensions counts as "contains images".
        return any(glob.glob(os.path.join(path, pattern))
                   for pattern in ('*.png', '*.jpg', '*.jpeg'))

    # Case 1: images live directly in base_path.
    if _has_images(base_path):
        return base_path

    # Case 2: a single nested folder, e.g. base_path/DatasetName_HR/.
    subdirs = [entry for entry in glob.glob(os.path.join(base_path, '*'))
               if os.path.isdir(entry)]
    if len(subdirs) == 1:
        candidate = subdirs[0]
        if candidate.endswith(expected_subdir_suffix) or _has_images(candidate):
            print(f"Found image directory: {candidate}")
            return candidate

    # Fallback: could not confirm; prefer the lone subdir if there is one.
    print(f"Warning: Could not definitively locate image subdirectory in {base_path}. Assuming images are directly within or in a single nested folder.")
    return subdirs[0] if len(subdirs) == 1 else base_path
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def downsample_images(hr_dir, lr_dir, scale_factor):
    """Write bicubically downscaled copies of every image in *hr_dir*.

    LR images keep their original filenames and land in *lr_dir* (created
    on demand). Returns True if at least one image was written.
    """
    if not os.path.exists(lr_dir):
        os.makedirs(lr_dir)
        print(f"Created LR directory: {lr_dir}")

    # Collect in a fixed extension order so the file list is deterministic.
    hr_images = []
    for pattern in ('*.png', '*.jpg', '*.jpeg'):
        hr_images.extend(glob.glob(os.path.join(hr_dir, pattern)))

    if not hr_images:
        print(f"Error: No images found in the determined HR directory: {hr_dir}")
        return False

    print(f"Found {len(hr_images)} HR images in {hr_dir}. Starting downsampling (x{scale_factor})...")

    processed_count = 0
    for hr_path in tqdm(hr_images, desc=f"Downsampling x{scale_factor}"):
        try:
            hr_img = Image.open(hr_path).convert('RGB')  # normalize to 3-channel RGB
            hr_width, hr_height = hr_img.size
            lr_width = hr_width // scale_factor
            lr_height = hr_height // scale_factor

            # Integer division can collapse tiny images to zero pixels.
            if lr_width == 0 or lr_height == 0:
                print(f"\nWarning: Image {os.path.basename(hr_path)} is too small ({hr_width}x{hr_height}) for scale factor {scale_factor}. Skipping.")
                continue

            lr_img = hr_img.resize((lr_width, lr_height), resample=Image.BICUBIC)
            lr_img.save(os.path.join(lr_dir, os.path.basename(hr_path)))
            processed_count += 1

        except Exception as e:
            # One bad file should not abort the whole batch.
            print(f"\nError processing {hr_path}: {e}")

    print(f"Downsampling complete. Processed {processed_count}/{len(hr_images)} images.")
    return processed_count > 0
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# --- Main Execution ---
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Download and prepare dataset for Super-Resolution.")
    parser.add_argument('--url', type=str,
                        default='https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip',
                        help='URL of the dataset zip file (default: DIV2K Train HR).')
    parser.add_argument('--base_dir', type=str, default='./datasets',
                        help='Base directory to store datasets.')
    parser.add_argument('--dataset_name', type=str, default='DIV2K',
                        help='Name for the dataset folder.')
    parser.add_argument('--scale', type=int, default=4,
                        help='Downsampling scale factor (e.g., 4 for x4).')
    parser.add_argument('--force', action='store_true',
                        help='Force redownload and reprocessing even if data exists.')
    args = parser.parse_args()

    # Layout: <base_dir>/<dataset_name>/{<zip>, HR_extracted/, DIV2K_train_LR_bicubic/X<scale>/}
    dataset_base_path = os.path.join(args.base_dir, args.dataset_name)
    zip_filename = os.path.basename(args.url)
    zip_save_path = os.path.join(dataset_base_path, zip_filename)
    hr_extract_base = os.path.join(dataset_base_path, 'HR_extracted')  # temp extraction location
    # The *actual* HR image dir is determined after extraction.
    lr_save_dir = os.path.join(dataset_base_path, f'DIV2K_train_LR_bicubic/X{args.scale}')

    print(f"--- Configuration ---")
    print(f"Dataset URL: {args.url}")
    print(f"Base Directory: {args.base_dir}")
    print(f"Dataset Name: {args.dataset_name}")
    print(f"Target Scale: x{args.scale}")
    print(f"Zip Save Path: {zip_save_path}")
    print(f"Initial Extract Path: {hr_extract_base}")
    print(f"LR Save Path: {lr_save_dir}")
    print(f"Force Re-run: {args.force}")
    print(f"--------------------")

    os.makedirs(dataset_base_path, exist_ok=True)

    # --- Step 1: Download ---
    extract_dir_present = os.path.isdir(hr_extract_base)
    need_download = not os.path.exists(zip_save_path) or args.force

    if need_download:
        if args.force and os.path.exists(zip_save_path):
            print("Force enabled: Removing existing zip file...")
            os.remove(zip_save_path)
        if not download_file(args.url, zip_save_path):
            print("Exiting due to download failure.")
            exit(1)
    elif extract_dir_present:
        # Zip and extraction dir both exist: nothing to do unless forced.
        print("Zip file already exists. Skipping download (use --force to override).")
    else:
        # Zip exists but was never extracted.
        print("Zip file found, but extraction directory missing. Will proceed to unzip.")

    # --- Step 2: Unzip ---
    need_unzip = not extract_dir_present or args.force
    actual_hr_dir = None  # path to the directory actually holding images

    if need_unzip:
        if args.force and extract_dir_present:
            print("Force enabled: Removing existing extraction directory...")
            import shutil
            shutil.rmtree(hr_extract_base)  # careful: removes directory and contents

        if not os.path.exists(zip_save_path):
            print("Error: Zip file not found, cannot unzip. Please check download step or path.")
            exit(1)

        os.makedirs(hr_extract_base, exist_ok=True)
        if not unzip_file(zip_save_path, hr_extract_base):
            print("Exiting due to extraction failure.")
            exit(1)

        # Post-extraction: resolve where the images actually landed.
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')
        if not actual_hr_dir or not (glob.glob(os.path.join(actual_hr_dir, '*.png')) or glob.glob(os.path.join(actual_hr_dir, '*.jpg'))):
            print(f"Error: Could not locate the directory with HR images within {hr_extract_base} after extraction.")
            exit(1)
        print(f"Located HR images in: {actual_hr_dir}")
    else:
        print("HR extraction directory already exists. Skipping unzip (use --force to override).")
        actual_hr_dir = find_image_dir(hr_extract_base, expected_subdir_suffix=f'{args.dataset_name}_HR')
        if not actual_hr_dir:
            print(f"Error: Could not locate the directory with HR images within existing {hr_extract_base}.")
            exit(1)
        print(f"Using existing HR images from: {actual_hr_dir}")

    # --- Step 3: Process (Downsample) ---
    lr_ready = os.path.isdir(lr_save_dir) and len(os.listdir(lr_save_dir)) > 0
    need_processing = not lr_ready or args.force

    if need_processing:
        if args.force and lr_ready:
            print("Force enabled: Removing existing LR directory...")
            import shutil
            shutil.rmtree(lr_save_dir)  # careful: removes directory and contents

        if not actual_hr_dir:
            print("Error: Cannot proceed with downsampling, HR image directory not determined.")
            exit(1)

        if not downsample_images(actual_hr_dir, lr_save_dir, args.scale):
            print("Downsampling process failed or produced no images.")
        else:
            print("Downsampling finished successfully.")
    else:
        print("LR directory already exists and is populated. Skipping downsampling (use --force to override).")

    print("\n--- Script Finished ---")
    print(f"HR images should be available in/under: {actual_hr_dir}")
    print(f"LR images (x{args.scale}) should be available in: {lr_save_dir}")
    print("You can now use these directories with the SRDataset class.")
|
saved_models/generator_x4_f64_b8_untrained.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2ea5ec7bb7ec436c504f98cf3380a7b2258bf3730cf4ae726b838dd7df52d0b1
|
| 3 |
+
size 3717459
|
saved_models/generator_x4_f64_b8_untrained.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de3bc0ab6790bede102d5a40fd5122bbff83e05b22331f9cc983eb76aace56db
|
| 3 |
+
size 3722508
|
train.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# RUN python train.py --epochs 2 --batch_size 2 --subset 10 --num_workers 0 --cpu --patch_size 48
|
| 2 |
+
import torch
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
import os
|
| 7 |
+
import argparse
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import time
|
| 10 |
+
|
| 11 |
+
# Import custom modules
|
| 12 |
+
from dataset import SRDataset # Make sure dataset.py is in the same directory
|
| 13 |
+
from models import Generator, Discriminator # Make sure models.py is in the same directory
|
| 14 |
+
from loss import PerceptualLoss # Make sure loss.py is in the same directory
|
| 15 |
+
|
| 16 |
+
def train(args):
    """Run the full SRGAN training loop.

    Args:
        args: argparse.Namespace carrying data paths (hr_dir, lr_dir, scale,
            patch_size), model hyperparameters (gen_features, gen_blocks,
            disc_features, disc_blocks), loss weights (lambda_content,
            lambda_percep, lambda_adv), learning rates, epoch count,
            checkpoint settings (save_dir, save_interval) and the --cpu flag.

    Side effects: prints progress and writes generator/discriminator
    ``.pth`` checkpoints into ``args.save_dir`` every ``args.save_interval``
    epochs (and at the final epoch).
    """
    # --- 1. Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    # Directory for model checkpoints (and potentially logs/outputs).
    os.makedirs(args.save_dir, exist_ok=True)

    # --- 2. Data ---
    print("Loading dataset...")
    # args.hr_dir and args.lr_dir are validated in the __main__ block.
    try:
        train_dataset = SRDataset(hr_dir=args.hr_dir, lr_dir=args.lr_dir, scale_factor=args.scale, patch_size_lr=args.patch_size)
    except FileNotFoundError as e:
        print(f"Error creating dataset: {e}")
        print("Please ensure the specified HR and LR directories contain correctly named image files.")
        exit(1)
    except Exception as e:
        print(f"An unexpected error occurred while creating the dataset: {e}")
        exit(1)

    # Optionally train on a random subset (debugging / CPU smoke tests).
    if args.subset > 0 and args.subset < len(train_dataset):
        print(f"Using a subset of {args.subset} images for training.")
        indices = torch.randperm(len(train_dataset))[:args.subset]
        train_dataset = torch.utils.data.Subset(train_dataset, indices)
    elif args.subset >= len(train_dataset) and len(train_dataset) > 0:
        print(f"Subset size ({args.subset}) is >= dataset size ({len(train_dataset)}). Using full dataset.")

    if len(train_dataset) == 0:
        print(f"Error: Dataset is empty after attempting to load. Please check HR dir '{args.hr_dir}' and LR dir '{args.lr_dir}'")
        return

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,  # set to 0 if workers misbehave on Windows/macOS
        # BUGFIX: `device` is a torch.device object, so the original
        # `device == 'cuda'` string comparison was unreliable and could
        # silently disable pinned memory on GPU runs. Compare the device
        # *type* instead; pin_memory only helps host->GPU transfers.
        pin_memory=(device.type == 'cuda')
    )
    print(f"Dataset loaded: {len(train_dataset)} training images.")
    print(f"Dataloader: {len(train_loader)} batches per epoch.")

    # --- 3. Models ---
    print("Initializing models...")
    generator = Generator(scale_factor=args.scale,
                          num_features=args.gen_features,
                          num_res_blocks=args.gen_blocks).to(device)

    discriminator = Discriminator(in_channels=3,  # assumes RGB input
                                  num_features_start=args.disc_features,
                                  num_blocks=args.disc_blocks).to(device)

    print(f"Generator params: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

    # --- 4. Loss Functions ---
    print("Initializing loss functions...")
    # Content loss (pixel-wise): L1 is common for SR.
    content_loss_criterion = nn.L1Loss().to(device)

    # Adversarial loss: BCEWithLogitsLoss is more stable than BCELoss + Sigmoid.
    adversarial_loss_criterion = nn.BCEWithLogitsLoss().to(device)

    # Perceptual loss (VGG-based feature distance).
    try:
        perceptual_loss_criterion = PerceptualLoss(device=device, use_l1=True)
    except Exception as e:
        print(f"Error initializing Perceptual Loss (check VGG weights download/torchvision install): {e}")
        exit(1)

    # --- 5. Optimizers ---
    print("Initializing optimizers...")
    optimizer_g = optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(0.9, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(0.9, 0.999))

    # Optional LR schedulers could be added here, e.g.:
    # scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=..., gamma=0.5)

    # --- 6. Training Loop ---
    print("\n--- Starting Training ---")
    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        generator.train()
        discriminator.train()
        epoch_loss_g = 0.0
        epoch_loss_d = 0.0
        epoch_start_time = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=True)

        for batch_idx, batch in enumerate(progress_bar):
            # Dataset __getitem__ may yield None on a failed sample.
            if batch is None:
                print(f"Warning: Skipping problematic batch at index {batch_idx}")
                continue

            try:
                lr_images = batch['lr'].to(device)  # low-resolution inputs
                hr_images = batch['hr'].to(device)  # high-resolution targets
            except KeyError as e:
                print(f"Error accessing batch data: {e}. Check SRDataset's __getitem__ return format.")
                continue

            # Adversarial labels: real = 1, fake = 0.
            # (Soft labels, e.g. 0.9, can sometimes stabilize GAN training.)
            real_labels = torch.ones((hr_images.size(0), 1)).to(device)
            fake_labels = torch.zeros((hr_images.size(0), 1)).to(device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_d.zero_grad()

            # Generator forward under no_grad: only D is being updated here.
            with torch.no_grad():
                fake_sr_images = generator(lr_images)

            real_logits = discriminator(hr_images)
            loss_d_real = adversarial_loss_criterion(real_logits, real_labels)

            fake_logits = discriminator(fake_sr_images)
            loss_d_fake = adversarial_loss_criterion(fake_logits, fake_labels)

            # Average the two halves so D's loss scale matches G's.
            loss_d = (loss_d_real + loss_d_fake) / 2

            loss_d.backward()
            # Optional: torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
            optimizer_d.step()

            # -----------------
            #  Train Generator
            #  (done every step here; some schedules update G less often)
            # -----------------
            optimizer_g.zero_grad()

            # Forward again, this time tracking gradients for G.
            generated_sr_images = generator(lr_images)

            # 1. Content loss: L1 distance to the ground-truth HR image.
            loss_content = content_loss_criterion(generated_sr_images, hr_images)

            # 2. Perceptual loss: VGG feature-space distance.
            loss_perceptual = perceptual_loss_criterion(generated_sr_images, hr_images)

            # 3. Adversarial loss: G wants D to label its output as real,
            #    hence real_labels as the target.
            generated_logits = discriminator(generated_sr_images)
            loss_adversarial = adversarial_loss_criterion(generated_logits, real_labels)

            # Weighted combination balances pixel accuracy, perceptual
            # quality and realism.
            loss_g = (args.lambda_content * loss_content +
                      args.lambda_percep * loss_perceptual +
                      args.lambda_adv * loss_adversarial)

            loss_g.backward()
            # Optional: torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
            optimizer_g.step()

            # --- Update running losses and progress bar ---
            epoch_loss_g += loss_g.item()
            epoch_loss_d += loss_d.item()
            progress_bar.set_postfix({
                'Loss G': f"{loss_g.item():.4f}",
                'Loss D': f"{loss_d.item():.4f}",
            })

        # --- End of Epoch ---
        avg_loss_g = epoch_loss_g / len(train_loader) if len(train_loader) > 0 else 0
        avg_loss_d = epoch_loss_d / len(train_loader) if len(train_loader) > 0 else 0
        epoch_time = time.time() - epoch_start_time

        # Optional: step LR schedulers here.

        print(f"\nEpoch {epoch}/{args.epochs} | Time: {epoch_time:.2f}s | Avg Loss G: {avg_loss_g:.4f} | Avg Loss D: {avg_loss_d:.4f}")

        # --- Save Checkpoint ---
        if epoch % args.save_interval == 0 or epoch == args.epochs:
            gen_path = os.path.join(args.save_dir, f"generator_epoch_{epoch}.pth")
            disc_path = os.path.join(args.save_dir, f"discriminator_epoch_{epoch}.pth")
            try:
                torch.save(generator.state_dict(), gen_path)
                torch.save(discriminator.state_dict(), disc_path)
                print(f"Checkpoint saved for epoch {epoch} to '{args.save_dir}'")
            except Exception as e:
                print(f"Error saving checkpoint for epoch {epoch}: {e}")

    # --- End of Training ---
    total_time = time.time() - start_time
    print(f"\n--- Training Finished ---")
    print(f"Total time: {total_time // 3600:.0f}h {(total_time % 3600) // 60:.0f}m {total_time % 60:.2f}s")
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SRGAN Model')

    # --- Data Args ---
    parser.add_argument('--hr_dir', type=str,
                        default='./datasets/DIV2K/HR_extracted/DIV2K_train_HR',
                        help='Path to high-resolution training images')
    parser.add_argument('--lr_dir', type=str, default=None,
                        help='Path to low-resolution training images (auto-set if None)')
    parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Training batch size (reduce for CPU/low VRAM)')
    parser.add_argument('--subset', type=int, default=0,
                        help='Use only N images for debugging (0 to use all)')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='Number of workers for DataLoader (set to 0 for Mac/Windows usually)')
    parser.add_argument('--patch_size', type=int, default=48,
                        help='Size (height/width) of LR patches for training')

    # --- Model Args ---
    parser.add_argument('--gen_features', type=int, default=64,
                        help='Number of features in Generator')
    parser.add_argument('--gen_blocks', type=int, default=16,
                        help='Number of residual blocks in Generator (reduce for faster training/less memory)')
    parser.add_argument('--disc_features', type=int, default=64,
                        help='Number of starting features in Discriminator')
    parser.add_argument('--disc_blocks', type=int, default=3,
                        help='Number of conv blocks in Discriminator')

    # --- Training Args ---
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--lr_gen', type=float, default=1e-4, help='Learning rate for Generator')
    parser.add_argument('--lr_disc', type=float, default=1e-4, help='Learning rate for Discriminator')
    # Loss weights follow the SRGAN paper: ~1e-2 pixel, 1.0 VGG, 1e-3 adversarial.
    parser.add_argument('--lambda_content', type=float, default=0.01, help='Weight for Content Loss (L1)')
    parser.add_argument('--lambda_percep', type=float, default=1.0, help='Weight for Perceptual Loss')
    parser.add_argument('--lambda_adv', type=float, default=0.001, help='Weight for Adversarial Loss')

    # --- Other Args ---
    parser.add_argument('--save_dir', type=str, default='checkpoints',
                        help='Directory to save model checkpoints')
    parser.add_argument('--save_interval', type=int, default=10,
                        help='Save checkpoint every N epochs')
    parser.add_argument('--cpu', action='store_true', help='Force training on CPU')

    args = parser.parse_args()

    # --- Set and Validate Directories ---
    # Derive the LR directory from the scale when not given explicitly.
    if args.lr_dir is None:
        args.lr_dir = f'./datasets/DIV2K/DIV2K_train_LR_bicubic/X{args.scale}'
        print(f"LR directory not provided, automatically setting based on scale {args.scale} to: {args.lr_dir}")

    if not os.path.isdir(args.hr_dir):
        print(f"\nERROR: High-Resolution directory not found at '{args.hr_dir}'")
        print("Please ensure the directory exists or provide the correct path using --hr_dir.")
        exit(1)

    if not os.path.isdir(args.lr_dir):
        print(f"\nERROR: Low-Resolution directory not found at '{args.lr_dir}'")
        print(f"Please ensure the directory exists (check scale factor {args.scale}?) or provide the correct path using --lr_dir.")
        exit(1)

    print("\n--- Training Configuration ---")
    cfg = vars(args)
    # Fit the separator to the terminal where possible.
    try:
        width = os.get_terminal_size().columns
    except OSError:
        width = 80

    print("-" * width)
    for key, value in cfg.items():
        print(f"{key:<25}: {value}")
    print("-" * width)

    train(args)
|