# RUN python train.py --epochs 2 --batch_size 2 --subset 10 --num_workers 0 --cpu --patch_size 48
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import argparse
from tqdm import tqdm
import time
# Import custom modules
from dataset import SRDataset # Make sure dataset.py is in the same directory
from models import Generator, Discriminator # Make sure models.py is in the same directory
from loss import PerceptualLoss # Make sure loss.py is in the same directory
def train(args):
    # --- 1. Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    # Create directories for saving models and potentially logs/outputs
    os.makedirs(args.save_dir, exist_ok=True)

    # --- 2. Data ---
    print("Loading dataset...")
    # Note: args.hr_dir and args.lr_dir are assumed to be valid paths by this point
    # due to checks in the __main__ block
    try:
        train_dataset = SRDataset(hr_dir=args.hr_dir, lr_dir=args.lr_dir, scale_factor=args.scale, patch_size_lr=args.patch_size)
    except FileNotFoundError as e:
        print(f"Error creating dataset: {e}")
        print("Please ensure the specified HR and LR directories contain correctly named image files.")
        exit(1)
    except Exception as e:
        print(f"An unexpected error occurred while creating the dataset: {e}")
        exit(1)

    # Use a smaller subset for initial testing on CPU if needed
    if 0 < args.subset < len(train_dataset):
        print(f"Using a subset of {args.subset} images for training.")
        indices = torch.randperm(len(train_dataset))[:args.subset]
        train_dataset = torch.utils.data.Subset(train_dataset, indices)
    elif args.subset >= len(train_dataset) and len(train_dataset) > 0:
        print(f"Subset size ({args.subset}) is >= dataset size ({len(train_dataset)}). Using full dataset.")

    if len(train_dataset) == 0:
        print(f"Error: Dataset is empty after attempting to load. Please check HR dir '{args.hr_dir}' and LR dir '{args.lr_dir}'")
        return
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,  # Set to 0 if you encounter issues on Windows/macOS
        pin_memory=(device.type == 'cuda')  # pin_memory only helps GPU transfers; compare device.type, since device is a torch.device object, not the string 'cuda'
    )
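    # Note: with PyTorch's default collate function, a sample returned as None
    # from __getitem__ raises inside the DataLoader before the "batch is None"
    # check in the loop below can run. A minimal sketch of a filtering
    # collate_fn one could pass via collate_fn= (hypothetical helper, not part
    # of dataset.py):
    #
    # def skip_none_collate(samples):
    #     samples = [s for s in samples if s is not None]
    #     return torch.utils.data.default_collate(samples) if samples else None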
print(f"Dataset loaded: {len(train_dataset)} training images.")
print(f"Dataloader: {len(train_loader)} batches per epoch.")
# --- 3. Models ---
print("Initializing models...")
generator = Generator(scale_factor=args.scale,
num_features=args.gen_features,
num_res_blocks=args.gen_blocks).to(device)
discriminator = Discriminator(in_channels=3, # Assuming RGB input for discriminator
num_features_start=args.disc_features,
num_blocks=args.disc_blocks).to(device)
print(f"Generator params: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")
    # --- 4. Loss Functions ---
    print("Initializing loss functions...")
    # Content Loss (pixel-wise) - L1 is common for SR
    content_loss_criterion = nn.L1Loss().to(device)
    # Adversarial Loss - measures how well G fools D and how well D identifies fakes
    adversarial_loss_criterion = nn.BCEWithLogitsLoss().to(device)  # More stable than BCELoss + Sigmoid
    # Perceptual Loss (VGG-based)
    try:
        perceptual_loss_criterion = PerceptualLoss(device=device, use_l1=True)  # Using L1 feature distance
    except Exception as e:
        print(f"Error initializing Perceptual Loss (check VGG weights download/torchvision install): {e}")
        exit(1)

    # --- 5. Optimizers ---
    print("Initializing optimizers...")
    optimizer_g = optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(0.9, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(0.9, 0.999))

    # --- Optional: Learning Rate Scheduler ---
    # Example (note: would require adding an --lr_decay_step argument below):
    # scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=args.lr_decay_step, gamma=0.5)
    # scheduler_d = optim.lr_scheduler.StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.5)
    # --- 6. Training Loop ---
    print("\n--- Starting Training ---")
    start_time = time.time()
    for epoch in range(1, args.epochs + 1):
        generator.train()      # Set generator to training mode
        discriminator.train()  # Set discriminator to training mode
        epoch_loss_g = 0.0
        epoch_loss_d = 0.0
        epoch_start_time = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=True)  # leave=True keeps the bar after the epoch
        for batch_idx, batch in enumerate(progress_bar):
            # Ensure batch is valid (dataset loader might return None on error in __getitem__)
            if batch is None:
                print(f"Warning: Skipping problematic batch at index {batch_idx}")
                continue
            try:
                lr_images = batch['lr'].to(device)  # Low-resolution images
                hr_images = batch['hr'].to(device)  # High-resolution (ground truth) images
            except KeyError as e:
                print(f"Error accessing batch data: {e}. Check SRDataset's __getitem__ return format.")
                continue  # Skip this batch
            # Create labels for adversarial loss
            # Real labels = 1, Fake labels = 0
            # Adding some noise, or using soft labels (e.g., 0.9 instead of 1.0), can sometimes help stabilize GAN training
            real_labels = torch.ones((hr_images.size(0), 1)).to(device)
            fake_labels = torch.zeros((hr_images.size(0), 1)).to(device)
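            # A one-sided label-smoothing sketch (an assumption, not part of
            # this script): soft targets of 0.9 keep the discriminator from
            # being pushed toward saturated logits.
            # real_labels = torch.full((hr_images.size(0), 1), 0.9, device=device)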
            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_d.zero_grad()

            # Generate fake HR images.
            # Use torch.no_grad() for the generator forward pass when only training the discriminator
            with torch.no_grad():
                fake_sr_images = generator(lr_images)  # No need to detach() when already in a no_grad context

            # Loss for real images
            real_logits = discriminator(hr_images)
            loss_d_real = adversarial_loss_criterion(real_logits, real_labels)
            # Loss for fake images
            fake_logits = discriminator(fake_sr_images)  # Use the generated fakes
            loss_d_fake = adversarial_loss_criterion(fake_logits, fake_labels)
            # Total discriminator loss
            loss_d = (loss_d_real + loss_d_fake) / 2

            # Backpropagate and update Discriminator
            loss_d.backward()
            # Optional: Gradient clipping for Discriminator (can help stability)
            # torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
            optimizer_d.step()

            # -----------------
            #  Train Generator
            # (Some GAN setups update the generator less often than the
            # discriminator, e.g., every k steps; for simplicity we update it
            # every step.)
            # -----------------
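            # A sketch of updating G only every k-th batch, assuming a
            # hypothetical --gen_update_every flag (not defined in the argparse
            # section below). Note a bare `continue` would also skip the
            # running-loss bookkeeping at the end of this loop body:
            # if batch_idx % args.gen_update_every != 0:
            #     continue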
            optimizer_g.zero_grad()

            # Generate fake HR images (this time track gradients for G)
            generated_sr_images = generator(lr_images)

            # --- Calculate Generator Losses ---
            # 1. Content Loss (e.g., L1 distance between generated and real HR)
            loss_content = content_loss_criterion(generated_sr_images, hr_images)
            # 2. Perceptual Loss (VGG feature distance)
            loss_perceptual = perceptual_loss_criterion(generated_sr_images, hr_images)
            # 3. Adversarial Loss (how well G fools D)
            # We want the discriminator to output 'real' (1) for the generated images
            # Pass generated images through the discriminator (ensure D is not in a no_grad context here)
            generated_logits = discriminator(generated_sr_images)
            loss_adversarial = adversarial_loss_criterion(generated_logits, real_labels)  # Use real_labels!

            # --- Combine Generator Losses ---
            # Weights control the balance between pixel accuracy, perceptual quality, and realism
            loss_g = (args.lambda_content * loss_content +
                      args.lambda_percep * loss_perceptual +
                      args.lambda_adv * loss_adversarial)
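            # With the defaults below (lambda_content=0.01, lambda_percep=1.0,
            # lambda_adv=0.001), the perceptual term dominates and the
            # adversarial term is scaled by 1e-3, mirroring the SRGAN paper's
            # weighting.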
            # Backpropagate and update Generator
            loss_g.backward()
            # Optional: Gradient clipping for Generator
            # torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
            optimizer_g.step()

            # --- Update running losses and progress bar ---
            epoch_loss_g += loss_g.item()
            epoch_loss_d += loss_d.item()
            progress_bar.set_postfix({
                'Loss G': f"{loss_g.item():.4f}",
                'Loss D': f"{loss_d.item():.4f}",
                # Optional: Show individual components of G loss
                # 'L_Cont': f"{loss_content.item():.4f}",
                # 'L_Perc': f"{loss_perceptual.item():.4f}",
                # 'L_Adv': f"{loss_adversarial.item():.4f}"
            })
        # --- End of Epoch ---
        avg_loss_g = epoch_loss_g / len(train_loader) if len(train_loader) > 0 else 0
        avg_loss_d = epoch_loss_d / len(train_loader) if len(train_loader) > 0 else 0
        epoch_time = time.time() - epoch_start_time

        # Optional: Update learning rate schedulers
        # scheduler_g.step()
        # scheduler_d.step()
        # current_lr_g = optimizer_g.param_groups[0]['lr']

        print(f"\nEpoch {epoch}/{args.epochs} | Time: {epoch_time:.2f}s | Avg Loss G: {avg_loss_g:.4f} | Avg Loss D: {avg_loss_d:.4f}")

        # --- Save Checkpoint ---
        if epoch % args.save_interval == 0 or epoch == args.epochs:
            gen_path = os.path.join(args.save_dir, f"generator_epoch_{epoch}.pth")
            disc_path = os.path.join(args.save_dir, f"discriminator_epoch_{epoch}.pth")
            try:
                torch.save(generator.state_dict(), gen_path)
                torch.save(discriminator.state_dict(), disc_path)
                print(f"Checkpoint saved for epoch {epoch} to '{args.save_dir}'")
            except Exception as e:
                print(f"Error saving checkpoint for epoch {epoch}: {e}")
    # --- End of Training ---
    total_time = time.time() - start_time
    print("\n--- Training Finished ---")
    print(f"Total time: {total_time // 3600:.0f}h {(total_time % 3600) // 60:.0f}m {total_time % 60:.2f}s")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SRGAN Model')

    # --- Data Args ---
    parser.add_argument('--hr_dir', type=str,
                        default='./datasets/DIV2K/HR_extracted/DIV2K_train_HR',
                        help='Path to high-resolution training images')
    parser.add_argument('--lr_dir', type=str, default=None,  # Default to None, will be auto-set
                        help='Path to low-resolution training images (auto-set if None)')
    parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
    parser.add_argument('--batch_size', type=int, default=16, help='Training batch size (reduce for CPU/low VRAM)')
    parser.add_argument('--subset', type=int, default=0, help='Use only N images for debugging (0 to use all)')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for DataLoader (0 is usually safest on macOS/Windows)')
    parser.add_argument('--patch_size', type=int, default=48, help='Size (height/width) of LR patches for training')

    # --- Model Args ---
    parser.add_argument('--gen_features', type=int, default=64, help='Number of features in Generator')
    parser.add_argument('--gen_blocks', type=int, default=16, help='Number of residual blocks in Generator (reduce for faster training/less memory)')
    parser.add_argument('--disc_features', type=int, default=64, help='Number of starting features in Discriminator')
    parser.add_argument('--disc_blocks', type=int, default=3, help='Number of conv blocks in Discriminator')

    # --- Training Args ---
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--lr_gen', type=float, default=1e-4, help='Learning rate for Generator')
    parser.add_argument('--lr_disc', type=float, default=1e-4, help='Learning rate for Discriminator')
    parser.add_argument('--lambda_content', type=float, default=0.01, help='Weight for Content Loss (L1)')  # SRGAN paper uses 1e-2 for L1/MSE when combined with VGG
    parser.add_argument('--lambda_percep', type=float, default=1.0, help='Weight for Perceptual Loss')  # SRGAN paper uses 1.0
    parser.add_argument('--lambda_adv', type=float, default=0.001, help='Weight for Adversarial Loss')  # SRGAN paper uses 1e-3

    # --- Other Args ---
    parser.add_argument('--save_dir', type=str, default='checkpoints', help='Directory to save model checkpoints')
    parser.add_argument('--save_interval', type=int, default=10, help='Save checkpoint every N epochs')
    parser.add_argument('--cpu', action='store_true', help='Force training on CPU')
    # parser.add_argument('--load_checkpoint', type=str, default=None, help='Path to checkpoint file to resume training')  # Example hook for adding resume functionality
    args = parser.parse_args()

    # --- Set and Validate Directories ---
    # Auto-set LR directory based on scale IF it wasn't provided via command line
    if args.lr_dir is None:
        args.lr_dir = f'./datasets/DIV2K/DIV2K_train_LR_bicubic/X{args.scale}'
        print(f"LR directory not provided, automatically setting based on scale {args.scale} to: {args.lr_dir}")

    # Validate HR directory
    if not os.path.isdir(args.hr_dir):
        print(f"\nERROR: High-Resolution directory not found at '{args.hr_dir}'")
        print("Please ensure the directory exists or provide the correct path using --hr_dir.")
        exit(1)  # Exit if the directory is invalid

    # Validate LR directory
    if not os.path.isdir(args.lr_dir):
        print(f"\nERROR: Low-Resolution directory not found at '{args.lr_dir}'")
        print(f"Please ensure the directory exists (check scale factor {args.scale}?) or provide the correct path using --lr_dir.")
        exit(1)  # Exit if the directory is invalid

    print("\n--- Training Configuration ---")
    # Print configuration cleanly
    config_dict = vars(args)
    # Get terminal width for better formatting (optional)
    try:
        term_width = os.get_terminal_size().columns
    except OSError:
        term_width = 80  # Default if terminal size is unavailable
    print("-" * term_width)
    for key, value in config_dict.items():
        print(f"{key:<25}: {value}")  # Align keys for readability
    print("-" * term_width)

    # Start the training process
    train(args)