from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
from flax.training.common_utils import shard
from flax.jax_utils import replicate
from diffusers.utils import load_image
import jax.numpy as jnp
import jax
import cv2
from PIL import Image
import numpy as np
import gradio as gr
|
|
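# Helper to build a JAX PRNG key from an integer seed.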
def create_key(seed=0):
    return jax.random.PRNGKey(seed)
|
|
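# Load the ControlNet weights for the requested checkpoint, stored as a
# subfolder of the Baptlem/baptlem-controlnet repository.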
def load_controlnet(controlnet_version):
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        "Baptlem/baptlem-controlnet",
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params
|
|
|
|
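# Assemble the full Flax Stable Diffusion + ControlNet pipeline: load the
# ControlNet, swap in a DPM-Solver++ multistep scheduler, and merge the extra
# parameter trees into the pipeline params.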
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler"
    )

    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        revision="flax",
        dtype=jnp.bfloat16
    )

    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params
|
|
|
|
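# Default ControlNet repository and checkpoint used by the demo.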
controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"
|
|
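# Canny edge-detection thresholds used in preprocess_canny.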
low_threshold = 100
high_threshold = 200
|
|
| print("Loaded models...") |
| def pipe_inference( |
| image, |
| prompt, |
| is_canny=False, |
| num_samples=4, |
| resolution=128, |
| num_inference_steps=50, |
| guidance_scale=7.5, |
| model="coyo-500k", |
| seed=0, |
| negative_prompt="", |
| ): |
| print("Loading pipe") |
| pipe, params = load_sb_pipe(model) |
| |
| if not isinstance(image, np.ndarray): |
| image = np.array(image) |
|
|
    processed_image = resize_image(image, resolution)

    # If the uploaded image is not already an edge map, compute Canny edges.
    if not is_canny:
        resized_image, processed_image = preprocess_canny(processed_image, resolution)
|
|
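    # One PRNG key per device for the parallel (pmap) sampler.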
    rng = create_key(seed)
    rng = jax.random.split(rng, jax.device_count())
|
|
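    # Tokenize the prompts, then replicate the params and shard the inputs
    # across local devices so the jitted pipeline call runs data-parallel.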
    prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
    processed_image = pipe.prepare_image_inputs([processed_image] * num_samples)

    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Merge the per-device outputs back into one batch and convert to PIL.
    all_outputs = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
    return all_outputs
|
|
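# Resize so the shorter side equals `resolution`, preserving the aspect ratio.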
def resize_image(image, resolution):
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    h, w = image.shape[:2]
    ratio = w / h
    if ratio > 1:
        resized_image = cv2.resize(image, (int(resolution * ratio), resolution), interpolation=cv2.INTER_NEAREST)
    elif ratio < 1:
        resized_image = cv2.resize(image, (resolution, int(resolution / ratio)), interpolation=cv2.INTER_NEAREST)
    else:
        resized_image = cv2.resize(image, (resolution, resolution), interpolation=cv2.INTER_NEAREST)

    return Image.fromarray(resized_image)


# Compute a Canny edge map and return both the original image and the
# 3-channel edge map that ControlNet expects as conditioning.
def preprocess_canny(image, resolution=128):
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    processed_image = cv2.Canny(image, low_threshold, high_threshold)
    processed_image = processed_image[:, :, None]
    processed_image = np.concatenate([processed_image, processed_image, processed_image], axis=2)

    resized_image = Image.fromarray(image)
    processed_image = Image.fromarray(processed_image)
    return resized_image, processed_image
|
|
|
|
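# Build the Gradio Blocks UI; `process` is the inference callback (pipe_inference).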
def create_demo(process, max_images=12, default_num_images=4):
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button('Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(
                        label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    """
                    canny_low_threshold = gr.Slider(
                        label='Canny low threshold',
                        minimum=1,
                        maximum=255,
                        value=100,
                        step=1)
                    canny_high_threshold = gr.Slider(
                        label='Canny high threshold',
                        minimum=1,
                        maximum=255,
                        value=200,
                        step=1)
                    """
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    model = gr.Dropdown(choices=["coyo-500k", "bridge-2M", "coyo2M-bridge3M"],
                                        value="coyo-500k",
                                        label="Model used for inference",
                                        info="Find all the models at https://huggingface.co/Baptlem/baptlem-controlnet")
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')
    return demo
|
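# Example of calling the inference function directly, without the UI. The
# image path below is only illustrative:
#   image = np.array(Image.open("some_scene.jpg"))
#   outputs = pipe_inference(image, "a robot arm picking up a cube", num_samples=4)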
|
if __name__ == '__main__':
    demo = create_demo(pipe_inference)
    demo.queue().launch()