|
|
import torch, os, gradio as gr, numpy as np |
|
|
from torchvision import utils, transforms |
|
|
from progan_modules import Generator |
|
|
|
|
|
# Folder holding the trained generator weights. This is a relative path, so it
# is resolved against the current working directory when the file is loaded.
CHECKPOINT_DIR = "./model"

# Latent (noise) vector length and the generator's channel width. Both are
# handed to the Generator constructor; presumably they must match the values
# the checkpoint was trained with.
Z_DIM = 128
CHANNEL_SIZE = 128

# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Progressive-growing stage and fade-in factor forwarded to the generator's
# forward pass on every request.
# NOTE(review): presumably FIXED_STEP selects the output resolution stage. If
# the Generator blends in the previous stage for 0 <= alpha < 1, an alpha of
# 0.0 would favor the lower-resolution path — confirm against progan_modules.
FIXED_STEP = 6
FIXED_ALPHA = 0.0
|
|
|
|
|
# Instantiate the generator with pixel_norm and tanh disabled, then place it on DEVICE.
g_running = Generator(CHANNEL_SIZE, Z_DIM, pixel_norm=False, tanh=False)
g_running = g_running.to(DEVICE)

# Copy the trained weights from <CHECKPOINT_DIR>/g.model into the model.
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved on
# a GPU can still be loaded on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
g_running.load_state_dict(
    torch.load(os.path.join(CHECKPOINT_DIR, "g.model"), map_location=DEVICE)
)

# Evaluation mode: this app only runs inference, it never trains.
g_running.eval()

# Shared tensor -> PIL.Image converter used when returning results to the UI.
to_pil = transforms.ToPILImage()
|
|
|
|
|
@torch.inference_mode()
def sample_images(n_images: int = 50, seed: int | None = None):
    """Sample ``n_images`` images from the generator and tile them into one picture.

    Args:
        n_images: How many latent vectors to draw. Values coming from the UI
            may be floats, so the value is coerced to ``int`` and clamped to at
            least 1 (``torch.randn`` needs an int size, and an empty batch
            cannot be tiled into a grid).
        seed: Non-negative seed for reproducible sampling. ``None`` or a
            negative value reseeds torch's global RNG non-deterministically
            instead. Seeds beyond a library's accepted range are wrapped into
            it rather than raising.

    Returns:
        The generated samples arranged 10 per row, converted by ``to_pil``.
    """
    n_images = max(1, int(n_images))

    if seed is not None and seed >= 0:
        seed = int(seed)
        # torch accepts seeds up to 2**64 - 1, and NumPy only accepts seeds in
        # [0, 2**32). Wrap large user-entered seeds instead of crashing on them;
        # seeds already in range are passed through unchanged.
        torch.manual_seed(seed % 2**64)
        # Kept from the original. NOTE(review): nothing visible here draws
        # from NumPy's RNG — presumably this is for code outside this file.
        np.random.seed(seed % 2**32)
    else:
        # Fresh non-deterministic seed, so an unseeded request does not replay
        # the RNG stream left behind by a previous seeded request.
        torch.seed()

    z = torch.randn(n_images, Z_DIM, device=DEVICE)
    imgs = g_running(z, step=FIXED_STEP, alpha=FIXED_ALPHA).cpu()

    # Clamp outputs to (-1, 1) and rescale them to [0, 1] so they can be
    # converted to a displayable image.
    grid = utils.make_grid(imgs, nrow=10, normalize=True, value_range=(-1, 1))
    return to_pil(grid)
|
|
|
|
|
# Web UI: a sample-count slider and a seed box feed sample_images; its result
# is shown as a single PIL image. (UI text is Indonesian: "Jumlah Gambar
# (kelipatan 10)" = "Number of images (multiples of 10)", "acak" = "random",
# "Grid Hasil" = "Result grid".)
demo = gr.Interface(
    fn=sample_images,
    inputs=[
        # Slider positions are minimum + k * step. Starting at 10 makes every
        # position a multiple of 10, as the label promises. With minimum=1 the
        # positions were 1, 11, 21, ..., and the default of 50 was off-grid.
        gr.Slider(10, 200, value=50, step=10, label="Jumlah Gambar (kelipatan 10)"),
        # precision=0 makes the seed arrive as an integer; -1 means "random seed".
        gr.Number(value=-1, precision=0, label="Seed (‑1 = acak)"),
    ],
    outputs=gr.Image(type="pil", label="Grid Hasil"),
    title="Progressive Growing Generative Adversarial Network",
    description="contoh implementasi PGGAN untuk dataset jerawat",
    allow_flagging="never",
)
|
|
|
|
|
if __name__ == "__main__":
    # Turn on request queuing, then serve the app. share=True additionally
    # requests a public Gradio share link; show_api=False hides the API docs page.
    demo.queue().launch(show_api=False, share=True)
|
|
|