| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | import os |
| |
|
# Work from inside the repository checkout: the relative './checkpoints'
# directory used below lives there, and (presumably, when '' is on sys.path as
# in a notebook) the project imports such as dnnlib/gradio_utils resolve from it.
_REPO_DIR = "DragGan"
os.chdir(_REPO_DIR)
| |
|
| | |
| |
|
| | import os.path as osp |
| | from argparse import ArgumentParser |
| | from functools import partial |
| | from pathlib import Path |
| | import time |
| |
|
| | import psutil |
| |
|
| | import gradio as gr |
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| |
|
| | import dnnlib |
| | from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image, |
| | get_latest_points_pair, get_valid_mask, |
| | on_change_single_global_state) |
| | from viz.renderer import Renderer, add_watermark_np |
| |
|
| |
|
| | |
| | from huggingface_hub import snapshot_download |
| |
|
# Download every generator checkpoint of the Hub repo into ./checkpoints.
model_dir = Path('./checkpoints')
# exist_ok=True: a plain os.mkdir() raised FileExistsError whenever the
# script/cell was run a second time with the directory already present.
model_dir.mkdir(parents=True, exist_ok=True)
snapshot_download('DragGan/DragGan-Models',
                  repo_type='model', local_dir=model_dir)
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
# Command-line parsing is disabled (e.g. for notebook use). The former setup was:
#   parser = ArgumentParser()
#   parser.add_argument('--share', action='store_true')
#   parser.add_argument('--cache-dir', type=str, default='./checkpoints')
#   args = parser.parse_args()

# Folder scanned for *.pkl generator checkpoints.
cache_dir = './checkpoints'

# Device string stored in the session state.
device = 'cuda'

# True only when running inside the "DragGan/DragGan" Hugging Face Space.
_space_id = os.environ.get('SPACE_ID', '')
IS_SPACE = "DragGan/DragGan" in _space_id

# Wall-clock limit (seconds) for one drag run; the drag loop only enforces it
# when IS_SPACE is true.
TIMEOUT = 80
| |
|
| |
|
def reverse_point_pairs(points):
    """Swap the two coordinates of every point.

    Args:
        points: iterable of 2-element sequences.

    Returns:
        A new list of 2-element lists ``[p[1], p[0]]``; the input is not
        modified.
    """
    return [[pair[1], pair[0]] for pair in points]
| |
|
| |
|
def clear_state(global_state, target=None):
    """Reset the point and/or mask history kept in ``global_state``.

    Args:
        global_state: session dict; ``global_state['images']['image_raw']``
            must provide a ``.size`` of ``(width, height)`` when the mask is
            cleared.
        target: ``'point'``, ``'mask'``, a list of those, or ``None`` for
            both.

    Effects:
        * ``'point'``: ``global_state['points']`` becomes an empty dict.
        * ``'mask'``: ``global_state['mask']`` becomes an all-ones uint8
          array of shape ``(height, width)``.

    Returns:
        The same ``global_state`` object.
    """
    if target is None:
        targets = ['point', 'mask']
    elif isinstance(target, list):
        targets = target
    else:
        targets = [target]

    if 'point' in targets:
        global_state['points'] = {}
        print('Clear Points State!')

    if 'mask' in targets:
        raw_size = global_state["images"]["image_raw"].size
        height, width = raw_size[1], raw_size[0]
        global_state['mask'] = np.ones((height, width), dtype=np.uint8)
        print('Clear mask State!')

    return global_state
| |
|
| |
|
def init_images(global_state):
    """(Re)build the generator and render a fresh, un-dragged image.

    Called once at start-up and again whenever the checkpoint, seed or latent
    space changes and when the image is reset.

    Steps:
        0. Unwrap ``gr.State`` if needed.
        1. Re-init the renderer network from the selected checkpoint and
           the current params.
        2. Render once with ``is_drag=False``.
        3. Store the new image as ``image_orig``/``image_raw``, a watermarked
           copy as ``image_show``, and reset the mask to all ones.

    Args:
        global_state: a ``gr.State`` wrapping the session dict, or the dict
            itself.

    Returns:
        The argument as passed in (still wrapped if it was a ``gr.State``).
    """
    state = global_state.value if isinstance(global_state, gr.State) else global_state
    params = state['params']
    renderer = state['renderer']
    generator_params = state['generator_params']

    renderer.init_network(
        generator_params,
        valid_checkpoints_dict[state['pretrained_weight']],
        params['seed'],
        None,
        params['latent_space'] == 'w+',
        'const',
        params['trunc_psi'],
        params['trunc_cutoff'],
        None,
        params['lr'],
    )
    renderer._render_drag_impl(generator_params, is_drag=False, to_pil=True)

    first_image = generator_params.image
    images = state['images']
    images['image_orig'] = first_image
    images['image_raw'] = first_image
    images['image_show'] = Image.fromarray(
        add_watermark_np(np.array(first_image)))

    width, height = first_image.size[0], first_image.size[1]
    state['mask'] = np.ones((height, width), dtype=np.uint8)
    return global_state
| |
|
| |
|
def update_image_draw(image, points, mask, show_mask, global_state=None):
    """Compose the preview image: points, optional mask overlay, watermark.

    Args:
        image: base image passed to ``draw_points_on_image``.
        points: point pairs to draw.
        mask: mask array or ``None``. It is overlaid only when ``show_mask``
            is truthy and the mask is neither all 0 nor all 1.
        show_mask: whether the mask overlay is wanted.
        global_state: when given, ``global_state['images']['image_show']``
            is set to the result.

    Returns:
        The watermarked PIL image.
    """
    canvas = draw_points_on_image(image, points)

    if show_mask and mask is not None:
        # A uniform mask carries no region information, so it is not drawn.
        is_uniform = (mask == 0).all() or (mask == 1).all()
        if not is_uniform:
            canvas = draw_mask_on_image(canvas, mask)

    watermarked = Image.fromarray(add_watermark_np(np.array(canvas)))
    if global_state is not None:
        global_state['images']['image_show'] = watermarked
    return watermarked
| |
|
| |
|
def preprocess_mask_info(global_state, image):
    """Fold the mask the user just sketched into ``global_state['mask']``.

    ``image`` is the ImageMask value. When it is a dict, its ``'mask'`` entry
    is converted with ``get_valid_mask`` and combined with the stored mask
    according to ``global_state['editing_state']``:

    1. no sketch (``image`` is not a dict, or the converted mask is None):
       the state is returned unchanged;
    2. ``'remove_mask'``: sketched pixels are subtracted, clipped to [0, 1];
    3. ``'add_mask'``: sketched pixels are added, clipped to [0, 1];
    4. any other mode: the stored mask is left untouched.

    An all-ones stored mask is a placeholder for "no mask yet". In modes 2
    and 3 it is replaced by the sketch before combining.

    Returns:
        The same ``global_state`` object.
    """
    last_mask = (get_valid_mask(image['mask'])
                 if isinstance(image, dict) else None)
    if last_mask is None:
        return global_state

    editing_mode = global_state['editing_state']
    if editing_mode not in ('add_mask', 'remove_mask'):
        # Previously the placeholder was still swapped for the sketch here,
        # even though this branch is documented/logged as a no-op.
        print(f'Last editing_state is {editing_mode}, '
              'do nothing to mask.')
        return global_state

    mask = global_state['mask']
    if (mask == 1).all():
        mask = last_mask

    # Keep the dtype the original arithmetic would have produced.
    out_dtype = np.result_type(mask, last_mask)
    # Combine in float64: with two uint8 masks, 0 - 1 wraps to 255 and
    # np.clip would then turn a "removed" pixel into 1.
    base = np.asarray(mask, dtype=np.float64)
    if editing_mode == 'remove_mask':
        combined = base - last_mask
        print(f'Last editing_state is {editing_mode}, do remove.')
    else:
        combined = base + last_mask
        print(f'Last editing_state is {editing_mode}, do add.')

    global_state['mask'] = np.clip(combined, 0, 1).astype(out_dtype)
    return global_state
| |
|
| |
|
def print_memory_usage():
    """Print host RAM usage and, if CUDA is available, GPU memory figures.

    "Available GPU memory" is reported as the device's total memory minus
    the peak memory allocated so far (``torch.cuda.max_memory_allocated``).
    """
    ram_percent = psutil.virtual_memory().percent
    print(f"System memory usage: {ram_percent}%")

    if not torch.cuda.is_available():
        print("No GPU available")
        return

    gpu = torch.device("cuda")
    allocated_bytes = torch.cuda.memory_allocated()
    print(f"GPU memory usage: {allocated_bytes / 1e9} GB")
    peak_bytes = torch.cuda.max_memory_allocated()
    print(f"Max GPU memory usage: {peak_bytes / 1e9} GB")
    total_bytes = torch.cuda.get_device_properties(gpu).total_memory
    headroom = total_bytes - torch.cuda.max_memory_allocated()
    print(f"Available GPU memory: {headroom / 1e9} GB")
| |
|
| |
|
| | |
# On the public Space only whitelisted checkpoints are offered. A previously
# commented-out list also allowed stylegan_human_v2_512.pkl,
# stylegan2_dogs_1024_pytorch.pkl and stylegan3-t-ffhq-1024x1024.pkl.
allowed_checkpoints = ["cat_512_stylegan2.pkl"] if IS_SPACE else []

# Checkpoint name (file name minus its ".pkl" suffix) -> checkpoint path.
# Path.stem keeps names such as "model_v1.5.pkl" intact, where
# name.split('.')[0] truncated them. sorted() makes the dropdown order
# deterministic instead of depending on filesystem glob order.
valid_checkpoints_dict = {
    f.stem: str(f)
    for f in sorted(Path(cache_dir).glob('*.pkl'))
    if f.name in allowed_checkpoints or not IS_SPACE
}
print('Valid checkpoint file:')
print(valid_checkpoints_dict)

# Checkpoint loaded at start-up. If it is not available, fall back to the
# first one found rather than failing later in init_images with a bare
# KeyError.
init_pkl = "cat_512_stylegan2"
if init_pkl not in valid_checkpoints_dict:
    if not valid_checkpoints_dict:
        raise FileNotFoundError(
            f'No usable *.pkl checkpoints found in {cache_dir!r}')
    init_pkl = next(iter(valid_checkpoints_dict))
| |
|
with gr.Blocks() as app:
    gr.Markdown("""
# DragGAN - Drag Your GAN
## Interactive Point-based Manipulation on the Generative Image Manifold
### Unofficial Gradio Demo

**Due to high demand, only one model can be run at a time, or you can duplicate the space and run your own copy.**

<a href="https://huggingface.co/spaces/radames/DragGan?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> for no queue on your own hardware.</p>

* Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
* Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) © [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
""")

    # Per-session state shared by every event handler below.
    global_state = gr.State({
        "images": {
            # Filled by init_images():
            #   image_orig: first image rendered for the current settings
            #   image_raw:  latest generator output, without overlays
            #   image_show: image_raw with point/mask overlays + watermark
        },
        "temporal_params": {
            # 'stop': set by the Stop button, polled by the drag loop
        },
        'mask': None,
        'last_mask': None,
        'show_mask': True,
        "generator_params": dnnlib.EasyDict(),
        "params": {
            # Explicit int64 so the upper bound fits on platforms whose
            # default integer is 32-bit.
            "seed": int(np.random.randint(0, 2**32 - 1, dtype=np.int64)),
            "motion_lambda": 20,
            "r1_in_pixels": 3,
            "r2_in_pixels": 12,
            "magnitude_direction_in_pixels": 1.0,
            "latent_space": "w+",
            "trunc_psi": 0.7,
            "trunc_cutoff": None,
            "lr": 0.001,
        },
        "device": device,
        "draw_interval": 1,
        "renderer": Renderer(disable_timing=True),
        "points": {},
        "curr_point": None,
        "curr_type_point": "start",
        'editing_state': 'add_points',
        'pretrained_weight': init_pkl,
    })

    # Build the network and render the first image.
    global_state = init_images(global_state)

    with gr.Row():
        with gr.Row():
            # Left column: controls.
            with gr.Column(scale=3):
                # Checkpoint selection.
                with gr.Row():
                    with gr.Column(scale=1, min_width=10):
                        gr.Markdown(value='Pickle', show_label=False)
                    with gr.Column(scale=4, min_width=10):
                        form_pretrained_dropdown = gr.Dropdown(
                            choices=list(valid_checkpoints_dict.keys()),
                            label="Pretrained Model",
                            value=init_pkl,
                        )

                # Latent code: seed, step size, latent space.
                with gr.Row():
                    with gr.Column(scale=1, min_width=10):
                        gr.Markdown(value='Latent', show_label=False)
                    with gr.Column(scale=4, min_width=10):
                        form_seed_number = gr.Slider(
                            # Was misspelled "mininium", which Gradio does
                            # not recognise as the minimum.
                            minimum=0,
                            maximum=2**32 - 1,
                            step=1,
                            value=global_state.value['params']['seed'],
                            interactive=True,
                            label="Seed",
                        )
                        form_lr_number = gr.Number(
                            value=global_state.value["params"]["lr"],
                            interactive=True,
                            label="Step Size")

                        with gr.Row():
                            with gr.Column(scale=2, min_width=10):
                                form_reset_image = gr.Button("Reset Image")
                            with gr.Column(scale=3, min_width=10):
                                form_latent_space = gr.Radio(
                                    ['w', 'w+'],
                                    value=global_state.value['params']
                                    ['latent_space'],
                                    interactive=True,
                                    label='Latent space to optimize',
                                    show_label=False,
                                )

                # Drag controls.
                with gr.Row():
                    with gr.Column(scale=1, min_width=10):
                        gr.Markdown(value='Drag', show_label=False)
                    with gr.Column(scale=4, min_width=10):
                        with gr.Row():
                            with gr.Column(scale=1, min_width=10):
                                enable_add_points = gr.Button('Add Points')
                            with gr.Column(scale=1, min_width=10):
                                undo_points = gr.Button('Reset Points')
                        with gr.Row():
                            with gr.Column(scale=1, min_width=10):
                                form_start_btn = gr.Button("Start")
                            with gr.Column(scale=1, min_width=10):
                                form_stop_btn = gr.Button("Stop")

                        form_steps_number = gr.Number(value=0,
                                                      label="Steps",
                                                      interactive=False)

                # Mask controls.
                with gr.Row():
                    with gr.Column(scale=1, min_width=10):
                        gr.Markdown(value='Mask', show_label=False)
                    with gr.Column(scale=4, min_width=10):
                        enable_add_mask = gr.Button('Edit Flexible Area')
                        with gr.Row():
                            with gr.Column(scale=1, min_width=10):
                                form_reset_mask_btn = gr.Button("Reset mask")
                            with gr.Column(scale=1, min_width=10):
                                show_mask = gr.Checkbox(
                                    label='Show Mask',
                                    value=global_state.value['show_mask'],
                                    show_label=False)

                        with gr.Row():
                            form_lambda_number = gr.Number(
                                value=global_state.value["params"]
                                ["motion_lambda"],
                                interactive=True,
                                label="Lambda",
                            )

                form_draw_interval_number = gr.Number(
                    value=global_state.value["draw_interval"],
                    label="Draw Interval (steps)",
                    interactive=True,
                    visible=False)

            # Right column: the image the user clicks and sketches on.
            with gr.Column(scale=8):
                form_image = ImageMask(
                    value=global_state.value['images']['image_show'],
                    brush_radius=20).style(
                        width=768,
                        height=768)

    gr.Markdown("""
## Quick Start

1. Select desired `Pretrained Model` and adjust `Seed` to generate an
   initial image.
2. Click on image to add control points.
3. Click `Start` and enjoy it!

## Advance Usage

1. Change `Step Size` to adjust learning rate in drag optimization.
2. Select `w` or `w+` to change latent space to optimize:
* Optimize on `w` space may cause greater influence to the image.
* Optimize on `w+` space may work slower than `w`, but usually achieve
  better results.
* Note that changing the latent space will reset the image, points and
  mask (this has the same effect as `Reset Image` button).
3. Click `Edit Flexible Area` to create a mask and constrain the
   unmasked region to remain unchanged.


""")
    gr.HTML("""
<style>
    .container {
        position: absolute;
        height: 50px;
        text-align: center;
        line-height: 50px;
        width: 100%;
    }
</style>
<div class="container">
Gradio demo supported by
<img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
<a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
</div>
""")

    # ------------------------------------------------------------------
    # Model / latent listeners
    # ------------------------------------------------------------------
    def _reinit_and_clear(global_state):
        """Re-render for the current settings and drop all points and mask.

        Returns the ``(global_state, image_show)`` pair used as outputs by
        the checkpoint, seed, latent-space and reset-image listeners.
        """
        init_images(global_state)
        clear_state(global_state)
        return global_state, global_state['images']['image_show']

    def on_change_pretrained_dropdown(pretrained_value, global_state):
        """Switch to another checkpoint.

        1. Store the selected checkpoint name in global_state.
        2. Re-init images and clear all states.
        """
        global_state['pretrained_weight'] = pretrained_value
        return _reinit_and_clear(global_state)

    form_pretrained_dropdown.change(
        on_change_pretrained_dropdown,
        inputs=[form_pretrained_dropdown, global_state],
        outputs=[global_state, form_image],
        queue=True,
    )

    def on_click_reset_image(global_state):
        """Re-render the image for the current settings and clear all states."""
        return _reinit_and_clear(global_state)

    form_reset_image.click(
        on_click_reset_image,
        inputs=[global_state],
        outputs=[global_state, form_image],
        queue=False,
    )

    def on_change_update_image_seed(seed, global_state):
        """Handle a generation-seed change.

        1. Store the seed (as int) in global_state.
        2. Re-init images and clear all states.
        """
        global_state["params"]["seed"] = int(seed)
        return _reinit_and_clear(global_state)

    form_seed_number.change(
        on_change_update_image_seed,
        inputs=[form_seed_number, global_state],
        outputs=[global_state, form_image],
    )

    def on_click_latent_space(latent_space, global_state):
        """Change the latent space to optimise ('w' or 'w+').

        NOTE: this resets the image and all controls.
        1. Store the latent space in global_state.
        2. Re-init images and clear all states.
        """
        global_state['params']['latent_space'] = latent_space
        return _reinit_and_clear(global_state)

    form_latent_space.change(on_click_latent_space,
                             inputs=[form_latent_space, global_state],
                             outputs=[global_state, form_image])

    form_lambda_number.change(
        partial(on_change_single_global_state, ["params", "motion_lambda"]),
        inputs=[form_lambda_number, global_state],
        outputs=[global_state],
    )

    def on_change_lr(lr, global_state):
        """Apply a new optimiser step size.

        Non-positive values and an empty field (``None``) are ignored.
        Previously ``None`` or a negative value was forwarded to
        ``renderer.update_lr`` and failed there.
        """
        if lr is None or lr <= 0:
            print(f'lr is {lr}, do nothing.')
            return global_state

        global_state["params"]["lr"] = lr
        renderer = global_state['renderer']
        renderer.update_lr(lr)
        print('New optimizer: ')
        print(renderer.w_optim)
        return global_state

    form_lr_number.change(
        on_change_lr,
        inputs=[form_lr_number, global_state],
        outputs=[global_state],
        queue=False,
    )

    # ------------------------------------------------------------------
    # Drag run
    # ------------------------------------------------------------------
    def _control_updates(idle):
        """Interactivity updates for the 13 controls after (state, steps, image).

        The order must match the ``outputs`` of ``form_start_btn.click``.
        When ``idle`` is True every control is enabled except Stop; while a
        drag is running (``idle`` False) only Stop is enabled.
        """
        return (
            gr.Button.update(interactive=idle),      # form_reset_image
            gr.Button.update(interactive=idle),      # enable_add_points
            gr.Button.update(interactive=idle),      # enable_add_mask
            gr.Button.update(interactive=idle),      # undo_points
            gr.Button.update(interactive=idle),      # form_reset_mask_btn
            gr.Radio.update(interactive=idle),       # form_latent_space
            gr.Button.update(interactive=idle),      # form_start_btn
            gr.Button.update(interactive=not idle),  # form_stop_btn
            gr.Dropdown.update(interactive=idle),    # form_pretrained_dropdown
            gr.Slider.update(interactive=idle),      # form_seed_number
            gr.Number.update(interactive=idle),      # form_lr_number
            gr.Checkbox.update(interactive=idle),    # show_mask
            gr.Number.update(interactive=idle),      # form_lambda_number
        )

    def on_click_start(global_state, image):
        """Run the drag optimisation, streaming intermediate results.

        Each yield is ``(global_state, step index, preview image,
        *_control_updates(...))``, which is exactly the 16 outputs of
        ``form_start_btn.click``. Earlier versions yielded 18 values and
        relied on Gradio silently dropping the extras.

        The loop stops when Stop is pressed, when the renderer sets
        ``generator_params['stop']``, or (on the Space only) after TIMEOUT
        seconds.
        """
        # Merge whatever the user just sketched into the stored mask.
        global_state = preprocess_mask_info(global_state, image)

        # Collect complete (start, target) pairs; a point still waiting for
        # its target is skipped.
        p_in_pixels = []
        t_in_pixels = []
        valid_points = []
        for key_point, point in global_state["points"].items():
            try:
                # Resume from the position tracked by a previous run.
                p_start = point.get("start_temp", point["start"])
                p_end = point["target"]
            except KeyError:
                continue
            if p_start is None or p_end is None:
                continue
            p_in_pixels.append(p_start)
            t_in_pixels.append(p_end)
            valid_points.append(key_point)

        if not valid_points:
            # Nothing to drag: refresh the preview and keep the UI idle.
            update_image_draw(
                global_state['images']['image_raw'],
                global_state['points'],
                global_state['mask'],
                global_state['show_mask'],
                global_state,
            )
            yield (global_state, 0, global_state['images']['image_show'],
                   *_control_updates(True))
            return

        # Per the Quick Start text, the region outside the mask must stay
        # unchanged; the renderer receives that complement as drag_mask.
        mask = torch.tensor(global_state['mask']).float()
        drag_mask = 1 - mask

        renderer: Renderer = global_state["renderer"]
        global_state['temporal_params']['stop'] = False
        # Clear the convergence flag left by a previous run. Otherwise Start
        # exits before the first step, and on the first run the key may not
        # exist at all.
        global_state['generator_params']['stop'] = False
        global_state['editing_state'] = 'running'

        # Point coordinates are swapped before being handed to the renderer
        # and swapped back when written into global_state['points'].
        p_to_opt = reverse_point_pairs(p_in_pixels)
        t_to_opt = reverse_point_pairs(t_in_pixels)
        print('Running with:')
        print(f' Source: {p_in_pixels}')
        print(f' Target: {t_in_pixels}')

        step_idx = 0
        last_time = time.time()
        while True:
            print_memory_usage()
            print(f'Running time: {time.time() - last_time}')
            if IS_SPACE and time.time() - last_time > TIMEOUT:
                print('Timeout break!')
                break
            if (global_state["temporal_params"]["stop"]
                    or global_state['generator_params']["stop"]):
                break

            # One optimisation step. p_to_opt is presumably updated in place
            # by the renderer's point tracking (it is read back below).
            renderer._render_drag_impl(
                global_state['generator_params'],
                p_to_opt,
                t_to_opt,
                drag_mask,
                global_state['params']['motion_lambda'],
                reg=0,
                feature_idx=5,
                r1=global_state['params']['r1_in_pixels'],
                r2=global_state['params']['r2_in_pixels'],
                trunc_psi=global_state['params']['trunc_psi'],
                is_drag=True,
                to_pil=True)

            if step_idx % global_state['draw_interval'] == 0:
                print('Current Source:')
                for key_point, p_i, t_i in zip(valid_points, p_to_opt,
                                               t_to_opt):
                    point = global_state["points"][key_point]
                    point["start_temp"] = [p_i[1], p_i[0]]
                    point["target"] = [t_i[1], t_i[0]]
                    print(f' {point["start_temp"]}')

                image_result = global_state['generator_params']['image']
                update_image_draw(
                    image_result,
                    global_state['points'],
                    global_state['mask'],
                    global_state['show_mask'],
                    global_state,
                )
                global_state['images']['image_raw'] = image_result

            yield (global_state, step_idx,
                   global_state['images']['image_show'],
                   *_control_updates(False))
            step_idx += 1

        image_result = global_state['generator_params']['image']
        global_state['images']['image_raw'] = image_result
        update_image_draw(image_result,
                          global_state['points'],
                          global_state['mask'],
                          global_state['show_mask'],
                          global_state)

        global_state['editing_state'] = 'add_points'
        yield (global_state, 0, global_state['images']['image_show'],
               *_control_updates(True))

    form_start_btn.click(
        on_click_start,
        inputs=[global_state, form_image],
        outputs=[
            global_state,
            form_steps_number,
            form_image,
            # Buttons / radio (order mirrored by _control_updates).
            form_reset_image,
            enable_add_points,
            enable_add_mask,
            undo_points,
            form_reset_mask_btn,
            form_latent_space,
            form_start_btn,
            form_stop_btn,
            # Input components.
            form_pretrained_dropdown,
            form_seed_number,
            form_lr_number,
            show_mask,
            form_lambda_number,
        ],
    )

    def on_click_stop(global_state):
        """Handle a click on Stop.

        1. Signal the drag loop via global_state["temporal_params"]["stop"].
        2. Disable the Stop button.
        """
        global_state["temporal_params"]["stop"] = True
        return global_state, gr.Button.update(interactive=False)

    form_stop_btn.click(on_click_stop,
                        inputs=[global_state],
                        outputs=[global_state, form_stop_btn],
                        queue=False)

    form_draw_interval_number.change(
        partial(
            on_change_single_global_state,
            "draw_interval",
            map_transform=int,
        ),
        inputs=[form_draw_interval_number, global_state],
        outputs=[global_state],
        queue=False,
    )

    def on_click_remove_point(global_state):
        """Delete the currently selected point pair.

        NOTE(review): not connected to any component in this file.

        Returns a Dropdown update listing the remaining pair ids (selecting
        the first, or nothing when none remain) and the updated state.
        """
        choice = global_state["curr_point"]
        global_state["points"].pop(choice, None)

        choices = list(global_state["points"].keys())
        # choices[0] used to raise IndexError once the last pair was removed.
        global_state["curr_point"] = choices[0] if choices else None

        return (
            gr.Dropdown.update(choices=choices,
                               value=global_state["curr_point"]),
            global_state,
        )

    # ------------------------------------------------------------------
    # Mask editing
    # ------------------------------------------------------------------
    def on_click_reset_mask(global_state):
        """Reset the mask to all ones and redraw the preview."""
        clear_state(global_state, target='mask')
        image_draw = update_image_draw(global_state['images']['image_raw'],
                                       global_state['points'],
                                       global_state['mask'],
                                       global_state['show_mask'], global_state)
        return global_state, image_draw

    form_reset_mask_btn.click(
        on_click_reset_mask,
        inputs=[global_state],
        outputs=[global_state, form_image],
    )

    def on_click_enable_draw(global_state, image):
        """Enter add-mask mode.

        1. Fold any pending sketch into the stored mask.
        2. Set editing state to 'add_mask'.
        3. Show the image (always with the mask overlay) and make it
           interactive for sketching.
        """
        global_state = preprocess_mask_info(global_state, image)
        global_state['editing_state'] = 'add_mask'
        image_raw = global_state['images']['image_raw']
        image_draw = update_image_draw(image_raw, global_state['points'],
                                       global_state['mask'], True,
                                       global_state)
        return (global_state,
                gr.Image.update(value=image_draw, interactive=True))

    def on_click_remove_draw(global_state, image):
        """Enter remove-mask mode.

        NOTE(review): not connected to any component in this file.

        1. Fold any pending sketch into the stored mask.
        2. Set editing state to 'remove_mask'.
        3. Show the image with the mask overlay, interactive for sketching.
        """
        global_state = preprocess_mask_info(global_state, image)
        # Was written to the misspelled key 'edinting_state', so the mode
        # never actually changed.
        global_state['editing_state'] = 'remove_mask'
        image_raw = global_state['images']['image_raw']
        image_draw = update_image_draw(image_raw, global_state['points'],
                                       global_state['mask'], True,
                                       global_state)
        return (global_state,
                gr.Image.update(value=image_draw, interactive=True))

    enable_add_mask.click(on_click_enable_draw,
                          inputs=[global_state, form_image],
                          outputs=[
                              global_state,
                              form_image,
                          ],
                          queue=False)

    # ------------------------------------------------------------------
    # Point editing
    # ------------------------------------------------------------------
    def on_click_add_point(global_state, image: dict):
        """Switch from mask editing to point editing.

        1. Fold any pending sketch into the stored mask.
        2. Set editing state to 'add_points'.
        3. Show the image (mask overlay per 'show_mask'), non-interactive.
        """
        global_state = preprocess_mask_info(global_state, image)
        global_state['editing_state'] = 'add_points'
        mask = global_state['mask']
        image_raw = global_state['images']['image_raw']
        image_draw = update_image_draw(image_raw, global_state['points'], mask,
                                       global_state['show_mask'], global_state)

        return (global_state,
                gr.Image.update(value=image_draw, interactive=False))

    enable_add_points.click(on_click_add_point,
                            inputs=[global_state, form_image],
                            outputs=[global_state, form_image],
                            queue=False)

    def on_click_image(global_state, evt: gr.SelectData):
        """Add a start or target point where the user clicked.

        Only active in 'add_points' mode. Clicks alternate between starting
        a new pair and completing the latest pair's target.
        """
        xy = evt.index
        if global_state['editing_state'] != 'add_points':
            print(f'In {global_state["editing_state"]} state. '
                  'Do not add points.')
            return global_state, global_state['images']['image_show']

        points = global_state["points"]
        point_idx = get_latest_points_pair(points)
        if point_idx is None:
            points[0] = {'start': xy, 'target': None}
            print(f'Click Image - Start - {xy}')
        elif points[point_idx].get('target', None) is None:
            points[point_idx]['target'] = xy
            print(f'Click Image - Target - {xy}')
        else:
            points[point_idx + 1] = {'start': xy, 'target': None}
            print(f'Click Image - Start - {xy}')

        image_raw = global_state['images']['image_raw']
        image_draw = update_image_draw(
            image_raw,
            global_state['points'],
            global_state['mask'],
            global_state['show_mask'],
            global_state,
        )
        return global_state, image_draw

    form_image.select(
        on_click_image,
        inputs=[global_state],
        outputs=[global_state, form_image],
        queue=False,
    )

    def on_click_clear_points(global_state):
        """Remove all control points.

        1. Clear global_state['points'].
        2. Reset the renderer's feature references (feat_refs = None).
        3. Redraw the image without points.
        """
        clear_state(global_state, target='point')

        renderer: Renderer = global_state["renderer"]
        renderer.feat_refs = None

        image_raw = global_state['images']['image_raw']
        image_draw = update_image_draw(image_raw, {}, global_state['mask'],
                                       global_state['show_mask'], global_state)
        return global_state, image_draw

    undo_points.click(on_click_clear_points,
                      inputs=[global_state],
                      outputs=[global_state, form_image],
                      queue=False)

    def on_click_show_mask(global_state, show_mask):
        """Toggle the mask overlay on the preview image."""
        global_state['show_mask'] = show_mask

        image_raw = global_state['images']['image_raw']
        image_draw = update_image_draw(
            image_raw,
            global_state['points'],
            global_state['mask'],
            global_state['show_mask'],
            global_state,
        )
        return global_state, image_draw

    show_mask.change(
        on_click_show_mask,
        inputs=[global_state, show_mask],
        outputs=[global_state, form_image],
        queue=False,
    )
| |
|
| | |
# Shut down any Gradio servers left from an earlier run of this script/cell,
# then serve the app through a single-worker queue with the API page closed.
gr.close_all()
_queue_options = dict(concurrency_count=1, max_size=200, api_open=False)
app.queue(**_queue_options)
app.launch(share=True)
| |
|