---
pipeline_tag: unconditional-image-generation
library_name: pytorch
tags:
- gan
- cgan
- image-generation
- generative-adversarial-network
license: mit
datasets:
- mnist
language:
- en
---

# Conditional GAN — MNIST Digit Generation

A Conditional Generative Adversarial Network (CGAN) for handwritten digit synthesis, trained on the [MNIST](https://huggingface.co/datasets/mnist) dataset.

## Architecture

A standard GAN framework conditioned on class labels, featuring a **Generator** and a **Discriminator** network.

| Component | Details |
| --- | --- |
| Discriminator | Takes a 1×28×28 grayscale image and a class label (0-9). Embeds the label, expands it to a 1×28×28 map, concatenates it with the image, and passes the result through two downsampling `Conv2d` blocks (with `BatchNorm2d` and `LeakyReLU`). The output is flattened and passed through a `Sigmoid` activation (1 output). |
| Generator | Takes a 100-dimensional noise vector and a target class label (0-9). Embeds the label, concatenates it with the noise, and passes the result through two `ConvTranspose2d` upsampling blocks (with `BatchNorm2d` and `LeakyReLU`). The output layer uses a `Tanh` activation (1×28×28). |
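
The Discriminator described above can be sketched in PyTorch as follows. The channel widths (64, 128) and the use of a per-pixel label embedding are assumptions chosen to mirror the Generator shown in the Usage section; they are not confirmed details of the released checkpoint.

```python
import torch
from torch import nn

class CGAN_Discriminator(nn.Module):
    """Sketch of the conditional discriminator; widths are assumptions."""

    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        # Embed the class label and reshape it into a 1x28x28 "label map"
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)
        self.img_size = img_size

        self.conv_blocks = nn.Sequential(
            # 2 input channels: the grayscale image + the label map
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),    # 28 -> 14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),  # probability that the input image is real
        )

    def forward(self, img, labels):
        b = img.size(0)
        label_map = self.label_embedding(labels).view(b, 1, self.img_size, self.img_size)
        x = torch.cat((img, label_map), dim=1)
        return self.classifier(self.conv_blocks(x))
```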

## Loss Function

The model is trained with the standard conditional GAN minimax objective, implemented with binary cross-entropy:

```
Loss = min_G max_D V(D, G) = E_x[log D(x|y)] + E_z[log(1 - D(G(z|y)|y))]
```

| Loss | Role |
| --- | --- |
| Discriminator loss | Penalises the Discriminator for incorrectly classifying real MNIST images as fake, or generated images as real. |
| Generator loss | Penalises the Generator when the Discriminator successfully identifies its generated images as fake. |
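
Implemented with `nn.BCELoss`, the two objectives above might look like the following sketch. The function signatures, and the `discriminator`/`generator` arguments, are illustrative; this is not this repository's actual training code.

```python
import torch
from torch import nn

bce = nn.BCELoss()

def discriminator_loss(discriminator, generator, real_imgs, labels, noise_dim=100):
    """D should output 1 for real conditioned images and 0 for generated ones."""
    batch = real_imgs.size(0)
    # Real images should be classified as real (target 1)
    loss_real = bce(discriminator(real_imgs, labels), torch.ones(batch, 1))
    # Generated images should be classified as fake (target 0);
    # detach() keeps gradients from flowing into the generator here
    noise = torch.randn(batch, noise_dim)
    fake_imgs = generator(noise, labels).detach()
    loss_fake = bce(discriminator(fake_imgs, labels), torch.zeros(batch, 1))
    return loss_real + loss_fake

def generator_loss(discriminator, generator, labels, noise_dim=100):
    """G wins when D labels its fakes as real (non-saturating BCE form)."""
    batch = labels.size(0)
    noise = torch.randn(batch, noise_dim)
    fake_imgs = generator(noise, labels)
    return bce(discriminator(fake_imgs, labels), torch.ones(batch, 1))
```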

## Training

* **Dataset:** MNIST, 60,000 training images
* **Input size:** 28×28, normalised to [-1, 1]
* **Batch size:** 32
* **Epochs:** 50
* **Optimizer:** Adam
* **Hyperparameter search:** Optuna over `learning_rate` in [1e-5, 2e-3], `beta1` in [0.0, 0.9], and `noise_dim` in {50, 100, 128}
* **Best weights:** snapshot of the epoch with the lowest Generator loss and a stable Discriminator loss. Best hyperparameters found: `learning_rate` = 0.000112, `beta1` = 0.037, `noise_dim` = 100.

### Data Preprocessing (train only)

**Photometric**:
Images are converted to PyTorch tensors and normalised with a mean of `0.5` and a standard deviation of `0.5`, scaling pixel values to the `[-1, 1]` range to match the `Tanh` activation of the Generator.

## Usage

```python
import torch
from torch import nn
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt

NOISE_DIM = 100  # must match the trained checkpoint

class CGAN_Generator(nn.Module):
    def __init__(self, noise_dim=NOISE_DIM, num_classes=10, img_size=28):
        super().__init__()
        self.init_size = img_size // 4  # 7: upsampled twice (7 -> 14 -> 28)
        self.embedding_dim = 20
        self.label_embedding = nn.Embedding(num_classes, self.embedding_dim)

        # Project noise + label embedding to a 128 x 7 x 7 feature map
        self.projection = nn.Sequential(
            nn.Linear(noise_dim + self.embedding_dim, 128 * self.init_size * self.init_size)
        )

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14 -> 28
            nn.Tanh()
        )

    def forward(self, noise, labels):
        label_embed = self.label_embedding(labels)
        merged_input = torch.cat((noise, label_embed), dim=1)
        out = self.projection(merged_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

model = CGAN_Generator(noise_dim=NOISE_DIM)

weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

target_label = torch.tensor([7])        # digit class to generate
eval_noise = torch.randn(1, NOISE_DIM)

with torch.no_grad():
    generated_img = model(eval_noise, target_label)

# Rescale from the Tanh range [-1, 1] to [0, 1] for display
generated_img = (generated_img + 1) / 2.0
img_numpy = generated_img.squeeze().cpu().numpy()

plt.imshow(img_numpy, cmap='gray')
plt.title(f"Generated Label: {target_label.item()}")
plt.axis('off')
plt.show()
```