Conditional Generative Adversarial Nets
Paper: arXiv:1411.1784
A PyTorch implementation of Generative Adversarial Networks (GAN) and Conditional GANs (cGAN) trained on FashionMNIST for generating fashion item images.
This project implements two types of generative models:

- **GAN**: a generator/discriminator pair trained adversarially on unlabeled images
- **Conditional GAN (cGAN)**: the same pair, with both networks conditioned on a class label so the category of the generated image can be chosen
Both models generate 28x28 grayscale images of fashion items.
Dataset: FashionMNIST (via torchvision)
| Split | Images |
|---|---|
| Train | 60,000 |
| Test | 10,000 |
| Total | 70,000 |
| Index | Category |
|---|---|
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
**Generator**

```
Input: z ~ N(0, 1), dimension 100

Linear(100 → 256*7*7) + BatchNorm + ReLU
        ↓ reshape to (256, 7, 7)
ConvTranspose2d(256 → 128, k=4, s=2, p=1) + BatchNorm + ReLU → (128, 14, 14)
        ↓
ConvTranspose2d(128 → 64, k=4, s=2, p=1) + BatchNorm + ReLU → (64, 28, 28)
        ↓
Conv2d(64 → 1, k=3, s=1, p=1) + Tanh → (1, 28, 28)

Output: image in [-1, 1]
```
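A sketch of this generator as an `nn.Module`, using the constructor signature from the usage example later in this README (the internal attribute names `fc` and `deconv` are assumptions):

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim=100, channels=1):
        super().__init__()
        # Project the latent vector to a 7x7 feature map with 256 channels
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(inplace=True),
        )
        # Upsample 7x7 → 14x14 → 28x28, then map to one image channel
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, channels, 3, 1, 1), nn.Tanh(),
        )

    def forward(self, z):
        x = self.fc(z).view(-1, 256, 7, 7)
        return self.deconv(x)

G = Generator()
out = G(torch.randn(4, 100))  # (4, 1, 28, 28), values in [-1, 1]
```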
**Discriminator**

```
Input: image (1, 28, 28)

Conv2d(1 → 64, k=4, s=2, p=1) + LeakyReLU(0.2) → (64, 14, 14)
        ↓
Conv2d(64 → 128, k=4, s=2, p=1) + BatchNorm + LeakyReLU(0.2) → (128, 7, 7)
        ↓
Flatten → Linear(128*7*7 → 1) + Sigmoid

Output: probability in [0, 1] (real vs. fake)
```
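The matching discriminator module, as a sketch (the class name and `channels` argument are assumptions, chosen to mirror the generator):

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, channels=1):
        super().__init__()
        # Downsample 28x28 → 14x14 → 7x7, then classify real vs. fake
        self.net = nn.Sequential(
            nn.Conv2d(channels, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)

D = Discriminator()
p = D(torch.randn(4, 1, 28, 28))  # (4, 1), each entry in [0, 1]
```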
**Conditional Generator (cGAN)**

```
Input: z ~ N(0, 1), dimension 100 + class label (0-9)

Label Embedding(10 → 50)
        ↓
Concatenate [z, embedding] → (150,)
        ↓
Linear(150 → 256*7*7) + BatchNorm + ReLU
        ↓
[Same architecture as the standard Generator from here on]

Output: class-conditioned image in [-1, 1]
```
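A sketch of the conditional generator, using the constructor signature from the usage example later in this README (the `fc`/`deconv` attribute names are assumptions; `label_embedding` matches the interpolation snippet below):

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim=100, num_classes=10, embedding_dim=50, channels=1):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        # 100-d noise + 50-d label embedding → 150-d input
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + embedding_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(inplace=True),
        )
        # Same upsampling stack as the unconditional generator
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, channels, 3, 1, 1), nn.Tanh(),
        )

    def forward(self, z, labels):
        x = torch.cat([z, self.label_embedding(labels)], dim=1)  # (N, 150)
        return self.deconv(self.fc(x).view(-1, 256, 7, 7))

cG = ConditionalGenerator()
imgs = cG(torch.randn(4, 100), torch.tensor([0, 3, 7, 9]))  # (4, 1, 28, 28)
```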
**Conditional Discriminator (cGAN)**

```
Input: image (1, 28, 28) + class label (0-9)

Label Embedding(10 → 28*28)
        ↓
Reshape embedding to (1, 28, 28)
        ↓
Concatenate [image, label_map] → (2, 28, 28)
        ↓
[Same Discriminator stack, with 2 input channels]

Output: probability in [0, 1] (real vs. fake for the given class)
```
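A sketch of the conditional discriminator: the label is embedded to 784 values, reshaped into an extra 28x28 "channel", and stacked onto the image (class and attribute names are assumptions):

```python
import torch
import torch.nn as nn

class ConditionalDiscriminator(nn.Module):
    def __init__(self, num_classes=10, channels=1):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, 28 * 28)
        # Same stack as the unconditional discriminator, but 2 input channels
        self.net = nn.Sequential(
            nn.Conv2d(channels + 1, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1), nn.Sigmoid(),
        )

    def forward(self, x, labels):
        label_map = self.label_embedding(labels).view(-1, 1, 28, 28)
        return self.net(torch.cat([x, label_map], dim=1))  # input (N, 2, 28, 28)

cD = ConditionalDiscriminator()
p = cD(torch.randn(4, 1, 28, 28), torch.tensor([1, 2, 3, 4]))  # (4, 1)
```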
| Parameter | Value |
|---|---|
| Latent Dimension | 100 |
| Batch Size | 64 |
| Learning Rate (G) | 0.0002 |
| Learning Rate (D) | 0.0002 |
| Optimizer | Adam (β1=0.5, β2=0.999) |
| Loss Function | Binary Cross-Entropy |
| Image Normalization | [-1, 1] |
| Label Embedding Dim | 50 (cGAN) |
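One training step with these settings (BCE loss, Adam with β1=0.5) looks roughly as follows. Tiny stand-in linear models keep the sketch self-contained and runnable; substitute the real `Generator` and `Discriminator`:

```python
import torch
import torch.nn as nn

latent_dim, batch = 100, 64
# Stand-in models (replace with the real networks)
G = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Tanh())
D = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1), nn.Sigmoid())

opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCELoss()

real = torch.rand(batch, 1, 28, 28) * 2 - 1  # stand-in real batch in [-1, 1]
ones, zeros = torch.ones(batch, 1), torch.zeros(batch, 1)

# Discriminator step: push real → 1 and fake → 0 (detach so G gets no gradient)
z = torch.randn(batch, latent_dim)
fake = G(z).view(batch, 1, 28, 28)
loss_D = bce(D(real), ones) + bce(D(fake.detach()), zeros)
opt_D.zero_grad(); loss_D.backward(); opt_D.step()

# Generator step: try to fool D (fake → 1)
loss_G = bce(D(fake), ones)
opt_G.zero_grad(); loss_G.backward(); opt_G.step()
```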
```shell
pip install torch torchvision numpy matplotlib tqdm
```
```python
import torch

# Load the trained generator
G = Generator(latent_dim=100, channels=1)
G.load_state_dict(torch.load('generator.pth'))
G.eval()

# Generate images
with torch.no_grad():
    z = torch.randn(16, 100)  # 16 random latent vectors
    fake_images = G(z)

# Denormalize: [-1, 1] → [0, 1]
fake_images = (fake_images + 1) / 2
```
```python
import torch

# Load the trained conditional generator
cG = ConditionalGenerator(latent_dim=100, num_classes=10, embedding_dim=50)
cG.load_state_dict(torch.load('conditional_generator.pth'))
cG.eval()

# Generate 8 sneakers (class 7)
with torch.no_grad():
    z = torch.randn(8, 100)
    labels = torch.full((8,), 7, dtype=torch.long)  # 7 = Sneaker
    fake_sneakers = cG(z, labels)

# Denormalize: [-1, 1] → [0, 1]
fake_sneakers = (fake_sneakers + 1) / 2
```
```python
def interpolate_latent(G, z1, z2, steps=10):
    """Smoothly interpolate between two latent vectors."""
    images = []
    for alpha in torch.linspace(0, 1, steps):
        z = (1 - alpha) * z1 + alpha * z2
        with torch.no_grad():
            img = G(z.unsqueeze(0))
        images.append(img)
    return torch.cat(images)
```
```python
def interpolate_between_classes(cG, class1, class2, steps=10):
    """Interpolate between two fashion categories."""
    z = torch.randn(1, 100)  # fixed noise
    emb1 = cG.label_embedding(torch.tensor([class1]))
    emb2 = cG.label_embedding(torch.tensor([class2]))
    images = []
    for alpha in torch.linspace(0, 1, steps):
        emb = (1 - alpha) * emb1 + alpha * emb2
        # Manual forward with interpolated embedding
        # ...
    return images
```
```
torch>=1.9.0
torchvision>=0.10.0
numpy>=1.20.0
matplotlib>=3.4.0
tqdm>=4.60.0
seaborn>=0.11.0
```
```bibtex
@inproceedings{goodfellow2014generative,
  title={Generative adversarial nets},
  author={Goodfellow, Ian and others},
  booktitle={Advances in Neural Information Processing Systems},
  year={2014}
}

@article{mirza2014conditional,
  title={Conditional generative adversarial nets},
  author={Mirza, Mehdi and Osindero, Simon},
  journal={arXiv preprint arXiv:1411.1784},
  year={2014}
}

@misc{fashionmnist,
  author={Xiao, Han and Rasul, Kashif and Vollgraf, Roland},
  title={Fashion-MNIST},
  year={2017},
  url={https://github.com/zalandoresearch/fashion-mnist}
}
```
MIT License