File size: 3,244 Bytes
1d329e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2af8e30
1d329e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52d8991
1d329e0
 
3b0be4a
1d329e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1ec0d0
3b0be4a
a1ec0d0
3b0be4a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import gradio as gr

# ------------------------
# CONFIG
# ------------------------
# Prefer GPU when available; model and noise tensors are placed here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NOISE_DIM = 100   # length of the latent noise vector fed to the generator
IMAGE_SIZE = 64   # Must match your GAN training
DEFAULT_NUM_SAMPLES = 16   # initial value of the "Number of images" slider
GENERATOR_PATH = "generator_final.pth"   # trained generator state_dict checkpoint
LOGO_PATH = "kyo.png"   # optional logo shown at the top of the UI if present

# ------------------------
# Generator Definition
# ------------------------
class Generator(nn.Module):
    """Fully-connected GAN generator mapping latent noise to an RGB image.

    Three widening hidden Linear layers (256 -> 512 -> 1024) with ReLU,
    then a final Linear to the flattened image and Tanh, so outputs lie
    in [-1, 1]. Layer ordering matches the trained checkpoint's
    nn.Sequential indices exactly.
    """

    def __init__(self):
        super().__init__()
        layers = []
        width_in = NOISE_DIM
        for width_out in (256, 512, 1024):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.ReLU(True))
            width_in = width_out
        layers.append(nn.Linear(width_in, 3 * IMAGE_SIZE * IMAGE_SIZE))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map (batch, NOISE_DIM) noise to (batch, 3, IMAGE_SIZE, IMAGE_SIZE)."""
        flat = self.net(x)
        return flat.view(-1, 3, IMAGE_SIZE, IMAGE_SIZE)

# ------------------------
# Load trained generator
# ------------------------
def load_generator():
    """Instantiate the Generator and load trained weights from GENERATOR_PATH.

    Returns:
        Generator: the model moved to DEVICE and set to eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    if not os.path.exists(GENERATOR_PATH):
        raise FileNotFoundError(
            f"Checkpoint not found: {GENERATOR_PATH}. Make sure it exists in your Hugging Face Space."
        )
    model = Generator().to(DEVICE)
    # weights_only=True restricts unpickling to tensors and primitives,
    # preventing arbitrary code execution from a tampered checkpoint
    # (torch.load uses pickle under the hood).
    state_dict = torch.load(GENERATOR_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    # Inference only: switch off training-mode behaviors (dropout, BN updates).
    model.eval()
    return model

# Load once at module import so every request reuses the same model instance.
G = load_generator()

# ------------------------
# Generate graffiti images
# ------------------------
def generate_graffiti(num_samples, seed):
    """Sample images from the generator and return them for display.

    Returns a pair of HxWx3 float arrays in [0, 1]: a tiled grid of all
    generated images, and the first image on its own.
    """
    count = int(num_samples)
    seed_value = int(seed)

    # Guard against out-of-range values arriving through the API.
    if not 1 <= count <= 32:
        raise gr.Error("Number of samples must be between 1 and 32.")

    # Seed CPU and (when present) GPU RNGs so output is reproducible.
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

    latent = torch.randn(count, NOISE_DIM, device=DEVICE)

    with torch.no_grad():
        batch = G(latent).cpu()

    # Tanh output is in [-1, 1]; rescale to [0, 1] for image display.
    batch = (batch * 0.5 + 0.5).clamp(0, 1)

    grid_img = make_grid(batch, nrow=min(4, count)).permute(1, 2, 0).numpy()
    example_img = batch[0].permute(1, 2, 0).numpy()

    return grid_img, example_img

# ------------------------
# Gradio UI
# ------------------------
with gr.Blocks(title="KYO") as demo:
    # Only render the logo when the asset is actually bundled with the app.
    if os.path.exists(LOGO_PATH):
        gr.Image(value=LOGO_PATH, show_label=False, width=180)

    gr.Markdown("# KYO")
    gr.Markdown("Generate graffiti-style images using PyTorch GAN.")

    with gr.Row():
        sample_slider = gr.Slider(
            minimum=1,
            maximum=32,
            step=1,
            value=DEFAULT_NUM_SAMPLES,
            label="Number of images",
        )
        seed_input = gr.Number(value=42, precision=0, label="Random seed")

    run_button = gr.Button("Generate")

    with gr.Row():
        grid_view = gr.Image(label="Generated Grid")
        first_view = gr.Image(label="First Sample")

    # Wire the button to the sampler: (count, seed) in, two images out.
    run_button.click(
        fn=generate_graffiti,
        inputs=[sample_slider, seed_input],
        outputs=[grid_view, first_view],
    )

# ------------------------
# Launch app
# ------------------------
# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()