lolzysiu committed on
Commit
c971098
·
verified ·
1 Parent(s): 90a5888

Create train_wgan.py

Browse files
Files changed (1) hide show
  1. train_wgan.py +117 -0
train_wgan.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, transforms
5
+ from torch.utils.data import DataLoader
6
+ from models_conv import ConvGenerator, ConvDiscriminator
7
+ import os
8
+ from torch.utils.tensorboard import SummaryWriter
9
+
10
+ # Hyperparameters
11
+ latent_dim = 100
12
+ batch_size = 64
13
+ n_epochs = 200
14
+ lr = 0.00005
15
+ n_critic = 5
16
+ clip_value = 0.01
17
+
18
+ # Create directories
19
+ os.makedirs('images', exist_ok=True)
20
+ os.makedirs('checkpoints', exist_ok=True)
21
+
22
+ # Initialize tensorboard
23
+ writer = SummaryWriter('runs/wgan_training')
24
+
25
+ # Configure data loader
26
+ transform = transforms.Compose([
27
+ transforms.ToTensor(),
28
+ transforms.Normalize([0.5], [0.5])
29
+ ])
30
+
31
+ dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
32
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
33
+
34
+ # Initialize generator and discriminator
35
+ generator = ConvGenerator(latent_dim=latent_dim)
36
+ discriminator = ConvDiscriminator()
37
+
38
+ # Optimizers
39
+ g_optimizer = optim.RMSprop(generator.parameters(), lr=lr)
40
+ d_optimizer = optim.RMSprop(discriminator.parameters(), lr=lr)
41
+
42
+ # Check if CUDA is available
43
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44
+ generator.to(device)
45
+ discriminator.to(device)
46
+
47
+ print(f'Starting training on {device}...')
48
+
49
+ # Training loop
50
+ for epoch in range(n_epochs):
51
+ for i, (real_imgs, _) in enumerate(dataloader):
52
+ real_imgs = real_imgs.to(device)
53
+
54
+ # ---------------------
55
+ # Train Discriminator
56
+ # ---------------------
57
+ d_optimizer.zero_grad()
58
+
59
+ # Sample noise as generator input
60
+ z = torch.randn(real_imgs.size(0), latent_dim).to(device)
61
+
62
+ # Generate a batch of images
63
+ fake_imgs = generator(z).detach()
64
+
65
+ # Compute discriminator loss
66
+ d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
67
+
68
+ d_loss.backward()
69
+ d_optimizer.step()
70
+
71
+ # Clip weights of discriminator
72
+ for p in discriminator.parameters():
73
+ p.data.clamp_(-clip_value, clip_value)
74
+
75
+ # Train the generator every n_critic iterations
76
+ if i % n_critic == 0:
77
+ # -----------------
78
+ # Train Generator
79
+ # -----------------
80
+ g_optimizer.zero_grad()
81
+
82
+ # Generate a batch of images
83
+ gen_imgs = generator(z)
84
+
85
+ # Adversarial loss
86
+ g_loss = -torch.mean(discriminator(gen_imgs))
87
+
88
+ g_loss.backward()
89
+ g_optimizer.step()
90
+
91
+ if i % 100 == 0:
92
+ print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
93
+ f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')
94
+
95
+ # Log losses to tensorboard
96
+ writer.add_scalar('D_loss', d_loss.item(), epoch * len(dataloader) + i)
97
+ writer.add_scalar('G_loss', g_loss.item(), epoch * len(dataloader) + i)
98
+
99
+ # Save checkpoints
100
+ if epoch % 10 == 0:
101
+ torch.save({
102
+ 'epoch': epoch,
103
+ 'generator_state_dict': generator.state_dict(),
104
+ 'discriminator_state_dict': discriminator.state_dict(),
105
+ 'g_optimizer_state_dict': g_optimizer.state_dict(),
106
+ 'd_optimizer_state_dict': d_optimizer.state_dict(),
107
+ }, f'checkpoints/wgan_checkpoint_epoch_{epoch}.pt')
108
+
109
+ # Save sample images
110
+ with torch.no_grad():
111
+ z = torch.randn(16, latent_dim).to(device)
112
+ gen_imgs = generator(z)
113
+ for j, img in enumerate(gen_imgs):
114
+ writer.add_image(f'generated_image_{j}', img, epoch)
115
+
116
+ print('Training finished!')
117
+ writer.close()