# Train a simple GAN on MNIST; Generator/Discriminator come from the local models module.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models import Generator, Discriminator
import os
# Hyperparameters
latent_dim = 100
batch_size = 64
n_epochs = 200
lr = 0.0002
beta1 = 0.5
# Create directory for saving images
os.makedirs('images', exist_ok=True)
# Configure data loader
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Initialize generator and discriminator
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator()
# Loss function
adversarial_loss = nn.BCELoss()
# Optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)
adversarial_loss.to(device)
print(f'Starting training on {device}...')
# Training loop: alternate one generator step and one discriminator step per batch.
for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        # Distinct name so the global `batch_size` hyperparameter is not
        # shadowed/clobbered; the final batch of an epoch may be smaller.
        cur_batch_size = real_imgs.shape[0]

        # Ground-truth labels: 1 for real, 0 for fake. Allocate directly on
        # the target device instead of building on CPU and copying.
        valid = torch.ones(cur_batch_size, 1, device=device)
        fake = torch.zeros(cur_batch_size, 1, device=device)

        # Configure input
        real_imgs = real_imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------
        g_optimizer.zero_grad()
        # Sample noise as generator input and produce a batch of fakes.
        z = torch.randn(cur_batch_size, latent_dim, device=device)
        gen_imgs = generator(z)
        # Generator loss: how well the fakes fool the discriminator
        # (labels are `valid` because G wants D to say "real").
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        g_optimizer.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        d_optimizer.zero_grad()
        # Real images should score 1; generated images (detached so the
        # generator receives no gradient from this step) should score 0.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        d_optimizer.step()

        if i % 100 == 0:
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')

    # Save a fixed 16-sample batch of generated images every 10 epochs
    # (epochs 0, 10, 20, ...; the final epoch 199 is NOT saved).
    if epoch % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            gen_imgs = generator(z)
        torch.save(gen_imgs, f'images/epoch_{epoch}.pt')

print('Training finished!')