|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torchvision import datasets, transforms |
|
|
from torch.utils.data import DataLoader |
|
|
from models_conv import ConvGenerator, ConvDiscriminator |
|
|
import os |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
|
|
|
# WGAN hyper-parameters (weight-clipping variant, Arjovsky et al., 2017).
latent_dim = 100    # size of the noise vector fed to the generator
batch_size = 64     # samples per training batch
n_epochs = 200      # full passes over the MNIST training set
lr = 0.00005        # RMSprop learning rate (5e-5, the WGAN paper's setting)
n_critic = 5        # generator is updated once every n_critic batches
clip_value = 0.01   # critic weights clamped to [-clip_value, clip_value]
|
|
|
|
|
|
|
|
# Output locations: generated-sample images, model checkpoints, and
# TensorBoard event files.
for directory in ('images', 'checkpoints'):
    os.makedirs(directory, exist_ok=True)

writer = SummaryWriter('runs/wgan_training')
|
|
|
|
|
|
|
|
# MNIST pipeline: ToTensor scales pixels to [0, 1], Normalize((x-0.5)/0.5)
# shifts them to [-1, 1] — presumably to match a tanh-output generator
# (confirm against ConvGenerator in models_conv).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
|
|
|
|
|
|
# Pick the compute device, build both networks on it, then give each its
# own RMSprop optimizer (a momentum-free optimizer, as used in the WGAN
# paper). Module.to() moves parameters in place, so constructing the
# optimizers after the move references the same Parameter objects.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = ConvGenerator(latent_dim=latent_dim)
discriminator = ConvDiscriminator()
generator.to(device)
discriminator.to(device)

g_optimizer = optim.RMSprop(generator.parameters(), lr=lr)
d_optimizer = optim.RMSprop(discriminator.parameters(), lr=lr)

print(f'Starting training on {device}...')
|
|
|
|
|
|
|
|
# --- Training loop: WGAN with weight clipping --------------------------------
# Per batch: one critic (discriminator) update with weights clamped to
# [-clip_value, clip_value]; one generator update every n_critic batches.
for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)

        # ----- Critic update -----
        d_optimizer.zero_grad()

        # Sample noise directly on the target device (avoids a CPU->GPU copy).
        z = torch.randn(real_imgs.size(0), latent_dim, device=device)

        # Detach so this step trains only the critic, not the generator.
        fake_imgs = generator(z).detach()

        # Wasserstein critic objective: maximize D(real) - D(fake),
        # implemented as minimizing its negation.
        d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        d_loss.backward()
        d_optimizer.step()

        # Enforce the Lipschitz constraint via weight clipping.
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # ----- Generator update (once every n_critic critic steps) -----
        if i % n_critic == 0:
            g_optimizer.zero_grad()

            gen_imgs = generator(z)

            # Generator wants the critic to score its fakes highly.
            g_loss = -torch.mean(discriminator(gen_imgs))

            g_loss.backward()
            g_optimizer.step()

        if i % 100 == 0:
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')

        # NOTE: g_loss is only refreshed every n_critic batches, so the
        # G_loss curve repeats the most recent value in between updates.
        step = epoch * len(dataloader) + i
        writer.add_scalar('D_loss', d_loss.item(), step)
        writer.add_scalar('G_loss', g_loss.item(), step)

    # Checkpoint every 10 epochs: both networks and both optimizer states,
    # plus the epoch number so training can be resumed.
    if epoch % 10 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
        }, f'checkpoints/wgan_checkpoint_epoch_{epoch}.pt')

    # Log a fixed-size batch of samples to TensorBoard at the end of each epoch.
    with torch.no_grad():
        z = torch.randn(16, latent_dim, device=device)
        gen_imgs = generator(z)
        for j, img in enumerate(gen_imgs):
            # The data pipeline normalizes images to [-1, 1] while
            # add_image expects floats in [0, 1], so rescale before
            # logging (assumes the generator matches the data range —
            # confirm ConvGenerator ends in tanh).
            writer.add_image(f'generated_image_{j}', img * 0.5 + 0.5, epoch)
|
|
|
|
|
# Flush and release the TensorBoard writer, then report completion.
writer.close()
print('Training finished!')