GAN & Conditional GAN for Fashion Image Generation

A PyTorch implementation of a Generative Adversarial Network (GAN) and a Conditional GAN (cGAN), trained on FashionMNIST to generate images of fashion items.

Model Description

This project implements two types of generative models:

  1. Standard GAN: Generates random fashion item images from noise
  2. Conditional GAN (cGAN): Generates specific fashion item categories on demand

Both models generate 28x28 grayscale images of fashion items.

Intended Uses

  • Educational: Learning about GANs and generative models
  • Research: Experimenting with latent space exploration
  • Creative: Generating synthetic fashion item images
  • Data Augmentation: Creating additional training samples

Training Data

Dataset: FashionMNIST (via torchvision)

Split Images
Train 60,000
Test 10,000
Total 70,000

Fashion Categories (10 classes)

Index Category
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
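For convenience, the class table above can be expressed as a Python list indexed by label (the name `CLASS_NAMES` is illustrative, not part of the repository's API):

```python
# FashionMNIST class index -> category name, matching the table above
CLASS_NAMES = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle boot",   # 9
]
```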

Model Architecture

Standard GAN

Generator

Input: z ~ N(0,1) of dimension 100

Linear(100 → 256*7*7) + BatchNorm + ReLU
    ↓
Reshape to (256, 7, 7)
    ↓
ConvTranspose2d(256 → 128, k=4, s=2, p=1) + BatchNorm + ReLU  → (128, 14, 14)
    ↓
ConvTranspose2d(128 → 64, k=4, s=2, p=1) + BatchNorm + ReLU   → (64, 28, 28)
    ↓
Conv2d(64 → 1, k=3, s=1, p=1) + Tanh                          → (1, 28, 28)

Output: Image in [-1, 1]
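A minimal PyTorch sketch of this generator; the submodule names (`fc`, `net`) are illustrative and need not match the repository's exact code:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Sketch of the generator diagrammed above."""
    def __init__(self, latent_dim=100, channels=1):
        super().__init__()
        # Project noise to a 7x7 feature map with 256 channels
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(inplace=True),
        )
        self.net = nn.Sequential(
            # (256, 7, 7) -> (128, 14, 14)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # (128, 14, 14) -> (64, 28, 28)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # (64, 28, 28) -> (1, 28, 28), squashed to [-1, 1]
            nn.Conv2d(64, channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.fc(z).view(-1, 256, 7, 7)
        return self.net(x)
```

Each ConvTranspose2d with k=4, s=2, p=1 exactly doubles spatial size, which is why two of them take 7x7 to 28x28.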

Discriminator

Input: Image (1, 28, 28)

Conv2d(1 → 64, k=4, s=2, p=1) + LeakyReLU(0.2)                → (64, 14, 14)
    ↓
Conv2d(64 → 128, k=4, s=2, p=1) + BatchNorm + LeakyReLU(0.2)  → (128, 7, 7)
    ↓
Flatten → Linear(128*7*7 → 1) + Sigmoid

Output: Probability [0, 1] (real vs fake)
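The matching discriminator can be sketched as follows (again, names are illustrative):

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Sketch of the discriminator diagrammed above."""
    def __init__(self, channels=1):
        super().__init__()
        self.net = nn.Sequential(
            # (1, 28, 28) -> (64, 14, 14); no BatchNorm on the first layer
            nn.Conv2d(channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # (64, 14, 14) -> (128, 7, 7)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # Flatten and classify real vs fake
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)
```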

Conditional GAN (cGAN)

Conditional Generator

Input: z ~ N(0,1) of dimension 100 + class label (0-9)

Label Embedding(10 → 50)
    ↓
Concatenate [z, embedding] → (150,)
    ↓
Linear(150 → 256*7*7) + BatchNorm + ReLU
    ↓
[Same architecture as standard Generator]

Output: Class-conditioned image in [-1, 1]
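A sketch of the conditional generator: the only change from the standard generator is the label embedding concatenated to the noise vector before the first linear layer. Submodule names are illustrative:

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    """Sketch of the class-conditioned generator described above."""
    def __init__(self, latent_dim=100, num_classes=10, embedding_dim=50, channels=1):
        super().__init__()
        self.label_embedding = nn.Embedding(num_classes, embedding_dim)
        # (latent + embedding) -> 7x7 feature map with 256 channels
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + embedding_dim, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(inplace=True),
        )
        # Same upsampling stack as the standard generator
        self.net = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, channels, 3, 1, 1), nn.Tanh(),
        )

    def forward(self, z, labels):
        emb = self.label_embedding(labels)   # (B, 50)
        x = torch.cat([z, emb], dim=1)       # (B, 150)
        x = self.fc(x).view(-1, 256, 7, 7)
        return self.net(x)
```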

Conditional Discriminator

Input: Image (1, 28, 28) + class label (0-9)

Label Embedding(10 → 28*28)
    ↓
Reshape embedding to (1, 28, 28)
    ↓
Concatenate [image, label_map] → (2, 28, 28)
    ↓
[Modified Discriminator with 2 input channels]

Output: Probability [0, 1] (real vs fake for given class)
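The conditional discriminator embeds the label as a 28x28 map and stacks it with the image as a second input channel. A sketch under the same naming assumptions:

```python
import torch
import torch.nn as nn

class ConditionalDiscriminator(nn.Module):
    """Sketch of the class-conditioned discriminator described above."""
    def __init__(self, num_classes=10, channels=1):
        super().__init__()
        # One embedding value per pixel, reshaped into a label "image"
        self.label_embedding = nn.Embedding(num_classes, 28 * 28)
        self.net = nn.Sequential(
            # 2 input channels: image + label map
            nn.Conv2d(channels + 1, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(), nn.Linear(128 * 7 * 7, 1), nn.Sigmoid(),
        )

    def forward(self, x, labels):
        label_map = self.label_embedding(labels).view(-1, 1, 28, 28)
        return self.net(torch.cat([x, label_map], dim=1))
```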

Training Configuration

Parameter Value
Latent Dimension 100
Batch Size 64
Learning Rate (G) 0.0002
Learning Rate (D) 0.0002
Optimizer Adam (β1=0.5, β2=0.999)
Loss Function Binary Cross-Entropy
Image Normalization [-1, 1]
Label Embedding Dim 50 (cGAN)
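A sketch of one training step with these hyperparameters (BCE loss, Adam with β1=0.5). The tiny linear `G` and `D` here are placeholders so the loop runs standalone; in the project they would be the convolutional networks above:

```python
import torch
import torch.nn as nn

# Placeholder networks (stand-ins for the real Generator/Discriminator)
G = nn.Sequential(nn.Linear(100, 28 * 28), nn.Tanh())
D = nn.Sequential(nn.Linear(28 * 28, 1), nn.Sigmoid())

criterion = nn.BCELoss()
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

def train_step(real):
    """One alternating GAN update on a batch of real images in [-1, 1]."""
    batch = real.size(0)
    ones = torch.ones(batch, 1)
    zeros = torch.zeros(batch, 1)

    # --- Discriminator: push real -> 1, fake -> 0 ---
    z = torch.randn(batch, 100)
    fake = G(z).detach()  # detach so D's update does not touch G
    d_loss = criterion(D(real), ones) + criterion(D(fake), zeros)
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # --- Generator: fool D into predicting 1 on fresh fakes ---
    z = torch.randn(batch, 100)
    g_loss = criterion(D(G(z)), ones)
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()

    return d_loss.item(), g_loss.item()
```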

Usage

Installation

pip install torch torchvision numpy matplotlib tqdm

Generate Random Images (Standard GAN)

import torch

# Load generator
G = Generator(latent_dim=100, channels=1)
G.load_state_dict(torch.load('generator.pth', map_location='cpu'))
G.eval()

# Generate images
with torch.no_grad():
    z = torch.randn(16, 100)  # 16 random images
    fake_images = G(z)

# Denormalize: [-1, 1] → [0, 1]
fake_images = (fake_images + 1) / 2

Generate Specific Category (cGAN)

import torch

# Load conditional generator
cG = ConditionalGenerator(latent_dim=100, num_classes=10, embedding_dim=50)
cG.load_state_dict(torch.load('conditional_generator.pth', map_location='cpu'))
cG.eval()

# Generate 8 sneakers (class 7)
with torch.no_grad():
    z = torch.randn(8, 100)
    labels = torch.full((8,), 7, dtype=torch.long)  # 7 = Sneaker
    fake_sneakers = cG(z, labels)

# Denormalize
fake_sneakers = (fake_sneakers + 1) / 2

Latent Space Interpolation

def interpolate_latent(G, z1, z2, steps=10):
    """Smoothly interpolate between two latent vectors"""
    images = []
    for alpha in torch.linspace(0, 1, steps):
        z = (1 - alpha) * z1 + alpha * z2
        with torch.no_grad():
            img = G(z.unsqueeze(0))
        images.append(img)
    return torch.cat(images)

Class Interpolation (cGAN)

def interpolate_between_classes(cG, class1, class2, steps=10):
    """Interpolate between two fashion categories with fixed noise.

    Note: the manual forward below assumes the generator exposes `fc`
    and `net` submodules (as in the architecture description above);
    adapt the attribute names to your implementation.
    """
    z = torch.randn(1, 100)  # fixed noise, so only the class varies

    emb1 = cG.label_embedding(torch.tensor([class1]))
    emb2 = cG.label_embedding(torch.tensor([class2]))

    images = []
    with torch.no_grad():
        for alpha in torch.linspace(0, 1, steps):
            emb = (1 - alpha) * emb1 + alpha * emb2
            # Manual forward with the interpolated embedding
            x = torch.cat([z, emb], dim=1)
            x = cG.fc(x).view(-1, 256, 7, 7)
            images.append(cG.net(x))
    return torch.cat(images)

Features

  • Image Generation: Create synthetic fashion images
  • Class Control: Generate specific categories with cGAN
  • Latent Exploration: Interpolate smoothly in latent space
  • Dimension Analysis: Understand what each latent dimension controls
  • Training Visualization: Monitor GAN training progress

Limitations

  • Low Resolution: Only generates 28x28 grayscale images
  • Limited Dataset: Fashion items only, no real-world complexity
  • Mode Collapse Risk: standard GAN training is unstable and can collapse to producing only a few modes
  • Quality: Not suitable for production-quality image generation

Technical Specifications

Dependencies

torch>=1.9.0
torchvision>=0.10.0
numpy>=1.20.0
matplotlib>=3.4.0
tqdm>=4.60.0
seaborn>=0.11.0

Hardware

  • Training: GPU recommended (CUDA)
  • Training Time: ~30 min on GPU, ~2-3 hours on CPU
  • Inference: CPU sufficient (< 50ms per batch of 64)
  • Memory: ~2GB GPU memory

Citation

@inproceedings{goodfellow2014generative,
  title={Generative adversarial nets},
  author={Goodfellow, Ian and others},
  booktitle={Advances in Neural Information Processing Systems},
  year={2014}
}

@article{mirza2014conditional,
  title={Conditional generative adversarial nets},
  author={Mirza, Mehdi and Osindero, Simon},
  journal={arXiv preprint arXiv:1411.1784},
  year={2014}
}

@online{fashionmnist,
  author={Xiao, Han and Rasul, Kashif and Vollgraf, Roland},
  title={Fashion-MNIST},
  year={2017},
  url={https://github.com/zalandoresearch/fashion-mnist}
}

License

MIT License

Acknowledgments

  • Original GAN paper by Goodfellow et al.
  • Conditional GAN paper by Mirza & Osindero
  • FashionMNIST by Zalando Research
  • PyTorch team for the deep learning framework