"""Gradio demo that shows DCGAN and WGAN-GP generator samples side by side.

Both generators share the ``DCGenerator`` architecture and are fed the same
latent vectors, so the two galleries differ only in the trained weights.
"""

import functools

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Key prefix PyTorch adds to every state-dict entry when a model is saved
# while wrapped in nn.DataParallel / DistributedDataParallel.
_DP_PREFIX = "module."


def denorm(x):
    """Map values from the Tanh range [-1, 1] to [0, 1], clamping outliers."""
    return (x * 0.5 + 0.5).clamp(0, 1)


class DCGenerator(nn.Module):
    """DCGAN-style generator built from five transposed convolutions.

    Args:
        z_dim: Number of channels of the latent input.
        channels: Number of channels of the generated image.
        f: Base feature-map width; hidden layers use ``f * 8`` down to ``f``.
    """

    def __init__(self, z_dim=100, channels=3, f=64):
        super().__init__()
        self.net = nn.Sequential(
            # input is Z
            nn.ConvTranspose2d(z_dim, f * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(f * 8),
            nn.ReLU(True),
            # (f*8) x 4 x 4
            nn.ConvTranspose2d(f * 8, f * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(f * 4),
            nn.ReLU(True),
            # (f*4) x 8 x 8
            nn.ConvTranspose2d(f * 4, f * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(f * 2),
            nn.ReLU(True),
            # (f*2) x 16 x 16
            nn.ConvTranspose2d(f * 2, f, 4, 2, 1, bias=False),
            nn.BatchNorm2d(f),
            nn.ReLU(True),
            # (f) x 32 x 32
            nn.ConvTranspose2d(f, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # (channels) x 64 x 64
        )

    def forward(self, z):
        """Run the latent batch ``z`` of shape (N, z_dim, 1, 1) through the net."""
        return self.net(z)


def _strip(sd):
    """Return a copy of state dict ``sd`` without leading ``module.`` prefixes.

    Only a *leading* prefix is removed (repeatedly, to cover doubly wrapped
    models). A ``"module."`` substring in the middle of a key, such as
    ``"submodule.weight"``, is left intact; a blanket ``str.replace`` would
    corrupt such keys.
    """
    cleaned = {}
    for key, value in sd.items():
        while key.startswith(_DP_PREFIX):
            key = key[len(_DP_PREFIX):]
        cleaned[key] = value
    return cleaned


# Download weights from model repo at startup
dc_path = hf_hub_download(repo_id="SotaSF/q1-models", filename="dcgan_final.pt")
wg_path = hf_hub_download(repo_id="SotaSF/q1-models", filename="wgangp_final.pt")


@functools.lru_cache(maxsize=4)
def _load_generator(path, z_dim):
    """Build a ``DCGenerator(z_dim)``, load the checkpoint at ``path``, set eval mode.

    Results are cached per ``(path, z_dim)`` so repeated button clicks do not
    re-read the checkpoint from disk. A failed load raises and is not cached.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Consider weights_only=True once it is confirmed that the
    # checkpoints contain nothing but tensors and plain containers.
    checkpoint = torch.load(path, map_location=device)
    # Checkpoints are either {"G": state_dict, ...} or a bare state dict.
    state = checkpoint["G"] if "G" in checkpoint else checkpoint
    generator = DCGenerator(z_dim).to(device)
    generator.load_state_dict(_strip(state))
    generator.eval()
    return generator


def _to_uint8_images(batch):
    """Convert an (N, C, H, W) CPU tensor in [0, 1] to a list of HWC uint8 arrays."""
    hwc = (batch.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
    return list(hwc)


@torch.no_grad()
def generate_compare(n_samples, seed, z_dim):
    """Sample both generators with the same seeded latent batch.

    Args:
        n_samples: Number of images to generate per model.
        seed: Seed for the latent-noise generator.
        z_dim: Latent size; must match the size the checkpoints were trained
            with, otherwise ``load_state_dict`` raises.

    Returns:
        A pair ``(dcgan_images, wgangp_images)``; each is a list of
        ``n_samples`` uint8 arrays in height x width x channel layout.
    """
    n_samples, seed, z_dim = int(n_samples), int(seed), int(z_dim)

    g1 = _load_generator(dc_path, z_dim)
    g2 = _load_generator(wg_path, z_dim)

    gen = torch.Generator(device=device).manual_seed(seed)
    z = torch.randn(n_samples, z_dim, 1, 1, device=device, generator=gen)

    out_a = _to_uint8_images(denorm(g1(z)).cpu())
    out_b = _to_uint8_images(denorm(g2(z)).cpu())
    return out_a, out_b


with gr.Blocks(title="Q1 DCGAN vs WGAN-GP") as demo:
    gr.Markdown("# Q1: DCGAN vs WGAN-GP")
    n = gr.Slider(5, 20, value=10, step=1, label="Samples")
    s = gr.Number(value=42, label="Seed", precision=0)
    z = gr.Number(value=100, label="z dim", precision=0)
    btn = gr.Button("Generate")
    g1 = gr.Gallery(label="DCGAN", columns=5)
    g2 = gr.Gallery(label="WGAN-GP", columns=5)
    btn.click(generate_compare, [n, s, z], [g1, g2])

demo.launch()