lolzysiu commited on
Commit
cd18d6b
·
verified ·
1 Parent(s): c971098

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +100 -0
train.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import datasets, transforms
5
+ from torch.utils.data import DataLoader
6
+ from models import Generator, Discriminator
7
+ import os
8
+
9
+ # Hyperparameters
10
+ latent_dim = 100
11
+ batch_size = 64
12
+ n_epochs = 200
13
+ lr = 0.0002
14
+ beta1 = 0.5
15
+
16
+ # Create directory for saving images
17
+ os.makedirs('images', exist_ok=True)
18
+
19
+ # Configure data loader
20
+ transform = transforms.Compose([
21
+ transforms.ToTensor(),
22
+ transforms.Normalize([0.5], [0.5])
23
+ ])
24
+
25
+ dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
26
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
27
+
28
+ # Initialize generator and discriminator
29
+ generator = Generator(latent_dim=latent_dim)
30
+ discriminator = Discriminator()
31
+
32
+ # Loss function
33
+ adversarial_loss = nn.BCELoss()
34
+
35
+ # Optimizers
36
+ g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
37
+ d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
38
+
39
+ # Check if CUDA is available
40
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
+ generator.to(device)
42
+ discriminator.to(device)
43
+ adversarial_loss.to(device)
44
+
45
+ print(f'Starting training on {device}...')
46
+
47
+ # Training loop
48
+ for epoch in range(n_epochs):
49
+ for i, (real_imgs, _) in enumerate(dataloader):
50
+ batch_size = real_imgs.shape[0]
51
+
52
+ # Ground truths
53
+ valid = torch.ones(batch_size, 1).to(device)
54
+ fake = torch.zeros(batch_size, 1).to(device)
55
+
56
+ # Configure input
57
+ real_imgs = real_imgs.to(device)
58
+
59
+ # -----------------
60
+ # Train Generator
61
+ # -----------------
62
+ g_optimizer.zero_grad()
63
+
64
+ # Sample noise as generator input
65
+ z = torch.randn(batch_size, latent_dim).to(device)
66
+
67
+ # Generate a batch of images
68
+ gen_imgs = generator(z)
69
+
70
+ # Loss measures generator's ability to fool the discriminator
71
+ g_loss = adversarial_loss(discriminator(gen_imgs), valid)
72
+
73
+ g_loss.backward()
74
+ g_optimizer.step()
75
+
76
+ # ---------------------
77
+ # Train Discriminator
78
+ # ---------------------
79
+ d_optimizer.zero_grad()
80
+
81
+ # Measure discriminator's ability to classify real from generated samples
82
+ real_loss = adversarial_loss(discriminator(real_imgs), valid)
83
+ fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
84
+ d_loss = (real_loss + fake_loss) / 2
85
+
86
+ d_loss.backward()
87
+ d_optimizer.step()
88
+
89
+ if i % 100 == 0:
90
+ print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
91
+ f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')
92
+
93
+ # Save generated images at the end of each epoch
94
+ if epoch % 10 == 0:
95
+ with torch.no_grad():
96
+ z = torch.randn(16, latent_dim).to(device)
97
+ gen_imgs = generator(z)
98
+ torch.save(gen_imgs, f'images/epoch_{epoch}.pt')
99
+
100
+ print('Training finished!')