File size: 1,482 Bytes
78598be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch, os, gradio as gr, numpy as np
from torchvision import utils, transforms
from progan_modules import Generator

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
CHECKPOINT_DIR = "./model"
Z_DIM = 128         # length of each latent noise vector fed to the generator
CHANNEL_SIZE = 128  # first positional arg of Generator (presumably base channel width — confirm in progan_modules)
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Forward-pass settings used for every request. Presumably `step` picks the
# resolution stage and `alpha` the fade-in blend weight — TODO confirm
# against progan_modules.Generator.
FIXED_STEP = 6
FIXED_ALPHA = 0.0

# ---------------------------------------------------------------------------
# Model: built once at import time and reused by every request.
# ---------------------------------------------------------------------------
g_running = Generator(CHANNEL_SIZE, Z_DIM, pixel_norm=False, tanh=False)
g_running = g_running.to(DEVICE)  # nn.Module.to returns the same module
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
g_running.load_state_dict(
    torch.load(os.path.join(CHECKPOINT_DIR, "g.model"), map_location=DEVICE)
)
g_running.eval()  # inference mode for layers that behave differently in training

# Converts a CHW image tensor into a PIL image for the Gradio output.
to_pil = transforms.ToPILImage()

@torch.inference_mode()
def sample_images(n_images: int = 50, seed: int | None = None):
    """Sample images from the loaded generator and tile them into one grid image.

    Args:
        n_images: Number of latent vectors to sample. Coerced to ``int`` because
            UI/API callers may pass numeric values as floats. Must be >= 1.
        seed: Seed for reproducible output. ``None`` or a negative value reseeds
            torch non-deterministically instead.

    Returns:
        A PIL image containing the samples, 10 per row.

    Raises:
        ValueError: If ``n_images`` is less than 1.
    """
    n_images = int(n_images)  # torch.randn requires integer sizes
    if n_images < 1:
        raise ValueError(f"n_images must be at least 1, got {n_images}")

    if seed is not None and seed >= 0:
        seed = int(seed)
        # torch.manual_seed accepts seeds up to 2**64 - 1, but np.random.seed
        # only up to 2**32 - 1; reduce per library so large seeds don't raise.
        # Seeds already inside both ranges are unchanged.
        torch.manual_seed(seed % 2**64)
        np.random.seed(seed % 2**32)
    else:
        torch.seed()  # fresh non-deterministic seed

    z = torch.randn(n_images, Z_DIM, device=DEVICE)
    imgs = g_running(z, step=FIXED_STEP, alpha=FIXED_ALPHA).cpu()

    # normalize=True with value_range=(-1, 1) clamps to that range, then
    # rescales to [0, 1]. The generator is built with tanh=False, so outputs
    # are presumably only roughly in [-1, 1] — values outside are clipped.
    grid = utils.make_grid(imgs, nrow=10, normalize=True, value_range=(-1, 1))
    return to_pil(grid)

demo = gr.Interface(
    fn=sample_images,
    inputs=[
        # The label promises multiples of 10 ("kelipatan 10"). A range slider
        # steps from its minimum, so the minimum must itself be 10: with
        # minimum=1 the reachable values were 1, 11, 21, ... and the default
        # of 50 was off-grid.
        gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Jumlah Gambar (kelipatan 10)"),
        # precision=0 asks Gradio for an integer; a negative seed means "random".
        gr.Number(value=-1, precision=0, label="Seed (‑1 = acak)"),
    ],
    outputs=gr.Image(type="pil", label="Grid Hasil"),
    title="Progressive Growing Generative Adversarial Network",
    description="contoh implementasi PGGAN untuk dataset jerawat",
    allow_flagging="never",
)

if __name__ == "__main__":
    # Enable request queuing, then start the server; queue() returns the
    # same Blocks object, so the calls chain. share=True requests a public
    # share link, and show_api=False hides the auto-generated API page.
    demo.queue().launch(show_api=False, share=True)