import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

from models_conv import ConvGenerator, ConvDiscriminator

# Hyperparameters
latent_dim = 100
batch_size = 64
n_epochs = 200
lr = 0.00005
n_critic = 5        # critic updates per generator update
clip_value = 0.01   # weight-clipping bound for the Lipschitz constraint

# Create directories
os.makedirs('images', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)

# Initialize tensorboard
writer = SummaryWriter('runs/wgan_training')

# Configure data loader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # scale MNIST to [-1, 1]
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize generator and discriminator (critic)
generator = ConvGenerator(latent_dim=latent_dim)
discriminator = ConvDiscriminator()

# Optimizers: the WGAN paper recommends RMSprop over momentum-based optimizers
g_optimizer = optim.RMSprop(generator.parameters(), lr=lr)
d_optimizer = optim.RMSprop(discriminator.parameters(), lr=lr)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)

print(f'Starting training on {device}...')

# Training loop
for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        d_optimizer.zero_grad()

        # Sample noise as generator input
        z = torch.randn(real_imgs.size(0), latent_dim, device=device)

        # Generate a batch of images; detach so no gradients flow to the generator
        fake_imgs = generator(z).detach()

        # Wasserstein critic loss: maximize D(real) - D(fake),
        # i.e. minimize -D(real) + D(fake)
        d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        d_loss.backward()
        d_optimizer.step()

        # Clip critic weights to enforce the Lipschitz constraint
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # Train the generator once every n_critic critic updates
        if i % n_critic == 0:
            # -----------------
            #  Train Generator
            # -----------------
            g_optimizer.zero_grad()

            # Generate a batch of images (reusing the noise sampled above)
            gen_imgs = generator(z)

            # Adversarial loss: maximize D(G(z)), i.e. minimize -D(G(z))
            g_loss = -torch.mean(discriminator(gen_imgs))

            g_loss.backward()
            g_optimizer.step()

        if i % 100 == 0:
            # g_loss is always defined here because n_critic divides 100
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')

            # Log losses to tensorboard
            global_step = epoch * len(dataloader) + i
            writer.add_scalar('D_loss', d_loss.item(), global_step)
            writer.add_scalar('G_loss', g_loss.item(), global_step)

    # Save checkpoints every 10 epochs
    if epoch % 10 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
        }, f'checkpoints/wgan_checkpoint_epoch_{epoch}.pt')

    # Save sample images
    with torch.no_grad():
        z = torch.randn(16, latent_dim, device=device)
        gen_imgs = generator(z)
        for j, img in enumerate(gen_imgs):
            # Rescale from [-1, 1] (assuming a tanh output layer) to the
            # [0, 1] range that add_image expects for float tensors
            writer.add_image(f'generated_image_{j}', img * 0.5 + 0.5, epoch)

print('Training finished!')
writer.close()
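

# ---------------------------------------------------------------------------
# Reference: the interface this script assumes from models_conv.py.
# The classes below are a hypothetical, minimal DCGAN-style sketch -- NOT the
# original ConvGenerator/ConvDiscriminator, whose definitions live in
# models_conv.py and may differ. They document the expected shapes: the
# generator maps (batch, latent_dim) noise to (batch, 1, 28, 28) images in
# [-1, 1], and the critic maps images to an unbounded scalar score (no
# sigmoid, as the Wasserstein loss requires). If models_conv.py is missing,
# saving this sketch under that name satisfies the script's interface.
# ---------------------------------------------------------------------------
#
# import torch.nn as nn
#
# class ConvGenerator(nn.Module):
#     def __init__(self, latent_dim=100):
#         super().__init__()
#         # Project the noise vector to a 128-channel 7x7 feature map
#         self.fc = nn.Linear(latent_dim, 128 * 7 * 7)
#         self.net = nn.Sequential(
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),
#             nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
#             nn.BatchNorm2d(64),
#             nn.ReLU(inplace=True),
#             nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14x14 -> 28x28
#             nn.Tanh(),  # outputs in [-1, 1], matching the data normalization
#         )
#
#     def forward(self, z):
#         return self.net(self.fc(z).view(-1, 128, 7, 7))
#
# class ConvDiscriminator(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),    # 28x28 -> 14x14
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Flatten(),
#             nn.Linear(128 * 7 * 7, 1),  # raw score, no sigmoid (Wasserstein critic)
#         )
#
#     def forward(self, img):
#         return self.net(img)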