Corelyn commited on
Commit
1d329e0
·
verified ·
1 Parent(s): fd2707d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision.utils import make_grid
5
+ import gradio as gr
6
+
7
+ # ------------------------
8
+ # CONFIG
9
+ # ------------------------
10
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
+ NOISE_DIM = 100
12
+ IMAGE_SIZE = 64 # Must match your GAN training
13
+ DEFAULT_NUM_SAMPLES = 16
14
+ GENERATOR_PATH = "generator_final.pth"
15
+ LOGO_PATH = "kyo.png"
16
+
17
+ # ------------------------
18
+ # Generator Definition
19
+ # ------------------------
20
+ class Generator(nn.Module):
21
+ def __init__(self):
22
+ super().__init__()
23
+ self.net = nn.Sequential(
24
+ nn.Linear(NOISE_DIM, 256),
25
+ nn.ReLU(True),
26
+ nn.Linear(256, 512),
27
+ nn.ReLU(True),
28
+ nn.Linear(512, 1024),
29
+ nn.ReLU(True),
30
+ nn.Linear(1024, 3 * IMAGE_SIZE * IMAGE_SIZE),
31
+ nn.Tanh()
32
+ )
33
+
34
+ def forward(self, x):
35
+ return self.net(x).view(-1, 3, IMAGE_SIZE, IMAGE_SIZE)
36
+
37
+ # ------------------------
38
+ # Load trained generator
39
+ # ------------------------
40
+ def load_generator():
41
+ if not os.path.exists(GENERATOR_PATH):
42
+ raise FileNotFoundError(
43
+ f"Checkpoint not found: {GENERATOR_PATH}. "
44
+ "Make sure it exists in your Hugging Face Space."
45
+ )
46
+
47
+ model = Generator().to(DEVICE)
48
+ state_dict = torch.load(GENERATOR_PATH, map_location=DEVICE)
49
+ model.load_state_dict(state_dict)
50
+ model.eval()
51
+ return model
52
+
53
+ G = load_generator()
54
+
55
+ # ------------------------
56
+ # Generate graffiti images
57
+ # ------------------------
58
+ def generate_graffiti(num_samples, seed):
59
+ num_samples = int(num_samples)
60
+ seed = int(seed)
61
+
62
+ if num_samples < 1 or num_samples > 32:
63
+ raise gr.Error("Number of samples must be between 1 and 32.")
64
+
65
+ torch.manual_seed(seed)
66
+ if torch.cuda.is_available():
67
+ torch.cuda.manual_seed_all(seed)
68
+
69
+ noise = torch.randn(num_samples, NOISE_DIM, device=DEVICE)
70
+
71
+ with torch.no_grad():
72
+ fake_images = G(noise).cpu()
73
+
74
+ fake_images = fake_images * 0.5 + 0.5
75
+ fake_images = fake_images.clamp(0, 1)
76
+
77
+ nrow = min(4, num_samples)
78
+ grid = make_grid(fake_images, nrow=nrow)
79
+ grid_img = grid.permute(1, 2, 0).numpy()
80
+
81
+ example_img = fake_images[0].permute(1, 2, 0).numpy()
82
+
83
+ return grid_img, example_img
84
+
85
+ # ------------------------
86
+ # Gradio UI
87
+ # ------------------------
88
+ with gr.Blocks(title="KYO") as demo:
89
+ if os.path.exists(LOGO_PATH):
90
+ gr.Image(LOGO_PATH, show_label=False, show_download_button=False, width=180)
91
+
92
+ gr.Markdown("# KYO")
93
+ gr.Markdown("Generate graffiti-style images using your trained PyTorch GAN.")
94
+
95
+ with gr.Row():
96
+ num_samples = gr.Slider(
97
+ minimum=1,
98
+ maximum=32,
99
+ step=1,
100
+ value=DEFAULT_NUM_SAMPLES,
101
+ label="Number of images"
102
+ )
103
+ seed = gr.Number(value=42, precision=0, label="Random seed")
104
+
105
+ generate_btn = gr.Button("Generate")
106
+
107
+ with gr.Row():
108
+ grid_output = gr.Image(label="Generated Grid")
109
+ sample_output = gr.Image(label="First Sample")
110
+
111
+ generate_btn.click(
112
+ fn=generate_graffiti,
113
+ inputs=[num_samples, seed],
114
+ outputs=[grid_output, sample_output]
115
+ )
116
+
117
+ if __name__ == "__main__":
118
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))