| import numpy as np |
| import torch |
| import torch.nn as nn |
| import gradio as gr |
| from huggingface_hub import hf_hub_download |
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
def denorm(x):
    """Map values from the [-1, 1] range to [0, 1], clamping anything outside.

    Args:
        x: Tensor to rescale (e.g. the Tanh output of a generator).

    Returns:
        ``x * 0.5 + 0.5`` with every element clamped into ``[0, 1]``.
    """
    rescaled = x * 0.5 + 0.5
    return rescaled.clamp(min=0, max=1)
|
|
class DCGenerator(nn.Module):
    """Transposed-convolution generator that upsamples a latent vector to an image.

    The network is four ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stages
    followed by a final ``ConvTranspose2d -> Tanh``. For a latent input of
    spatial size 1x1 the first stage produces 4x4 feature maps and each of the
    remaining four transposed convolutions doubles the resolution, giving a
    64x64 output with values in ``[-1, 1]``.

    Args:
        z_dim: Number of channels of the latent input.
        channels: Number of channels of the generated image.
        f: Base feature width; the stages use ``8f, 4f, 2f, f`` channels.
    """

    def __init__(self, z_dim=100, channels=3, f=64):
        super().__init__()

        # Channel widths at the boundaries of the four normalised stages.
        widths = [z_dim, f * 8, f * 4, f * 2, f]

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The first stage expands 1x1 -> 4x4 (stride 1, no padding);
            # every later stage doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.append(
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))

        # Output stage: project to image channels and squash with Tanh.
        layers.append(nn.ConvTranspose2d(f, channels, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        # Layer order (and therefore the ``net.<index>`` state-dict keys) must
        # stay fixed so that saved checkpoints keep loading.
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Run the latent batch ``z`` through the generator stack."""
        return self.net(z)
|
|
| def _strip(sd): |
| return {k.replace("module.", ""): v for k, v in sd.items()} |
|
|
| |
# Resolve both checkpoint files from the "SotaSF/q1-models" Hub repository at
# import time: first the DCGAN weights, then the WGAN-GP weights.
dc_path, wg_path = (
    hf_hub_download(repo_id="SotaSF/q1-models", filename=ckpt_name)
    for ckpt_name in ("dcgan_final.pt", "wgangp_final.pt")
)
|
|
# Generators that have already been built and loaded, keyed by
# (checkpoint path, z_dim), so repeated button clicks do not re-read and
# re-deserialize the checkpoint files.
_GENERATOR_CACHE = {}


def _load_generator(ckpt_path, z_dim):
    """Return an eval-mode DCGenerator with weights from ``ckpt_path`` (cached)."""
    key = (ckpt_path, z_dim)
    cached = _GENERATOR_CACHE.get(key)
    if cached is not None:
        return cached

    model = DCGenerator(z_dim).to(device)
    # NOTE(review): torch.load unpickles the file, which is only safe for
    # trusted checkpoints. Consider weights_only=True if the checkpoint is
    # known to contain nothing but tensors.
    checkpoint = torch.load(ckpt_path, map_location=device)
    # Checkpoints may either be a bare state dict or wrap it under "G".
    state = checkpoint["G"] if "G" in checkpoint else checkpoint
    model.load_state_dict(_strip(state))
    model.eval()

    # Only cache after a successful load so a mismatching z_dim is retried.
    _GENERATOR_CACHE[key] = model
    return model


def _to_uint8_images(batch):
    """Turn an (N, C, H, W) CPU tensor in [0, 1] into a list of HWC uint8 arrays."""
    channels_last = batch.permute(0, 2, 3, 1).numpy()
    return list((channels_last * 255).astype(np.uint8))


@torch.no_grad()
def generate_compare(n_samples, seed, z_dim):
    """Generate images from both generators using the same latent vectors.

    Args:
        n_samples: Number of images to produce per model (converted with int()).
        seed: Seed for the latent-noise generator (converted with int()).
        z_dim: Latent dimensionality used to build both generators; it must
            match the checkpoints or ``load_state_dict`` will raise.

    Returns:
        A pair ``(dcgan_images, wgangp_images)``; each is a list of
        ``n_samples`` uint8 arrays in height x width x channel layout.
    """
    count = int(n_samples)
    latent_dim = int(z_dim)

    dcgan = _load_generator(dc_path, latent_dim)
    wgangp = _load_generator(wg_path, latent_dim)

    # A dedicated, seeded generator keeps the output reproducible per seed and
    # guarantees both models see exactly the same noise.
    rng = torch.Generator(device=device).manual_seed(int(seed))
    z = torch.randn(count, latent_dim, 1, 1, device=device, generator=rng)

    dcgan_batch = denorm(dcgan(z)).cpu()
    wgangp_batch = denorm(wgangp(z)).cpu()

    return _to_uint8_images(dcgan_batch), _to_uint8_images(wgangp_batch)
|
|
# Gradio UI: three inputs (sample count, seed, latent size), one button and
# two galleries showing the DCGAN and WGAN-GP outputs.
with gr.Blocks(title="Q1 DCGAN vs WGAN-GP") as demo:
    gr.Markdown("# Q1: DCGAN vs WGAN-GP")

    # Input controls, passed to generate_compare in this order.
    n = gr.Slider(minimum=5, maximum=20, value=10, step=1, label="Samples")
    s = gr.Number(value=42, label="Seed", precision=0)
    z = gr.Number(value=100, label="z dim", precision=0)
    btn = gr.Button("Generate")

    # Output galleries, one per model.
    g1 = gr.Gallery(label="DCGAN", columns=5)
    g2 = gr.Gallery(label="WGAN-GP", columns=5)

    # Clicking the button fills both galleries from the same latent vectors.
    btn.click(fn=generate_compare, inputs=[n, s, z], outputs=[g1, g2])

demo.launch()
|
|