|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torchvision import datasets, transforms |
|
|
from torch.utils.data import DataLoader |
|
|
from models import Generator, Discriminator |
|
|
import os |
|
|
|
|
|
|
|
|
# ----- Hyperparameters -----

latent_dim = 100  # dimensionality of the generator's input noise vector z

batch_size = 64  # samples per training batch

n_epochs = 200  # full passes over the training set

lr = 0.0002  # Adam learning rate, shared by both networks

beta1 = 0.5  # Adam first-moment decay (second moment fixed at 0.999 below)

# Directory where sampled generator outputs are written every 10 epochs.
os.makedirs('images', exist_ok=True)
|
|
|
|
|
|
|
|
# Preprocessing: PIL image -> tensor in [0, 1], then normalize the single
# channel with mean 0.5 / std 0.5, mapping pixel values into [-1, 1].
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

# MNIST training split; downloaded to ./data on first run.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=mnist_transform,
)

# Shuffled batches for training.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
|
|
|
|
|
|
# ----- Models, loss, and optimizers -----

# Generator maps latent vectors of size `latent_dim` to images;
# Discriminator scores images (definitions live in models.py).
generator = Generator(latent_dim=latent_dim)


discriminator = Discriminator()


# Binary cross-entropy on the discriminator's output.
# NOTE(review): BCELoss expects probabilities in (0, 1), i.e. the
# Discriminator's final layer should be a sigmoid — confirm in models.py.
adversarial_loss = nn.BCELoss()


# Separate Adam optimizers for the two networks, sharing lr/beta settings.
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))


d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))


# Move everything to GPU when available. nn.Module.to() moves parameters
# in-place, so the optimizers created above keep valid references.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


generator.to(device)


discriminator.to(device)


adversarial_loss.to(device)


print(f'Starting training on {device}...')
|
|
|
|
|
|
|
|
# ----- Training loop -----
for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        # Actual size of this batch (the last batch of an epoch may be
        # smaller). Kept in a local name instead of clobbering the
        # module-level `batch_size` hyperparameter, as the original did.
        bsz = real_imgs.shape[0]

        # Ground-truth labels: 1 for real images, 0 for generated ones.
        # Built directly on the target device to avoid an extra host copy.
        valid = torch.ones(bsz, 1, device=device)
        fake = torch.zeros(bsz, 1, device=device)

        real_imgs = real_imgs.to(device)

        # --------------------
        #  Train the generator
        # --------------------
        g_optimizer.zero_grad()

        # Sample latent noise and generate a batch of fake images.
        z = torch.randn(bsz, latent_dim, device=device)
        gen_imgs = generator(z)

        # Generator wants the discriminator to label its output as real.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        g_optimizer.step()

        # ------------------------
        #  Train the discriminator
        # ------------------------
        d_optimizer.zero_grad()

        # Real images should score 1, generated images 0; detach() keeps
        # discriminator gradients from flowing back into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        d_optimizer.step()

        if i % 100 == 0:
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')

    # Every 10 epochs, save 16 generated samples as a raw tensor for
    # later inspection (no gradients needed for sampling).
    if epoch % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            torch.save(generator(z), f'images/epoch_{epoch}.pt')


print('Training finished!')