Spaces:
Runtime error
Runtime error
| # ************************************************************************* | |
| # Copyright (2023) Bytedance Inc. | |
| # | |
| # Copyright (2023) DragDiffusion Authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ************************************************************************* | |
| import cv2 | |
| import numpy as np | |
| import PIL | |
| from PIL import Image | |
| from PIL.ImageOps import exif_transpose | |
| import os | |
| import gradio as gr | |
| import datetime | |
| import pickle | |
| from copy import deepcopy | |
# Side length, in pixels, of the square areas used to display/edit images.
LENGTH: int = 480
def clear_all(length=480):
    """Reset the labeling UI to its empty state.

    Args:
        length: height and width (pixels) kept for the two emptied image
            components.

    Returns:
        A 5-tuple: two empty ``gr.Image`` updates, an empty point list and
        two ``None`` values. The UI wires these, in order, to the mask
        canvas, the click image, the selected points, the stored original
        image and the stored mask.
    """
    emptied = [gr.Image.update(value=None, height=length, width=length)
               for _ in range(2)]
    return emptied[0], emptied[1], [], None, None
def mask_image(image,
               mask,
               color=(255, 0, 0),
               alpha=0.5):
    """Overlay a mask on an image for visualization purposes.

    Pixels where ``mask == 1`` are painted with ``color`` and the painted
    copy is blended with the original image. Every other pixel is blended
    with itself and so keeps (up to rounding) its original value. ``image``
    itself is not modified.

    Args:
        image: (H, W, 3) or (H, W) image array.
        mask: (H, W) array; pixels equal to 1 are overlaid.
        color: color of the overlay. It must broadcast against the selected
            pixels (e.g. 3 values for an (H, W, 3) image, a scalar for an
            (H, W) image). A tuple default avoids a shared mutable default.
        alpha: weight of the painted copy in the blend (transparency).

    Returns:
        A new array of the same shape and dtype as ``image``.
    """
    painted = image.copy()
    painted[mask == 1] = color
    # alpha * painted + (1 - alpha) * image; cv2 allocates the output array.
    return cv2.addWeighted(painted, alpha, image, 1 - alpha, 0)
def store_img(img, length=512):
    """Preprocess the image/mask drawn on the sketch canvas.

    The image is resized to width ``length``, with the height scaled to keep
    the aspect ratio. The mask (first channel of ``img["mask"]``, divided by
    255) is resized to the same size with nearest-neighbour interpolation.

    Args:
        img: canvas value with keys ``"image"`` (an (H, W, C) array) and
            ``"mask"`` (an array whose first channel holds the sketch;
            presumably 0..255 since it is divided by 255 -- confirm).
        length: target width in pixels.

    Returns:
        (image, [], preview, mask):
        - image: the resized image.
        - []: a fresh, empty point list (a new image resets the points).
        - preview: the image with pixels outside the mask darkened, or a
          plain copy when the mask is empty.
        - mask: the resized mask, converted to uint8 0/1 when non-empty.
    """
    raw_image = img["image"]
    raw_mask = np.float32(img["mask"][:, :, 0]) / 255.
    src_h, src_w, _ = raw_image.shape
    target_size = (length, int(length * src_h / src_w))

    pil_image = exif_transpose(Image.fromarray(raw_image))
    resized = np.array(pil_image.resize(target_size, PIL.Image.BILINEAR))
    mask = cv2.resize(raw_mask, target_size, interpolation=cv2.INTER_NEAREST)

    if mask.sum() <= 0:
        # Nothing painted: show the image unchanged.
        return resized, [], resized.copy(), mask

    mask = np.uint8(mask > 0)
    preview = mask_image(resized, 1 - mask, color=[0, 0, 0], alpha=0.3)
    return resized, [], preview, mask
def get_points(img,
               sel_pix,
               evt: gr.SelectData):
    """Record a click on the image and redraw every clicked point on it.

    Clicks alternate between handle points (even positions in ``sel_pix``,
    drawn as red dots) and target points (odd positions, drawn as blue
    dots). Each complete handle/target pair is joined by a white arrow.

    Args:
        img: the image shown in the click panel; drawn on in place.
        sel_pix: list of clicked points kept in a ``gr.State``; the new click
            is appended to it in place.
        evt: Gradio select event; ``evt.index`` is the clicked pixel position
            and is used directly as a cv2 point.

    Returns:
        The annotated image as a numpy array.
    """
    sel_pix.append(evt.index)

    for start in range(0, len(sel_pix), 2):
        handle = tuple(sel_pix[start])
        cv2.circle(img, handle, 10, (255, 0, 0), -1)  # handle point: red
        if start + 1 < len(sel_pix):
            target = tuple(sel_pix[start + 1])
            cv2.circle(img, target, 10, (0, 0, 255), -1)  # target point: blue
            # arrow from the handle point to its target point
            cv2.arrowedLine(img, handle, target, (255, 255, 255), 4, tipLength=0.5)

    if isinstance(img, np.ndarray):
        return img
    return np.array(img)
def undo_points(original_image,
                mask):
    """Discard all handle/target points and restore the click-panel image.

    Args:
        original_image: the resized image stored by ``store_img``, or None
            when no image has been uploaded yet (or after "clear all").
        mask: the mask stored with it, or None.

    Returns:
        (image_to_display, []): the image with pixels outside the mask
        darkened (a plain copy when the mask is empty or missing, None when
        there is no image), and an empty point list.
    """
    # Nothing uploaded yet / UI cleared: nothing to restore. Previously
    # `mask.sum()` raised AttributeError on the None mask here.
    if original_image is None:
        return None, []

    if mask is not None and mask.sum() > 0:
        mask = np.uint8(mask > 0)
        # Darken everything outside the mask (same preview as store_img).
        masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = original_image.copy()
    return masked_img, []
def save_all(category,
             source_image,
             image_with_clicks,
             mask,
             labeler,
             prompt,
             points,
             root_dir='./drag_bench_data'):
    """Save one labeled drag-editing sample to disk.

    The sample is written to ``<root_dir>/<category>/<labeler>_<timestamp>/``
    and contains:
        original_image.png -- ``source_image``
        user_drag.png      -- ``image_with_clicks``
        meta_data.pkl      -- pickled dict with keys 'prompt', 'points', 'mask'

    Args:
        category: sub-directory name for the image category; must be a plain
            directory name (no path separators, not '', '.' or '..').
        source_image: image array accepted by ``PIL.Image.fromarray``.
        image_with_clicks: image array with the clicked points drawn on it.
        mask: editing mask, pickled as-is.
        labeler: labeler name, used as the sample-folder prefix; must not
            contain path separators.
        prompt: text prompt, pickled as-is.
        points: clicked points, pickled as-is.
        root_dir: dataset root; created (with parents) if missing.

    Returns:
        ``"<prefix> saved!"`` on success. When the inputs are rejected, a
        message starting with ``"save failed:"`` is returned and nothing is
        written.
    """
    # No image uploaded yet (or the UI was cleared). Bail out before creating
    # directories so no empty sample folder is left behind.
    if source_image is None or image_with_clicks is None:
        return "save failed: please upload an image first."

    # `labeler` is free text from a (possibly publicly shared) UI. Keep it and
    # `category` from escaping `root_dir` through separators or '.'/'..'.
    separators = [sep for sep in (os.sep, os.altsep) if sep]
    bad_category = (not category or category in ('.', '..')
                    or any(sep in category for sep in separators))
    bad_labeler = any(sep in labeler for sep in separators)
    if bad_category or bad_labeler:
        return "save failed: invalid labeler or category."

    # Second-resolution timestamp: two saves by the same labeler within the
    # same second share (and overwrite) one folder.
    save_prefix = labeler + '_' + datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
    save_dir = os.path.join(root_dir, category, save_prefix)
    # Creates root_dir and root_dir/category as needed; existing dirs are fine.
    os.makedirs(save_dir, exist_ok=True)

    # save images
    Image.fromarray(source_image).save(os.path.join(save_dir, 'original_image.png'))
    Image.fromarray(image_with_clicks).save(os.path.join(save_dir, 'user_drag.png'))

    # save meta data (NOTE: only ever unpickle this file from trusted sources)
    meta_data = {
        'prompt': prompt,
        'points': points,
        'mask': mask,
    }
    with open(os.path.join(save_dir, 'meta_data.pkl'), 'wb') as f:
        pickle.dump(meta_data, f)

    return save_prefix + " saved!"
# Labeling UI: draw a mask on the left canvas, click alternating handle/target
# points on the right image, then save the sample with `save_all`.
# NOTE(review): `tool="sketch"` and `gr.Image.update` (used by clear_all) appear
# to be Gradio 3.x-only APIs -- confirm the pinned gradio version.
with gr.Blocks() as demo:
    # UI components for editing real images
    with gr.Tab(label="Editing Real Image"):
        mask = gr.State(value=None) # store mask
        selected_points = gr.State([]) # store points (handle/target, alternating)
        original_image = gr.State(value=None) # store original input image
        with gr.Row():
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                    show_label=True, height=LENGTH, width=LENGTH) # for mask painting
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
                input_image = gr.Image(type="numpy", label="Click Points",
                    show_label=True, height=LENGTH, width=LENGTH) # for points clicking
        # sample metadata: who labeled it, its category (output sub-folder),
        # the text prompt, and a read-only status line for save results
        with gr.Row():
            labeler = gr.Textbox(label="Labeler")
            category = gr.Dropdown(value="art_work",
                label="Image Category",
                choices=[
                    'art_work',
                    'land_scape',
                    'building_city_view',
                    'building_countryside_view',
                    'animals',
                    'human_head',
                    'human_upper_body',
                    'human_full_body',
                    'interior_design',
                    'other_objects',
                ]
            )
            prompt = gr.Textbox(label="Prompt")
            save_status = gr.Textbox(label="display saving status")
        with gr.Row():
            undo_button = gr.Button("undo points")
            clear_all_button = gr.Button("clear all")
            save_button = gr.Button("save")

    # event definition
    # editing the mask canvas: resize the image, reset the points, and show
    # the masked preview in the click panel
    canvas.edit(
        store_img,
        [canvas],
        [original_image, selected_points, input_image, mask]
    )
    # clicking the image: append the point and redraw all points/arrows
    input_image.select(
        get_points,
        [input_image, selected_points],
        [input_image],
    )
    # drop all points and restore the masked preview
    undo_button.click(
        undo_points,
        [original_image, mask],
        [input_image, selected_points]
    )
    # reset everything; the hidden Number (created here, so it is added to
    # the layout invisibly) feeds LENGTH to clear_all as its `length`
    clear_all_button.click(
        clear_all,
        [gr.Number(value=LENGTH, visible=False, precision=0)],
        [canvas,
        input_image,
        selected_points,
        original_image,
        mask]
    )
    # write the sample to disk and report the result in the status box
    save_button.click(
        save_all,
        [category,
        original_image,
        input_image,
        mask,
        labeler,
        prompt,
        selected_points,],
        [save_status]
    )
# Enable request queuing and start the server. share=True asks Gradio for a
# public link. NOTE(review): anyone with that link can then write samples
# to disk via "save".
demo.queue().launch(share=True, debug=True)