# hellooooo / train.py
# Author: lolzysiu
# Commit: "Create train.py" (cd18d6b, verified)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models import Generator, Discriminator
import os
# ---- Hyperparameters -------------------------------------------------------
latent_dim = 100   # length of the noise vector given to the generator
batch_size = 64    # images per DataLoader mini-batch
n_epochs = 200     # number of full passes over the training set
lr = 2e-4          # Adam learning rate, shared by both networks
beta1 = 0.5        # Adam beta1, shared by both networks

# ---- Output location -------------------------------------------------------
# Generated samples are written under ./images; an existing directory is fine.
os.makedirs(name='images', exist_ok=True)
# ---- Data --------------------------------------------------------------------
# ToTensor() converts each image to a float tensor; Normalize(mean=0.5, std=0.5)
# then rescales every value as (x - 0.5) / 0.5.
# NOTE(review): this presumably matches a generator whose output lies in
# [-1, 1] (e.g. tanh) -- confirm in models.py.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# MNIST training split, kept under ./data and downloaded if not already there.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Reshuffled every epoch, `batch_size` images per step.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)
# ---- Models, loss, optimizers ------------------------------------------------
# Pick the compute device first so both networks are placed on it *before*
# their optimizers are constructed (the ordering recommended by torch.optim).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize generator and discriminator directly on `device`.
# nn.Module.to() moves parameters in place and returns the module itself.
generator = Generator(latent_dim=latent_dim).to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy on the discriminator's outputs.
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; presumably
# Discriminator ends in a sigmoid -- confirm in models.py.
# BCELoss has no parameters, so .to() is a no-op here; it is kept so that a
# future `weight=` tensor would follow the device automatically.
adversarial_loss = nn.BCELoss().to(device)

# One Adam optimizer per network, sharing lr and betas.
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

print(f'Starting training on {device}...')
# ---- Training loop -----------------------------------------------------------
# Each iteration updates the generator first, then the discriminator, using
# one mini-batch of real images.
# NOTE(review): the (N, 1) label shape below assumes Discriminator returns one
# score per sample shaped (N, 1) -- confirm in models.py.
sample_interval = 10  # save generator samples every this many epochs
n_samples = 16        # number of samples written per checkpoint

for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        # Size of *this* batch (the last batch of an epoch may be smaller).
        # A separate name keeps the `batch_size` hyperparameter intact.
        cur_batch_size = real_imgs.shape[0]

        # Ground truths, allocated directly on the target device.
        valid = torch.ones(cur_batch_size, 1, device=device)
        fake = torch.zeros(cur_batch_size, 1, device=device)

        # Configure input
        real_imgs = real_imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------
        g_optimizer.zero_grad()

        # Sample noise as generator input
        z = torch.randn(cur_batch_size, latent_dim, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures the generator's ability to fool the discriminator.
        # This backward pass also deposits gradients in the discriminator's
        # parameters; d_optimizer.zero_grad() below clears them before the
        # discriminator step.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        g_optimizer.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        d_optimizer.zero_grad()

        # Measure the discriminator's ability to tell real from generated
        # samples. detach() stops this loss from flowing into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        d_optimizer.step()

        if i % 100 == 0:
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')

    # Save generated samples every `sample_interval` epochs and after the final
    # epoch, so the finished generator's output is always captured.
    if epoch % sample_interval == 0 or epoch == n_epochs - 1:
        # NOTE(review): sampling runs with the generator still in train mode;
        # if Generator contains BatchNorm/Dropout, consider generator.eval()
        # here and generator.train() afterwards -- confirm in models.py.
        with torch.no_grad():
            sample_z = torch.randn(n_samples, latent_dim, device=device)
            samples = generator(sample_z)
        # Store on CPU so the file can be loaded on machines without CUDA.
        torch.save(samples.cpu(), f'images/epoch_{epoch}.pt')

print('Training finished!')