---
pipeline_tag: unconditional-image-generation
library_name: pytorch
tags:
- gan
- cgan
- image-generation
- generative-adversarial-network
license: mit
datasets:
- mnist
language:
- en
---

# Conditional GAN — MNIST Digit Generation

A Conditional Generative Adversarial Network (CGAN) for handwritten digit synthesis, trained on the [MNIST](https://huggingface.co/datasets/mnist) dataset.

## Architecture

A standard GAN framework conditioned on class labels, featuring a **Generator** and a **Discriminator** network.

| Component | Details |
| --- | --- |
| Discriminator | Takes a 1×28×28 grayscale image and a class label (0-9). Embeds the label, expands it to 1×28×28, concatenates it with the image, and passes the result through two downsampling `Conv2d` blocks (with `BatchNorm2d` and `LeakyReLU`). The output is flattened and passed through a `Sigmoid` activation (1 output). |
| Generator | Takes a 100-dimensional noise vector and a target class label (0-9). Embeds the label, concatenates it with the noise, projects the result through a linear layer, and passes it through two `ConvTranspose2d` upsampling blocks (with `BatchNorm2d` and `LeakyReLU`). The output uses a `Tanh` activation (1×28×28). |

## Loss Function

The model is trained with the standard Conditional GAN minimax objective, implemented with Binary Cross Entropy:

```
min_G max_D V(D, G) = E_x[log D(x|y)] + E_z[log(1 - D(G(z|y)|y))]
```

| Loss | Role |
| --- | --- |
| Discriminator Loss | Penalises the Discriminator for classifying real MNIST images as fake, or generated images as real. |
| Generator Loss | Penalises the Generator when the Discriminator correctly identifies its generated images as fake.
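The card does not include the Discriminator's code, so a minimal sketch consistent with the table above is given below. The channel widths (64 and 128), kernel sizes, and the choice of embedding the label to a full 28×28 map are assumptions; the card does not specify them.

```python
import torch
from torch import nn


class CGAN_Discriminator(nn.Module):
    """Sketch of the Discriminator described above (widths/kernels assumed)."""

    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        self.img_size = img_size
        # Embed the label directly to img_size*img_size values so it can be
        # reshaped into a 1x28x28 "label channel".
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),    # 28 -> 14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * (img_size // 4) ** 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        # Expand the embedded label into an extra image channel
        label_map = self.label_embedding(labels).view(-1, 1, self.img_size, self.img_size)
        x = torch.cat((img, label_map), dim=1)
        return self.classifier(self.conv_blocks(x))
```

With this layout the image and the label map are concatenated along the channel dimension, so the first `Conv2d` sees 2 input channels.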
## Training

* **Dataset:** MNIST, 60,000 training images
* **Input size:** 28×28, normalised to [-1, 1]
* **Batch size:** 32
* **Epochs:** 50
* **Optimizer:** Adam
* **Hyperparameter search:** Optuna over `learning_rate` (1e-5 to 2e-3), `beta1` (0.0 to 0.9), and `noise_dim` {50, 100, 128}
* **Best weights:** snapshot of the epoch with the lowest Generator loss and a stable Discriminator loss

Best hyperparameters found: `learning_rate` = 0.000112, `beta1` = 0.037, `noise_dim` = 100.

### Data Preprocessing (train only)

**Photometric**: Images are converted to PyTorch tensors and normalised with a mean of `0.5` and a standard deviation of `0.5`, scaling pixel values to the `[-1, 1]` range to match the `Tanh` activation of the Generator.

## Usage

```python
import torch
from torch import nn
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt


class CGAN_Generator(nn.Module):
    def __init__(self, noise_dim=100, num_classes=10, img_size=28):
        super().__init__()
        self.init_size = img_size // 4
        self.embedding_dim = 20
        self.label_embedding = nn.Embedding(num_classes, self.embedding_dim)
        # Project noise + label embedding to a 128 x 7 x 7 feature map
        self.projection = nn.Sequential(
            nn.Linear(noise_dim + self.embedding_dim, 128 * self.init_size * self.init_size)
        )
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14 -> 28
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        label_embed = self.label_embedding(labels)
        merged_input = torch.cat((noise, label_embed), dim=1)
        out = self.projection(merged_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(out)


noise_dim = 100
model = CGAN_Generator(noise_dim=noise_dim)
weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

target_label = torch.tensor([7])
eval_noise = torch.randn(1, noise_dim)

with torch.no_grad():
    generated_img = model(eval_noise, target_label)

# Rescale from [-1, 1] back to [0, 1] for display
generated_img = (generated_img + 1) / 2.0
img_numpy = generated_img.squeeze().cpu().numpy()

plt.imshow(img_numpy, cmap="gray")
plt.title(f"Generated Label: {target_label.item()}")
plt.axis("off")
plt.show()
```
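The training recipe above follows the usual alternating GAN update: one Discriminator step (real images labelled 1, generated images labelled 0), then one Generator step that tries to make the Discriminator output 1 on generated images. A runnable sketch is shown below; the tiny fully-connected stand-ins are illustrative only (the real model uses the convolutional networks described earlier), while the Adam settings are the best hyperparameters reported in the card:

```python
import torch
from torch import nn

# Tiny stand-ins so the loop is short and runnable; the real G and D are
# the convolutional networks from the Architecture section.
G = nn.Sequential(nn.Linear(100 + 10, 28 * 28), nn.Tanh())
D = nn.Sequential(nn.Linear(28 * 28 + 10, 1), nn.Sigmoid())

bce = nn.BCELoss()
# Best hyperparameters reported in the card: lr=0.000112, beta1=0.037
opt_g = torch.optim.Adam(G.parameters(), lr=1.12e-4, betas=(0.037, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=1.12e-4, betas=(0.037, 0.999))

for step in range(2):  # a couple of steps for illustration
    real = torch.randn(32, 28 * 28)          # stand-in for a real MNIST batch
    labels = torch.randint(0, 10, (32,))
    y = nn.functional.one_hot(labels, 10).float()
    z = torch.randn(32, 100)

    # --- Discriminator step: real -> 1, fake -> 0 ---
    fake = G(torch.cat([z, y], dim=1)).detach()
    d_loss = bce(D(torch.cat([real, y], dim=1)), torch.ones(32, 1)) + \
             bce(D(torch.cat([fake, y], dim=1)), torch.zeros(32, 1))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # --- Generator step: fool D into predicting 1 on fakes ---
    fake = G(torch.cat([z, y], dim=1))
    g_loss = bce(D(torch.cat([fake, y], dim=1)), torch.ones(32, 1))
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
```

Detaching the fake batch during the Discriminator step keeps the Generator's weights out of that backward pass; the Generator step then rebuilds the fakes with gradients enabled.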