File size: 3,051 Bytes
cd18d6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models import Generator, Discriminator
import os

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# latent_dim: width of the noise vector z passed to the generator
# batch_size: number of images per DataLoader batch
# n_epochs:   number of full passes over the training set
latent_dim, batch_size, n_epochs = 100, 64, 200
# Adam settings: learning rate and beta1 (first entry of Adam's `betas`)
lr, beta1 = 2e-4, 0.5

# Output folder for the generated-sample tensors written during training.
os.makedirs('images', exist_ok=True)

# Preprocessing: ToTensor yields float values in [0, 1]; Normalize with
# mean 0.5 and std 0.5 then maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)

# MNIST training split, downloaded into ./data on first use.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Pick the compute device first: the models should live on their final
# device before optimizers are constructed over their parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize generator and discriminator directly on the chosen device.
generator = Generator(latent_dim=latent_dim).to(device)
discriminator = Discriminator().to(device)

# Loss function. BCELoss expects probabilities, so this presumes the
# Discriminator's output is already squashed (e.g. by a sigmoid) - confirm
# in models.py.
adversarial_loss = nn.BCELoss()
adversarial_loss.to(device)  # no learnable parameters; harmless no-op

# Optimizers (built after the .to(device) calls above, as torch.optim recommends).
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

print(f'Starting training on {device}...')

# Training loop: each iteration takes one generator step, then one
# discriminator step on the same batch of generated images.
for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        # Size of this particular batch (the last batch of an epoch may be
        # smaller). Kept in its own name so the global `batch_size`
        # hyperparameter is not overwritten.
        cur_batch = real_imgs.size(0)

        # Ground truths: 1 for real, 0 for generated. The (N, 1) shape
        # assumes the discriminator emits one score per image - confirm
        # against models.py.
        valid = torch.ones(cur_batch, 1, device=device)
        fake = torch.zeros(cur_batch, 1, device=device)

        # Configure input
        real_imgs = real_imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------
        g_optimizer.zero_grad()

        # Sample noise as generator input, allocated directly on the device.
        z = torch.randn(cur_batch, latent_dim, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # The generator is rewarded when the discriminator labels its
        # output as real.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        # This backward pass also fills discriminator .grad fields; they are
        # cleared by d_optimizer.zero_grad() before the discriminator update.
        g_loss.backward()
        g_optimizer.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        d_optimizer.zero_grad()

        # Measure discriminator's ability to classify real from generated
        # samples; detach() keeps this loss from back-propagating into the
        # generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        d_optimizer.step()

        if i % 100 == 0:
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')

    # Save 16 generated samples every 10 epochs and after the final epoch
    # (which `epoch % 10 == 0` alone would skip).
    if epoch % 10 == 0 or epoch == n_epochs - 1:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            gen_imgs = generator(z)
            # Store on CPU so the file can be loaded on machines without CUDA.
            torch.save(gen_imgs.cpu(), f'images/epoch_{epoch}.pt')

print('Training finished!')