Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
# Number of sprites generated per button click; the same value is used as the
# gallery's column count.
MAX_IMAGES = 1
def generate_images(
    type1: str,
    type2: str,
    hp_num: int,
    attack_num: int,
    defense_num: int,
    sp_attack_num: int,
    sp_defense_num: int,
    speed_num: int,
) -> list:
    """Render sprites for a monster described by two types and six stats.

    The six stats are added up into a base total, everything is written
    into one text prompt, and that prompt is passed to the module-level
    ``pipe`` once per requested image (``MAX_IMAGES`` times).

    Parameters
    ----------
    type1 : str
        Text placed after ``type1:`` in the prompt.
    type2 : str
        Text placed after ``type2:`` in the prompt.
    hp_num : int
        HP stat; part of the base total and written into the prompt.
    attack_num : int
        Attack stat; part of the base total and written into the prompt.
    defense_num : int
        Defense stat; part of the base total and written into the prompt.
    sp_attack_num : int
        Special attack stat; part of the base total and written into the
        prompt.
    sp_defense_num : int
        Special defense stat; part of the base total and written into the
        prompt.
    speed_num : int
        Speed stat; part of the base total and written into the prompt.

    Returns
    -------
    list
        ``MAX_IMAGES`` PIL images, in the order they were generated.
    """
    # The prompt carries the sum of the six stats alongside each one.
    base_total = (
        hp_num
        + attack_num
        + defense_num
        + sp_attack_num
        + sp_defense_num
        + speed_num
    )

    # Adjacent f-string literals are concatenated into a single prompt string.
    prompt = (
        f"type1: {type1}, type2: {type2}, base_total: {base_total}, "
        f"hp: {hp_num}, attack: {attack_num}, defense: {defense_num}, "
        f"sp_attack: {sp_attack_num}, sp_defense: {sp_defense_num}, "
        f"speed: {speed_num}"
    )

    def render_sprite():
        """Run the pipeline once and return its first image as a new PIL image."""
        result = pipe(
            prompt,
            height=288,
            width=288,
            num_inference_steps=10,
            guidance_scale=7.5,
            cross_attention_kwargs={"scale": 1.0},
        )
        # Copy the pipeline's image through a NumPy array into a fresh PIL image.
        pixels = np.array(result.images[0])
        return Image.fromarray(pixels)

    return [render_sprite() for _ in range(MAX_IMAGES)]
# Top-level Blocks container that the interface is built inside
demo = gr.Blocks()

# Model identifiers: the base Stable Diffusion checkpoint and the LoRA
# weights that are loaded into its UNet
model_base = "stabilityai/stable-diffusion-2-base"
lora_model_path = "michaelriedl/MonsterForgeFusion-sd-2-base"

# Load the base pipeline with float32 weights
pipe = StableDiffusionPipeline.from_pretrained(
    model_base,
    torch_dtype=torch.float32,
    use_safetensors=False,
    local_files_only=False,
)

# Replace the pipeline's scheduler with a DPM-Solver multistep scheduler
# built from the existing scheduler's config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Load the fine-tuned attention processors into the UNet
pipe.unet.load_attn_procs(lora_model_path)
# Type names offered by both the "Type 1" and "Type 2" dropdowns
_TYPE_CHOICES = (
    "bug",
    "dark",
    "dragon",
    "electric",
    "fairy",
    "fighting",
    "fire",
    "flying",
    "ghost",
    "grass",
    "ground",
    "ice",
    "normal",
    "poison",
    "psychic",
    "rock",
    "steel",
    "water",
)

# Introductory text shown at the top of the page
_INTRO_HTML = """
<div style="text-align: center; margin: 0 auto;">
<p style="margin-bottom: 14px; line-height: 23px;">
Gradio demo for MonsterForgeFusion models. This was built with LoRA fine-tuning of Stable Diffusion models.
</p>
</div>
"""

# Attribution shown at the bottom of the page
_FOOTER_HTML = """
<div class="footer">
<div style='text-align: center;'>MonsterForgeFusion by <a href='https://michaelriedl.com/' target='_blank'>Michael Riedl</a></div>
</div>
"""


def _stat_slider(label):
    """Create a slider from 1 to 100 in steps of 1, starting at 50."""
    return gr.Slider(minimum=1, maximum=100, value=50, step=1, label=label)


# Create the interface
with demo:
    gr.HTML(_INTRO_HTML)
    with gr.Column():
        # Output gallery, one column per generated image
        with gr.Row():
            gallery = gr.Gallery(
                columns=MAX_IMAGES,
                preview=True,
                object_fit="scale-down",
            )
        # Type selectors; each dropdown gets its own copy of the choices
        with gr.Row():
            type1 = gr.Dropdown(list(_TYPE_CHOICES), value="steel", label="Type 1")
            type2 = gr.Dropdown(list(_TYPE_CHOICES), value="fire", label="Type 2")
        # Stat sliders, two per row
        with gr.Row():
            hp_num = _stat_slider("HP")
            attack_num = _stat_slider("Attack")
        with gr.Row():
            defense_num = _stat_slider("Defense")
            sp_attack_num = _stat_slider("Special Attack")
        with gr.Row():
            sp_defense_num = _stat_slider("Special Defense")
            speed_num = _stat_slider("Speed")
        gen_btn = gr.Button("Generate")
        # Clicking the button calls generate_images with the widget values, in
        # the order of its parameters, and shows the result in the gallery
        gen_btn.click(
            fn=generate_images,
            inputs=[
                type1,
                type2,
                hp_num,
                attack_num,
                defense_num,
                sp_attack_num,
                sp_defense_num,
                speed_num,
            ],
            outputs=gallery,
        )
    gr.HTML(_FOOTER_HTML)

# Launch the interface
demo.launch()