Spaces:
Build error
Build error
| from pathlib import Path | |
| import cv2 | |
| import sys | |
| import gradio as gr | |
| import os | |
| import numpy as np | |
| from gradio_utils import * | |
| from transformers import CLIPTokenizer | |
def image_mod(image):
    """Rotate ``image`` by a fixed 45 degrees and return the result.

    Simply delegates to ``image.rotate``; presumably ``image`` is a PIL image
    (TODO confirm). NOTE(review): not referenced elsewhere in this file.
    """
    degrees = 45
    rotated = image.rotate(degrees)
    return rotated
# NOTE(review): this path tweak runs *after* the imports above (including
# ``from gradio_utils import *``), so it cannot influence how those resolve;
# it only affects imports performed later. It inserts the parent directory of
# sys.path[0] (normally the running script's directory) at position 1 of the
# module search path.
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# NOTE(review): NUM_POINTS, NUM_FRAMES, LARGE_BOX_SIZE and ``data`` are not
# referenced by the code in this file — possibly leftovers; TODO confirm
# before removing.
NUM_POINTS = 3
NUM_FRAMES = 16
LARGE_BOX_SIZE = 176
data = {}
# CLIP tokenizer used by get_token_number() to find a word's token position
# inside a prompt. Loaded once at import time via from_pretrained (may need
# network access or a populated Hugging Face cache on first run).
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
def get_token_number(prompt, word):
    """Return the position of ``word``'s first sub-token within the tokenized ``prompt``.

    Both strings are encoded with the module-level CLIP ``tokenizer``. CLIP wraps
    every encoding as ``[<start>, tok_1, ..., tok_k, <end>]``, so ``word``'s first
    real sub-token is ``word_tokens[1]``; its index in the prompt's ids is the
    token number used in the attention-mask file names.

    Args:
        prompt: full text prompt.
        word: a word taken from ``prompt``.

    Returns:
        Index (into the prompt's token ids, start token included) of the first
        sub-token of ``word``.

    Raises:
        ValueError: if ``word`` encodes to no content tokens (e.g. an empty
            string), or if its first sub-token does not occur in ``prompt``.
    """
    all_tokens = tokenizer(prompt).input_ids
    word_tokens = tokenizer(word).input_ids
    # Need at least <start>, one content token and <end>; with only the two
    # special tokens, index 1 would be the end token and we would silently
    # return the position of the prompt's end token instead of the word.
    if len(word_tokens) < 3:
        raise ValueError(f"word {word!r} does not produce any tokens")
    first_subtoken = word_tokens[1]
    try:
        return all_tokens.index(first_subtoken)
    except ValueError:
        raise ValueError(
            f"first token of word {word!r} not found in prompt {prompt!r}"
        ) from None
def overlay_mask(img, mask, opacity=0.3):
    """Blend a grayscale attention mask on top of an image.

    The mask is resized with nearest-neighbour interpolation to the image's
    width/height, expanded to three channels (``cv2.COLOR_GRAY2BGR``) and
    combined as ``opacity * mask + (1 - opacity) * img`` via
    ``cv2.addWeighted``, so brighter mask pixels lighten the image more.

    Args:
        img: 3-channel image array, e.g. from ``cv2.imread`` in colour mode.
        mask: single-channel mask array, e.g. from ``cv2.imread`` with
            ``cv2.IMREAD_GRAYSCALE``. Assumed to share ``img``'s dtype
            (uint8 from imread) — required by ``addWeighted``.
        opacity: weight of the mask in the blend; the image gets
            ``1 - opacity``. Defaults to 0.3, the value previously hard-coded.

    Returns:
        The blended image, with the same height/width as ``img``.

    Raises:
        ValueError: if ``img`` or ``mask`` is ``None`` (``cv2.imread`` returns
            ``None`` when a file cannot be read).
    """
    if img is None or mask is None:
        raise ValueError("overlay_mask requires both an image and a mask (got None)")
    height, width = img.shape[:2]
    # cv2.resize takes (width, height); nearest keeps the mask's discrete values.
    mask_resized = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    # addWeighted needs matching channel counts, so lift the mask to 3 channels.
    mask_3ch = cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR)
    return cv2.addWeighted(mask_3ch, opacity, img, 1 - opacity, 0)
def _imread_or_raise(path, flags):
    """Read ``path`` with ``cv2.imread(path, flags)``, raising instead of returning ``None``."""
    image = cv2.imread(path, flags)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {path}")
    return image


def fetch_proper_img(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Load one generated video frame and, optionally, its attention-mask overlay.

    Args:
        prompt: video prompt; selects the sub-directories under ``./data``
            (spaces become underscores for the video path only).
        word: word of ``prompt`` whose attention mask to overlay, or ``None``
            to return the plain frame for both outputs.
        frame_num: 1-based frame index from the UI slider (made 0-based here).
        diffusion_step: diffusion-step index, inserted verbatim into the mask
            file name.
        layer_num: layer index from the UI. ``None`` means 0, and 3 is mapped
            to 100 — presumably the naming used by the exported mask files
            (TODO confirm).

    Returns:
        ``(frame, overlaid_frame)`` as RGB arrays suitable for ``gr.Image``.

    Raises:
        FileNotFoundError: if the frame or the mask image cannot be read.
        ValueError: propagated from ``get_token_number`` if ``word`` cannot be
            located in ``prompt``.
    """
    frame_idx = int(frame_num) - 1  # int() so the :04d format below always works
    if layer_num is None:
        layer_num = 0
    elif layer_num == 3:
        layer_num = 100

    frame_path = f"./data/videos/{prompt.replace(' ', '_')}/video/frame_{frame_idx:04d}.png"
    img = _imread_or_raise(frame_path, cv2.IMREAD_COLOR)

    if word is None:
        overlaid_image = img
    else:
        token_idx = get_token_number(prompt, word)
        mask_path = (
            f'./data/final_masks/attention_probs_{prompt}/frame_{frame_idx}'
            f'_layer_{layer_num}_diffusionstep_{diffusion_step}_token_{token_idx}.png'
        )
        mask = _imread_or_raise(mask_path, cv2.IMREAD_GRAYSCALE)
        overlaid_image = overlay_mask(img, mask)

    # cv2 decodes images in BGR order, while gr.Image interprets numpy arrays
    # as RGB; convert so colours are not displayed with red/blue swapped.
    return (
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB),
        cv2.cvtColor(overlaid_image, cv2.COLOR_BGR2RGB),
    )
def fetch_proper_img_and_change_prompt(prompt, word, frame_num, diffusion_step, layer_num=0):
    """Rebuild the word selector for ``prompt`` and reload both images.

    The word selector is built first (via ``change_text_prompt``), then the
    images are fetched with ``fetch_proper_img`` using the given arguments.

    Returns:
        ``[frame, overlaid_frame, word_radio]`` — the two images returned by
        ``fetch_proper_img`` followed by the new word-selection component.
    """
    word_radio = change_text_prompt(prompt)
    frame, overlaid_frame = fetch_proper_img(
        prompt, word, frame_num, diffusion_step, layer_num
    )
    return [frame, overlaid_frame, word_radio]
# Custom stylesheet passed to gr.Blocks(css=...) below.
# NOTE(review): no component in this file sets the ``word-btn`` or
# ``word-btns-container`` classes — possibly leftover; TODO confirm.
css = """
.word-btn {
width: fit-content;
padding: 3px;
}
.word-btns-container {
flex-direction: row;
}
"""
# Word -> mask-name lookup. NOTE(review): not referenced elsewhere in this file.
registry = dict(spider='mask_1', descending='mask_2')
# NOTE(review): not referenced elsewhere in this file; fetch_proper_img builds
# its './data/...' paths directly.
data_path = Path('data')
# Prompts offered in the "Video Prompt" dropdown; fetch_proper_img reads the
# frames and masks for the selected prompt from ./data/.
available_prompts = [
    'a dog and a cat sitting',
    'A fish swimming in the water',
    'A spider descending from its web',
    'An astronaut riding a horse',
]
with gr.Blocks(css=css) as demo:
    # Top row: the raw video frame next to the same frame with the mask overlaid.
    with gr.Row():
        video_1 = gr.Image(label="Image")
        video_2 = gr.Image(label="Image with Attention Mask")

    def change_text_prompt(text):
        """Build a word-selection Radio whose choices are the words of ``text``.

        Used to create the initial selector and as the "Get words" click
        handler, whose return value replaces ``radio``. ``str.split()`` with no
        argument ignores leading, trailing and repeated whitespace, so no empty
        choices are produced.
        """
        return gr.Radio(text.split(), value=None, label='Choose a word to visualize its attention mask.')

    # Initial prompt; matches the dropdown's default value below.
    text = available_prompts[0]
    gr.Markdown("""
## Visualizing Attention Masks
* Select a prompt from the drop down
* Click on "Get words" to get the words in the prompt
* Select a radio button from the words to visualize the attention mask
* Play around with the index of diffusion steps, layers to visualize different masks
* Brighter mask corresponds to larger values of attention.
""")
    # gr.Group accepts only keyword arguments (visible, elem_id, ...), so the
    # section names are kept as comments rather than passed positionally.
    with gr.Group():  # Video selection
        txt_1 = gr.Dropdown(choices=available_prompts, label="Video Prompt", value=available_prompts[0])
        submit_btn = gr.Button('Get words')
    with gr.Group():  # Word selection
        radio = change_text_prompt(text)
    range_slider = gr.Slider(1, 16, 1, step=2, label='Frame of the generated video to visualize the attention mask.')
    diffusion_slider = gr.Slider(0, 35, 0, step=5, label='Index of diffusion steps.')
    layer_num_slider = gr.Slider(0, 6, 0, step=1, label='Layer number for attention mask.')

    # Changing the selected word or any slider re-renders both images with the
    # same set of inputs.
    _viz_inputs = [txt_1, radio, range_slider, diffusion_slider, layer_num_slider]
    for _control in (radio, range_slider, diffusion_slider, layer_num_slider):
        _control.change(fetch_proper_img, inputs=_viz_inputs, outputs=[video_1, video_2])
    submit_btn.click(change_text_prompt, inputs=[txt_1], outputs=[radio])
# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    # server_name='0.0.0.0' binds to all network interfaces rather than
    # localhost only.
    demo.launch(server_name='0.0.0.0')