| import gradio as gr |
| import sys |
| import torch |
|
|
| from PIL import Image |
| import numpy as np |
| from io import BytesIO |
| import os |
|
|
| from diffusers.utils import load_image |
| from diffusers import ControlNetModel |
| import numpy as np |
| import torch |
| from diffusers.image_processor import VaeImageProcessor |
| from PIL import Image |
| from pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline |
|
|
# Use the GPU when torch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Inpainting ControlNet; it is attached to the pipeline below and overrides
# whatever controlnet the pretrained pipeline checkpoint loaded.
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")

# BLIP-Diffusion ControlNet pipeline (class from the local
# pipeline_controlnet_blip_diffusion module).
blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "Salesforce/blipdiffusion-controlnet"
)
blip_diffusion_pipe.controlnet = controlnet
blip_diffusion_pipe.to(device)
|
|
def make_inpaint_condition(image, image_mask, threshold=0.5):
    """Build the ControlNet inpainting conditioning tensor.

    The image is converted to RGB and scaled to [0, 1]. Every pixel whose
    mask value (converted to grayscale and scaled to [0, 1]) is strictly
    greater than ``threshold`` has all of its channels set to -1.

    Args:
        image: PIL-like image; ``image.convert("RGB")`` must yield an
            array-like of shape (H, W, 3).
        image_mask: PIL-like mask; ``image_mask.convert("L")`` must yield an
            array-like of shape (H, W).
        threshold: Normalized mask level above which a pixel is set to -1.
            Defaults to 0.5, the value previously hard-coded.

    Returns:
        A float32 ``torch.Tensor`` of shape (1, 3, H, W).

    Raises:
        ValueError: If the image and mask differ in height or width.
    """
    pixels = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
    # Compare both height AND width. The old ``shape[0:1]`` slice compared
    # only the height, so a width mismatch slipped through and failed later
    # with an obscure IndexError on the boolean-mask assignment. An explicit
    # raise (not assert) keeps the check active under ``python -O``.
    if pixels.shape[:2] != mask.shape[:2]:
        raise ValueError("image and image_mask must have the same image size")
    pixels[mask > threshold] = -1
    # (H, W, C) -> (1, C, H, W): add a batch axis, then move channels first.
    pixels = np.expand_dims(pixels, 0).transpose(0, 3, 1, 2)
    return torch.from_numpy(pixels)
|
|
# Page-level CSS passed to gr.Blocks: centers the layout, gives the image
# upload widgets (elem_classes="image_upload") a 500px minimum height, and
# fixes the output gallery (elem_id="output_image") at 500px.
# Fixed: "max-height=500px" was invalid CSS (browsers drop the declaration);
# CSS properties use ":" not "=".
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
.image_upload{min-height:500px}
.image_upload [data-testid="image"], .image_upload [data-testid="image"] > div{min-height: 500px}
.image_upload [data-testid="target"], .image_upload [data-testid="target"] > div{min-height: 500px}
.image_upload .touch-none{display: flex}
#output_image{min-height:500px;max-height:500px;}
'''
|
|
|
|
def create_demo():
    """Build the BLIP-Diffusion ControlNet inpainting Gradio app.

    The UI collects:
      - a source image with a hand-drawn mask;
      - a target (style) image;
      - a text prompt and the name of the target object;
      - sampling options.

    It renders the pipeline's output images into a gallery.

    Returns:
        The configured ``gr.Blocks`` instance (not yet queued or launched).
    """
    HEIGHT, WIDTH = 512, 512

    theme = gr.themes.Default(
        font=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"],
        primary_hue="lime",
        secondary_hue="emerald",
        neutral_hue="slate",
    )

    with gr.Blocks(theme=theme, css=css) as demo:
        gr.Markdown('# BLIP-Diffusion')
        with gr.Accordion('Instructions', open=False):
            gr.Markdown('1. Upload src image and draw mask')
            gr.Markdown('2. Upload tgt image')
            gr.Markdown('3. Input name of tgt object and description')
            gr.Markdown('4. Click `Generate` when it is ready!')

        with gr.Group():
            with gr.Box():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown('### Source Input and Add Mask')
                            # tool='sketch' lets the user paint a mask;
                            # generate() reads the 'image' and 'mask'
                            # entries of this component's value.
                            src_input = gr.Image(
                                source='upload',
                                shape=[HEIGHT, WIDTH],
                                type='pil',
                                elem_classes="image_upload",
                                label='Source Image',
                                tool='sketch',
                                brush_radius=60,
                            ).style(height=500)
                            text_prompt = gr.Textbox(label='Prompt')
                            run_button = gr.Button(
                                label='Generate', value='Generate', variant="primary"
                            )

                        with gr.Column():
                            gr.Markdown('### Target Input')
                            tgt_input = gr.Image(
                                source='upload',
                                shape=[HEIGHT, WIDTH],
                                type='pil',
                                elem_classes="image_upload",
                                label='Target Image',
                            ).style(height=500)
                            style_subject = gr.Textbox(label='Target Object')

                    with gr.Row():
                        with gr.Column():
                            gr.Markdown('### Output')
                            # Fixed typo: the keyword was "containter", which
                            # style() silently ignored.
                            output_image = gr.Gallery(
                                label="Generated images",
                                show_label=False,
                                elem_id="output_image",
                            ).style(height=500, container=True)

                    with gr.Accordion('Advanced options', open=False):
                        num_inference_steps = gr.Slider(
                            label='Steps',
                            minimum=1,
                            maximum=100,
                            value=50,
                            step=1,
                        )
                        guidance_scale = gr.Slider(
                            label='Text Guidance Scale',
                            minimum=0.1,
                            maximum=30.0,
                            value=7.5,
                            step=0.1,
                        )
                        seed = gr.Slider(
                            label='Seed',
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            randomize=True,
                        )

        inputs = [
            src_input,
            tgt_input,
            text_prompt,
            style_subject,
            num_inference_steps,
            guidance_scale,
            seed,
        ]

        def generate(
            src_input,
            tgt_input,
            text_prompt,
            style_subject,
            num_inference_steps,
            guidance_scale,
            seed,
        ):
            """Run the pipeline on the current UI values.

            Returns:
                A dict mapping the output gallery component to the list of
                images produced by the pipeline.

            Raises:
                gr.Error: If the source or target image is missing.
            """
            if src_input is None or tgt_input is None:
                # gr.Error is only shown to the user when it is *raised*; the
                # previous code built it and discarded it, so the user saw
                # no message at all.
                raise gr.Error("You must upload an image first.")

            init_image = src_input['image']
            cldm_cond_image = src_input['mask']
            control_image = make_inpaint_condition(init_image, cldm_cond_image)
            style_image = tgt_input
            # The single "Target Object" field serves as both the style
            # subject and the target subject.
            tgt_subject = style_subject
            # int() guards against float slider values:
            # Generator.manual_seed only accepts integers.
            generator = torch.Generator(device="cpu").manual_seed(int(seed))

            negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"

            output = blip_diffusion_pipe(
                text_prompt,
                style_image,
                control_image,
                style_subject,
                tgt_subject,
                generator=generator,
                image=init_image,
                mask_image=cldm_cond_image,
                guidance_scale=guidance_scale,
                num_inference_steps=int(num_inference_steps),
                neg_prompt=negative_prompt,
                height=HEIGHT,
                width=WIDTH,
            ).images
            return {output_image: output}

        # Event listeners must be registered inside the Blocks context.
        run_button.click(fn=generate, inputs=inputs, outputs=[output_image])

    return demo
|
|
if __name__ == '__main__':
    # Build the app, enable Gradio's request queue, then start the server.
    demo = create_demo()
    demo.queue()
    demo.launch()
|
|
|
|
|
|