"""KYO: a Gradio app that samples graffiti-style images from a trained GAN generator."""

import os

import gradio as gr
import torch
import torch.nn as nn
from torchvision.utils import make_grid

# ------------------------
# CONFIG
# ------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NOISE_DIM = 100
IMAGE_SIZE = 64  # Must match your GAN training
DEFAULT_NUM_SAMPLES = 16
MAX_NUM_SAMPLES = 32  # Shared by the UI slider and the server-side validation
GRID_COLUMNS = 4  # Images per row in the generated grid
GENERATOR_PATH = "generator_final.pth"
LOGO_PATH = "kyo.png"


# ------------------------
# Generator Definition
# ------------------------
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to RGB images.

    Input is a (batch, NOISE_DIM) tensor; output is reshaped to
    (batch, 3, IMAGE_SIZE, IMAGE_SIZE) with values in [-1, 1] (final Tanh).
    """

    def __init__(self):
        super().__init__()
        # Attribute name and layer layout must match the training-time model,
        # otherwise load_state_dict() will reject the checkpoint keys/shapes.
        self.net = nn.Sequential(
            nn.Linear(NOISE_DIM, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 3 * IMAGE_SIZE * IMAGE_SIZE),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise ``x`` of shape (batch, NOISE_DIM) to images (batch, 3, H, W)."""
        return self.net(x).view(-1, 3, IMAGE_SIZE, IMAGE_SIZE)


# ------------------------
# Load trained generator
# ------------------------
def load_generator():
    """Build a Generator on DEVICE and load trained weights from GENERATOR_PATH.

    Returns:
        The Generator instance, moved to DEVICE and switched to eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    if not os.path.exists(GENERATOR_PATH):
        raise FileNotFoundError(
            f"Checkpoint not found: {GENERATOR_PATH}. Make sure it exists in your Hugging Face Space."
        )
    model = Generator().to(DEVICE)
    # The checkpoint is only used as a state_dict, so restrict unpickling to
    # tensors/plain containers; a tampered file then cannot execute code on load.
    state_dict = torch.load(GENERATOR_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Loaded once at import so a missing/incompatible checkpoint fails at startup.
G = load_generator()


# ------------------------
# Generate graffiti images
# ------------------------
def generate_graffiti(num_samples, seed):
    """Sample ``num_samples`` images from the generator using ``seed``.

    Args:
        num_samples: number of images to generate (1..MAX_NUM_SAMPLES).
        seed: integer seed; the same seed yields the same images on the same device.

    Returns:
        A ``(grid, first)`` pair of float32 NumPy arrays in (H, W, 3) layout with
        values in [0, 1]: the tiled grid of all samples and the first sample alone.

    Raises:
        gr.Error: if an input is missing, out of range, or the seed is invalid.
    """
    # gr.Number yields None when the user clears the field.
    if num_samples is None or seed is None:
        raise gr.Error("Please provide both the number of images and a random seed.")
    num_samples = int(num_samples)
    seed = int(seed)
    if num_samples < 1 or num_samples > MAX_NUM_SAMPLES:
        raise gr.Error(f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}.")

    # A request-local RNG gives the same noise as seeding the global RNG, but
    # leaves torch's global RNG state untouched.
    rng = torch.Generator(device=DEVICE)
    try:
        rng.manual_seed(seed)
    except (RuntimeError, OverflowError) as err:
        raise gr.Error(f"Invalid random seed: {seed}.") from err

    noise = torch.randn(num_samples, NOISE_DIM, device=DEVICE, generator=rng)
    with torch.no_grad():
        fake_images = G(noise).cpu()

    # Tanh output lies in [-1, 1]; rescale to [0, 1] for display.
    fake_images = (fake_images * 0.5 + 0.5).clamp(0, 1)

    grid = make_grid(fake_images, nrow=min(GRID_COLUMNS, num_samples))
    # Convert CHW tensors to HWC arrays, the layout gr.Image expects for NumPy input.
    grid_img = grid.permute(1, 2, 0).numpy()
    example_img = fake_images[0].permute(1, 2, 0).numpy()
    return grid_img, example_img


# ------------------------
# Gradio UI
# ------------------------
with gr.Blocks(title="KYO") as demo:
    if os.path.exists(LOGO_PATH):
        gr.Image(value=LOGO_PATH, show_label=False, width=180)
    gr.Markdown("# KYO")
    gr.Markdown("Generate graffiti-style images using PyTorch GAN.")

    with gr.Row():
        num_samples = gr.Slider(
            minimum=1,
            maximum=MAX_NUM_SAMPLES,
            step=1,
            value=DEFAULT_NUM_SAMPLES,
            label="Number of images",
        )
        seed = gr.Number(value=42, precision=0, label="Random seed")

    generate_btn = gr.Button("Generate")

    with gr.Row():
        grid_output = gr.Image(label="Generated Grid")
        sample_output = gr.Image(label="First Sample")

    generate_btn.click(
        fn=generate_graffiti,
        inputs=[num_samples, seed],
        outputs=[grid_output, sample_output],
    )


# ------------------------
# Launch app
# ------------------------
if __name__ == "__main__":
    demo.launch()