# progan / app.py
# Fliw
# chore(inference) : add gradio model
# 78598be
import torch, os, gradio as gr, numpy as np
from torchvision import utils, transforms
from progan_modules import Generator
# --- Configuration ---------------------------------------------------------
# Directory that holds the generator checkpoint file ("g.model").
CHECKPOINT_DIR = "./model"
# Latent vector width and the channel-size argument handed to Generator.
Z_DIM = 128
CHANNEL_SIZE = 128
# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Values forwarded as `step` / `alpha` on every generator call in
# sample_images.
# NOTE(review): presumably the progressive-growing stage and fade-in blend
# factor -- confirm against progan_modules.Generator.
FIXED_STEP = 6
FIXED_ALPHA = 0.0

# --- Model -----------------------------------------------------------------
# Build the generator, move it to DEVICE, restore its weights, and switch it
# to evaluation mode.
g_running = Generator(CHANNEL_SIZE, Z_DIM, pixel_norm=False, tanh=False)
g_running = g_running.to(DEVICE)
g_running.load_state_dict(
    torch.load(os.path.join(CHECKPOINT_DIR, "g.model"), map_location=DEVICE)
)
g_running.eval()

# Converter used to turn the final image-grid tensor into a PIL image.
to_pil = transforms.ToPILImage()
@torch.inference_mode()
def sample_images(n_images: int = 50, seed: int | None = None):
if seed is not None and seed >= 0:
torch.manual_seed(seed); np.random.seed(seed)
else:
torch.seed()
z = torch.randn(n_images, Z_DIM, device=DEVICE)
imgs = g_running(z, step=FIXED_STEP, alpha=FIXED_ALPHA).cpu()
grid = utils.make_grid(imgs, nrow=10, normalize=True, value_range=(-1, 1))
return to_pil(grid)
# Gradio UI: two inputs (image count and seed) feed sample_images, whose
# returned PIL image is shown as the output.
demo = gr.Interface(
    fn=sample_images,
    inputs=[
        # Slider positions are minimum + k * step. Starting at 10 keeps every
        # position a multiple of 10, as the label promises, and makes the
        # default of 50 a reachable position.
        gr.Slider(10, 200, value=50, step=10, label="Jumlah Gambar (kelipatan 10)"),
        # A negative value (the default, -1) makes sample_images pick a
        # random seed.
        gr.Number(value=-1, precision=0, label="Seed (‑1 = acak)"),
    ],
    outputs=gr.Image(type="pil", label="Grid Hasil"),
    title="Progressive Growing Generative Adversarial Network",
    description="contoh implementasi PGGAN untuk dataset jerawat",
    allow_flagging="never",
)
if __name__ == "__main__":
    # Enable the request queue before starting the server.
    demo.queue()
    # Start the app with the API page hidden and a share link requested.
    launch_options = {"show_api": False, "share": True}
    demo.launch(**launch_options)