# (HuggingFace file-view header) Uploaded by waleed-12 in commit a015496
# ("Upload 2 files", verified).
# ============================================================
# app.py — HuggingFace Spaces Gradio App
# DCGAN vs WGAN-GP: Anime Face Generation
# ============================================================
# Deploy instructions:
# 1. Create a new Space on HuggingFace (SDK: Gradio)
# 2. Upload this app.py and requirements.txt
# 3. Upload dcgan_G_final.pt and wgan_G_final.pt to the Space files
# (or host them on HF Hub and pull with hf_hub_download)
# ============================================================
import os
import gc
import numpy as np
import torch
import torch.nn as nn
import torchvision.utils as vutils
from PIL import Image
import gradio as gr
# ── Re-define architectures (must match training code exactly) ───────────────
class DCGANGenerator(nn.Module):
    """DCGAN generator: latent noise -> image via transposed convolutions.

    Four ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stages with channel widths
    ``8f, 4f, 2f, f`` (``f = features_g``), then a final ``ConvTranspose2d``
    down to ``num_channels`` and a ``Tanh`` so outputs lie in [-1, 1].
    For a ``(N, latent_dim, 1, 1)`` input the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64.

    The layers are stored in ``self.net`` (an ``nn.Sequential``) in a fixed
    order, so checkpoint keys such as ``net.0.weight`` keep their indices.
    """

    def __init__(self, latent_dim=100, features_g=64, num_channels=3):
        super().__init__()
        layers = []
        in_ch = latent_dim
        for stage, mult in enumerate((8, 4, 2, 1)):
            out_ch = features_g * mult
            # Stage 0 projects the 1x1 latent to 4x4 (stride 1, no padding);
            # every later stage doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.extend((
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ))
            in_ch = out_ch
        # Output head: map to image channels and squash into [-1, 1].
        layers.extend((
            nn.ConvTranspose2d(in_ch, num_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Run noise ``z`` through the generator stack."""
        return self.net(z)
class WGANGenerator(nn.Module):
    """WGAN-GP generator; same layer stack as ``DCGANGenerator``.

    Four up-sampling stages (``ConvTranspose2d -> BatchNorm2d -> ReLU``) with
    widths ``8f, 4f, 2f, f`` for ``f = features_g``, followed by a
    ``ConvTranspose2d`` to ``num_channels`` and ``Tanh`` (outputs in [-1, 1]).
    A ``(N, latent_dim, 1, 1)`` input yields a ``(N, num_channels, 64, 64)`` image.
    """

    def __init__(self, latent_dim=100, features_g=64, num_channels=3):
        super().__init__()

        def up_block(c_in, c_out, stride, padding):
            # One transposed-conv stage; returned as a list so it can be
            # splatted into nn.Sequential, preserving the flat layer indices.
            return [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        f = features_g
        self.net = nn.Sequential(
            *up_block(latent_dim, f * 8, 1, 0),  # 1x1 -> 4x4
            *up_block(f * 8, f * 4, 2, 1),       # 4x4 -> 8x8
            *up_block(f * 4, f * 2, 2, 1),       # 8x8 -> 16x16
            *up_block(f * 2, f, 2, 1),           # 16x16 -> 32x32
            nn.ConvTranspose2d(f, num_channels, 4, 2, 1, bias=False),  # -> 64x64
            nn.Tanh(),
        )

    def forward(self, z):
        """Run noise ``z`` through the generator stack."""
        return self.net(z)
# ── Load models ──────────────────────────────────────────────────────────────
# Size of the noise vector fed to both generators.
LATENT_DIM = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")

dcgan_gen = DCGANGenerator(LATENT_DIM).to(device)
wgan_gen = WGANGenerator(LATENT_DIM).to(device)

# Resolve checkpoint paths next to this file rather than against the current
# working directory: launching the app from elsewhere would otherwise miss the
# files and silently fall back to random initialisation.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
DCGAN_WEIGHTS = os.path.join(_APP_DIR, "dcgan_G_final.pt")
WGAN_WEIGHTS = os.path.join(_APP_DIR, "wgan_G_final.pt")
def _strip_dataparallel_prefix(state):
    """Return a copy of *state* with a leading ``"module."`` removed from each key.

    Checkpoints saved from a DataParallel-wrapped model prefix every key with
    ``"module."``. Only that leading prefix is stripped; keys that merely
    contain ``"module."`` further in are left untouched.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }


def _load_generator_weights(model, path, label, map_location):
    """Load the state dict at *path* into *model* if the file exists.

    Args:
        model: Module whose parameters are overwritten.
        path: Checkpoint file path.
        label: Human-readable model name used in log messages.
        map_location: Passed to ``torch.load`` (e.g. the inference device).

    Returns:
        True if weights were loaded, False if *path* does not exist (the model
        then keeps its random initialisation as a demo fallback).
    """
    if not os.path.exists(path):
        print(f"⚠ {label} weights not found — using random init.")
        return False
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    state = torch.load(path, map_location=map_location)
    model.load_state_dict(_strip_dataparallel_prefix(state))
    print(f"✔ {label} weights loaded.")
    return True


def load_weights():
    """Load both generators' weights if available, then switch them to eval mode.

    A missing checkpoint is reported and that model keeps its random
    initialisation (demo fallback).
    """
    _load_generator_weights(dcgan_gen, DCGAN_WEIGHTS, "DCGAN", device)
    _load_generator_weights(wgan_gen, WGAN_WEIGHTS, "WGAN-GP", device)
    # Inference only: BatchNorm layers use their running statistics.
    dcgan_gen.eval()
    wgan_gen.eval()
# Load checkpoints once at import time so every request handler below reuses
# the same in-memory generators.
load_weights()
# ── Inference helpers ─────────────────────────────────────────────────────────
def _grid_to_uint8_hwc(grid):
    """Convert a (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 array.

    Values are scaled to [0, 255] and rounded to nearest (the conversion
    ``torchvision.utils.save_image`` uses) rather than truncated, which would
    bias every pixel slightly darker. Works for tensors on any device.
    """
    return (
        grid.detach()
        .mul(255)
        .add(0.5)
        .clamp(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )


def tensor_to_pil_grid(tensor_batch, nrow=4):
    """Convert a (B, 3, H, W) tensor in [-1, 1] to a PIL image grid.

    Args:
        tensor_batch: Batch of generator outputs (any device).
        nrow: Number of images per grid row.

    Returns:
        A PIL image containing the tiled batch.
    """
    # normalize=True with value_range=(-1, 1) rescales the batch into [0, 1].
    grid = vutils.make_grid(tensor_batch, nrow=nrow, normalize=True, value_range=(-1, 1))
    return Image.fromarray(_grid_to_uint8_hwc(grid))
# Short blurbs shown next to the single-model explorer output; keys are the
# accepted ``model_choice`` values.
_MODEL_DESCRIPTIONS = {
    "DCGAN": ("Binary Cross Entropy loss. Faster to train but prone to mode collapse "
              "— may generate repetitive or blurry samples."),
    "WGAN-GP": ("Wasserstein loss + Gradient Penalty. More stable training, "
                "better sample diversity, and less mode collapse."),
}

# Upper bound on images per request.
_MAX_IMAGES = 16


def _clamp_image_count(n_images):
    """Coerce a UI value (possibly a float) to int and clamp it to [1, _MAX_IMAGES]."""
    return max(1, min(int(n_images), _MAX_IMAGES))


def _grid_columns(n_images):
    """Images per grid row: 4, or fewer when the batch is smaller than that."""
    return min(4, n_images)


def _sample_latent(n_images, seed, latent_dim, target_device):
    """Draw a reproducible (n_images, latent_dim, 1, 1) noise batch for *seed*.

    A private CPU ``torch.Generator`` replaces ``torch.manual_seed`` so that
    concurrent requests cannot reseed each other's global RNG, and a given
    seed yields the same noise whether the models run on CPU or GPU.
    """
    rng = torch.Generator(device="cpu").manual_seed(int(seed))
    z = torch.randn(n_images, latent_dim, 1, 1, generator=rng)
    return z.to(target_device)


def _release_memory():
    """Run the garbage collector and release unoccupied cached CUDA memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@torch.no_grad()
def generate_comparison(n_images: int, seed: int):
    """Push the same latent batch through both generators.

    Args:
        n_images: Requested number of images; clamped to [1, 16].
        seed: Seed for the shared latent noise.

    Returns:
        Tuple ``(dcgan_grid, wgan_grid)`` of PIL images.
    """
    n_images = _clamp_image_count(n_images)
    z = _sample_latent(n_images, seed, LATENT_DIM, device)
    dcgan_imgs = dcgan_gen(z).cpu()
    wgan_imgs = wgan_gen(z).cpu()
    nrow = _grid_columns(n_images)
    pil_dcgan = tensor_to_pil_grid(dcgan_imgs, nrow=nrow)
    pil_wgan = tensor_to_pil_grid(wgan_imgs, nrow=nrow)
    _release_memory()
    return pil_dcgan, pil_wgan


@torch.no_grad()
def generate_single(model_choice: str, n_images: int, seed: int):
    """Generate a batch from one model and return it with a short description.

    Args:
        model_choice: ``"DCGAN"`` or ``"WGAN-GP"``.
        n_images: Requested number of images; clamped to [1, 16].
        seed: Seed for the latent noise.

    Returns:
        Tuple ``(grid, description)``: a PIL image and the model's blurb.

    Raises:
        ValueError: If *model_choice* is not a known model name.
    """
    # Validate before doing any GPU work (previously an unknown choice ran the
    # WGAN model and only then failed on the description lookup).
    if model_choice not in _MODEL_DESCRIPTIONS:
        raise ValueError(f"Unknown model choice: {model_choice!r}")
    n_images = _clamp_image_count(n_images)
    z = _sample_latent(n_images, seed, LATENT_DIM, device)
    gen = dcgan_gen if model_choice == "DCGAN" else wgan_gen
    imgs = gen(z).cpu()
    pil_out = tensor_to_pil_grid(imgs, nrow=_grid_columns(n_images))
    _release_memory()
    return pil_out, _MODEL_DESCRIPTIONS[model_choice]
# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(
    title="DCGAN vs WGAN-GP | Anime Face Generator",
    theme=gr.themes.Soft(),
) as demo:
    # Header: project summary plus a quick comparison table of the two models.
    # Blank lines separate Markdown paragraphs so they don't run together.
    gr.Markdown(
        """
# 🎨 DCGAN vs WGAN-GP — Anime Face Generator

**AI4009 Generative AI | Assignment 3 — Question 1**

Generate anime faces using two GAN variants and compare output diversity.
Both models were trained on the [Anime Faces](https://www.kaggle.com/datasets/soumikrakshit/anime-faces)
dataset (64×64, normalised to [-1, 1]).

| Model | Loss | Key Property |
|-------|------|--------------|
| DCGAN | Binary Cross-Entropy | Baseline — fast but unstable |
| WGAN-GP | Wasserstein + Gradient Penalty | Stable, diverse, mode-collapse-resistant |
"""
    )

    with gr.Tabs():
        # ── Tab 1: Side-by-side comparison ──────────────────────────────────
        with gr.TabItem("🔄 Compare Both Models"):
            gr.Markdown("### Generate the same latent noise through both models")
            with gr.Row():
                with gr.Column(scale=1):
                    n_img_compare = gr.Slider(1, 16, value=8, step=1,
                                              label="Number of Images")
                    seed_compare = gr.Slider(0, 9999, value=42, step=1,
                                             label="Random Seed")
                    btn_compare = gr.Button("🚀 Generate & Compare", variant="primary")
            with gr.Row():
                out_dcgan = gr.Image(label="DCGAN Output", type="pil")
                out_wgan = gr.Image(label="WGAN-GP Output", type="pil")
            btn_compare.click(
                fn=generate_comparison,
                inputs=[n_img_compare, seed_compare],
                outputs=[out_dcgan, out_wgan],
            )
            gr.Examples(
                examples=[[8, 42], [16, 123], [4, 777], [16, 2024]],
                inputs=[n_img_compare, seed_compare],
                outputs=[out_dcgan, out_wgan],
                fn=generate_comparison,
                cache_examples=False,
            )

        # ── Tab 2: Single model explorer ────────────────────────────────────
        with gr.TabItem("🔍 Explore Single Model"):
            gr.Markdown("### Explore a specific model in detail")
            with gr.Row():
                with gr.Column(scale=1):
                    # Choices must match the keys generate_single accepts.
                    model_choice = gr.Radio(["DCGAN", "WGAN-GP"], value="WGAN-GP",
                                            label="Select Model")
                    n_img_single = gr.Slider(1, 16, value=8, step=1,
                                             label="Number of Images")
                    seed_single = gr.Slider(0, 9999, value=0, step=1,
                                            label="Random Seed")
                    btn_single = gr.Button("Generate", variant="primary")
            with gr.Row():
                single_out = gr.Image(label="Generated Images", type="pil", scale=2)
                single_desc = gr.Textbox(label="Model Description", lines=4, scale=1)
            btn_single.click(
                fn=generate_single,
                inputs=[model_choice, n_img_single, seed_single],
                outputs=[single_out, single_desc],
            )

        # ── Tab 3: About ─────────────────────────────────────────────────────
        with gr.TabItem("ℹ️ About"):
            gr.Markdown(
                """
## Model Details

### DCGAN (Deep Convolutional GAN)
- **Generator**: 5 ConvTranspose2d layers, BatchNorm, ReLU, Tanh output
- **Discriminator**: 5 Conv2d layers, LeakyReLU, Sigmoid output
- **Loss**: Binary Cross-Entropy
- **Known weakness**: Mode collapse — the generator may learn to produce
  only a few "safe" outputs that fool the discriminator.

### WGAN-GP (Wasserstein GAN with Gradient Penalty)
- **Generator**: Same architecture as DCGAN
- **Critic**: Same structure but uses InstanceNorm and **no Sigmoid** —
  outputs raw Wasserstein scores instead of probabilities
- **Loss**: Wasserstein distance + Gradient Penalty (λ=10)
- **Training**: 5 critic updates per generator step
- **Advantage**: The Wasserstein distance provides meaningful gradients even
  when distributions don't overlap, which greatly reduces mode collapse.

### Training Setup
- Dataset: Anime Faces 64×64
- Optimizer: Adam (lr=0.0002, β=(0.5, 0.999))
- Mixed precision (torch.cuda.amp)
- Platform: Kaggle T4 x2 Dual GPU
"""
            )

    # Footer shown below all tabs.
    gr.Markdown(
        "<center>Built for AI4009 GenAI Assignment 3 · "
        "Model trained on Kaggle · Deployed on HuggingFace Spaces</center>"
    )
# Start the Gradio server only when this file is executed as a script
# (e.g. `python app.py`), not when it is imported.
if __name__ == "__main__":
    demo.launch()