---
pipeline_tag: unconditional-image-generation
library_name: pytorch
tags:
- gan
- cgan
- image-generation
- generative-adversarial-network
license: mit
datasets:
- mnist
language:
- en
---
# Conditional GAN — MNIST Digit Generation
Conditional Generative Adversarial Network (CGAN) for handwritten digit synthesis, trained on the [MNIST](https://huggingface.co/datasets/mnist) dataset.
## Architecture
A standard GAN framework conditioned on class labels, featuring a **Generator** and a **Discriminator** network.
| Component | Details |
| --- | --- |
| Discriminator | Takes a 1×28×28 grayscale image and a class label (0-9). Embeds the label, expands it to 1×28×28, concatenates it with the image, and passes through two downsampling `Conv2d` blocks (with `BatchNorm2d` and `LeakyReLU`). The output is flattened and passed through a `Sigmoid` activation (1 output). |
| Generator | Takes a 100-dimensional noise vector and a target class label (0-9). Uses embedding layers for the label, concatenates it with noise, and passes through two `ConvTranspose2d` upsampling blocks (with `BatchNorm2d` and `LeakyReLU`). Output uses a `Tanh` activation function (1×28×28). |
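The Discriminator described in the table can be sketched as follows. The channel widths (64, 128) and the 4×2×1 convolution parameters are assumptions for illustration; only the overall structure (label embedded to a 1×28×28 map, concatenation, two `Conv2d`/`BatchNorm2d`/`LeakyReLU` blocks, flatten, `Sigmoid`) comes from the description above.

```python
import torch
from torch import nn

class CGAN_Discriminator(nn.Module):
    """Sketch of the discriminator above; layer widths are assumptions."""
    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        # Embed the label and reshape it into a 1 x 28 x 28 "image" channel
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)
        self.img_size = img_size
        self.conv_blocks = nn.Sequential(
            # 2 input channels: the grayscale image plus the label map
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),    # 28 -> 14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * (img_size // 4) ** 2, 1),
            nn.Sigmoid(),  # probability that the (image, label) pair is real
        )

    def forward(self, img, labels):
        label_map = self.label_embedding(labels).view(-1, 1, self.img_size, self.img_size)
        x = torch.cat((img, label_map), dim=1)
        return self.classifier(self.conv_blocks(x))
```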
## Loss Function
The model is trained using the standard Conditional GAN Minimax objective, implemented with Binary Cross Entropy:
```
min_G max_D V(D, G) = E_x[log D(x|y)] + E_z[log(1 - D(G(z|y)|y))]
```
| Loss | Role |
| --- | --- |
| Discriminator Loss | Penalises the Discriminator for incorrectly classifying real MNIST images as fake, or generated images as real. |
| Generator Loss | Penalises the Generator when the Discriminator successfully identifies its generated images as fake. |
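The two losses can be implemented with `BCELoss` as in the sketch below. This is an illustrative single training step, not the card's exact training loop: `G`, `D`, the optimizers, and the batch are assumed to be defined elsewhere, and the generator uses the common non-saturating form (maximize `log D(G(z|y)|y)` via BCE against real targets).

```python
import torch
from torch import nn

bce = nn.BCELoss()

def train_step(G, D, opt_G, opt_D, real_imgs, labels, noise_dim=100):
    """One illustrative CGAN step; G and D are assumed defined elsewhere."""
    batch = real_imgs.size(0)
    real_tgt = torch.ones(batch, 1)
    fake_tgt = torch.zeros(batch, 1)

    # Discriminator: real (image, label) pairs should score 1, fakes 0
    noise = torch.randn(batch, noise_dim)
    fake_imgs = G(noise, labels)
    d_loss = (bce(D(real_imgs, labels), real_tgt)
              + bce(D(fake_imgs.detach(), labels), fake_tgt))
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # Generator: penalised when D correctly flags its output as fake
    g_loss = bce(D(fake_imgs, labels), real_tgt)
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()
    return d_loss.item(), g_loss.item()
```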
## Training
* **Dataset:** MNIST, 60,000 training images
* **Input size:** 28×28, normalised to [-1, 1]
* **Batch size:** 32
* **Epochs:** 50
* **Optimizer:** Adam
* **Hyperparameter search:** Optuna over `learning_rate` in [1e-5, 2e-3], `beta1` in [0.0, 0.9], and `noise_dim` in {50, 100, 128}
* **Best weights:** snapshot of the epoch with the lowest Generator loss and a stable Discriminator loss. Best hyperparameters found: `learning_rate` = 0.000112, `beta1` = 0.037, `noise_dim` = 100.
### Data Preprocessing (train only)
**Photometric**:
Images are converted to PyTorch tensors and normalised with a mean of `0.5` and standard deviation of `0.5`, scaling the pixel values to the `[-1, 1]` range to match the `Tanh` activation of the Generator.
## Usage
```python
import torch
from torch import nn
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt

class CGAN_Generator(nn.Module):
    def __init__(self, noise_dim=100, num_classes=10, img_size=28):
        super().__init__()
        self.init_size = img_size // 4
        self.embedding_dim = 20
        self.label_embedding = nn.Embedding(num_classes, self.embedding_dim)
        self.projection = nn.Sequential(
            nn.Linear(noise_dim + self.embedding_dim, 128 * self.init_size * self.init_size)
        )
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        label_embed = self.label_embedding(labels)
        merged_input = torch.cat((noise, label_embed), dim=1)
        out = self.projection(merged_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

noise_dim = 100
model = CGAN_Generator(noise_dim=noise_dim)
weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# Generate a single digit conditioned on the target label
target_label = torch.tensor([7])
eval_noise = torch.randn(1, noise_dim)
with torch.no_grad():
    generated_img = model(eval_noise, target_label)

# Rescale from [-1, 1] (Tanh output) to [0, 1] for display
generated_img = (generated_img + 1) / 2.0
img_numpy = generated_img.squeeze().cpu().numpy()
plt.imshow(img_numpy, cmap='gray')
plt.title(f"Generated Label: {target_label.item()}")
plt.axis('off')
plt.show()
```