# q1 / app.py
# Author: SotaSF — uploaded with huggingface_hub (commit 84edefd, verified).
import functools

import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download
# Use CUDA when a GPU is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def denorm(x):
    """Map values from [-1, 1] (the generator's Tanh range) to [0, 1].

    Computes ``x * 0.5 + 0.5`` and clips the result into [0, 1], so values
    outside [-1, 1] saturate at the bounds.
    """
    shifted = x.mul(0.5).add(0.5)
    return torch.clamp(shifted, 0, 1)
class DCGenerator(nn.Module):
    """DCGAN-style generator producing ``channels`` x 64 x 64 images in [-1, 1].

    Four upsampling stages (ConvTranspose2d -> BatchNorm2d -> ReLU) grow a
    1 x 1 latent map to 32 x 32. A final ConvTranspose2d + Tanh then produces
    the 64 x 64 output. The Sequential's layer order fixes the ``net.<idx>``
    state-dict keys (0..13), which the saved checkpoints depend on.

    Args:
        z_dim: number of latent channels expected on the input.
        channels: number of output image channels.
        f: base feature width; the stages use f*8, f*4, f*2 and f channels.
    """

    def __init__(self, z_dim=100, channels=3, f=64):
        super().__init__()
        up = self._upsample_block
        self.net = nn.Sequential(
            # z_dim x 1 x 1 -> (f*8) x 4 x 4
            *up(z_dim, f * 8, stride=1, padding=0),
            # -> (f*4) x 8 x 8
            *up(f * 8, f * 4),
            # -> (f*2) x 16 x 16
            *up(f * 4, f * 2),
            # -> f x 32 x 32
            *up(f * 2, f),
            # -> channels x 64 x 64, squashed into [-1, 1]
            nn.ConvTranspose2d(f, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    @staticmethod
    def _upsample_block(in_ch, out_ch, stride=2, padding=1):
        """Return [ConvTranspose2d(kernel 4, no bias), BatchNorm2d, in-place ReLU]."""
        return [
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride,
                               padding=padding, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def forward(self, z):
        """Generate images from latent batch ``z`` (callers pass shape (N, z_dim, 1, 1))."""
        return self.net(z)
def _strip(sd):
return {k.replace("module.", ""): v for k, v in sd.items()}
# Fetch both generator checkpoints from the Hub model repo once, at import
# time; the resulting local file paths are reused by every request.
_MODEL_REPO = "SotaSF/q1-models"
dc_path, wg_path = (
    hf_hub_download(repo_id=_MODEL_REPO, filename=fname)
    for fname in ("dcgan_final.pt", "wgangp_final.pt")
)
@functools.lru_cache(maxsize=4)
def _load_generator(path, z_dim):
    """Build a DCGenerator(z_dim) on ``device``, load weights from ``path``, set eval mode.

    Cached on ``(path, z_dim)`` so each checkpoint is read from disk and moved
    to the device once, not on every button click. A failed load raises and
    is not cached.
    """
    model = DCGenerator(z_dim).to(device)
    # NOTE(review): torch.load unpickles the file; acceptable here only because
    # the checkpoints come from the author's own Hub repo.
    state = torch.load(path, map_location=device)
    # A checkpoint is either a bare state dict or a dict holding it under "G".
    if "G" in state:
        state = state["G"]
    model.load_state_dict(_strip(state))
    model.eval()
    return model


def _to_uint8_images(batch):
    """Convert an (N, C, H, W) generator batch in [-1, 1] to a list of (H, W, C) uint8 arrays."""
    # Round rather than truncate so that, e.g., 0.999 maps to 255 instead of 254.
    pixels = denorm(batch).mul(255).round().to(torch.uint8)
    hwc = pixels.permute(0, 2, 3, 1).contiguous().cpu().numpy()
    return list(hwc)


@torch.no_grad()
def generate_compare(n_samples, seed, z_dim):
    """Sample the same latent batch through the DCGAN and WGAN-GP generators.

    Args:
        n_samples: number of images per model (cast to int).
        seed: RNG seed for the latent batch (cast to int); fixing it makes the
            output reproducible on a given device.
        z_dim: latent size (cast to int). It must match the checkpoints, or
            ``load_state_dict`` will fail on the mismatched weight shapes.

    Returns:
        A pair ``(dcgan_images, wgangp_images)``. Each is a list of
        ``n_samples`` (H, W, C) uint8 arrays, and both are generated from
        identical latents so they can be compared side by side.
    """
    n_samples, seed, z_dim = int(n_samples), int(seed), int(z_dim)
    dcgan = _load_generator(dc_path, z_dim)
    wgangp = _load_generator(wg_path, z_dim)
    rng = torch.Generator(device=device).manual_seed(seed)
    z = torch.randn(n_samples, z_dim, 1, 1, device=device, generator=rng)
    return _to_uint8_images(dcgan(z)), _to_uint8_images(wgangp(z))
with gr.Blocks(title="Q1 DCGAN vs WGAN-GP") as demo:
    gr.Markdown("# Q1: DCGAN vs WGAN-GP")

    # Inputs forwarded to generate_compare: sample count, RNG seed, latent size.
    samples_in = gr.Slider(minimum=5, maximum=20, value=10, step=1, label="Samples")
    seed_in = gr.Number(value=42, label="Seed", precision=0)
    zdim_in = gr.Number(value=100, label="z dim", precision=0)
    run_btn = gr.Button("Generate")

    # One gallery per model, both filled from the same latent batch.
    dc_gallery = gr.Gallery(label="DCGAN", columns=5)
    wg_gallery = gr.Gallery(label="WGAN-GP", columns=5)

    run_btn.click(
        fn=generate_compare,
        inputs=[samples_in, seed_in, zdim_in],
        outputs=[dc_gallery, wg_gallery],
    )

demo.launch()