# OxO_Image-Repair / train.py
# (Hugging Face upload metadata: uploaded by Gordon-H, "Upload 13 files", commit fd5c0a6 verified)
# RUN python train.py --epochs 2 --batch_size 2 --subset 10 --num_workers 0 --cpu --patch_size 48
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import argparse
from tqdm import tqdm
import time
# Import custom modules
from dataset import SRDataset # Make sure dataset.py is in the same directory
from models import Generator, Discriminator # Make sure models.py is in the same directory
from loss import PerceptualLoss # Make sure loss.py is in the same directory
def train(args):
    """Run the SRGAN training loop.

    Builds the dataset/dataloader, the generator/discriminator pair, the
    loss functions and optimizers from ``args``, then alternates one
    discriminator update and one generator update per batch for
    ``args.epochs`` epochs, saving checkpoints every ``args.save_interval``
    epochs (and on the final epoch).

    Args:
        args: argparse.Namespace carrying the fields defined in the
            ``__main__`` block (data paths, model sizes, learning rates,
            loss weights, checkpointing options, ...).

    Raises:
        SystemExit: if the dataset or the perceptual loss cannot be built.
    """
    # --- 1. Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")
    # Create directories for saving models and potentially logs/outputs
    os.makedirs(args.save_dir, exist_ok=True)

    # --- 2. Data ---
    print("Loading dataset...")
    # args.hr_dir / args.lr_dir were validated in the __main__ block, but the
    # dataset may still fail on missing or misnamed image files inside them.
    try:
        train_dataset = SRDataset(hr_dir=args.hr_dir, lr_dir=args.lr_dir,
                                  scale_factor=args.scale,
                                  patch_size_lr=args.patch_size)
    except FileNotFoundError as e:
        print(f"Error creating dataset: {e}")
        print("Please ensure the specified HR and LR directories contain correctly named image files.")
        # FIX: bare exit() is a site-module convenience; SystemExit is the
        # guaranteed equivalent and works even when site isn't loaded.
        raise SystemExit(1)
    except Exception as e:
        print(f"An unexpected error occurred while creating the dataset: {e}")
        raise SystemExit(1)

    # Optionally restrict training to a random subset (handy for CPU debugging).
    if 0 < args.subset < len(train_dataset):
        print(f"Using a subset of {args.subset} images for training.")
        indices = torch.randperm(len(train_dataset))[:args.subset]
        train_dataset = torch.utils.data.Subset(train_dataset, indices)
    elif args.subset >= len(train_dataset) and len(train_dataset) > 0:
        print(f"Subset size ({args.subset}) is >= dataset size ({len(train_dataset)}). Using full dataset.")
    if len(train_dataset) == 0:
        print(f"Error: Dataset is empty after attempting to load. Please check HR dir '{args.hr_dir}' and LR dir '{args.lr_dir}'")
        return

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,  # 0 avoids multiprocessing issues on Windows/macOS
        # FIX: the original compared a torch.device object to the string
        # 'cuda', which is version-dependent (and False on older PyTorch),
        # so pin_memory was silently never enabled. device.type is robust.
        pin_memory=(device.type == "cuda"),
    )
    print(f"Dataset loaded: {len(train_dataset)} training images.")
    print(f"Dataloader: {len(train_loader)} batches per epoch.")

    # --- 3. Models ---
    print("Initializing models...")
    generator = Generator(scale_factor=args.scale,
                          num_features=args.gen_features,
                          num_res_blocks=args.gen_blocks).to(device)
    discriminator = Discriminator(in_channels=3,  # assumes RGB input — confirm against dataset.py
                                  num_features_start=args.disc_features,
                                  num_blocks=args.disc_blocks).to(device)
    print(f"Generator params: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

    # --- 4. Loss Functions ---
    print("Initializing loss functions...")
    # Content loss (pixel-wise): L1 is the common choice for SR.
    content_loss_criterion = nn.L1Loss().to(device)
    # Adversarial loss: BCEWithLogits is more stable than BCELoss + Sigmoid.
    adversarial_loss_criterion = nn.BCEWithLogitsLoss().to(device)
    # Perceptual loss (VGG feature distance); may download VGG weights.
    try:
        perceptual_loss_criterion = PerceptualLoss(device=device, use_l1=True)
    except Exception as e:
        print(f"Error initializing Perceptual Loss (check VGG weights download/torchvision install): {e}")
        raise SystemExit(1)

    # --- 5. Optimizers ---
    print("Initializing optimizers...")
    optimizer_g = optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(0.9, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(0.9, 0.999))
    # Optional LR schedulers could be added here, e.g.:
    #   scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=..., gamma=0.5)

    # --- 6. Training Loop ---
    print("\n--- Starting Training ---")
    start_time = time.time()
    for epoch in range(1, args.epochs + 1):
        generator.train()
        discriminator.train()
        epoch_loss_g = 0.0
        epoch_loss_d = 0.0
        epoch_start_time = time.time()
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=True)

        for batch_idx, batch in enumerate(progress_bar):
            # NOTE(review): the default collate never yields None; this guard
            # only matters if SRDataset uses a custom collate — confirm.
            if batch is None:
                print(f"Warning: Skipping problematic batch at index {batch_idx}")
                continue
            try:
                lr_images = batch['lr'].to(device)  # low-resolution inputs
                hr_images = batch['hr'].to(device)  # high-resolution ground truth
            except KeyError as e:
                print(f"Error accessing batch data: {e}. Check SRDataset's __getitem__ return format.")
                continue  # skip this batch

            # Adversarial targets: real = 1, fake = 0. Soft labels (e.g. 0.9)
            # can stabilize GAN training if needed.
            # assumes the discriminator outputs one logit per image with
            # shape (B, 1) — confirm against models.py.
            real_labels = torch.ones((hr_images.size(0), 1)).to(device)
            fake_labels = torch.zeros((hr_images.size(0), 1)).to(device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_d.zero_grad()
            # Generate fakes without tracking generator gradients.
            with torch.no_grad():
                fake_sr_images = generator(lr_images)
            real_logits = discriminator(hr_images)
            loss_d_real = adversarial_loss_criterion(real_logits, real_labels)
            fake_logits = discriminator(fake_sr_images)
            loss_d_fake = adversarial_loss_criterion(fake_logits, fake_labels)
            loss_d = (loss_d_real + loss_d_fake) / 2
            loss_d.backward()
            optimizer_d.step()

            # -----------------
            #  Train Generator
            # (could be done every k discriminator steps; here every step)
            # -----------------
            optimizer_g.zero_grad()
            # Forward again, this time tracking gradients for G.
            generated_sr_images = generator(lr_images)
            # 1. Content loss: L1 distance to the ground truth.
            loss_content = content_loss_criterion(generated_sr_images, hr_images)
            # 2. Perceptual loss: VGG feature distance.
            loss_perceptual = perceptual_loss_criterion(generated_sr_images, hr_images)
            # 3. Adversarial loss: G wants D to label its outputs as real,
            #    hence real_labels as the target.
            generated_logits = discriminator(generated_sr_images)
            loss_adversarial = adversarial_loss_criterion(generated_logits, real_labels)
            # Weighted sum balances pixel accuracy, perceptual quality, realism.
            loss_g = (args.lambda_content * loss_content +
                      args.lambda_percep * loss_perceptual +
                      args.lambda_adv * loss_adversarial)
            loss_g.backward()
            optimizer_g.step()

            # Running totals and live progress display.
            epoch_loss_g += loss_g.item()
            epoch_loss_d += loss_d.item()
            progress_bar.set_postfix({
                'Loss G': f"{loss_g.item():.4f}",
                'Loss D': f"{loss_d.item():.4f}",
            })

        # --- End of epoch: report averages (guard against an empty loader). ---
        avg_loss_g = epoch_loss_g / len(train_loader) if len(train_loader) > 0 else 0
        avg_loss_d = epoch_loss_d / len(train_loader) if len(train_loader) > 0 else 0
        epoch_time = time.time() - epoch_start_time
        print(f"\nEpoch {epoch}/{args.epochs} | Time: {epoch_time:.2f}s | Avg Loss G: {avg_loss_g:.4f} | Avg Loss D: {avg_loss_d:.4f}")

        # --- Save checkpoints periodically and on the last epoch. ---
        if epoch % args.save_interval == 0 or epoch == args.epochs:
            gen_path = os.path.join(args.save_dir, f"generator_epoch_{epoch}.pth")
            disc_path = os.path.join(args.save_dir, f"discriminator_epoch_{epoch}.pth")
            try:
                torch.save(generator.state_dict(), gen_path)
                torch.save(discriminator.state_dict(), disc_path)
                print(f"Checkpoint saved for epoch {epoch} to '{args.save_dir}'")
            except Exception as e:
                print(f"Error saving checkpoint for epoch {epoch}: {e}")

    # --- End of training ---
    total_time = time.time() - start_time
    print(f"\n--- Training Finished ---")
    print(f"Total time: {total_time // 3600:.0f}h {(total_time % 3600) // 60:.0f}m {total_time % 60:.2f}s")
def main():
    """Parse command-line options, validate data directories, launch training."""
    parser = argparse.ArgumentParser(description='Train SRGAN Model')

    # Data options.
    parser.add_argument('--hr_dir', type=str,
                        default='./datasets/DIV2K/HR_extracted/DIV2K_train_HR',
                        help='Path to high-resolution training images')
    parser.add_argument('--lr_dir', type=str, default=None,  # auto-derived from --scale when None
                        help='Path to low-resolution training images (auto-set if None)')
    parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
    parser.add_argument('--batch_size', type=int, default=16, help='Training batch size (reduce for CPU/low VRAM)')
    parser.add_argument('--subset', type=int, default=0, help='Use only N images for debugging (0 to use all)')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for DataLoader (set to 0 for Mac/Windows usually)')
    parser.add_argument('--patch_size', type=int, default=48, help='Size (height/width) of LR patches for training')

    # Model options.
    parser.add_argument('--gen_features', type=int, default=64, help='Number of features in Generator')
    parser.add_argument('--gen_blocks', type=int, default=16, help='Number of residual blocks in Generator (reduce for faster training/less memory)')
    parser.add_argument('--disc_features', type=int, default=64, help='Number of starting features in Discriminator')
    parser.add_argument('--disc_blocks', type=int, default=3, help='Number of conv blocks in Discriminator')

    # Training options (loss weights follow the SRGAN paper: 1e-2 / 1.0 / 1e-3).
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--lr_gen', type=float, default=1e-4, help='Learning rate for Generator')
    parser.add_argument('--lr_disc', type=float, default=1e-4, help='Learning rate for Discriminator')
    parser.add_argument('--lambda_content', type=float, default=0.01, help='Weight for Content Loss (L1)')
    parser.add_argument('--lambda_percep', type=float, default=1.0, help='Weight for Perceptual Loss')
    parser.add_argument('--lambda_adv', type=float, default=0.001, help='Weight for Adversarial Loss')

    # Miscellaneous options.
    parser.add_argument('--save_dir', type=str, default='checkpoints', help='Directory to save model checkpoints')
    parser.add_argument('--save_interval', type=int, default=10, help='Save checkpoint every N epochs')
    parser.add_argument('--cpu', action='store_true', help='Force training on CPU')

    args = parser.parse_args()

    # Derive the LR directory from the scale factor when not given explicitly.
    if args.lr_dir is None:
        args.lr_dir = f'./datasets/DIV2K/DIV2K_train_LR_bicubic/X{args.scale}'
        print(f"LR directory not provided, automatically setting based on scale {args.scale} to: {args.lr_dir}")

    # Both image directories must exist before handing them to the dataset.
    if not os.path.isdir(args.hr_dir):
        print(f"\nERROR: High-Resolution directory not found at '{args.hr_dir}'")
        print("Please ensure the directory exists or provide the correct path using --hr_dir.")
        exit(1)
    if not os.path.isdir(args.lr_dir):
        print(f"\nERROR: Low-Resolution directory not found at '{args.lr_dir}'")
        print(f"Please ensure the directory exists (check scale factor {args.scale}?) or provide the correct path using --lr_dir.")
        exit(1)

    # Echo the full configuration, sized to the terminal when available.
    print("\n--- Training Configuration ---")
    try:
        term_width = os.get_terminal_size().columns
    except OSError:
        term_width = 80  # fallback when not attached to a terminal
    print("-" * term_width)
    for key, value in vars(args).items():
        print(f"{key:<25}: {value}")
    print("-" * term_width)

    train(args)


if __name__ == '__main__':
    main()