Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import jax | |
| import numpy as np | |
| import jax.numpy as jnp | |
| from flax.jax_utils import replicate | |
| from flax.training.common_utils import shard | |
| from PIL import Image | |
| from diffusers import FlaxStableDiffusionPipeline | |
def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed.

    Args:
        seed: Seed value forwarded unchanged to ``jax.random.PRNGKey``.

    Returns:
        The key returned by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key
# Load the Flax Stable Diffusion pipeline and its weight pytree once, at
# import time; ``infer`` reuses both for every request.
# NOTE(review): ``dtype=jnp.bfloat16`` is the dtype requested from diffusers;
# whether the loaded ``params`` themselves are bfloat16 or float32 depends on
# the checkpoint/revision — confirm before relying on memory estimates.
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "MuhammadHanif/stable-diffusion-v1-5-high-res",
    use_memory_efficient_attention=True,
    dtype=jnp.bfloat16,
)
# Device-replicated copy of ``params``. It is built lazily on the first call
# to ``infer`` so the weights are copied to every device once, instead of on
# every request.
_replicated_params = None


def _get_replicated_params():
    """Return ``params`` replicated across local devices, caching the result."""
    global _replicated_params
    if _replicated_params is None:
        _replicated_params = replicate(params)
    return _replicated_params


def infer(prompts, negative_prompts, width=1088, height=1088, inference_steps=30, seed=0):
    """Generate one image from a text prompt with the Flax pipeline.

    Args:
        prompts: Positive text prompt.
        negative_prompts: Negative text prompt ("" for none).
        width: Output width in pixels; cast to ``int`` because UI widgets may
            deliver floats.
        height: Output height in pixels; cast to ``int`` for the same reason.
        inference_steps: Number of denoising steps; cast to ``int`` since it
            is passed to the jitted pipeline call.
        seed: Random seed; cast to ``int`` before building the PRNG key.

    Returns:
        The first image produced by ``pipe.numpy_to_pil``.
    """
    # ``shard`` reshapes the leading batch axis to (device_count, -1, ...), so
    # the batch must be a multiple of the device count. One sample per device
    # works on single- and multi-device hosts alike; a batch of 1 fails on
    # any host with more than one device.
    num_samples = jax.device_count()
    width = int(width)
    height = int(height)
    inference_steps = int(inference_steps)

    # One PRNG key per device, matching the sharded batch layout.
    rng = jax.random.split(create_key(int(seed)), jax.device_count())

    prompt_ids = shard(pipe.prepare_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(pipe.prepare_inputs([negative_prompts] * num_samples))

    output = pipe(
        prompt_ids=prompt_ids,
        params=_get_replicated_params(),
        height=height,
        width=width,
        prng_seed=rng,
        num_inference_steps=inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the (devices, per_device_batch) leading axes into one batch axis.
    images = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(images)[0]
# UI components. The legacy ``gr.inputs.*`` namespace (with ``default=``) was
# deprecated in Gradio 3 and removed in Gradio 4; the top-level components
# take ``value=`` instead and work on both major versions.
prompt_input = gr.Textbox(
    label="Prompt",
    placeholder="a highly detailed mansion in the autumn by studio ghibli, makoto shinkai",
)
neg_prompt_input = gr.Textbox(
    label="Negative Prompt",
    placeholder="",
)
width_slider = gr.Slider(
    minimum=512, maximum=2048, value=1088, step=64, label="width"
)
height_slider = gr.Slider(
    minimum=512, maximum=2048, value=1088, step=64, label="height"
)
inf_steps_input = gr.Slider(
    minimum=1, maximum=100, value=30, step=1, label="Inference Steps"
)
# precision=0 makes the widget return an integer seed.
seed_input = gr.Number(value=0, label="Seed", precision=0)

app = gr.Interface(
    fn=infer,
    inputs=[prompt_input, neg_prompt_input, width_slider, height_slider, inf_steps_input, seed_input],
    outputs="image",
    title="Stable Diffusion High Resolution",
    description=(
        "Based on Stable Diffusion 1.5 and fine-tuned on images from 576x576 up to 1088x1088, "
        "Stable Diffusion High Resolution is compatible with other SD1.5 models and can be merged with them, "
        "allowing those models to generate high-resolution images without an upscaler."
    ),
    # examples=[["a highly detailed mansion in the autumn by studio ghibli, makoto shinkai","", 1088, 1088, 30, 0]],
)
app.launch()