# train_wgan.py
# Provenance: Hugging Face repo "hellooooo" by user lolzysiu
# (commit c971098, verified) — original upload message: "Create train_wgan.py".
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models_conv import ConvGenerator, ConvDiscriminator
import os
from torch.utils.tensorboard import SummaryWriter
# Hyperparameters — values match the original WGAN recipe
# (Arjovsky et al. 2017: RMSprop, lr 5e-5, 5 critic steps per
# generator step, weight clip 0.01).
latent_dim = 100      # dimensionality of the generator's noise input
batch_size = 64
n_epochs = 200
lr = 0.00005          # small LR; WGAN is sensitive to large steps
n_critic = 5          # critic updates per generator update
clip_value = 0.01     # weight-clipping bound (Lipschitz constraint)
# Create directories
# NOTE(review): 'images' is created here but nothing in this file
# writes to it — samples go to TensorBoard instead; confirm intent.
os.makedirs('images', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)
# Initialize tensorboard
writer = SummaryWriter('runs/wgan_training')
# Configure data loader: scale MNIST pixels to [0, 1] then normalize
# to [-1, 1] (mean 0.5, std 0.5 on the single grayscale channel).
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
# Downloads MNIST into ./data on first run (import-time side effect).
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# ----------------------------------------------------------------
# Models, device placement, and optimizers
# ----------------------------------------------------------------
# Pick the device first so the models can be placed on it before the
# optimizers are built.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize generator and discriminator (critic) and move them to the
# target device BEFORE constructing the optimizers — the PyTorch
# torch.optim docs recommend this ordering so the optimizers are
# guaranteed to track the on-device parameters.
generator = ConvGenerator(latent_dim=latent_dim).to(device)
discriminator = ConvDiscriminator().to(device)

# WGAN uses plain RMSprop (no momentum), per Arjovsky et al. 2017.
g_optimizer = optim.RMSprop(generator.parameters(), lr=lr)
d_optimizer = optim.RMSprop(discriminator.parameters(), lr=lr)

print(f'Starting training on {device}...')
# ----------------------------------------------------------------
# Training loop — WGAN with weight clipping (Arjovsky et al., 2017).
# NOTE(review): the original file's indentation was stripped (every
# statement sat at column 0, an IndentationError). The nesting below
# restores the standard WGAN structure: per-batch critic update,
# generator update every `n_critic` batches, per-epoch checkpointing.
# ----------------------------------------------------------------
for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)

        # ---------------------
        #  Train the critic (discriminator)
        # ---------------------
        d_optimizer.zero_grad()

        # Sample noise as generator input.
        z = torch.randn(real_imgs.size(0), latent_dim, device=device)

        # Generate fakes; detach so this backward pass does not
        # propagate gradients into the generator.
        fake_imgs = generator(z).detach()

        # Wasserstein critic loss: maximize E[D(real)] - E[D(fake)]
        # by minimizing its negation.
        d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
        d_loss.backward()
        d_optimizer.step()

        # Weight clipping enforces the (approximate) Lipschitz
        # constraint required by the Wasserstein objective.
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # -----------------
        #  Train the generator once every n_critic critic steps
        # -----------------
        if i % n_critic == 0:
            g_optimizer.zero_grad()
            # Reuse the noise batch sampled for the critic step.
            gen_imgs = generator(z)
            # Generator loss: maximize E[D(G(z))] via the negation.
            g_loss = -torch.mean(discriminator(gen_imgs))
            g_loss.backward()
            g_optimizer.step()

        if i % 100 == 0:
            # g_loss would be stale on iterations where the generator
            # was skipped; since n_critic (5) divides 100 it is always
            # fresh here — keep that coupling in mind if either changes.
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')
            # Log losses to tensorboard against the global step.
            writer.add_scalar('D_loss', d_loss.item(), epoch * len(dataloader) + i)
            writer.add_scalar('G_loss', g_loss.item(), epoch * len(dataloader) + i)

    # Per-epoch bookkeeping every 10 epochs — runs after the full pass
    # over the data, not once per batch.
    if epoch % 10 == 0:
        # Save a resumable checkpoint (both nets + both optimizers).
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
        }, f'checkpoints/wgan_checkpoint_epoch_{epoch}.pt')

        # Save sample images (no_grad: inference only, no autograd graph).
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            gen_imgs = generator(z)
            for j, img in enumerate(gen_imgs):
                # NOTE(review): add_image expects float images in [0, 1];
                # if the generator's output is tanh-scaled to [-1, 1]
                # (which would match the Normalize([0.5], [0.5]) input
                # transform), these samples will render clipped — pass
                # (img + 1) / 2 instead. Confirm ConvGenerator's range.
                writer.add_image(f'generated_image_{j}', img, epoch)

print('Training finished!')
writer.close()