| | import gradio as gr |
| | import spaces |
| | import torch |
| |
|
| | from diffusers import AutoencoderKL, ControlNetUnionModel, DiffusionPipeline, TCDScheduler |
| |
|
| |
|
def callback_cfg_cutoff(pipeline, step_index, timestep, callback_kwargs):
    """Turn off classifier-free guidance once 20% of the steps have run.

    At the cutoff step the negative (unconditional) half of every batched
    input is dropped so only the conditional half keeps being denoised, and
    the pipeline's guidance scale is zeroed. The (possibly trimmed)
    callback_kwargs dict is always returned so the pipeline picks it up.
    """
    cutoff = int(pipeline.num_timesteps * 0.2)
    if step_index == cutoff:
        # Keep only the last (conditional) entry of each batched input.
        for key in ("prompt_embeds", "add_text_embeds", "add_time_ids", "control_type"):
            callback_kwargs[key] = callback_kwargs[key][-1:]

        # control_image is a list of tensors; trim its first entry in place.
        callback_kwargs["control_image"][0] = callback_kwargs["control_image"][0][-1:]

        # From here on the pipeline runs a single conditional pass only.
        pipeline._guidance_scale = 0.0

    return callback_kwargs
| |
|
| |
|
# Display names mapped to Hugging Face repo ids.
# NOTE(review): only one entry, and the dropdown built from this dict is
# passed to fill_image but never read there — confirm whether model
# switching was intended.
MODELS = {
    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}

# ControlNet Union (promax) checkpoint in fp16, moved to the GPU.
controlnet_model = ControlNetUnionModel.from_pretrained(
    "OzzyGT/controlnet-union-promax-sdxl-1.0", variant="fp16", torch_dtype=torch.float16
)
controlnet_model.to(device="cuda", dtype=torch.float16)

# fp16-safe SDXL VAE (presumably used because the stock SDXL VAE is known
# to misbehave in half precision — the repo name suggests so).
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")

# SDXL pipeline assembled with the custom ControlNet-Union community
# pipeline code; everything lives on the GPU in fp16.
pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=controlnet_model,
    custom_pipeline="OzzyGT/custom_sdxl_cnet_union",
).to("cuda")

# Swap in a TCD scheduler; fill_image samples with only 8 steps.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
| |
|
| |
|
@spaces.GPU(duration=24)
def fill_image(prompt, negative_prompt, image, model_selection, paste_back):
    """Inpaint the masked region of *image* and stream back the result.

    Args mirror the Gradio inputs: the two text prompts, the ImageMask
    editor value (a dict with "background" and "layers"), the model
    dropdown value (currently unused — a single pipeline is loaded at
    import time), and a flag controlling whether the original pixels are
    composited back outside the mask.

    Yields:
        A (source, result) pair for the before/after ImageSlider.
    """
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(prompt, device="cuda", negative_prompt=negative_prompt)

    source = image["background"]
    drawn_layer = image["layers"][0]

    # The drawn layer's alpha channel marks the area to inpaint; binarize it.
    binary_mask = drawn_layer.split()[3].point(lambda p: 255 if p > 0 else 0)

    # Black out the masked area on a copy of the source; this conditions
    # the ControlNet on what should be regenerated.
    cnet_image = source.copy()
    cnet_image.paste(0, (0, 0), binary_mask)

    generated = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        control_image=[cnet_image],
        controlnet_conditioning_scale=[1.0],
        control_mode=[7],  # presumably the union model's inpaint/repaint mode — confirm
        num_inference_steps=8,
        guidance_scale=1.5,
        callback_on_step_end=callback_cfg_cutoff,
        callback_on_step_end_tensor_inputs=[
            "prompt_embeds",
            "add_text_embeds",
            "add_time_ids",
            "control_image",
            "control_type",
        ],
    ).images[0]

    if paste_back:
        # Composite: generated pixels inside the mask, original elsewhere.
        generated = generated.convert("RGBA")
        cnet_image.paste(generated, (0, 0), binary_mask)
    else:
        cnet_image = generated

    yield source, cnet_image
| |
|
| |
|
def clear_result():
    """Blank out the result slider before a new generation starts."""
    return gr.update(value=None)
| |
|
| |
|
# App header rendered as raw HTML at the top of the Blocks layout.
title = """<h2 align="center">Diffusers Fast Inpaint</h2>
<div align="center">Draw the mask over the subject you want to erase or change and write what you want to inpaint it with.</div>
"""

with gr.Blocks() as demo:
    gr.HTML(title)
    # Prompt and negative prompt side by side.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                lines=1,
            )
        with gr.Column():
            with gr.Row():
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    lines=1,
                )

    with gr.Row():
        with gr.Column():
            run_button = gr.Button("Generate")

        with gr.Column():
            paste_back = gr.Checkbox(True, label="Paste back original")

    with gr.Row():
        # Editor for the source image plus a hand-drawn mask layer.
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            height=512,
        )

        # Before/after comparison of source vs. inpainted output.
        result = gr.ImageSlider(
            interactive=False,
            label="Generated Image",
        )

    # Hidden until a generation finishes; feeds the output back as input.
    use_as_input_button = gr.Button("Use as Input Image", visible=False)

    # NOTE(review): the selected value is passed to fill_image but never
    # read there — a single model is loaded at import time; confirm intent.
    model_selection = gr.Dropdown(choices=list(MODELS.keys()), value="RealVisXL V5.0 Lightning", label="Model")

    def use_output_as_input(output_image):
        # output_image is the (source, result) pair; index 1 is the result.
        return gr.update(value=output_image[1])

    use_as_input_button.click(fn=use_output_as_input, inputs=[result], outputs=[input_image])

    # Clicking Generate: clear the result, hide the reuse button, run the
    # pipeline, then reveal the reuse button again.
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, negative_prompt, input_image, model_selection, paste_back],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )

    # Submitting the prompt textbox triggers the same chain as the button.
    prompt.submit(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, negative_prompt, input_image, model_selection, paste_back],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )

demo.queue(max_size=12).launch(share=False)
| |
|