# Spaces: Runtime error
| import jax | |
| import jax.numpy as jnp | |
| from flax import jax_utils | |
| from flax.training.common_utils import shard | |
| from PIL import Image | |
| from argparse import Namespace | |
| import gradio as gr | |
| from diffusers import ( | |
| FlaxControlNetModel, | |
| FlaxStableDiffusionControlNetPipeline, | |
| ) | |
# Checkpoint identifiers and loading flags consumed by the model-loading code
# below. The meaning of each flag is defined by the `from_pretrained` calls
# that receive it.
args = Namespace(
    **{
        "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
        "revision": "non-ema",
        "from_pt": True,
        "controlnet_model_name_or_path": "Vincent-luo/controlnet-hands",
        "controlnet_revision": None,
        "controlnet_from_pt": False,
    }
)
# dtype handed to the Stable Diffusion pipeline loader.
weight_dtype = jnp.float32

# Load the ControlNet module together with its parameters.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    from_pt=args.controlnet_from_pt,
    revision=args.controlnet_revision,
    dtype=jnp.float32,
)

# Load the base pipeline with the ControlNet plugged in; no safety checker is
# attached.
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    controlnet=controlnet,
    safety_checker=None,
    from_pt=args.from_pt,
    revision=args.revision,
    dtype=weight_dtype,
)

# The ControlNet parameters are loaded separately above, so add them under the
# "controlnet" key before copying the whole parameter tree to every device.
pipeline_params = jax_utils.replicate(
    {**pipeline_params, "controlnet": controlnet_params}
)

# One sample and one PRNG key per local JAX device. The keys are derived once
# from a fixed seed (0) and reused by every inference call.
num_samples = jax.device_count()
rng = jax.random.PRNGKey(0)
prng_seed = jax.random.split(rng, num_samples)
def infer(prompt, negative_prompt, image):
    """Generate an image from a text prompt and a conditioning image.

    The prompt, negative prompt and conditioning image are each repeated
    ``num_samples`` times (one copy per JAX device), sharded across devices,
    and run through the module-level ``pipeline`` for 50 inference steps.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Text describing what the image should avoid.
        image: Conditioning image as an array accepted by
            ``PIL.Image.fromarray`` (presumably the numpy array delivered by
            the ``gr.Image`` input -- confirm for the Gradio version in use).

    Returns:
        The first generated image, taken after the two leading axes of the
        pipeline output are merged into a single batch axis.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    # An empty image input arrives as None; fail with a readable message
    # instead of an opaque AttributeError from Image.fromarray.
    if image is None:
        raise gr.Error("Please provide an input image.")

    prompt_ids = shard(pipeline.prepare_text_inputs(num_samples * [prompt]))
    negative_prompt_ids = shard(
        pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    )

    validation_image = Image.fromarray(image).convert("RGB")
    processed_image = shard(
        pipeline.prepare_image_inputs(num_samples * [validation_image])
    )

    # NOTE: prng_seed is created once at module level, so every call runs with
    # the same random keys.
    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Merge the two leading axes into one batch axis, keeping the last three
    # axes of each image unchanged.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return images[0]
# Gradio front end: two text boxes and an input image feed `infer`, whose
# result is shown in the output image.
# NOTE(review): the original indentation was lost; the nesting below (widgets
# inside the column, event wiring inside the Blocks context, launch at module
# level) is reconstructed -- confirm against the intended layout.
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("In this app, you can find different ControlNets with different filters. ")

    with gr.Column():
        prompt_input = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
        input_image = gr.Image(label="Input Image")
        output_image = gr.Image(label="Output Image")
        submit_btn = gr.Button(value="Submit")

    # Argument order must match infer(prompt, negative_prompt, image).
    inputs = [prompt_input, negative_prompt, input_image]
    submit_btn.click(
        fn=infer,
        inputs=inputs,
        outputs=[output_image],
    )

demo.launch()