Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoModel | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| MAX_IMAGES = 4 | |
def generate_small(color_indexed: bool, color_num: int) -> list:
    """Generate a batch of small sprites with the small model.

    Parameters
    ----------
    color_indexed : bool
        If true, each sprite is returned as a palette ("P" mode) image
        instead of the RGB image.
    color_num : int
        Number of colors in the adaptive palette. Only used when
        ``color_indexed`` is true.

    Returns
    -------
    list
        List of exactly ``MAX_IMAGES`` PIL images, one per sampled latent.
    """
    latent_dim = model_small.model.latent_dim
    # UI sliders may deliver the value as a float; the palette size must be
    # an integer, so coerce it once up front.
    palette_size = int(color_num)

    images_list = []
    # Pure inference: disable autograd once for the whole batch
    with torch.no_grad():
        for _ in range(MAX_IMAGES):
            # One random latent vector per sprite
            latents = torch.randn((1, latent_dim))
            generated_image = model_small(latents)
            # Clamp to [0, 1] so scaling by 255 stays inside the uint8 range
            generated_image = generated_image.clamp_(0.0, 1.0).cpu().numpy()
            # Drop the batch axis and move channels last before handing the
            # array to PIL as an RGB image
            sprite = Image.fromarray(
                np.uint8(generated_image[0] * 255).transpose(1, 2, 0), "RGB"
            )
            if color_indexed:
                # Quantize with an adaptive palette of the requested size
                sprite = sprite.convert(
                    "P", palette=Image.ADAPTIVE, colors=palette_size
                )
            # Exactly one image per sample: either the indexed version or
            # the RGB version, never both
            images_list.append(sprite)
    return images_list
def generate_med(color_indexed: bool, color_num: int) -> list:
    """Generate a batch of medium sprites with the medium model.

    Parameters
    ----------
    color_indexed : bool
        If true, each sprite is returned as a palette ("P" mode) image
        instead of the RGBA image.
    color_num : int
        Number of colors in the adaptive palette. Only used when
        ``color_indexed`` is true.

    Returns
    -------
    list
        List of exactly ``MAX_IMAGES`` PIL images, one per sampled latent.
    """
    latent_dim = model_med.model.latent_dim
    # UI sliders may deliver the value as a float; the palette size must be
    # an integer, so coerce it once up front.
    palette_size = int(color_num)

    images_list = []
    # Pure inference: disable autograd once for the whole batch
    with torch.no_grad():
        for _ in range(MAX_IMAGES):
            # One random latent vector per sprite
            latents = torch.randn((1, latent_dim))
            generated_image = model_med(latents)
            # Clamp to [0, 1] so scaling by 255 stays inside the uint8 range
            generated_image = generated_image.clamp_(0.0, 1.0).cpu().numpy()
            # Drop the batch axis and move channels last before handing the
            # array to PIL as an RGBA image
            sprite = Image.fromarray(
                np.uint8(generated_image[0] * 255).transpose(1, 2, 0), "RGBA"
            )
            if color_indexed:
                # Quantize with an adaptive palette of the requested size
                sprite = sprite.convert(
                    "P", palette=Image.ADAPTIVE, colors=palette_size
                )
            # Exactly one image per sample: either the indexed version or
            # the RGBA version, never both
            images_list.append(sprite)
    return images_list
# Top-level Blocks container that the UI layout is attached to
demo = gr.Blocks()

# Small sprite generator.
# NOTE: trust_remote_code=True executes Python code shipped with the model
# repository, so only use it with repositories you trust.
model_small = AutoModel.from_pretrained(
    pretrained_model_name_or_path="michaelriedl/MonsterForge-small",
    trust_remote_code=True,
)
# Switch to evaluation mode; the model is only used for inference here
model_small.eval()

# Medium sprite generator (same loading procedure as the small one)
model_med = AutoModel.from_pretrained(
    pretrained_model_name_or_path="michaelriedl/MonsterForge-medium",
    trust_remote_code=True,
)
model_med.eval()
# Lay out the UI: a header, one tab per model size, and a footer
with demo:
    # Header with a short description and a link to the GAN implementation
    gr.HTML(
        """
        <div style="text-align: center; margin: 0 auto;">
        <p style="margin-bottom: 14px; line-height: 23px;">
        Gradio demo for MonsterForge models. This was built with Lightweight GAN using the implementation from <a href='https://github.com/lucidrains/lightweight-gan' target='_blank'>lucidrains</a>.
        </p>
        </div>
        """
    )
    with gr.Tabs():
        # --- Tab for the small model ---
        with gr.TabItem("Small Sprite"):
            with gr.Column():
                # Output gallery
                with gr.Row():
                    gallery_small = gr.Gallery(
                        object_fit="scale-down",
                        columns=4,
                    )
                # Palette options passed to generate_small
                with gr.Row():
                    color_index_small = gr.Checkbox(value=False, label="Color indexed")
                    color_num_small = gr.Slider(
                        label="Number of colors in the palette",
                        minimum=8,
                        maximum=32,
                        step=4,
                        value=32,
                    )
                gen_btn_small = gr.Button("Generate")
            # Clicking the button fills the gallery with new sprites
            gen_btn_small.click(
                fn=generate_small,
                inputs=[color_index_small, color_num_small],
                outputs=gallery_small,
            )
        # --- Tab for the medium model ---
        with gr.TabItem("Medium Sprite"):
            with gr.Column():
                # Output gallery
                with gr.Row():
                    gallery_med = gr.Gallery(
                        object_fit="scale-down",
                        columns=4,
                    )
                # Palette options passed to generate_med
                with gr.Row():
                    color_index_med = gr.Checkbox(value=False, label="Color indexed")
                    color_num_med = gr.Slider(
                        label="Number of colors in the palette",
                        minimum=8,
                        maximum=32,
                        step=4,
                        value=32,
                    )
                gen_btn_med = gr.Button("Generate")
            # Clicking the button fills the gallery with new sprites
            gen_btn_med.click(
                fn=generate_med,
                inputs=[color_index_med, color_num_med],
                outputs=gallery_med,
            )
    # Footer with author attribution
    gr.HTML(
        """
        <div class="footer">
        <div style='text-align: center;'>MonsterForge by <a href='https://michaelriedl.com/' target='_blank'>Michael Riedl</a></div>
        </div>
        """
    )
# Launch the interface only when this file is executed as a script, so that
# importing the module (e.g. for tests or tooling) does not start a server.
if __name__ == "__main__":
    demo.launch()