---
pipeline_tag: unconditional-image-generation
library_name: pytorch
tags:
- gan
- cgan
- image-generation
- generative-adversarial-network
license: mit
datasets:
- mnist
language:
- en
---

# Conditional GAN — MNIST Digit Generation

Conditional Generative Adversarial Network (CGAN) for handwritten digit synthesis, trained on the [MNIST](https://huggingface.co/datasets/mnist) dataset.

## Architecture

A standard GAN framework conditioned on class labels, featuring a **Generator** and a **Discriminator** network.

| Component | Details |
| --- | --- |
| Discriminator | Takes a 1×28×28 grayscale image and a class label (0-9). Embeds the label, expands it to 1×28×28, concatenates it with the image, and passes the result through two downsampling `Conv2d` blocks (with `BatchNorm2d` and `LeakyReLU`). The output is flattened and passed through a `Sigmoid` activation (1 output). |
| Generator | Takes a 100-dimensional noise vector and a target class label (0-9). Uses embedding layers for the label, concatenates it with noise, and passes through two `ConvTranspose2d` upsampling blocks (with `BatchNorm2d` and `LeakyReLU`). Output uses a `Tanh` activation function (1×28×28). |
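The Generator is reproduced in full under Usage below; the Discriminator weights are what's trained against it but its code is not included in this card. The following is a sketch reconstructed from the table above — the channel widths (64, 128) and the label-embedding strategy are assumptions, not the exact training code:

```python
import torch
from torch import nn

class CGAN_Discriminator(nn.Module):
    def __init__(self, num_classes=10, img_size=28):
        super().__init__()
        self.img_size = img_size
        # Embed the class label and expand it to a 1×28×28 map,
        # so it can be concatenated with the image as a second channel
        self.label_embedding = nn.Embedding(num_classes, img_size * img_size)
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),    # 28 -> 14
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * (img_size // 4) ** 2, 1),
            nn.Sigmoid(),  # probability that the (image, label) pair is real
        )

    def forward(self, img, labels):
        label_map = self.label_embedding(labels).view(-1, 1, self.img_size, self.img_size)
        x = torch.cat((img, label_map), dim=1)
        return self.classifier(self.conv_blocks(x))
```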

## Loss Function

The model is trained with the standard conditional GAN minimax objective, implemented with binary cross-entropy:

```
min_G max_D V(D, G) = E_x[log D(x|y)] + E_z[log(1 - D(G(z|y)|y))]
```

| Loss | Role |
| --- | --- |
| Discriminator Loss | Penalises the Discriminator for incorrectly classifying real MNIST images as fake, or generated images as real. |
| Generator Loss | Penalises the Generator when the Discriminator successfully identifies its generated images as fake. |
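In PyTorch, both rows of this table reduce to `nn.BCELoss` against hard real/fake targets. A minimal sketch of the two update steps (function names and target construction are illustrative; optimizer calls are omitted):

```python
import torch
from torch import nn

bce = nn.BCELoss()

def discriminator_step(D, G, real_imgs, labels, noise_dim=100):
    """Discriminator loss: real (image, label) pairs should score 1, generated pairs 0."""
    batch = real_imgs.size(0)
    noise = torch.randn(batch, noise_dim)
    fake_imgs = G(noise, labels).detach()  # no gradient into G on this step

    d_loss = bce(D(real_imgs, labels), torch.ones(batch, 1)) + \
             bce(D(fake_imgs, labels), torch.zeros(batch, 1))
    return d_loss

def generator_step(D, G, labels, noise_dim=100):
    """Generator loss: fool D into scoring generated pairs as real (target 1)."""
    batch = labels.size(0)
    noise = torch.randn(batch, noise_dim)
    fake_imgs = G(noise, labels)
    return bce(D(fake_imgs, labels), torch.ones(batch, 1))
```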

## Training

  * **Dataset:** MNIST, 60,000 training images
  * **Input size:** 28×28, normalised to [-1, 1]
  * **Batch size:** 32
  * **Epochs:** 50
  * **Optimizer:** Adam
  * **Hyperparameter search:** Optuna over `learning_rate` in [1e-5, 2e-3], `beta1` in [0.0, 0.9], and `noise_dim` in {50, 100, 128}
  * **Best weights:** snapshot of the epoch with the lowest Generator loss and a stable Discriminator loss. Best hyperparameters found: `learning_rate` = 0.000112, `beta1` = 0.037, `noise_dim` = 100.

### Data Preprocessing (train only)

**Photometric**:
Images are converted to PyTorch tensors and normalised with a mean of `0.5` and standard deviation of `0.5`, scaling the pixel values to the `[-1, 1]` range to match the `Tanh` activation of the Generator.

## Usage

```python
import torch
from torch import nn
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt

class CGAN_Generator(nn.Module):
    def __init__(self, noise_dim=100, num_classes=10, img_size=28):
        super().__init__()
        self.init_size = img_size // 4
        self.embedding_dim = 20
        self.label_embedding = nn.Embedding(num_classes, self.embedding_dim)
        
        self.projection = nn.Sequential(
            nn.Linear(noise_dim + self.embedding_dim, 128 * self.init_size * self.init_size)
        )
        
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        label_embed = self.label_embedding(labels)
        merged_input = torch.cat((noise, label_embed), dim=1)
        out = self.projection(merged_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

model = CGAN_Generator(noise_dim=100)

weights_path = hf_hub_download(repo_id="VioletaR/cgan-mnist", filename="mnist_cgan_generator.pth")
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

noise_dim = 100  # must match the noise dimension the Generator was trained with
target_label = torch.tensor([7])
eval_noise = torch.randn(1, noise_dim)

with torch.no_grad():
    generated_img = model(eval_noise, target_label)

generated_img = (generated_img + 1) / 2.0
img_numpy = generated_img.squeeze().cpu().numpy()

plt.imshow(img_numpy, cmap='gray')
plt.title(f"Generated Label: {target_label.item()}")
plt.axis('off')
plt.show()
```