# RUN python train.py --epochs 2 --batch_size 2 --subset 10 --num_workers 0 --cpu --patch_size 48

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader

import os
import argparse
from tqdm import tqdm
import time

# Import custom modules
from dataset import SRDataset                 # Make sure dataset.py is in the same directory
from models import Generator, Discriminator   # Make sure models.py is in the same directory
from loss import PerceptualLoss               # Make sure loss.py is in the same directory


def train(args):
    # --- 1. Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    # Create directories for saving models and potentially logs/outputs
    os.makedirs(args.save_dir, exist_ok=True)

    # --- 2. Data ---
    print("Loading dataset...")
    # Note: args.hr_dir and args.lr_dir are assumed to be valid paths by this point
    # due to checks in the __main__ block
    try:
        train_dataset = SRDataset(hr_dir=args.hr_dir,
                                  lr_dir=args.lr_dir,
                                  scale_factor=args.scale,
                                  patch_size_lr=args.patch_size)
    except FileNotFoundError as e:
        print(f"Error creating dataset: {e}")
        print("Please ensure the specified HR and LR directories contain correctly named image files.")
        exit(1)
    except Exception as e:
        print(f"An unexpected error occurred while creating the dataset: {e}")
        exit(1)

    # Use a smaller subset for initial testing on CPU if needed
    if 0 < args.subset < len(train_dataset):
        print(f"Using a subset of {args.subset} images for training.")
        indices = torch.randperm(len(train_dataset))[:args.subset]
        train_dataset = torch.utils.data.Subset(train_dataset, indices)
    elif args.subset >= len(train_dataset) and len(train_dataset) > 0:
        print(f"Subset size ({args.subset}) is >= dataset size ({len(train_dataset)}). Using full dataset.")

    if len(train_dataset) == 0:
        print(f"Error: Dataset is empty after attempting to load. "
              f"Please check HR dir '{args.hr_dir}' and LR dir '{args.lr_dir}'")
        return

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,        # Set to 0 if you encounter issues on Windows/macOS
        pin_memory=(device.type == 'cuda')   # pin_memory is only useful for GPU transfers
    )
    print(f"Dataset loaded: {len(train_dataset)} training images.")
    print(f"Dataloader: {len(train_loader)} batches per epoch.")
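
    # Optional sanity check (a hedged sketch, commented out; it assumes SRDataset's
    # __getitem__ returns a dict with 'lr' and 'hr' tensors, which is what the
    # training loop below expects):
    # sample = next(iter(train_loader))
    # print(sample['lr'].shape, sample['hr'].shape)  # e.g. [B, 3, 48, 48] and [B, 3, 192, 192] at scale 4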

    # --- 3. Models ---
    print("Initializing models...")
    generator = Generator(scale_factor=args.scale,
                          num_features=args.gen_features,
                          num_res_blocks=args.gen_blocks).to(device)
    discriminator = Discriminator(in_channels=3,  # Assuming RGB input for the discriminator
                                  num_features_start=args.disc_features,
                                  num_blocks=args.disc_blocks).to(device)

    print(f"Generator params: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

    # --- 4. Loss Functions ---
    print("Initializing loss functions...")
    # Content Loss (pixel-wise) - L1 is common for SR
    content_loss_criterion = nn.L1Loss().to(device)
    # Adversarial Loss - measures how well G fools D and D identifies fakes
    adversarial_loss_criterion = nn.BCEWithLogitsLoss().to(device)  # More stable than BCELoss + Sigmoid
    # Perceptual Loss (VGG-based)
    try:
        perceptual_loss_criterion = PerceptualLoss(device=device, use_l1=True)  # Using L1 feature distance
    except Exception as e:
        print(f"Error initializing Perceptual Loss (check VGG weights download/torchvision install): {e}")
        exit(1)

    # --- 5. Optimizers ---
    print("Initializing optimizers...")
    optimizer_g = optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(0.9, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(0.9, 0.999))

    # --- Optional: Learning Rate Scheduler ---
    # Example: scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=args.lr_decay_step, gamma=0.5)
    # Example: scheduler_d = optim.lr_scheduler.StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.5)

    # --- 6. Training Loop ---
    print("\n--- Starting Training ---")
    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        generator.train()      # Set generator to training mode
        discriminator.train()  # Set discriminator to training mode

        epoch_loss_g = 0.0
        epoch_loss_d = 0.0
        epoch_start_time = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}",
                            leave=True)  # leave=True keeps the bar visible after each epoch

        for batch_idx, batch in enumerate(progress_bar):
            # Ensure the batch is valid (the dataset's __getitem__ might return None
            # on error; note this only reaches the loop with a custom collate_fn)
            if batch is None:
                print(f"Warning: Skipping problematic batch at index {batch_idx}")
                continue

            try:
                lr_images = batch['lr'].to(device)  # Low-resolution images
                hr_images = batch['hr'].to(device)  # High-resolution (ground truth) images
            except KeyError as e:
                print(f"Error accessing batch data: {e}. Check SRDataset's __getitem__ return format.")
                continue  # Skip this batch

            # Create labels for the adversarial loss: real = 1, fake = 0.
            # Adding noise or using soft labels (e.g., 0.9 instead of 1.0)
            # can sometimes help stabilize GAN training.
            real_labels = torch.ones((hr_images.size(0), 1), device=device)
            fake_labels = torch.zeros((hr_images.size(0), 1), device=device)
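
            # Hedged example of the soft-label trick mentioned above (commented out;
            # one-sided smoothing of only the real labels is a common variant):
            # real_labels = torch.full((hr_images.size(0), 1), 0.9, device=device)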

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_d.zero_grad()

            # Generate fake HR images. Use torch.no_grad() for the generator's
            # forward pass when only training the discriminator.
            with torch.no_grad():
                fake_sr_images = generator(lr_images)  # No need to detach() inside a no_grad context

            # Loss for real images
            real_logits = discriminator(hr_images)
            loss_d_real = adversarial_loss_criterion(real_logits, real_labels)

            # Loss for fake images
            fake_logits = discriminator(fake_sr_images)  # Use the generated fakes
            loss_d_fake = adversarial_loss_criterion(fake_logits, fake_labels)

            # Total discriminator loss
            loss_d = (loss_d_real + loss_d_fake) / 2

            # Backpropagate and update the Discriminator
            loss_d.backward()
            # Optional: gradient clipping for the Discriminator (can help stability)
            # torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
            optimizer_d.step()

            # -----------------
            #  Train Generator
            # -----------------
            # (G is sometimes updated less frequently than D, e.g., every k steps;
            # for simplicity we update it every step.)
            optimizer_g.zero_grad()

            # Generate fake HR images (this time tracking gradients for G)
            generated_sr_images = generator(lr_images)

            # --- Calculate Generator Losses ---
            # 1. Content Loss (e.g., L1 distance between generated and real HR)
            loss_content = content_loss_criterion(generated_sr_images, hr_images)

            # 2. Perceptual Loss (VGG feature distance)
            loss_perceptual = perceptual_loss_criterion(generated_sr_images, hr_images)

            # 3. Adversarial Loss (how well G fools D).
            # We want the discriminator to output 'real' (1) for the generated images.
            # Pass the generated images through the discriminator (D must not be
            # inside a no_grad context here).
            generated_logits = discriminator(generated_sr_images)
            loss_adversarial = adversarial_loss_criterion(generated_logits, real_labels)  # Use real_labels!

            # --- Combine Generator Losses ---
            # Weights control the balance between pixel accuracy, perceptual quality, and realism
            loss_g = (args.lambda_content * loss_content +
                      args.lambda_percep * loss_perceptual +
                      args.lambda_adv * loss_adversarial)

            # Backpropagate and update the Generator
            loss_g.backward()
            # Optional: gradient clipping for the Generator
            # torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
            optimizer_g.step()

            # --- Update running losses and progress bar ---
            epoch_loss_g += loss_g.item()
            epoch_loss_d += loss_d.item()
            progress_bar.set_postfix({
                'Loss G': f"{loss_g.item():.4f}",
                'Loss D': f"{loss_d.item():.4f}",
                # Optional: show individual components of the G loss
                # 'L_Cont': f"{loss_content.item():.4f}",
                # 'L_Perc': f"{loss_perceptual.item():.4f}",
                # 'L_Adv': f"{loss_adversarial.item():.4f}"
            })

        # --- End of Epoch ---
        avg_loss_g = epoch_loss_g / len(train_loader) if len(train_loader) > 0 else 0
        avg_loss_d = epoch_loss_d / len(train_loader) if len(train_loader) > 0 else 0
        epoch_time = time.time() - epoch_start_time

        # Optional: update learning rate schedulers
        # scheduler_g.step()
        # scheduler_d.step()
        # current_lr_g = optimizer_g.param_groups[0]['lr']

        print(f"\nEpoch {epoch}/{args.epochs} | Time: {epoch_time:.2f}s | "
              f"Avg Loss G: {avg_loss_g:.4f} | Avg Loss D: {avg_loss_d:.4f}")

        # --- Save Checkpoint ---
        if epoch % args.save_interval == 0 or epoch == args.epochs:
            gen_path = os.path.join(args.save_dir, f"generator_epoch_{epoch}.pth")
            disc_path = os.path.join(args.save_dir, f"discriminator_epoch_{epoch}.pth")
            try:
                torch.save(generator.state_dict(), gen_path)
                torch.save(discriminator.state_dict(), disc_path)
                print(f"Checkpoint saved for epoch {epoch} to '{args.save_dir}'")
            except Exception as e:
                print(f"Error saving checkpoint for epoch {epoch}: {e}")

    # --- End of Training ---
    total_time = time.time() - start_time
    print("\n--- Training Finished ---")
    print(f"Total time: {total_time // 3600:.0f}h {(total_time % 3600) // 60:.0f}m {total_time % 60:.2f}s")
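
# Hedged sketch of resume support (hypothetical, commented out): it would pair with
# the commented --load_checkpoint argument below, with the load running inside
# train() right after the models are created.
# if args.load_checkpoint:
#     generator.load_state_dict(torch.load(args.load_checkpoint, map_location=device))
#     print(f"Loaded generator weights from {args.load_checkpoint}")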

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SRGAN Model')

    # --- Data Args ---
    parser.add_argument('--hr_dir', type=str, default='./datasets/DIV2K/HR_extracted/DIV2K_train_HR',
                        help='Path to high-resolution training images')
    parser.add_argument('--lr_dir', type=str, default=None,  # Default to None, will be auto-set
                        help='Path to low-resolution training images (auto-set if None)')
    parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Training batch size (reduce for CPU/low VRAM)')
    parser.add_argument('--subset', type=int, default=0,
                        help='Use only N images for debugging (0 to use all)')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='Number of workers for DataLoader (set to 0 for Mac/Windows usually)')
    parser.add_argument('--patch_size', type=int, default=48,
                        help='Size (height/width) of LR patches for training')  # NEW ARGUMENT

    # --- Model Args ---
    parser.add_argument('--gen_features', type=int, default=64,
                        help='Number of features in Generator')
    parser.add_argument('--gen_blocks', type=int, default=16,
                        help='Number of residual blocks in Generator (reduce for faster training/less memory)')
    parser.add_argument('--disc_features', type=int, default=64,
                        help='Number of starting features in Discriminator')
    parser.add_argument('--disc_blocks', type=int, default=3,
                        help='Number of conv blocks in Discriminator')

    # --- Training Args ---
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--lr_gen', type=float, default=1e-4, help='Learning rate for Generator')
    parser.add_argument('--lr_disc', type=float, default=1e-4, help='Learning rate for Discriminator')
    parser.add_argument('--lambda_content', type=float, default=0.01,
                        help='Weight for Content Loss (L1)')  # SRGAN paper uses 1e-2 for L1/MSE when combined with VGG
    parser.add_argument('--lambda_percep', type=float, default=1.0,
                        help='Weight for Perceptual Loss')    # SRGAN paper uses 1.0
    parser.add_argument('--lambda_adv', type=float, default=0.001,
                        help='Weight for Adversarial Loss')   # SRGAN paper uses 1e-3

    # --- Other Args ---
    parser.add_argument('--save_dir', type=str, default='checkpoints',
                        help='Directory to save model checkpoints')
    parser.add_argument('--save_interval', type=int, default=10,
                        help='Save checkpoint every N epochs')
    parser.add_argument('--cpu', action='store_true', help='Force training on CPU')
    # parser.add_argument('--load_checkpoint', type=str, default=None,
    #                     help='Path to checkpoint file to resume training')  # Example for adding resume functionality

    args = parser.parse_args()

    # --- Set and Validate Directories ---
    # Auto-set the LR directory based on scale IF it wasn't provided on the command line
    if args.lr_dir is None:
        args.lr_dir = f'./datasets/DIV2K/DIV2K_train_LR_bicubic/X{args.scale}'
        print(f"LR directory not provided, automatically setting based on scale {args.scale} to: {args.lr_dir}")

    # Validate HR directory
    if not os.path.isdir(args.hr_dir):
        print(f"\nERROR: High-Resolution directory not found at '{args.hr_dir}'")
        print("Please ensure the directory exists or provide the correct path using --hr_dir.")
        exit(1)  # Exit if the directory is invalid

    # Validate LR directory
    if not os.path.isdir(args.lr_dir):
        print(f"\nERROR: Low-Resolution directory not found at '{args.lr_dir}'")
        print(f"Please ensure the directory exists (check scale factor {args.scale}?) or provide the correct path using --lr_dir.")
        exit(1)  # Exit if the directory is invalid

    print("\n--- Training Configuration ---")
    # Print the configuration cleanly
    config_dict = vars(args)
    # Use the terminal width for better formatting (optional)
    try:
        term_width = os.get_terminal_size().columns
    except OSError:
        term_width = 80  # Default if terminal size is unavailable
    print("-" * term_width)
    for key, value in config_dict.items():
        print(f"{key:<25}: {value}")  # Align keys for readability
    print("-" * term_width)

    # Start the training process
    train(args)