File size: 11,737 Bytes
a015496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# ============================================================
# app.py — HuggingFace Spaces Gradio App
# DCGAN vs WGAN-GP: Anime Face Generation
# ============================================================
# Deploy instructions:
#   1. Create a new Space on HuggingFace (SDK: Gradio)
#   2. Upload this app.py and requirements.txt
#   3. Upload dcgan_G_final.pt and wgan_G_final.pt to the Space files
#      (or host them on HF Hub and pull with hf_hub_download)
# ============================================================

import os
import gc
import numpy as np
import torch
import torch.nn as nn
import torchvision.utils as vutils
from PIL import Image
import gradio as gr

# ── Re-define architectures (must match training code exactly) ───────────────

class DCGANGenerator(nn.Module):
    """Generator used for the DCGAN model.

    The network is a stack of five transposed convolutions with 4x4 kernels
    and no bias. The first uses stride 1 and no padding; the other four use
    stride 2 and padding 1. Each of the first four is followed by BatchNorm
    and an in-place ReLU; the last one is followed by Tanh, so outputs lie
    in [-1, 1].

    Args:
        latent_dim: Channel count of the latent input.
        features_g: Base width; hidden layers use 8x, 4x, 2x and 1x of it.
        num_channels: Channel count of the generated image.
    """

    def __init__(self, latent_dim=100, features_g=64, num_channels=3):
        super().__init__()
        hidden = [features_g * mult for mult in (8, 4, 2, 1)]

        # Entry block: stride 1, no padding.
        stack = [
            nn.ConvTranspose2d(latent_dim, hidden[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(hidden[0]),
            nn.ReLU(True),
        ]
        # Three upsampling blocks that halve the channel count each time.
        for c_in, c_out in zip(hidden[:-1], hidden[1:]):
            stack.extend((
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ))
        # Output block: no normalisation, Tanh activation.
        stack.extend((
            nn.ConvTranspose2d(hidden[-1], num_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ))

        # The attribute name and module order determine the state_dict keys
        # ("net.0.weight", ...), so they are kept as-is.
        self.net = nn.Sequential(*stack)

    def forward(self, z):
        """Pass the latent batch ``z`` through the layer stack."""
        generated = self.net(z)
        return generated


class WGANGenerator(nn.Module):
    """Generator used for the WGAN-GP model.

    Four hidden blocks (bias-free transposed convolution with a 4x4 kernel,
    then BatchNorm, then in-place ReLU) are followed by a final transposed
    convolution and a Tanh, so outputs lie in [-1, 1].

    Args:
        latent_dim: Channel count of the latent input.
        features_g: Base width; hidden layers use 8x, 4x, 2x and 1x of it.
        num_channels: Channel count of the generated image.
    """

    def __init__(self, latent_dim=100, features_g=64, num_channels=3):
        super().__init__()
        # One row per hidden block: (in_channels, out_channels, stride, padding).
        hidden_specs = (
            (latent_dim, features_g * 8, 1, 0),
            (features_g * 8, features_g * 4, 2, 1),
            (features_g * 4, features_g * 2, 2, 1),
            (features_g * 2, features_g, 2, 1),
        )

        modules = []
        for c_in, c_out, stride, pad in hidden_specs:
            modules.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            modules.append(nn.BatchNorm2d(c_out))
            modules.append(nn.ReLU(True))

        # Output layer: project to image channels and squash with Tanh.
        modules.append(nn.ConvTranspose2d(features_g, num_channels, 4, 2, 1, bias=False))
        modules.append(nn.Tanh())

        # Module order fixes the state_dict keys ("net.0.weight", ...).
        self.net = nn.Sequential(*modules)

    def forward(self, z):
        """Pass the latent batch ``z`` through the layer stack."""
        images = self.net(z)
        return images


# ── Load models ──────────────────────────────────────────────────────────────

# Channel count of the latent noise fed to both generators.
LATENT_DIM = 100

# Checkpoint file names; relative paths, so they resolve against the working directory.
DCGAN_WEIGHTS = "dcgan_G_final.pt"
WGAN_WEIGHTS = "wgan_G_final.pt"

# Use CUDA when PyTorch reports it as available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on: {device}")

# Instantiate both generators (DCGAN first, then WGAN) on the selected device.
dcgan_gen = DCGANGenerator(latent_dim=LATENT_DIM).to(device)
wgan_gen = WGANGenerator(latent_dim=LATENT_DIM).to(device)

def _strip_dataparallel_prefix(state):
    """Return a copy of ``state`` with one leading ``"module."`` removed from each key.

    Only the prefix is removed, so keys that merely contain the substring
    (for example ``"submodule.weight"``) are left intact.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }


def _load_generator_weights(model, path, label):
    """Load the checkpoint at ``path`` into ``model``; keep current weights if the file is absent."""
    if not os.path.exists(path):
        print(f"⚠  {label} weights not found — using random init.")
        return

    # NOTE(review): torch.load unpickles the file. The checkpoints are expected
    # to ship with the app; consider weights_only=True if the installed
    # PyTorch supports it and the files could ever come from elsewhere.
    state = torch.load(path, map_location=device)
    # Checkpoints saved from a DataParallel wrapper prefix every key with "module.".
    model.load_state_dict(_strip_dataparallel_prefix(state))
    print(f"✔ {label} weights loaded.")


def load_weights():
    """Load both generator checkpoints if present and switch the models to eval mode.

    Reads the module-level ``DCGAN_WEIGHTS`` / ``WGAN_WEIGHTS`` paths and
    updates ``dcgan_gen`` / ``wgan_gen`` in place. A missing file is not an
    error: the affected generator keeps its random initialisation (demo
    fallback) and a warning is printed. Both models are put into eval mode
    regardless of whether weights were found.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            keys or shapes do not match the model.
    """
    _load_generator_weights(dcgan_gen, DCGAN_WEIGHTS, "DCGAN")
    _load_generator_weights(wgan_gen, WGAN_WEIGHTS, "WGAN-GP")

    dcgan_gen.eval()
    wgan_gen.eval()

# Populate both generators once at import time, before the UI below is built.
load_weights()


# ── Inference helpers ─────────────────────────────────────────────────────────

def tensor_to_pil_grid(tensor_batch, nrow=4):
    """Tile a batch of images into a single grid and return it as a PIL image.

    Args:
        tensor_batch: Image batch handed to ``make_grid``; callers in this
            file pass a (B, 3, H, W) float tensor with values in [-1, 1].
        nrow: Number of images per grid row (forwarded to ``make_grid``).

    Returns:
        A ``PIL.Image.Image`` built from an (H, W, 3) ``uint8`` array.
    """
    # normalize + value_range rescale the [-1, 1] inputs to [0, 1].
    grid = vutils.make_grid(tensor_batch, nrow=nrow, normalize=True, value_range=(-1, 1))
    # Tensor.numpy() raises for tensors that track gradients or live on an
    # accelerator, so detach and move to host memory before converting.
    pixels = grid.detach().cpu().permute(1, 2, 0).numpy()  # (H, W, 3)
    pixels = (pixels * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(pixels)


@torch.no_grad()
def generate_comparison(n_images: int, seed: int):
    """Generate image grids from both generators using the same latent noise.

    Args:
        n_images: Requested number of images; clamped to [1, 16]. Floats are
            truncated to int.
        seed: Seed for the global PyTorch RNG, so equal seeds give equal noise.

    Returns:
        A ``(dcgan_grid, wgan_grid)`` tuple of PIL images.
    """
    # torch.randn only accepts integer sizes, so coerce values such as 8.0
    # before clamping.
    n_images = max(1, min(int(n_images), 16))
    torch.manual_seed(int(seed))
    z = torch.randn(n_images, LATENT_DIM, 1, 1, device=device)

    # Feed the identical noise to both models so the grids are comparable.
    dcgan_imgs = dcgan_gen(z).cpu()
    wgan_imgs = wgan_gen(z).cpu()

    # At most four images per row; fewer images form a single row.
    nrow = min(4, n_images)
    pil_dcgan = tensor_to_pil_grid(dcgan_imgs, nrow=nrow)
    pil_wgan = tensor_to_pil_grid(wgan_imgs, nrow=nrow)

    # Release memory eagerly after each request.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return pil_dcgan, pil_wgan


@torch.no_grad()
def generate_single(model_choice: str, n_images: int, seed: int):
    """Generate an image grid from one generator plus a short model description.

    Args:
        model_choice: Either ``"DCGAN"`` or ``"WGAN-GP"``.
        n_images: Requested number of images; clamped to [1, 16]. Floats are
            truncated to int.
        seed: Seed for the global PyTorch RNG, so equal seeds give equal noise.

    Returns:
        A ``(grid, description)`` tuple: a PIL image and a description string.

    Raises:
        ValueError: If ``model_choice`` is not one of the supported names.
    """
    descriptions = {
        "DCGAN":   ("Binary Cross Entropy loss. Faster to train but prone to mode collapse "
                    "— may generate repetitive or blurry samples."),
        "WGAN-GP": ("Wasserstein loss + Gradient Penalty. More stable training, "
                    "better sample diversity, and less mode collapse."),
    }
    # Validate before doing any work, so an unknown choice neither runs the
    # wrong generator nor fails late on the description lookup.
    if model_choice not in descriptions:
        raise ValueError(
            f"Unknown model {model_choice!r}; expected one of {sorted(descriptions)}"
        )
    desc = descriptions[model_choice]

    # torch.randn only accepts integer sizes, so coerce values such as 8.0
    # before clamping.
    n_images = max(1, min(int(n_images), 16))
    torch.manual_seed(int(seed))
    z = torch.randn(n_images, LATENT_DIM, 1, 1, device=device)

    gen = dcgan_gen if model_choice == "DCGAN" else wgan_gen
    imgs = gen(z).cpu()
    # At most four images per row; fewer images form a single row.
    nrow = min(4, n_images)
    pil_out = tensor_to_pil_grid(imgs, nrow=nrow)

    # Release memory eagerly after each request.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return pil_out, desc


# ── Gradio UI ─────────────────────────────────────────────────────────────────

# Build the three-tab interface. Text below is user-visible; the mis-encoded
# characters (dashes, ×, λ, β, ·, emoji) have been restored to proper Unicode.
with gr.Blocks(
    title="DCGAN vs WGAN-GP | Anime Face Generator",
    theme=gr.themes.Soft(),
) as demo:

    gr.Markdown(
        """
        # 🎨 DCGAN vs WGAN-GP — Anime Face Generator
        **AI4009 Generative AI | Assignment 3 — Question 1**

        Generate anime faces using two GAN variants and compare output diversity.
        Both models were trained on the [Anime Faces](https://www.kaggle.com/datasets/soumikrakshit/anime-faces)
        dataset (64×64, normalised to [-1, 1]).

        | Model | Loss | Key Property |
        |-------|------|--------------|
        | DCGAN | Binary Cross-Entropy | Baseline — fast but unstable |
        | WGAN-GP | Wasserstein + Gradient Penalty | Stable, diverse, mode-collapse-resistant |
        """
    )

    with gr.Tabs():

        # ── Tab 1: Side-by-side comparison ──────────────────────────────────
        with gr.TabItem("🔄 Compare Both Models"):
            gr.Markdown("### Generate the same latent noise through both models")

            with gr.Row():
                with gr.Column(scale=1):
                    n_img_compare = gr.Slider(1, 16, value=8, step=1,
                                              label="Number of Images")
                    seed_compare = gr.Slider(0, 9999, value=42, step=1,
                                             label="Random Seed")
                    btn_compare = gr.Button("🚀 Generate & Compare", variant="primary")

            with gr.Row():
                out_dcgan = gr.Image(label="DCGAN Output", type="pil")
                out_wgan  = gr.Image(label="WGAN-GP Output", type="pil")

            # Both grids are produced by one call so they share the same noise.
            btn_compare.click(
                fn=generate_comparison,
                inputs=[n_img_compare, seed_compare],
                outputs=[out_dcgan, out_wgan],
            )

            # Preset (image count, seed) pairs; not pre-computed at startup.
            gr.Examples(
                examples=[[8, 42], [16, 123], [4, 777], [16, 2024]],
                inputs=[n_img_compare, seed_compare],
                outputs=[out_dcgan, out_wgan],
                fn=generate_comparison,
                cache_examples=False,
            )

        # ── Tab 2: Single model explorer ────────────────────────────────────
        with gr.TabItem("🔍 Explore Single Model"):
            gr.Markdown("### Explore a specific model in detail")

            with gr.Row():
                with gr.Column(scale=1):
                    model_choice = gr.Radio(["DCGAN", "WGAN-GP"], value="WGAN-GP",
                                            label="Select Model")
                    n_img_single = gr.Slider(1, 16, value=8, step=1,
                                             label="Number of Images")
                    seed_single  = gr.Slider(0, 9999, value=0, step=1,
                                             label="Random Seed")
                    btn_single   = gr.Button("Generate", variant="primary")

            with gr.Row():
                single_out  = gr.Image(label="Generated Images", type="pil", scale=2)
                single_desc = gr.Textbox(label="Model Description", lines=4, scale=1)

            btn_single.click(
                fn=generate_single,
                inputs=[model_choice, n_img_single, seed_single],
                outputs=[single_out, single_desc],
            )

        # ── Tab 3: About ─────────────────────────────────────────────────────
        with gr.TabItem("ℹ️ About"):
            gr.Markdown(
                """
                ## Model Details

                ### DCGAN (Deep Convolutional GAN)
                - **Generator**: 5 ConvTranspose2d layers, BatchNorm, ReLU, Tanh output
                - **Discriminator**: 5 Conv2d layers, LeakyReLU, Sigmoid output
                - **Loss**: Binary Cross-Entropy
                - **Known weakness**: Mode collapse — the generator may learn to produce
                  only a few "safe" outputs that fool the discriminator.

                ### WGAN-GP (Wasserstein GAN with Gradient Penalty)
                - **Generator**: Same architecture as DCGAN
                - **Critic**: Same structure but uses InstanceNorm and **no Sigmoid** —
                  outputs raw Wasserstein scores instead of probabilities
                - **Loss**: Wasserstein distance + Gradient Penalty (λ=10)
                - **Training**: 5 critic updates per generator step
                - **Advantage**: The Wasserstein distance provides meaningful gradients even
                  when distributions don't overlap — eliminates mode collapse.

                ### Training Setup
                - Dataset: Anime Faces 64×64
                - Optimizer: Adam (lr=0.0002, β=(0.5, 0.999))
                - Mixed precision (torch.cuda.amp)
                - Platform: Kaggle T4 x2 Dual GPU
                """
            )

    gr.Markdown(
        "<center>Built for AI4009 GenAI Assignment 3 · "
        "Model trained on Kaggle · Deployed on HuggingFace Spaces</center>"
    )


# Start the Gradio server only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()