| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import gc |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torchvision.utils as vutils |
| from PIL import Image |
| import gradio as gr |
|
|
| |
|
|
class DCGANGenerator(nn.Module):
    """Transposed-convolution generator that turns latent vectors into images.

    The network expects input of shape ``(B, latent_dim, 1, 1)``. The first
    stage expands the 1x1 latent to 4x4, and each of the four following
    transposed convolutions doubles the spatial size, giving 64x64 output
    with ``num_channels`` channels squashed into ``[-1, 1]`` by ``Tanh``.

    Args:
        latent_dim: Number of channels of the latent input.
        features_g: Base channel width; hidden stages use 8x, 4x, 2x and 1x
            this value.
        num_channels: Number of channels of the generated image.
    """

    def __init__(self, latent_dim=100, features_g=64, num_channels=3):
        super().__init__()

        # Output widths of the four hidden stages, widest first.
        hidden_widths = [features_g * mult for mult in (8, 4, 2, 1)]

        layers = []
        in_ch = latent_dim
        for stage_idx, out_ch in enumerate(hidden_widths):
            # Stage 0 projects the 1x1 latent (stride 1, no padding); the
            # remaining stages upsample by 2 (stride 2, padding 1).
            stride, padding = (1, 0) if stage_idx == 0 else (2, 1)
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))
            in_ch = out_ch

        # Output stage: no normalisation, Tanh bounds the pixel range.
        layers.append(nn.ConvTranspose2d(in_ch, num_channels, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        # Layers are appended in the same order as a literal Sequential, so
        # sub-module indices (and therefore state_dict keys) are unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Run the latent batch ``z`` through the generator stack."""
        return self.net(z)
|
|
|
|
class WGANGenerator(nn.Module):
    """Generator network used for the WGAN-GP model.

    Takes latent input of shape ``(B, latent_dim, 1, 1)`` and applies five
    transposed convolutions (the first four followed by BatchNorm + ReLU,
    the last by Tanh), producing ``(B, num_channels, 64, 64)`` images with
    values in ``[-1, 1]``.

    Args:
        latent_dim: Number of channels of the latent input.
        features_g: Base channel width of the hidden stages.
        num_channels: Number of channels of the generated image.
    """

    def __init__(self, latent_dim=100, features_g=64, num_channels=3):
        super().__init__()

        def up_stage(c_in, c_out, stride, padding):
            # One hidden stage: 4x4 transposed conv -> batch norm -> ReLU.
            return (
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            )

        wide8, wide4, wide2 = features_g * 8, features_g * 4, features_g * 2

        # Stages are unpacked in order, so module indices 0..13 match a
        # hand-written Sequential with the same fourteen layers.
        self.net = nn.Sequential(
            *up_stage(latent_dim, wide8, 1, 0),   # 1x1  -> 4x4
            *up_stage(wide8, wide4, 2, 1),        # 4x4  -> 8x8
            *up_stage(wide4, wide2, 2, 1),        # 8x8  -> 16x16
            *up_stage(wide2, features_g, 2, 1),   # 16x16 -> 32x32
            nn.ConvTranspose2d(features_g, num_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map the latent batch ``z`` to generated images."""
        return self.net(z)
|
|
|
|
| |
|
|
# ── Runtime configuration ────────────────────────────────────────────────────
# Size of the latent vector fed to both generators.
LATENT_DIM = 100

# Checkpoint file names, resolved relative to the current working directory.
DCGAN_WEIGHTS = "dcgan_G_final.pt"
WGAN_WEIGHTS = "wgan_G_final.pt"

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on: {device}")

# Both generators are instantiated once at import time and shared by every
# UI callback.
dcgan_gen = DCGANGenerator(latent_dim=LATENT_DIM).to(device)
wgan_gen = WGANGenerator(latent_dim=LATENT_DIM).to(device)
|
|
def _load_generator_weights(model, path, label):
    """Load the checkpoint at ``path`` into ``model`` if the file exists.

    Args:
        model: Generator module that receives the state dict.
        path: Checkpoint file path.
        label: Human-readable model name used in the status message.

    Returns:
        True if weights were loaded, False if the file was missing and the
        model keeps its random initialisation.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint
            does not match the model's parameters.
    """
    if not os.path.exists(path):
        print(f"⚠ {label} weights not found — using random init.")
        return False

    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs and avoids executing arbitrary pickled
    # code from the checkpoint file.
    state = torch.load(path, map_location=device, weights_only=True)

    # Drop a leading "module." (the prefix added to parameter names when a
    # model is saved while wrapped in nn.DataParallel). Only the prefix is
    # removed; a "module." occurring elsewhere in a key is left untouched.
    prefix = "module."
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }

    model.load_state_dict(state)
    print(f"✓ {label} weights loaded.")
    return True


def load_weights():
    """Load weights if available; otherwise use random init (demo fallback).

    Each generator is handled independently, so one missing checkpoint does
    not prevent the other from loading. Both generators are switched to eval
    mode afterwards regardless of whether weights were found.
    """
    _load_generator_weights(dcgan_gen, DCGAN_WEIGHTS, "DCGAN")
    _load_generator_weights(wgan_gen, WGAN_WEIGHTS, "WGAN-GP")

    # Inference only: BatchNorm must use its running statistics.
    dcgan_gen.eval()
    wgan_gen.eval()


# Load checkpoints once at import time so the UI callbacks below can use the
# generators immediately.
load_weights()
|
|
|
|
| |
|
|
def tensor_to_pil_grid(tensor_batch, nrow=4):
    """Convert a (B,3,H,W) tensor in [-1,1] to a PIL image grid.

    Args:
        tensor_batch: Batch of images with values in ``[-1, 1]``. The tensor
            may live on any device and may be attached to an autograd graph;
            it is detached and copied to the CPU before conversion.
        nrow: Number of images per grid row, forwarded to ``make_grid``.

    Returns:
        A ``PIL.Image.Image`` holding the tiled batch as 8-bit pixels.
    """
    # Tensor.numpy() only works for CPU tensors that do not require grad, so
    # normalise that here instead of relying on every caller to do it.
    cpu_batch = tensor_batch.detach().cpu()

    # normalize + value_range rescales [-1, 1] to [0, 1].
    grid = vutils.make_grid(cpu_batch, nrow=nrow, normalize=True, value_range=(-1, 1))

    # make_grid returns channels-first (C, H, W); PIL expects (H, W, C).
    np_img = grid.permute(1, 2, 0).numpy()
    np_img = (np_img * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(np_img)
|
|
|
|
@torch.no_grad()
def generate_comparison(n_images: int, seed: int):
    """Generate the same latent batch through both generators.

    Args:
        n_images: Requested number of images; clamped to the range 1..16.
            Accepts any value ``int()`` can convert, since numeric UI widgets
            may deliver floats such as ``8.0``.
        seed: Seed for the global PyTorch RNG, so a given seed reproduces the
            same latent batch. Converted with ``int()`` for the same reason.

    Returns:
        A ``(dcgan_grid, wgan_grid)`` tuple of PIL images.
    """
    # torch.randn and torch.manual_seed both reject float arguments.
    n_images = max(1, min(int(n_images), 16))
    torch.manual_seed(int(seed))

    try:
        # One shared latent batch so the two grids are directly comparable.
        z = torch.randn(n_images, LATENT_DIM, 1, 1, device=device)

        dcgan_imgs = dcgan_gen(z).cpu()
        wgan_imgs = wgan_gen(z).cpu()

        # At most four images per row; fewer when the batch is smaller.
        nrow = min(n_images, 4)
        pil_dcgan = tensor_to_pil_grid(dcgan_imgs, nrow=nrow)
        pil_wgan = tensor_to_pil_grid(wgan_imgs, nrow=nrow)
        return pil_dcgan, pil_wgan
    finally:
        # Release memory after every request, including failed ones.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
@torch.no_grad()
def generate_single(model_choice: str, n_images: int, seed: int):
    """Generate a grid from one model and return it with a short description.

    Args:
        model_choice: Either ``"DCGAN"`` or ``"WGAN-GP"``.
        n_images: Requested number of images; clamped to the range 1..16.
            Accepts any value ``int()`` can convert, since numeric UI widgets
            may deliver floats such as ``8.0``.
        seed: Seed for the global PyTorch RNG. Converted with ``int()``.

    Returns:
        A ``(grid, description)`` tuple: the PIL image grid and a text blurb
        about the selected model.

    Raises:
        KeyError: If ``model_choice`` is not one of the supported names.
    """
    descriptions = {
        "DCGAN": ("Binary Cross Entropy loss. Faster to train but prone to mode collapse "
                  "— may generate repetitive or blurry samples."),
        "WGAN-GP": ("Wasserstein loss + Gradient Penalty. More stable training, "
                    "better sample diversity, and less mode collapse."),
    }

    # Validate before doing any work: previously an unknown choice silently
    # ran the WGAN generator and only failed afterwards on the lookup.
    if model_choice not in descriptions:
        raise KeyError(
            f"Unknown model choice {model_choice!r}; "
            f"expected one of {sorted(descriptions)}"
        )
    desc = descriptions[model_choice]
    gen = dcgan_gen if model_choice == "DCGAN" else wgan_gen

    # torch.randn and torch.manual_seed both reject float arguments.
    n_images = max(1, min(int(n_images), 16))
    torch.manual_seed(int(seed))

    try:
        z = torch.randn(n_images, LATENT_DIM, 1, 1, device=device)
        imgs = gen(z).cpu()

        # At most four images per row; fewer when the batch is smaller.
        nrow = min(n_images, 4)
        pil_out = tensor_to_pil_grid(imgs, nrow=nrow)
        return pil_out, desc
    finally:
        # Release memory after every request, including failed ones.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
| |
|
|
# ── Gradio UI ────────────────────────────────────────────────────────────────
# Three tabs: a side-by-side comparison of both generators, a single-model
# explorer, and a static "About" page. Markdown text is indented with the
# surrounding code; gr.Markdown strips the common leading whitespace.
with gr.Blocks(
    title="DCGAN vs WGAN-GP | Anime Face Generator",
    theme=gr.themes.Soft(),
) as demo:

    gr.Markdown(
        """
        # 🎨 DCGAN vs WGAN-GP — Anime Face Generator
        **AI4009 Generative AI | Assignment 3 — Question 1**

        Generate anime faces using two GAN variants and compare output diversity.
        Both models were trained on the [Anime Faces](https://www.kaggle.com/datasets/soumikrakshit/anime-faces)
        dataset (64×64, normalised to [-1, 1]).

        | Model | Loss | Key Property |
        |-------|------|--------------|
        | DCGAN | Binary Cross-Entropy | Baseline — fast but unstable |
        | WGAN-GP | Wasserstein + Gradient Penalty | Stable, diverse, mode-collapse-resistant |
        """
    )

    with gr.Tabs():

        # ── Tab 1: same latent noise through both generators ────────────────
        with gr.TabItem("🔀 Compare Both Models"):
            gr.Markdown("### Generate the same latent noise through both models")

            with gr.Row():
                with gr.Column(scale=1):
                    n_img_compare = gr.Slider(1, 16, value=8, step=1,
                                              label="Number of Images")
                    seed_compare = gr.Slider(0, 9999, value=42, step=1,
                                             label="Random Seed")
                    btn_compare = gr.Button("🎲 Generate & Compare", variant="primary")

            with gr.Row():
                out_dcgan = gr.Image(label="DCGAN Output", type="pil")
                out_wgan = gr.Image(label="WGAN-GP Output", type="pil")

            btn_compare.click(
                fn=generate_comparison,
                inputs=[n_img_compare, seed_compare],
                outputs=[out_dcgan, out_wgan],
            )

            # Preset (n_images, seed) pairs; results are computed on click
            # rather than cached at start-up.
            gr.Examples(
                examples=[[8, 42], [16, 123], [4, 777], [16, 2024]],
                inputs=[n_img_compare, seed_compare],
                outputs=[out_dcgan, out_wgan],
                fn=generate_comparison,
                cache_examples=False,
            )

        # ── Tab 2: one generator at a time, with a description ──────────────
        with gr.TabItem("🔍 Explore Single Model"):
            gr.Markdown("### Explore a specific model in detail")

            with gr.Row():
                with gr.Column(scale=1):
                    model_choice = gr.Radio(["DCGAN", "WGAN-GP"], value="WGAN-GP",
                                            label="Select Model")
                    n_img_single = gr.Slider(1, 16, value=8, step=1,
                                             label="Number of Images")
                    seed_single = gr.Slider(0, 9999, value=0, step=1,
                                            label="Random Seed")
                    btn_single = gr.Button("Generate", variant="primary")

            with gr.Row():
                single_out = gr.Image(label="Generated Images", type="pil", scale=2)
                single_desc = gr.Textbox(label="Model Description", lines=4, scale=1)

            btn_single.click(
                fn=generate_single,
                inputs=[model_choice, n_img_single, seed_single],
                outputs=[single_out, single_desc],
            )

        # ── Tab 3: static background information ────────────────────────────
        with gr.TabItem("ℹ️ About"):
            gr.Markdown(
                """
                ## Model Details

                ### DCGAN (Deep Convolutional GAN)
                - **Generator**: 5 ConvTranspose2d layers, BatchNorm, ReLU, Tanh output
                - **Discriminator**: 5 Conv2d layers, LeakyReLU, Sigmoid output
                - **Loss**: Binary Cross-Entropy
                - **Known weakness**: Mode collapse — the generator may learn to produce
                  only a few "safe" outputs that fool the discriminator.

                ### WGAN-GP (Wasserstein GAN with Gradient Penalty)
                - **Generator**: Same architecture as DCGAN
                - **Critic**: Same structure but uses InstanceNorm and **no Sigmoid** —
                  outputs raw Wasserstein scores instead of probabilities
                - **Loss**: Wasserstein distance + Gradient Penalty (λ=10)
                - **Training**: 5 critic updates per generator step
                - **Advantage**: The Wasserstein distance provides meaningful gradients even
                  when distributions don't overlap — eliminates mode collapse.

                ### Training Setup
                - Dataset: Anime Faces 64×64
                - Optimizer: Adam (lr=0.0002, β=(0.5, 0.999))
                - Mixed precision (torch.cuda.amp)
                - Platform: Kaggle T4 x2 Dual GPU
                """
            )

    gr.Markdown(
        "<center>Built for AI4009 GenAI Assignment 3 · "
        "Model trained on Kaggle · Deployed on HuggingFace Spaces</center>"
    )


# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|