---
license: mit
language:
- en
library_name: pytorch
tags:
- gan
- generative-adversarial-network
- conditional-gan
- cgan
- image-generation
- fashion-mnist
- deep-learning
- computer-vision
datasets:
- fashion_mnist
pipeline_tag: unconditional-image-generation
---
# GAN & Conditional GAN for Fashion Image Generation

A PyTorch implementation of a standard Generative Adversarial Network (GAN) and a Conditional GAN (cGAN), trained on FashionMNIST to generate images of fashion items.
## Model Description
This project implements two types of generative models:
- Standard GAN: Generates random fashion item images from noise
- Conditional GAN (cGAN): Generates specific fashion item categories on demand
Both models generate 28x28 grayscale images of fashion items.
## Intended Uses
- Educational: Learning about GANs and generative models
- Research: Experimenting with latent space exploration
- Creative: Generating synthetic fashion item images
- Data Augmentation: Creating additional training samples
## Training Data
Dataset: FashionMNIST (via torchvision)
| Split | Images |
|---|---|
| Train | 60,000 |
| Test | 10,000 |
| Total | 70,000 |
### Fashion Categories (10 classes)
| Index | Category |
|---|---|
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
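For convenience when sampling specific classes from the cGAN, the table above can be kept as a small mapping (a hypothetical helper, not part of the repository):

```python
# Index-to-name mapping for the 10 FashionMNIST classes
FASHION_CLASSES = {
    0: "T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
    5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot",
}
```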
## Model Architecture

### Standard GAN
#### Generator

```
Input: z ~ N(0, 1), dimension 100

Linear(100 → 256*7*7) + BatchNorm + ReLU
        ↓
Reshape to (256, 7, 7)
        ↓
ConvTranspose2d(256 → 128, k=4, s=2, p=1) + BatchNorm + ReLU → (128, 14, 14)
        ↓
ConvTranspose2d(128 → 64, k=4, s=2, p=1) + BatchNorm + ReLU → (64, 28, 28)
        ↓
Conv2d(64 → 1, k=3, s=1, p=1) + Tanh → (1, 28, 28)

Output: Image in [-1, 1]
```
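The generator above can be sketched in PyTorch as follows; the repository's own class may differ in naming and layer grouping:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Sketch of the generator architecture described above."""
    def __init__(self, latent_dim=100, channels=1):
        super().__init__()
        # Project the noise vector up to a 256x7x7 feature map
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(inplace=True),
        )
        # Upsample 7x7 -> 14x14 -> 28x28, then map to one image channel
        self.net = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # (128, 14, 14)
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # (64, 28, 28)
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, channels, kernel_size=3, stride=1, padding=1),       # (1, 28, 28)
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.fc(z).view(-1, 256, 7, 7)
        return self.net(x)
```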
#### Discriminator

```
Input: Image (1, 28, 28)

Conv2d(1 → 64, k=4, s=2, p=1) + LeakyReLU(0.2) → (64, 14, 14)
        ↓
Conv2d(64 → 128, k=4, s=2, p=1) + BatchNorm + LeakyReLU(0.2) → (128, 7, 7)
        ↓
Flatten → Linear(128*7*7 → 1) + Sigmoid

Output: Probability [0, 1] (real vs. fake)
```
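A matching PyTorch sketch of this discriminator (class name and layout are assumptions):

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Sketch of the discriminator architecture described above."""
    def __init__(self, channels=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, 64, kernel_size=4, stride=2, padding=1),  # (64, 14, 14)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),       # (128, 7, 7)
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),  # probability that the input is real
        )

    def forward(self, x):
        return self.net(x)
```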
### Conditional GAN (cGAN)

#### Conditional Generator

```
Input: z ~ N(0, 1), dimension 100 + class label (0-9)

Label Embedding(10 → 50)
        ↓
Concatenate [z, embedding] → (150,)
        ↓
Linear(150 → 256*7*7) + BatchNorm + ReLU
        ↓
[Same architecture as the standard Generator]

Output: Class-conditioned image in [-1, 1]
```
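In code, conditioning amounts to concatenating the label embedding onto the noise before the first linear layer. A self-contained sketch (names are assumptions, and the upsampling stack is repeated here so the snippet runs standalone):

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    """Sketch of the conditional generator described above."""
    def __init__(self, latent_dim=100, num_classes=10, embedding_dim=50):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + embedding_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(inplace=True),
        )
        # Same upsampling stack as the standard generator
        self.net = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, 3, 1, 1),
            nn.Tanh(),
        )

    def forward(self, z, labels):
        # Condition by concatenating the label embedding onto the noise
        x = torch.cat([z, self.label_embedding(labels)], dim=1)
        x = self.fc(x).view(-1, 256, 7, 7)
        return self.net(x)
```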
#### Conditional Discriminator

```
Input: Image (1, 28, 28) + class label (0-9)

Label Embedding(10 → 28*28)
        ↓
Reshape embedding to (1, 28, 28)
        ↓
Concatenate [image, label_map] → (2, 28, 28)
        ↓
[Modified Discriminator with 2 input channels]

Output: Probability [0, 1] (real vs. fake for the given class)
```
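Here the label is broadcast as an extra image plane before the first convolution. A self-contained sketch under the same naming assumptions:

```python
import torch
import torch.nn as nn

class ConditionalDiscriminator(nn.Module):
    """Sketch of the conditional discriminator described above."""
    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        self.img_size = img_size
        # Embed each label as a full 28*28 map
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)
        self.net = nn.Sequential(
            nn.Conv2d(2, 64, 4, 2, 1),   # 2 channels: image + label map
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, labels):
        # Reshape the label embedding into an image-sized plane and stack it
        label_map = self.label_embedding(labels).view(-1, 1, self.img_size, self.img_size)
        return self.net(torch.cat([x, label_map], dim=1))
```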
## Training Configuration
| Parameter | Value |
|---|---|
| Latent Dimension | 100 |
| Batch Size | 64 |
| Learning Rate (G) | 0.0002 |
| Learning Rate (D) | 0.0002 |
| Optimizer | Adam (β₁=0.5, β₂=0.999) |
| Loss Function | Binary Cross-Entropy |
| Image Normalization | [-1, 1] |
| Label Embedding Dim | 50 (cGAN) |
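With this configuration, one alternating training step looks roughly as follows. This is a sketch: it uses tiny linear stand-ins for `G` and `D` so the snippet is self-contained; in practice, substitute the convolutional models described above.

```python
import torch
import torch.nn as nn

latent_dim = 100
# Hypothetical minimal stand-ins for the real generator/discriminator
G = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Tanh())
D = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1), nn.Sigmoid())

criterion = nn.BCELoss()
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

def train_step(real_images):
    batch = real_images.size(0)
    real_labels = torch.ones(batch, 1)
    fake_labels = torch.zeros(batch, 1)

    # Discriminator: push real images toward 1, generated images toward 0
    opt_D.zero_grad()
    z = torch.randn(batch, latent_dim)
    fake_images = G(z)
    d_loss = (criterion(D(real_images), real_labels)
              + criterion(D(fake_images.detach()), fake_labels))
    d_loss.backward()
    opt_D.step()

    # Generator: try to fool the discriminator (fake -> 1)
    opt_G.zero_grad()
    g_loss = criterion(D(fake_images), real_labels)
    g_loss.backward()
    opt_G.step()
    return d_loss.item(), g_loss.item()
```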
## Usage

### Installation

```shell
pip install torch torchvision numpy matplotlib tqdm seaborn
```
### Generate Random Images (Standard GAN)

```python
import torch

# Load the trained generator (Generator class from this repo)
G = Generator(latent_dim=100, channels=1)
G.load_state_dict(torch.load('generator.pth', map_location='cpu'))
G.eval()

# Generate 16 random images
with torch.no_grad():
    z = torch.randn(16, 100)
    fake_images = G(z)

# Denormalize: [-1, 1] -> [0, 1]
fake_images = (fake_images + 1) / 2
```
### Generate Specific Category (cGAN)

```python
import torch

# Load the trained conditional generator
cG = ConditionalGenerator(latent_dim=100, num_classes=10, embedding_dim=50)
cG.load_state_dict(torch.load('conditional_generator.pth', map_location='cpu'))
cG.eval()

# Generate 8 sneakers (class 7)
with torch.no_grad():
    z = torch.randn(8, 100)
    labels = torch.full((8,), 7, dtype=torch.long)  # 7 = Sneaker
    fake_sneakers = cG(z, labels)

# Denormalize: [-1, 1] -> [0, 1]
fake_sneakers = (fake_sneakers + 1) / 2
```
### Latent Space Interpolation

```python
def interpolate_latent(G, z1, z2, steps=10):
    """Smoothly interpolate between two latent vectors."""
    images = []
    for alpha in torch.linspace(0, 1, steps):
        z = (1 - alpha) * z1 + alpha * z2
        with torch.no_grad():
            img = G(z.unsqueeze(0))
        images.append(img)
    return torch.cat(images)
```
### Class Interpolation (cGAN)

```python
def interpolate_between_classes(cG, class1, class2, steps=10):
    """Interpolate between two fashion categories under fixed noise."""
    z = torch.randn(1, 100)  # fixed noise
    emb1 = cG.label_embedding(torch.tensor([class1]))
    emb2 = cG.label_embedding(torch.tensor([class2]))
    images = []
    for alpha in torch.linspace(0, 1, steps):
        emb = (1 - alpha) * emb1 + alpha * emb2
        # Manual forward pass with the interpolated embedding
        # (depends on the generator's internals)
        # ...
    return images
```
## Features
- Image Generation: Create synthetic fashion images
- Class Control: Generate specific categories with cGAN
- Latent Exploration: Interpolate smoothly in latent space
- Dimension Analysis: Understand what each latent dimension controls
- Training Visualization: Monitor GAN training progress
## Limitations
- Low Resolution: Only generates 28x28 grayscale images
- Limited Dataset: Fashion items only, no real-world complexity
- Mode Collapse Risk: Standard GAN training challenges
- Quality: Not suitable for production-quality image generation
## Technical Specifications

### Dependencies

```
torch>=1.9.0
torchvision>=0.10.0
numpy>=1.20.0
matplotlib>=3.4.0
tqdm>=4.60.0
seaborn>=0.11.0
```
### Hardware
- Training: GPU recommended (CUDA)
- Training Time: ~30 min on GPU, ~2-3 hours on CPU
- Inference: CPU sufficient (< 50ms per batch of 64)
- Memory: ~2GB GPU memory
## Citation

```bibtex
@article{goodfellow2014generative,
  title={Generative adversarial nets},
  author={Goodfellow, Ian and others},
  journal={Advances in Neural Information Processing Systems},
  year={2014}
}

@article{mirza2014conditional,
  title={Conditional generative adversarial nets},
  author={Mirza, Mehdi and Osindero, Simon},
  journal={arXiv preprint arXiv:1411.1784},
  year={2014}
}

@online{fashionmnist,
  author={Xiao, Han and Rasul, Kashif and Vollgraf, Roland},
  title={Fashion-MNIST},
  year={2017},
  url={https://github.com/zalandoresearch/fashion-mnist}
}
```
## License
MIT License
## Acknowledgments
- Original GAN paper by Goodfellow et al.
- Conditional GAN paper by Mirza & Osindero
- FashionMNIST by Zalando Research
- PyTorch team for the deep learning framework