| import os |
| import torch |
| import torch.nn as nn |
| from torchvision.utils import make_grid |
| import gradio as gr |
|
|
| |
| |
| |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model geometry: these feed the Generator's layer sizes, so they must match
# the shapes stored in the checkpoint or load_state_dict will reject it.
NOISE_DIM = 100  # length of the latent vector fed to the generator
IMAGE_SIZE = 64  # generated images are IMAGE_SIZE x IMAGE_SIZE pixels

# Initial value of the "Number of images" slider in the UI.
DEFAULT_NUM_SAMPLES = 16

# Asset files; relative paths, so they are resolved against the current
# working directory.
GENERATOR_PATH = "generator_final.pth"
LOGO_PATH = "kyo.png"
|
|
| |
| |
| |
class Generator(nn.Module):
    """Fully connected GAN generator mapping latent vectors to images.

    The layer stack lives in ``self.net`` (an ``nn.Sequential``) so parameter
    names (``net.0.weight``, ``net.0.bias``, ...) match checkpoints saved from
    this architecture.

    Args:
        noise_dim: Length of the input latent vector. ``None`` (default) uses
            the module-level ``NOISE_DIM``.
        image_size: Side length of the square output image. ``None`` (default)
            uses the module-level ``IMAGE_SIZE``.
        channels: Number of output channels (3 for RGB, the original layout).
    """

    def __init__(self, noise_dim=None, image_size=None, channels=3):
        super().__init__()
        # Resolve defaults at construction time so the module-level constants
        # stay the single source of truth for the app's model geometry.
        self.noise_dim = NOISE_DIM if noise_dim is None else noise_dim
        self.image_size = IMAGE_SIZE if image_size is None else image_size
        self.channels = channels
        self.net = nn.Sequential(
            nn.Linear(self.noise_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, self.channels * self.image_size * self.image_size),
            nn.Tanh(),  # squashes outputs into [-1, 1]
        )

    def forward(self, x):
        """Map latent vectors ``x`` of shape ``(batch, noise_dim)`` to images.

        Returns a tensor of shape ``(batch, channels, image_size, image_size)``
        with values in [-1, 1].
        """
        return self.net(x).view(-1, self.channels, self.image_size, self.image_size)
|
|
| |
| |
| |
def load_generator(path=None, device=None):
    """Build a ``Generator`` and load trained weights from a checkpoint file.

    Args:
        path: Checkpoint file containing a ``state_dict`` for ``Generator``.
            ``None`` (default) uses ``GENERATOR_PATH``.
        device: Device to place the model on. ``None`` (default) uses
            ``DEVICE``.

    Returns:
        The loaded generator on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    path = GENERATOR_PATH if path is None else path
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"Checkpoint not found: {path}. Make sure it exists in your Hugging Face Space."
        )
    device = DEVICE if device is None else device

    model = Generator().to(device)
    # A state_dict only needs tensors and plain containers, so restrict
    # unpickling to those instead of allowing arbitrary pickled objects
    # (this is the default from PyTorch 2.6; the flag exists since 1.13).
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
# Build and load the generator once at import time; every request handled by
# generate_graffiti reuses this single model. Importing this module therefore
# raises FileNotFoundError if the checkpoint file is missing.
G = load_generator()
|
|
| |
| |
| |
def _chw_to_hwc_numpy(image):
    """Convert a ``(C, H, W)`` CPU tensor to an ``(H, W, C)`` NumPy array for display."""
    return image.permute(1, 2, 0).numpy()


def generate_graffiti(num_samples, seed):
    """Sample images from the GAN generator for the Gradio UI.

    Args:
        num_samples: Number of images to generate; converted with ``int()``
            and required to lie in [1, 32].
        seed: Seed for the latent-noise sampler; converted with ``int()``.
            The same seed and sample count reproduce the same images.

    Returns:
        ``(grid_img, example_img)``: NumPy arrays in ``(H, W, C)`` layout with
        values in [0, 1] -- a grid of all samples (at most 4 per row) and the
        first sample on its own.

    Raises:
        gr.Error: If an input is missing or not an integer, the sample count
            is out of range, or PyTorch rejects the seed.
    """
    try:
        num_samples = int(num_samples)
        seed = int(seed)
    except (TypeError, ValueError, OverflowError) as err:
        # e.g. an input arriving as None or NaN; show a readable message in
        # the UI instead of an unhandled traceback.
        raise gr.Error("Number of samples and seed must both be integers.") from err

    if not 1 <= num_samples <= 32:
        raise gr.Error("Number of samples must be between 1 and 32.")

    # A dedicated, freshly seeded generator keeps sampling reproducible without
    # reading or overwriting process-wide RNG state (torch.manual_seed would).
    try:
        rng = torch.Generator(device=DEVICE).manual_seed(seed)
    except RuntimeError as err:
        raise gr.Error(f"Seed {seed} is outside the range PyTorch accepts.") from err

    noise = torch.randn(num_samples, NOISE_DIM, generator=rng, device=DEVICE)

    with torch.no_grad():
        images = G(noise).cpu()

    # The generator ends in Tanh (outputs in [-1, 1]); map to [0, 1] for display.
    images = (images * 0.5 + 0.5).clamp(0, 1)

    grid = make_grid(images, nrow=min(4, num_samples))
    return _chw_to_hwc_numpy(grid), _chw_to_hwc_numpy(images[0])
|
|
| |
| |
| |
with gr.Blocks(title="KYO") as demo:
    # Branding: the logo is optional and only rendered when the file exists.
    if os.path.exists(LOGO_PATH):
        gr.Image(value=LOGO_PATH, show_label=False, width=180)

    gr.Markdown("# KYO")
    gr.Markdown("Generate graffiti-style images using PyTorch GAN.")

    # Inputs: how many images to sample (1..32) and the RNG seed.
    with gr.Row():
        num_samples = gr.Slider(
            label="Number of images",
            minimum=1,
            maximum=32,
            value=DEFAULT_NUM_SAMPLES,
            step=1,
        )
        seed = gr.Number(label="Random seed", value=42, precision=0)

    generate_btn = gr.Button("Generate")

    # Outputs: the full grid and, separately, the first generated image.
    with gr.Row():
        grid_output = gr.Image(label="Generated Grid")
        sample_output = gr.Image(label="First Sample")

    # Wire the button; inputs and outputs map onto the handler positionally.
    generate_btn.click(
        generate_graffiti,
        [num_samples, seed],
        [grid_output, sample_output],
    )
|
|
| |
| |
| |
# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()
|
|