|
|
from diffusers import StableDiffusionPipeline, DiffusionPipeline |
|
|
import torch |
|
|
import gradio as gr |
|
|
|
|
|
import spaces |
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...). It targets the gallery created
# with elem_id="img-display-output" and every element nested inside it, capping
# both at 60% of the viewport height (presumably so large generated images do
# not dominate the page). The two selectors share one rule body so the limit is
# defined in a single place.
css = """
#img-display-output,
#img-display-output * {
    max-height: 60vh;
}
"""
|
|
|
|
|
# Device the pipeline is moved to after loading.
DEVICE = 'cuda'

# Model identifier handed to StableDiffusionPipeline.from_pretrained
# (Hugging Face Hub "owner/repo" form).
model_id = "Onodofthenorth/SD_PixelArt_SpriteSheet_Generator"

# Load the checkpoint with half-precision (float16) weights.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

# Use DEVICE instead of repeating the "cuda" literal so the configured device
# and the one actually used cannot drift apart.
pipe.to(DEVICE)
|
|
|
|
|
|
|
|
@spaces.GPU(enable_queue=True)
def generate_sprite(prompt):
    """Generate sprite-sheet images for a text prompt.

    Args:
        prompt: Text passed unchanged to the module-level ``pipe``.

    Returns:
        The ``images`` attribute of the pipeline output; the Gradio UI feeds
        this value into its image gallery.
    """
    pipeline_output = pipe(prompt)
    return pipeline_output.images
|
|
|
|
|
# Markdown rendered at the top of the page: a level-1 heading followed by a
# one-line description of the app.
title = "# SD_PixelArt_SpriteSheet_Generator"
description = "Pixel Art Sprite Sheet Generator with Stable Diffusion Checkpoint."
|
|
|
|
|
# Build the Gradio page: a prompt box, an image gallery and a button that runs
# generate_sprite on the prompt and shows the returned images in the gallery.
with gr.Blocks(css=css) as API:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Column():
        inputs = gr.TextArea(label="Prompt", placeholder="Prompt")

    # elem_id links this gallery to the "#img-display-output" rules in `css`,
    # which limit its height to 60vh.
    outputs = gr.Gallery(label="Output Images", columns=4, elem_id="img-display-output")
    generate_btn = gr.Button(value="Generate")

    # NOTE(review): the endpoint is named "generate_mesh" even though it returns
    # sprite images; the name is left unchanged so existing API clients that
    # call this endpoint by name keep working.
    generate_btn.click(generate_sprite, inputs=inputs, outputs=outputs, api_name="generate_mesh")

API.launch()