File size: 3,244 Bytes
1d329e0 2af8e30 1d329e0 52d8991 1d329e0 3b0be4a 1d329e0 a1ec0d0 3b0be4a a1ec0d0 3b0be4a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 | import os
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import gradio as gr
# ------------------------
# CONFIG
# ------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

NOISE_DIM = 100  # length of each latent noise vector fed to the generator
IMAGE_SIZE = 64  # output height/width in pixels; must match the trained checkpoint
DEFAULT_NUM_SAMPLES = 16  # initial value of the "Number of images" slider

GENERATOR_PATH = "generator_final.pth"  # trained generator weights (a state_dict)
LOGO_PATH = "kyo.png"  # optional logo, displayed only if the file exists
# ------------------------
# Generator Definition
# ------------------------
class Generator(nn.Module):
    """Fully connected GAN generator mapping noise vectors to images.

    The layer stack is identical to the original hard-coded version, so the
    ``state_dict`` keys (``net.0``, ``net.2``, ``net.4``, ``net.6``) are
    unchanged. Existing checkpoints still load, and ``Generator()`` with no
    arguments behaves exactly as before.

    Args:
        noise_dim: Length of each input noise vector. ``None`` means the
            module-level ``NOISE_DIM``.
        image_size: Height and width of the square output images. ``None``
            means the module-level ``IMAGE_SIZE``.
        channels: Number of output channels (3 for RGB).
    """

    def __init__(self, noise_dim=None, image_size=None, channels=3):
        super().__init__()
        # Defaults are resolved lazily so the class can be built with explicit
        # sizes without depending on the module-level config.
        self.noise_dim = NOISE_DIM if noise_dim is None else int(noise_dim)
        self.image_size = IMAGE_SIZE if image_size is None else int(image_size)
        self.channels = int(channels)
        self.net = nn.Sequential(
            nn.Linear(self.noise_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, self.channels * self.image_size * self.image_size),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise whose last dimension is ``noise_dim`` to a batch of images.

        Returns a tensor of shape ``(batch, channels, image_size, image_size)``
        with values in ``[-1, 1]`` (the output of the final ``Tanh``).
        """
        return self.net(x).view(
            -1, self.channels, self.image_size, self.image_size
        )
# ------------------------
# Load trained generator
# ------------------------
def load_generator():
    """Build a ``Generator`` and load the trained weights from ``GENERATOR_PATH``.

    Returns:
        The generator, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    if not os.path.exists(GENERATOR_PATH):
        raise FileNotFoundError(
            f"Checkpoint not found: {GENERATOR_PATH}. Make sure it exists in your Hugging Face Space."
        )
    model = Generator().to(DEVICE)
    # weights_only=True (available since PyTorch 1.13) restricts unpickling to
    # tensors and plain containers, so a tampered checkpoint cannot run
    # arbitrary code at load time. That is enough here because the file is
    # only used as a state_dict.
    state_dict = torch.load(GENERATOR_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Loaded once at import time so every request reuses the same model.
G = load_generator()
# ------------------------
# Generate graffiti images
# ------------------------
def _to_uint8_hwc(image):
    """Convert a ``(C, H, W)`` float tensor in ``[0, 1]`` to an ``(H, W, C)`` uint8 array."""
    return image.mul(255).round().to(torch.uint8).permute(1, 2, 0).numpy()


def generate_graffiti(num_samples, seed):
    """Sample a batch of images from the trained generator ``G``.

    Args:
        num_samples: Number of images to generate, 1 to 32 (coerced with ``int``).
        seed: Integer seed for the noise. The same seed reproduces the same
            images.

    Returns:
        ``(grid_img, example_img)``. ``grid_img`` is an ``(H, W, 3)`` uint8
        array of all samples tiled ``min(4, num_samples)`` per row.
        ``example_img`` is an ``(IMAGE_SIZE, IMAGE_SIZE, 3)`` uint8 array of
        the first sample.

    Raises:
        gr.Error: If the seed is missing or out of range, or ``num_samples``
            is outside 1 to 32.
    """
    if seed is None:
        # The Number field may deliver None when cleared. Report it in the UI
        # instead of failing with an opaque TypeError from int(None).
        raise gr.Error("Random seed must be an integer.")
    num_samples = int(num_samples)
    seed = int(seed)
    if num_samples < 1 or num_samples > 32:
        raise gr.Error("Number of samples must be between 1 and 32.")
    # torch.Generator.manual_seed rejects values outside this range with a
    # RuntimeError.
    if not -(2**63) <= seed <= 2**64 - 1:
        raise gr.Error("Random seed is out of range.")

    # Use a per-call generator instead of torch.manual_seed. Seeding the
    # process-global RNG is shared state that other requests may also touch,
    # which can break per-seed reproducibility.
    rng = torch.Generator(device=DEVICE)
    rng.manual_seed(seed)
    noise = torch.randn(num_samples, NOISE_DIM, generator=rng, device=DEVICE)

    with torch.no_grad():
        fake_images = G(noise).cpu()

    # The generator ends in Tanh, so outputs are in [-1, 1]. Map them to [0, 1].
    fake_images = (fake_images * 0.5 + 0.5).clamp(0, 1)
    grid = make_grid(fake_images, nrow=min(4, num_samples))

    # Hand Gradio explicit uint8 HWC arrays rather than float32 ones, which it
    # would otherwise convert (lossily) on its own.
    return _to_uint8_hwc(grid), _to_uint8_hwc(fake_images[0])
# ------------------------
# Gradio UI
# ------------------------
with gr.Blocks(title="KYO") as demo:
    # Header: the logo (only if LOGO_PATH exists), then the title and tagline.
    if os.path.exists(LOGO_PATH):
        gr.Image(value=LOGO_PATH, show_label=False, width=180)
    gr.Markdown("# KYO")
    gr.Markdown("Generate graffiti-style images using PyTorch GAN.")

    # Inputs: how many images to draw and which seed to use.
    with gr.Row():
        num_samples = gr.Slider(
            label="Number of images",
            value=DEFAULT_NUM_SAMPLES,
            minimum=1,
            maximum=32,
            step=1,
        )
        seed = gr.Number(label="Random seed", precision=0, value=42)

    generate_btn = gr.Button("Generate")

    # Outputs: the tiled batch next to the first image on its own.
    with gr.Row():
        grid_output = gr.Image(label="Generated Grid")
        sample_output = gr.Image(label="First Sample")

    # Clicking the button feeds (num_samples, seed) to generate_graffiti and
    # routes its two returned images to the two output panels.
    generate_btn.click(
        generate_graffiti,
        [num_samples, seed],
        [grid_output, sample_output],
    )
# ------------------------
# Launch app
# ------------------------
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()
|