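"""Training script for an SRGAN-style super-resolution model.

Each batch alternates one discriminator update and one generator update; the
generator is trained with a weighted sum of L1 content loss, VGG-based
perceptual loss, and adversarial loss.
"""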
import argparse
import os
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import SRDataset
from models import Generator, Discriminator
from loss import PerceptualLoss


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    os.makedirs(args.save_dir, exist_ok=True)

    print("Loading dataset...")
    try:
        train_dataset = SRDataset(hr_dir=args.hr_dir, lr_dir=args.lr_dir,
                                  scale_factor=args.scale, patch_size_lr=args.patch_size)
    except FileNotFoundError as e:
        print(f"Error creating dataset: {e}")
        print("Please ensure the specified HR and LR directories contain correctly named image files.")
        sys.exit(1)
    except Exception as e:
        print(f"An unexpected error occurred while creating the dataset: {e}")
        sys.exit(1)

    # Optionally restrict training to a random subset of images for quick debugging runs.
    if 0 < args.subset < len(train_dataset):
        print(f"Using a subset of {args.subset} images for training.")
        indices = torch.randperm(len(train_dataset))[:args.subset]
        train_dataset = torch.utils.data.Subset(train_dataset, indices)
    elif args.subset >= len(train_dataset) and len(train_dataset) > 0:
        print(f"Subset size ({args.subset}) is >= dataset size ({len(train_dataset)}). Using full dataset.")

    if len(train_dataset) == 0:
        print(f"Error: Dataset is empty after attempting to load. Please check HR dir '{args.hr_dir}' and LR dir '{args.lr_dir}'.")
        return

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=(device.type == 'cuda'),
    )
    print(f"Dataset loaded: {len(train_dataset)} training images.")
    print(f"Dataloader: {len(train_loader)} batches per epoch.")

    print("Initializing models...")
    generator = Generator(scale_factor=args.scale,
                          num_features=args.gen_features,
                          num_res_blocks=args.gen_blocks).to(device)
    discriminator = Discriminator(in_channels=3,
                                  num_features_start=args.disc_features,
                                  num_blocks=args.disc_blocks).to(device)

    print(f"Generator params: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

    print("Initializing loss functions...")
    content_loss_criterion = nn.L1Loss().to(device)                  # pixel-wise content loss
    adversarial_loss_criterion = nn.BCEWithLogitsLoss().to(device)   # GAN loss on raw logits

    try:
        perceptual_loss_criterion = PerceptualLoss(device=device, use_l1=True)
    except Exception as e:
        print(f"Error initializing Perceptual Loss (check VGG weights download/torchvision install): {e}")
        sys.exit(1)

    print("Initializing optimizers...")
    optimizer_g = optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(0.9, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(0.9, 0.999))

    print("\n--- Starting Training ---")
    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        generator.train()
        discriminator.train()
        epoch_loss_g = 0.0
        epoch_loss_d = 0.0
        epoch_start_time = time.time()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=True)

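        # Each batch performs one discriminator step followed by one generator
        # step (standard alternating GAN training).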
        for batch_idx, batch in enumerate(progress_bar):
            # A custom collate_fn may return None for corrupt samples; skip those batches.
            if batch is None:
                print(f"Warning: Skipping problematic batch at index {batch_idx}")
                continue

            try:
                lr_images = batch['lr'].to(device)
                hr_images = batch['hr'].to(device)
            except KeyError as e:
                print(f"Error accessing batch data: {e}. Check SRDataset's __getitem__ return format.")
                continue

            # Adversarial targets; assumes the discriminator returns one logit
            # per image with shape (batch_size, 1).
            real_labels = torch.ones((hr_images.size(0), 1), device=device)
            fake_labels = torch.zeros((hr_images.size(0), 1), device=device)

            # ----- Discriminator update -----
            optimizer_d.zero_grad()

            # Generate SR images without tracking generator gradients; only the
            # discriminator is updated in this step.
            with torch.no_grad():
                fake_sr_images = generator(lr_images)

            real_logits = discriminator(hr_images)
            loss_d_real = adversarial_loss_criterion(real_logits, real_labels)

            fake_logits = discriminator(fake_sr_images)
            loss_d_fake = adversarial_loss_criterion(fake_logits, fake_labels)

            loss_d = (loss_d_real + loss_d_fake) / 2
            loss_d.backward()
            optimizer_d.step()

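            # ----- Generator update -----
            # The generator runs a second forward pass here, this time with
            # gradients enabled, so its parameters receive gradient; the earlier
            # no_grad output cannot be used for backpropagation.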
            optimizer_g.zero_grad()

            generated_sr_images = generator(lr_images)

            loss_content = content_loss_criterion(generated_sr_images, hr_images)
            loss_perceptual = perceptual_loss_criterion(generated_sr_images, hr_images)

            # The generator is rewarded when the discriminator classifies its
            # outputs as real, hence the real_labels target.
            generated_logits = discriminator(generated_sr_images)
            loss_adversarial = adversarial_loss_criterion(generated_logits, real_labels)

            loss_g = (args.lambda_content * loss_content +
                      args.lambda_percep * loss_perceptual +
                      args.lambda_adv * loss_adversarial)

            loss_g.backward()
            optimizer_g.step()

            epoch_loss_g += loss_g.item()
            epoch_loss_d += loss_d.item()
            progress_bar.set_postfix({
                'Loss G': f"{loss_g.item():.4f}",
                'Loss D': f"{loss_d.item():.4f}",
            })

        avg_loss_g = epoch_loss_g / len(train_loader) if len(train_loader) > 0 else 0
        avg_loss_d = epoch_loss_d / len(train_loader) if len(train_loader) > 0 else 0
        epoch_time = time.time() - epoch_start_time

        print(f"\nEpoch {epoch}/{args.epochs} | Time: {epoch_time:.2f}s | "
              f"Avg Loss G: {avg_loss_g:.4f} | Avg Loss D: {avg_loss_d:.4f}")

        if epoch % args.save_interval == 0 or epoch == args.epochs:
            gen_path = os.path.join(args.save_dir, f"generator_epoch_{epoch}.pth")
            disc_path = os.path.join(args.save_dir, f"discriminator_epoch_{epoch}.pth")
            try:
                torch.save(generator.state_dict(), gen_path)
                torch.save(discriminator.state_dict(), disc_path)
                print(f"Checkpoint saved for epoch {epoch} to '{args.save_dir}'")
            except Exception as e:
                print(f"Error saving checkpoint for epoch {epoch}: {e}")

    total_time = time.time() - start_time
    print("\n--- Training Finished ---")
    print(f"Total time: {total_time // 3600:.0f}h {(total_time % 3600) // 60:.0f}m {total_time % 60:.2f}s")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SRGAN Model')

    # Data
    parser.add_argument('--hr_dir', type=str,
                        default='./datasets/DIV2K/HR_extracted/DIV2K_train_HR',
                        help='Path to high-resolution training images')
    parser.add_argument('--lr_dir', type=str, default=None,
                        help='Path to low-resolution training images (auto-set from --scale if omitted)')
    parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
    parser.add_argument('--batch_size', type=int, default=16, help='Training batch size (reduce for CPU/low VRAM)')
    parser.add_argument('--subset', type=int, default=0, help='Use only N images for debugging (0 to use all)')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of DataLoader workers (0 is usually safest on macOS/Windows)')
    parser.add_argument('--patch_size', type=int, default=48, help='Size (height/width) of LR patches for training')

    # Model
    parser.add_argument('--gen_features', type=int, default=64, help='Number of features in Generator')
    parser.add_argument('--gen_blocks', type=int, default=16, help='Number of residual blocks in Generator (reduce for faster training/less memory)')
    parser.add_argument('--disc_features', type=int, default=64, help='Number of starting features in Discriminator')
    parser.add_argument('--disc_blocks', type=int, default=3, help='Number of conv blocks in Discriminator')

    # Training (the adversarial weight is kept small relative to the perceptual
    # weight, as is common for SRGAN-style training)
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--lr_gen', type=float, default=1e-4, help='Learning rate for Generator')
    parser.add_argument('--lr_disc', type=float, default=1e-4, help='Learning rate for Discriminator')
    parser.add_argument('--lambda_content', type=float, default=0.01, help='Weight for Content Loss (L1)')
    parser.add_argument('--lambda_percep', type=float, default=1.0, help='Weight for Perceptual Loss')
    parser.add_argument('--lambda_adv', type=float, default=0.001, help='Weight for Adversarial Loss')

    # Checkpointing / hardware
    parser.add_argument('--save_dir', type=str, default='checkpoints', help='Directory to save model checkpoints')
    parser.add_argument('--save_interval', type=int, default=10, help='Save checkpoint every N epochs')
    parser.add_argument('--cpu', action='store_true', help='Force training on CPU')

    args = parser.parse_args()

    # Derive the LR directory from the scale factor if it was not given explicitly.
    if args.lr_dir is None:
        args.lr_dir = f'./datasets/DIV2K/DIV2K_train_LR_bicubic/X{args.scale}'
        print(f"LR directory not provided, setting it from scale {args.scale}: {args.lr_dir}")

    if not os.path.isdir(args.hr_dir):
        print(f"\nERROR: High-resolution directory not found at '{args.hr_dir}'")
        print("Please ensure the directory exists or provide the correct path using --hr_dir.")
        sys.exit(1)

    if not os.path.isdir(args.lr_dir):
        print(f"\nERROR: Low-resolution directory not found at '{args.lr_dir}'")
        print(f"Please ensure the directory exists (check scale factor {args.scale}) or provide the correct path using --lr_dir.")
        sys.exit(1)

    print("\n--- Training Configuration ---")
    config_dict = vars(args)
    try:
        term_width = os.get_terminal_size().columns
    except OSError:
        term_width = 80

    print("-" * term_width)
    for key, value in config_dict.items():
        print(f"{key:<25}: {value}")
    print("-" * term_width)

    train(args)