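# Gradio demo for Bounded Attention (https://omer11a.github.io/bounded-attention/):
# layout-controlled multi-subject generation with Stable Diffusion XL. The user draws
# one bounding box per subject on a sketchpad, and a BoundedAttention editor constrains
# the cross- and self-attention maps so that each subject stays within its box.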
import spaces
import gradio as gr
import torch
import nltk
import numpy as np
from PIL import Image, ImageDraw
from diffusers import DDIMScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
from injection_utils import register_attention_editor_diffusers
from bounded_attention import BoundedAttention
from pytorch_lightning import seed_everything

REMOTE_MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
LOCAL_MODEL_PATH = "./model"

RESOLUTION = 256
MIN_SIZE = 0.01
WHITE = 255
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]

PROMPT1 = "a ginger kitten and a gray puppy in a yard"
SUBJECT_SUB_PROMPTS1 = "ginger kitten;gray puppy"
SUBJECT_TOKEN_INDICES1 = "2,3;6,7"
FILTER_TOKEN_INDICES1 = "1,4,5,8,9"
NUM_TOKENS1 = "10"

PROMPT2 = "3 D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
PROMPT3 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
PROMPT4 = "a realistic photo of a highway with a semi trailer and a concrete mixer and a helicopter"
PROMPT5 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"

EXAMPLE_BOXES = {
    PROMPT1: [
        [0.15, 0.2, 0.45, 0.9],
        [0.55, 0.25, 0.85, 0.95],
    ],
    PROMPT2: [
        [0.35, 0.4, 0.65, 0.9],
        [0, 0.6, 0.3, 0.9],
        [0.7, 0.55, 1, 0.85],
    ],
    PROMPT3: [
        [0.4, 0.45, 0.6, 0.95],
        [0.2, 0.3, 0.4, 0.85],
        [0.6, 0.3, 0.8, 0.85],
        [0.1, 0, 0.9, 0.3],
    ],
    PROMPT4: [
        [0.05, 0.5, 0.45, 0.85],
        [0.55, 0.6, 0.95, 0.85],
        [0.3, 0.2, 0.7, 0.45],
    ],
    PROMPT5: [
        [0, 0.5, 0.2, 0.8],
        [0.2, 0.2, 0.4, 0.5],
        [0.4, 0.5, 0.6, 0.8],
        [0.6, 0.2, 0.8, 0.5],
        [0.8, 0.5, 1, 0.8],
    ],
}
| CSS = """ | |
| #paper-info a { | |
| color:#008AD7; | |
| text-decoration: none; | |
| } | |
| #paper-info a:hover { | |
| cursor: pointer; | |
| text-decoration: none; | |
| } | |
| .tooltip { | |
| color: #555; | |
| position: relative; | |
| display: inline-block; | |
| cursor: pointer; | |
| } | |
| .tooltip .tooltiptext { | |
| visibility: hidden; | |
| width: 400px; | |
| background-color: #555; | |
| color: #fff; | |
| text-align: center; | |
| padding: 5px; | |
| border-radius: 5px; | |
| position: absolute; | |
| z-index: 1; /* Set z-index to 1 */ | |
| left: 10px; | |
| top: 100%; | |
| opacity: 0; | |
| transition: opacity 0.3s; | |
| } | |
| .tooltip:hover .tooltiptext { | |
| visibility: visible; | |
| opacity: 1; | |
| z-index: 9999; /* Set a high z-index value when hovering */ | |
| } | |
| """ | |
| DESCRIPTION = """ | |
| <p style="text-align: center; font-weight: bold;"> | |
| <span style="font-size: 28px">Bounded Attention</span> | |
| <br> | |
| <span style="font-size: 18px" id="paper-info"> | |
| [<a href="https://omer11a.github.io/bounded-attention/" target="_blank">Project Page</a>] | |
| [<a href="https://arxiv.org/abs/2403.16990" target="_blank">Paper</a>] | |
| [<a href="https://github.com/omer11a/bounded-attention" target="_blank">GitHub</a>] | |
| </span> | |
| </p> | |
| """ | |

COPY_LINK = """
<a href="https://huggingface.co/spaces/omer11a/bounded-attention?duplicate=true">
    <img src="https://bit.ly/3gLdBN6" alt="Duplicate Space">
</a>
Duplicate this Space to generate more samples without waiting in the queue.
<br>
For better results, increase the number of guidance steps to 15.
"""

ADVANCED_OPTION_DESCRIPTION = """
<div class="tooltip">Number of guidance steps ⓘ
    <span class="tooltiptext">The number of timesteps in which to perform guidance. The recommended value is 15, but increasing this will also increase the runtime.</span>
</div>
<div class="tooltip">Batch size ⓘ
    <span class="tooltiptext">The number of images to generate.</span>
</div>
<div class="tooltip">Initial step size ⓘ
    <span class="tooltiptext">The initial step size of the linear step size scheduler when performing guidance.</span>
</div>
<div class="tooltip">Final step size ⓘ
    <span class="tooltiptext">The final step size of the linear step size scheduler when performing guidance.</span>
</div>
<div class="tooltip">First refinement step ⓘ
    <span class="tooltiptext">The timestep from which subject mask refinement is performed.</span>
</div>
<div class="tooltip">Number of self-attention clusters per subject ⓘ
    <span class="tooltiptext">The number of clusters computed when clustering the self-attention maps (#clusters = #subjects x #clusters_per_subject). Changing this value might improve semantics (adherence to the prompt), especially when the subjects exceed their bounding boxes.</span>
</div>
<div class="tooltip">Cross-attention loss scale factor ⓘ
    <span class="tooltiptext">The scale factor of the cross-attention loss term. Increasing it will improve semantic control (adherence to the prompt), but may reduce image quality.</span>
</div>
<div class="tooltip">Self-attention loss scale factor ⓘ
    <span class="tooltiptext">The scale factor of the self-attention loss term. Increasing it will improve layout control (adherence to the bounding boxes), but may reduce image quality.</span>
</div>
<div class="tooltip">Number of Gradient Descent iterations per timestep ⓘ
    <span class="tooltiptext">The number of gradient descent iterations for each timestep when performing guidance.</span>
</div>
<div class="tooltip">Loss threshold ⓘ
    <span class="tooltiptext">If the loss is below this threshold, gradient descent stops for that timestep.</span>
</div>
<div class="tooltip">Classifier-free guidance scale ⓘ
    <span class="tooltiptext">The scale factor of classifier-free guidance.</span>
</div>
"""

FOOTNOTE = """
<p>The source code of this demo is based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GLIGEN demo</a>.</p>
"""
def inference(
    boxes,
    prompts,
    subject_sub_prompts,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    first_refinement_step,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
):
    if not torch.cuda.is_available():
        raise gr.Error("cuda is not available")

    device = torch.device("cuda")
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
    model = StableDiffusionXLPipeline.from_pretrained(LOCAL_MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16, device_map="auto")
    model.to(device)
    model.unet.set_attn_processor(AttnProcessor2_0())
    model.enable_sequential_cpu_offload()

    seed_everything(seed)
    start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
    eos_token_index = None if num_tokens is None else num_tokens + 1
    editor = BoundedAttention(
        boxes,
        prompts,
        list(range(70, 82)),
        list(range(70, 82)),
        subject_sub_prompts=subject_sub_prompts,
        subject_token_indices=subject_token_indices,
        filter_token_indices=filter_token_indices,
        eos_token_index=eos_token_index,
        cross_loss_coef=cross_loss_scale,
        self_loss_coef=self_loss_scale,
        max_guidance_iter=num_guidance_steps,
        max_guidance_iter_per_step=num_iterations,
        start_step_size=init_step_size,
        end_step_size=final_step_size,
        loss_stopping_value=loss_threshold,
        min_clustering_step=first_refinement_step,
        num_clusters_per_box=num_clusters_per_subject,
        max_resolution=32,
    )

    register_attention_editor_diffusers(model, editor)
    return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
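
# Parse and validate the raw textbox inputs, check that the number of drawn boxes
# matches the number of subjects, and delegate to inference().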
def generate(
    prompt,
    subject_sub_prompts,
    subject_token_indices,
    filter_token_indices,
    num_tokens,
    init_step_size,
    final_step_size,
    first_refinement_step,
    num_clusters_per_subject,
    cross_loss_scale,
    self_loss_scale,
    classifier_free_guidance_scale,
    batch_size,
    num_iterations,
    loss_threshold,
    num_guidance_steps,
    seed,
    boxes,
):
    num_subjects = 0
    subject_sub_prompts = convert_sub_prompts(subject_sub_prompts)
    subject_token_indices = convert_token_indices(subject_token_indices, nested=True)
    if subject_sub_prompts is not None:
        num_subjects = len(subject_sub_prompts)
    if subject_token_indices is not None:
        num_subjects = len(subject_token_indices)

    if len(boxes) != num_subjects:
        raise gr.Error("""
            The number of boxes should be equal to the number of subjects.
            Number of boxes drawn: {}, number of subjects: {}.
        """.format(len(boxes), num_subjects))

    filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
    num_tokens = int(num_tokens) if len(num_tokens.strip()) > 0 else None
    prompts = [prompt.strip(".").strip(",").strip()] * batch_size

    images = inference(
        boxes, prompts, subject_sub_prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
        final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
        classifier_free_guidance_scale, num_iterations, loss_threshold, num_guidance_steps, seed)

    return images
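
# Parsing helpers: sub-prompts are separated by semicolons; token indices are
# comma-separated, with semicolons between subjects when nested=True. Empty inputs
# are converted to None so that the defaults apply downstream.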
def convert_sub_prompts(sub_prompts):
    sub_prompts = sub_prompts.strip()
    if len(sub_prompts) == 0:
        return None

    return [sub_prompt.strip() for sub_prompt in sub_prompts.split(";")]


def convert_token_indices(token_indices, nested=False):
    token_indices = token_indices.strip()
    if len(token_indices) == 0:
        return None

    if nested:
        return [convert_token_indices(indices, nested=False) for indices in token_indices.split(";")]

    return [int(index.strip()) for index in token_indices.split(",") if len(index.strip()) > 0]
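
# Convert the sketchpad layers into normalized bounding boxes: the non-zero pixels of
# each layer define one (x1, y1, x2, y2) box in [0, 1] coordinates. A layer whose box
# is smaller than MIN_SIZE raises an error.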
def draw(sketchpad):
    boxes = []
    for i, layer in enumerate(sketchpad["layers"]):
        non_zeros = layer.nonzero()
        x1 = x2 = y1 = y2 = 0
        if len(non_zeros[0]) > 0:
            x1x2 = non_zeros[1] / layer.shape[1]
            y1y2 = non_zeros[0] / layer.shape[0]
            x1 = x1x2.min()
            x2 = x1x2.max()
            y1 = y1y2.min()
            y2 = y1y2.max()

        if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
            raise gr.Error(f"Box in layer {i} is too small")

        boxes.append((x1, y1, x2, y2))

    print(f"Drawn boxes: {boxes}")
    layout_image = draw_boxes(boxes)
    return [boxes, layout_image]
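
# Render the boxes on a white RESOLUTION x RESOLUTION canvas, either as black sketch
# outlines or with a distinct color per subject.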
def draw_boxes(boxes, is_sketch=False):
    if len(boxes) == 0:
        return None

    boxes = np.array(boxes) * RESOLUTION
    image = Image.new("RGB", (RESOLUTION, RESOLUTION), (WHITE, WHITE, WHITE))
    drawing = ImageDraw.Draw(image)
    for i, box in enumerate(boxes.astype(int).tolist()):
        color = "black" if is_sketch else COLORS[i % len(COLORS)]
        drawing.rectangle(box, outline=color, width=4)

    return image

def clear(batch_size):
    return [[], None, None, None]
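
# Rebuild the sketchpad state and layout image for a predefined example prompt, using
# the hard-coded boxes in EXAMPLE_BOXES.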
def build_example_layout(prompt, *args):
    boxes = EXAMPLE_BOXES[prompt]
    print(f"Loaded boxes: {boxes}")
    composite = draw_boxes(boxes, is_sketch=True)
    sketchpad = {"background": None, "layers": [], "composite": composite}
    layout_image = draw_boxes(boxes)
    return boxes, sketchpad, layout_image
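
# Build and launch the Gradio UI. The SDXL weights are downloaded once and cached to
# LOCAL_MODEL_PATH so that inference() can later load them from disk.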
def main():
    nltk.download("averaged_perceptron_tagger")

    model = StableDiffusionXLPipeline.from_pretrained(REMOTE_MODEL_PATH)
    model.save_pretrained(LOCAL_MODEL_PATH)
    del model

    with gr.Blocks(
        css=CSS,
        title="Bounded Attention demo",
    ) as demo:
        gr.HTML(DESCRIPTION)
        gr.HTML(COPY_LINK)

        with gr.Column():
            gr.HTML("Scroll down to see examples of the required input format.")

            prompt = gr.Textbox(
                label="Text prompt",
                placeholder=PROMPT1,
            )

            subject_sub_prompts = gr.Textbox(
                label="Sub-prompts for each subject (separate with semicolons)",
                placeholder=SUBJECT_SUB_PROMPTS1,
            )

            with gr.Accordion("Precise inputs", open=False):
                subject_token_indices = gr.Textbox(
                    label="Optional: The token indices of each subject (separate indices for the same subject with commas, and for different subjects with semicolons)",
                    placeholder=SUBJECT_TOKEN_INDICES1,
                )

                filter_token_indices = gr.Textbox(
                    label="Optional: The token indices to filter, e.g. conjunctions, numbers, positional relations, etc. (if left empty, this will be automatically inferred)",
                    placeholder=FILTER_TOKEN_INDICES1,
                )

                num_tokens = gr.Textbox(
                    label="Optional: The number of tokens in the prompt (we use this to verify your input, as rare words are sometimes split into more than one token)",
                    placeholder=NUM_TOKENS1,
                )

            with gr.Row():
                sketchpad = gr.Sketchpad(label="Sketch Pad (draw each bounding box in a different layer)")
                layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False)

            with gr.Row():
                generate_layout_button = gr.Button(value="Generate layout")
                generate_image_button = gr.Button(value="Generate image")
                clear_button = gr.Button(value="Clear")

            with gr.Row():
                out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
| with gr.Accordion("Advanced Options", open=False): | |
| with gr.Column(): | |
| gr.HTML(ADVANCED_OPTION_DESCRIPTION) | |
| batch_size = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of samples (limited to one sample on current space)") | |
| num_guidance_steps = gr.Slider(minimum=5, maximum=20, step=1, value=8, label="Number of timesteps to perform guidance") | |
| init_step_size = gr.Slider(minimum=0, maximum=50, step=0.5, value=30, label="Initial step size") | |
| final_step_size = gr.Slider(minimum=0, maximum=20, step=0.5, value=15, label="Final step size") | |
| first_refinement_step = gr.Slider(minimum=0, maximum=50, step=1, value=15, label="The timestep from which to start refining the subject masks") | |
| num_clusters_per_subject = gr.Slider(minimum=0, maximum=5, step=0.5, value=3, label="Number of clusters per subject") | |
| cross_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Cross-attention loss scale factor") | |
| self_loss_scale = gr.Slider(minimum=0, maximum=2, step=0.1, value=1, label="Self-attention loss scale factor") | |
| num_iterations = gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Number of Gradient Descent iterations") | |
| loss_threshold = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Loss threshold") | |
| classifier_free_guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Classifier-free guidance Scale") | |
| seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random Seed") | |
| boxes = gr.State([]) | |

            clear_button.click(
                clear,
                inputs=[batch_size],
                outputs=[boxes, sketchpad, layout_image, out_images],
                queue=False,
            )

            generate_layout_button.click(
                draw,
                inputs=[sketchpad],
                outputs=[boxes, layout_image],
                queue=False,
            )

            generate_image_button.click(
                fn=generate,
                inputs=[
                    prompt, subject_sub_prompts, subject_token_indices, filter_token_indices, num_tokens,
                    init_step_size, final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
                    classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
                    seed,
                    boxes,
                ],
                outputs=[out_images],
                queue=True,
            )

        with gr.Column():
            gr.Examples(
                examples=[
                    [
                        PROMPT1, SUBJECT_SUB_PROMPTS1, SUBJECT_TOKEN_INDICES1, FILTER_TOKEN_INDICES1, NUM_TOKENS1,
                        15, 10, 15, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        12,
                    ],
                    [
                        PROMPT2, "cute unicorn;pink hedgehog;nerdy owl", "7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
                        25, 18, 15, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        286,
                    ],
                    [
                        PROMPT3, "astronaut;robot;green alien;spaceship", "7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
                        18, 12, 15, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        216,
                    ],
                    [
                        PROMPT4, "semi trailer;concrete mixer;helicopter", "9,10;13,14;17", "1,4,5,7,8,11,12,15,16", "17",
                        25, 18, 15, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        82,
                    ],
                    [
                        PROMPT5, "golden retriever;german shepherd;boston terrier;english bulldog;border collie", "2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
                        18, 12, 15, 3, 1, 1,
                        7.5, 1, 5, 0.2, 8,
                        152,
                    ],
                ],
                fn=build_example_layout,
                inputs=[
                    prompt, subject_sub_prompts, subject_token_indices, filter_token_indices, num_tokens,
                    init_step_size, final_step_size, first_refinement_step, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
                    classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
                    seed,
                ],
                outputs=[boxes, sketchpad, layout_image],
                run_on_click=True,
            )

        gr.HTML(FOOTNOTE)

    demo.launch(show_api=False, show_error=True)


if __name__ == "__main__":
    main()