"""Train a vanilla GAN on MNIST.

Alternates generator/discriminator updates with BCE adversarial loss
(DCGAN-style Adam settings) and checkpoints a 16-sample tensor of
generated images every 10 epochs under ./images/.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models import Generator, Discriminator
import os

# Hyperparameters
latent_dim = 100   # dimensionality of the generator's noise input z
batch_size = 64    # mini-batch size handed to the DataLoader
n_epochs = 200
lr = 0.0002        # Adam learning rate
beta1 = 0.5        # Adam beta1; 0.5 is the usual choice for stable GAN training

# Create directory for saving sampled-image checkpoints
os.makedirs('images', exist_ok=True)

# Configure data loader: ToTensor gives [0, 1]; Normalize([0.5], [0.5])
# rescales pixels to [-1, 1] (matches a tanh generator output — TODO confirm
# against models.py)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataset = datasets.MNIST(root='./data', train=True, download=True,
                         transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize generator and discriminator
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator()

# Loss function. BCELoss assumes the discriminator already applies a sigmoid;
# if models.Discriminator outputs raw logits, BCEWithLogitsLoss would be
# required instead — verify against models.py.
adversarial_loss = nn.BCELoss()

# Optimizers (separate Adam instances so G and D steps are independent)
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)
adversarial_loss.to(device)

print(f'Starting training on {device}...')

# Training loop
for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        # Actual size of this mini-batch — the final batch of an epoch may be
        # smaller than `batch_size`. Use a distinct name so the module-level
        # hyperparameter is not shadowed (the original rebound `batch_size`).
        cur_batch = real_imgs.shape[0]

        # Ground-truth labels: 1 for real images, 0 for generated ones.
        # Allocate directly on `device` instead of CPU-allocate-then-copy.
        valid = torch.ones(cur_batch, 1, device=device)
        fake = torch.zeros(cur_batch, 1, device=device)

        # Configure input
        real_imgs = real_imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------
        g_optimizer.zero_grad()

        # Sample noise as generator input
        z = torch.randn(cur_batch, latent_dim, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Generator loss: how well G fools D into labeling samples as real
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        g_optimizer.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        d_optimizer.zero_grad()

        # Measure discriminator's ability to classify real vs. generated.
        # detach() blocks gradients from the D update flowing into G.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        d_optimizer.step()

        if i % 100 == 0:
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')

    # Save a fixed 16-sample checkpoint every 10 epochs. (The original comment
    # said "each epoch", contradicting the modulo condition below.)
    if epoch % 10 == 0:
        # Switch to eval mode so any norm/dropout layers in G behave
        # deterministically while sampling, then restore training mode.
        generator.eval()
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            torch.save(generator(z), f'images/epoch_{epoch}.pt')
        generator.train()

print('Training finished!')