SuperPauly commited on
Commit
49feeef
Β·
verified Β·
1 Parent(s): 2222d57

Update visualizer_drag_gradio.py

Browse files
Files changed (1) hide show
  1. visualizer_drag_gradio.py +326 -313
visualizer_drag_gradio.py CHANGED
@@ -15,28 +15,31 @@ import torch
15
  from PIL import Image
16
 
17
  import dnnlib
18
- from gradio_utils import (ImageMask, draw_mask_on_image, draw_points_on_image,
19
- get_latest_points_pair, get_valid_mask,
20
- on_change_single_global_state)
 
 
 
 
 
21
  from viz.renderer import Renderer, add_watermark_np
22
 
23
-
24
  # download models from Hugging Face hub
25
  from huggingface_hub import snapshot_download
26
 
27
- model_dir = Path('./checkpoints')
28
- snapshot_download('DragGan/DragGan-Models',
29
- repo_type='model', local_dir=model_dir)
30
 
31
  parser = ArgumentParser()
32
- parser.add_argument('--share', action='store_true')
33
- parser.add_argument('--cache-dir', type=str, default='./checkpoints')
34
  args = parser.parse_args()
35
 
36
  cache_dir = args.cache_dir
37
 
38
- device = 'cuda'
39
- IS_SPACE = "DragGan/DragGan" in os.environ.get('SPACE_ID', '')
40
  TIMEOUT = 80
41
 
42
 
@@ -54,23 +57,24 @@ def clear_state(global_state, target=None):
54
  2. set global_state['mask'] as full-one mask.
55
  """
56
  if target is None:
57
- target = ['point', 'mask']
58
  if not isinstance(target, list):
59
  target = [target]
60
- if 'point' in target:
61
- global_state['points'] = dict()
62
- print('Clear Points State!')
63
- if 'mask' in target:
64
  image_raw = global_state["images"]["image_raw"]
65
- global_state['mask'] = np.ones((image_raw.size[1], image_raw.size[0]),
66
- dtype=np.uint8)
67
- print('Clear mask State!')
 
68
 
69
  return global_state
70
 
71
 
72
  def init_images(global_state):
73
- """This function is called only ones with Gradio App is started.
74
  0. pre-process global_state, unpack value from global_state of need
75
  1. Re-init renderer
76
  2. run `renderer._render_drag_impl` with `is_drag=False` to generate
@@ -83,43 +87,48 @@ def init_images(global_state):
83
  else:
84
  state = global_state
85
 
86
- state['renderer'].init_network(
87
- state['generator_params'], # res
88
- valid_checkpoints_dict[state['pretrained_weight']], # pkl
89
- state['params']['seed'], # w0_seed,
90
  None, # w_load
91
- state['params']['latent_space'] == 'w+', # w_plus
92
- 'const',
93
- state['params']['trunc_psi'], # trunc_psi,
94
- state['params']['trunc_cutoff'], # trunc_cutoff,
95
  None, # input_transform
96
- state['params']['lr'] # lr,
 
 
 
 
97
  )
98
 
99
- state['renderer']._render_drag_impl(state['generator_params'],
100
- is_drag=False,
101
- to_pil=True)
102
-
103
- init_image = state['generator_params'].image
104
- state['images']['image_orig'] = init_image
105
- state['images']['image_raw'] = init_image
106
- state['images']['image_show'] = Image.fromarray(
107
- add_watermark_np(np.array(init_image)))
108
- state['mask'] = np.ones((init_image.size[1], init_image.size[0]),
109
- dtype=np.uint8)
110
  return global_state
111
 
112
 
113
  def update_image_draw(image, points, mask, show_mask, global_state=None):
114
-
115
  image_draw = draw_points_on_image(image, points)
116
- if show_mask and mask is not None and not (mask == 0).all() and not (
117
- mask == 1).all():
 
 
 
 
118
  image_draw = draw_mask_on_image(image_draw, mask)
119
 
120
  image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
121
  if global_state is not None:
122
- global_state['images']['image_show'] = image_draw
123
  return image_draw
124
 
125
 
@@ -131,34 +140,31 @@ def preprocess_mask_info(global_state, image):
131
  2.2 global_state is add_mask:
132
  """
133
  if isinstance(image, dict):
134
- last_mask = get_valid_mask(image['mask'])
135
  else:
136
  last_mask = None
137
- mask = global_state['mask']
138
 
139
  # mask in global state is a placeholder with all 1.
140
  if (mask == 1).all():
141
  mask = last_mask
142
 
143
- # last_mask = global_state['last_mask']
144
- editing_mode = global_state['editing_state']
145
 
146
  if last_mask is None:
147
  return global_state
148
 
149
- if editing_mode == 'remove_mask':
150
  updated_mask = np.clip(mask - last_mask, 0, 1)
151
- print(f'Last editing_state is {editing_mode}, do remove.')
152
- elif editing_mode == 'add_mask':
153
  updated_mask = np.clip(mask + last_mask, 0, 1)
154
- print(f'Last editing_state is {editing_mode}, do add.')
155
  else:
156
  updated_mask = mask
157
- print(f'Last editing_state is {editing_mode}, '
158
- 'do nothing to mask.')
159
 
160
- global_state['mask'] = updated_mask
161
- # global_state['last_mask'] = None # clear buffer
162
  return global_state
163
 
164
 
@@ -171,10 +177,12 @@ def print_memory_usage():
171
  device = torch.device("cuda")
172
  print(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9} GB")
173
  print(
174
- f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9} GB")
 
175
  device_properties = torch.cuda.get_device_properties(device)
176
- available_memory = device_properties.total_memory - \
177
- torch.cuda.max_memory_allocated()
 
178
  print(f"Available GPU memory: {available_memory / 1e9} GB")
179
  else:
180
  print("No GPU available")
@@ -183,21 +191,24 @@ def print_memory_usage():
183
  # filter large models running on SPACES
184
  allowed_checkpoints = [] # all checkpoints
185
  if IS_SPACE:
186
- allowed_checkpoints = ["stylegan_human_v2_512.pkl",
187
- "stylegan2_dogs_1024_pytorch.pkl"]
 
 
188
 
189
  valid_checkpoints_dict = {
190
- f.name.split('.')[0]: str(f)
191
- for f in Path(cache_dir).glob('*.pkl')
192
  if f.name in allowed_checkpoints or not IS_SPACE
193
  }
194
- print('Valid checkpoint file:')
195
  print(valid_checkpoints_dict)
196
 
197
- init_pkl = 'stylegan_human_v2_512'
198
 
199
  with gr.Blocks() as app:
200
- gr.Markdown("""
 
201
  # DragGAN - Drag Your GAN
202
  ## Interactive Point-based Manipulation on the Generative Image Manifold
203
  ### Unofficial Gradio Demo
@@ -209,58 +220,49 @@ with gr.Blocks() as app:
209
 
210
  * Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
211
  * Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) Β© [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
212
- """)
213
-
214
- # renderer = Renderer()
215
- global_state = gr.State({
216
- "images": {
217
- # image_orig: the original image, change with seed/model is changed
218
- # image_raw: image with mask and points, change durning optimization
219
- # image_show: image showed on screen
220
- },
221
- "temporal_params": {
222
- # stop
223
- },
224
- 'mask':
225
- None, # mask for visualization, 1 for editing and 0 for unchange
226
- 'last_mask': None, # last edited mask
227
- 'show_mask': True, # add button
228
- "generator_params": dnnlib.EasyDict(),
229
- "params": {
230
- "seed": int(np.random.randint(0, 2**32 - 1)),
231
- "motion_lambda": 20,
232
- "r1_in_pixels": 3,
233
- "r2_in_pixels": 12,
234
- "magnitude_direction_in_pixels": 1.0,
235
- "latent_space": "w+",
236
- "trunc_psi": 0.7,
237
- "trunc_cutoff": None,
238
- "lr": 0.001,
239
- },
240
- "device": device,
241
- "draw_interval": 1,
242
- "renderer": Renderer(disable_timing=True),
243
- "points": {},
244
- "curr_point": None,
245
- "curr_type_point": "start",
246
- 'editing_state': 'add_points',
247
- 'pretrained_weight': init_pkl
248
- })
249
 
250
  # init image
251
  global_state = init_images(global_state)
252
  with gr.Row():
253
-
254
  with gr.Row():
255
-
256
  # Left --> tools
257
  with gr.Column(scale=3):
258
-
259
  # Pickle
260
  with gr.Row():
261
-
262
  with gr.Column(scale=1, min_width=10):
263
- gr.Markdown(value='Pickle', show_label=False)
264
 
265
  with gr.Column(scale=4, min_width=10):
266
  form_pretrained_dropdown = gr.Dropdown(
@@ -272,14 +274,14 @@ with gr.Blocks() as app:
272
  # Latent
273
  with gr.Row():
274
  with gr.Column(scale=1, min_width=10):
275
- gr.Markdown(value='Latent', show_label=False)
276
 
277
  with gr.Column(scale=4, min_width=10):
278
  form_seed_number = gr.Slider(
279
- mininium=0,
280
- maximum=2**32-1,
281
  step=1,
282
- value=global_state.value['params']['seed'],
283
  interactive=True,
284
  # randomize=True,
285
  label="Seed",
@@ -287,60 +289,66 @@ with gr.Blocks() as app:
287
  form_lr_number = gr.Number(
288
  value=global_state.value["params"]["lr"],
289
  interactive=True,
290
- label="Step Size")
 
291
 
292
  with gr.Row():
293
  with gr.Column(scale=2, min_width=10):
294
  form_reset_image = gr.Button("Reset Image")
295
  with gr.Column(scale=3, min_width=10):
296
  form_latent_space = gr.Radio(
297
- ['w', 'w+'],
298
- value=global_state.value['params']
299
- ['latent_space'],
 
300
  interactive=True,
301
- label='Latent space to optimize',
302
  show_label=False,
303
  )
304
 
305
  # Drag
306
  with gr.Row():
307
  with gr.Column(scale=1, min_width=10):
308
- gr.Markdown(value='Drag', show_label=False)
309
  with gr.Column(scale=4, min_width=10):
310
  with gr.Row():
311
  with gr.Column(scale=1, min_width=10):
312
- enable_add_points = gr.Button('Add Points')
313
  with gr.Column(scale=1, min_width=10):
314
- undo_points = gr.Button('Reset Points')
315
  with gr.Row():
316
  with gr.Column(scale=1, min_width=10):
317
  form_start_btn = gr.Button("Start")
318
  with gr.Column(scale=1, min_width=10):
319
  form_stop_btn = gr.Button("Stop")
320
 
321
- form_steps_number = gr.Number(value=0,
322
- label="Steps",
323
- interactive=False)
 
 
324
 
325
  # Mask
326
  with gr.Row():
327
  with gr.Column(scale=1, min_width=10):
328
- gr.Markdown(value='Mask', show_label=False)
329
  with gr.Column(scale=4, min_width=10):
330
- enable_add_mask = gr.Button('Edit Flexible Area')
331
  with gr.Row():
332
  with gr.Column(scale=1, min_width=10):
333
  form_reset_mask_btn = gr.Button("Reset mask")
334
  with gr.Column(scale=1, min_width=10):
335
  show_mask = gr.Checkbox(
336
- label='Show Mask',
337
- value=global_state.value['show_mask'],
338
- show_label=False)
 
339
 
340
  with gr.Row():
341
  form_lambda_number = gr.Number(
342
- value=global_state.value["params"]
343
- ["motion_lambda"],
 
344
  interactive=True,
345
  label="Lambda",
346
  )
@@ -349,16 +357,21 @@ with gr.Blocks() as app:
349
  value=global_state.value["draw_interval"],
350
  label="Draw Interval (steps)",
351
  interactive=True,
352
- visible=False)
 
353
 
354
  # Right --> Image
355
  with gr.Column(scale=8):
356
  form_image = ImageMask(
357
- value=global_state.value['images']['image_show'],
358
- brush_radius=20).style(
359
- width=768,
360
- height=768) # NOTE: hard image size code here.
361
- gr.Markdown("""
 
 
 
 
362
  ## Quick Start
363
 
364
  1. Select desired `Pretrained Model` and adjust `Seed` to generate an
@@ -377,10 +390,10 @@ with gr.Blocks() as app:
377
  mask (this has the same effect as `Reset Image` button).
378
  3. Click `Edit Flexible Area` to create a mask and constrain the
379
  unmasked region to remain unchanged.
380
-
381
-
382
- """)
383
- gr.HTML("""
384
  <style>
385
  .container {
386
  position: absolute;
@@ -395,8 +408,8 @@ with gr.Blocks() as app:
395
  <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
396
  <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
397
  </div>
398
- """)
399
- # Network & latents tab listeners
400
 
401
  def on_change_pretrained_dropdown(pretrained_value, global_state):
402
  """Function to handle model change.
@@ -404,11 +417,11 @@ with gr.Blocks() as app:
404
  2. Re-init images and clear all states
405
  """
406
 
407
- global_state['pretrained_weight'] = pretrained_value
408
  init_images(global_state)
409
  clear_state(global_state)
410
 
411
- return global_state, global_state["images"]['image_show']
412
 
413
  form_pretrained_dropdown.change(
414
  on_change_pretrained_dropdown,
@@ -426,7 +439,7 @@ with gr.Blocks() as app:
426
  init_images(global_state)
427
  clear_state(global_state)
428
 
429
- return global_state, global_state['images']['image_show']
430
 
431
  form_reset_image.click(
432
  on_click_reset_image,
@@ -446,7 +459,7 @@ with gr.Blocks() as app:
446
  init_images(global_state)
447
  clear_state(global_state)
448
 
449
- return global_state, global_state['images']['image_show']
450
 
451
  form_seed_number.change(
452
  on_change_update_image_seed,
@@ -461,15 +474,17 @@ with gr.Blocks() as app:
461
  2. Re-init images and clear all state
462
  """
463
 
464
- global_state['params']['latent_space'] = latent_space
465
  init_images(global_state)
466
  clear_state(global_state)
467
 
468
- return global_state, global_state['images']['image_show']
469
 
470
- form_latent_space.change(on_click_latent_space,
471
- inputs=[form_latent_space, global_state],
472
- outputs=[global_state, form_image])
 
 
473
 
474
  # ==== Params
475
  form_lambda_number.change(
@@ -480,13 +495,13 @@ with gr.Blocks() as app:
480
 
481
  def on_change_lr(lr, global_state):
482
  if lr == 0:
483
- print('lr is 0, do nothing.')
484
  return global_state
485
  else:
486
  global_state["params"]["lr"] = lr
487
- renderer = global_state['renderer']
488
  renderer.update_lr(lr)
489
- print('New optimizer: ')
490
  print(renderer.w_optim)
491
  return global_state
492
 
@@ -507,45 +522,36 @@ with gr.Blocks() as app:
507
 
508
  # Prepare the points for the inference
509
  if len(global_state["points"]) == 0:
510
- # yield on_click_start_wo_points(global_state, image)
511
- image_raw = global_state['images']['image_raw']
512
  update_image_draw(
513
  image_raw,
514
- global_state['points'],
515
- global_state['mask'],
516
- global_state['show_mask'],
517
  global_state,
518
  )
519
 
520
  yield (
521
  global_state,
522
  0,
523
- global_state['images']['image_show'],
524
- # gr.File.update(visible=False),
525
  gr.Button.update(interactive=True),
526
  gr.Button.update(interactive=True),
527
  gr.Button.update(interactive=True),
528
  gr.Button.update(interactive=True),
529
  gr.Button.update(interactive=True),
530
- # latent space
531
  gr.Radio.update(interactive=True),
532
  gr.Button.update(interactive=True),
533
- # NOTE: disable stop button
534
  gr.Button.update(interactive=False),
535
-
536
- # update other comps
537
  gr.Dropdown.update(interactive=True),
538
  gr.Number.update(interactive=True),
539
  gr.Number.update(interactive=True),
540
  gr.Button.update(interactive=True),
541
  gr.Button.update(interactive=True),
542
  gr.Checkbox.update(interactive=True),
543
- # gr.Number.update(interactive=True),
544
  gr.Number.update(interactive=True),
545
  )
546
  else:
547
-
548
- # Transform the points into torch tensors
549
  for key_point, point in global_state["points"].items():
550
  try:
551
  p_start = point.get("start_temp", point["start"])
@@ -561,59 +567,51 @@ with gr.Blocks() as app:
561
  t_in_pixels.append(p_end)
562
  valid_points.append(key_point)
563
 
564
- mask = torch.tensor(global_state['mask']).float()
565
  drag_mask = 1 - mask
566
 
567
  renderer: Renderer = global_state["renderer"]
568
- global_state['temporal_params']['stop'] = False
569
- global_state['editing_state'] = 'running'
570
 
571
- # reverse points order
572
  p_to_opt = reverse_point_pairs(p_in_pixels)
573
  t_to_opt = reverse_point_pairs(t_in_pixels)
574
- print('Running with:')
575
- print(f' Source: {p_in_pixels}')
576
- print(f' Target: {t_in_pixels}')
577
  step_idx = 0
578
  last_time = time.time()
579
  while True:
580
  print_memory_usage()
581
- # add a TIMEOUT break
582
- print(f'Running time: {time.time() - last_time}')
583
  if IS_SPACE and time.time() - last_time > TIMEOUT:
584
- print('Timeout break!')
585
  break
586
- if global_state["temporal_params"]["stop"] or global_state['generator_params']["stop"]:
 
 
587
  break
588
 
589
- # do drage here!
590
  renderer._render_drag_impl(
591
- global_state['generator_params'],
592
- p_to_opt, # point
593
- t_to_opt, # target
594
- drag_mask, # mask,
595
- global_state['params']['motion_lambda'], # lambda_mask
596
  reg=0,
597
- feature_idx=5, # NOTE: do not support change for now
598
- r1=global_state['params']['r1_in_pixels'], # r1
599
- r2=global_state['params']['r2_in_pixels'], # r2
600
- # random_seed = 0,
601
- # noise_mode = 'const',
602
- trunc_psi=global_state['params']['trunc_psi'],
603
- # force_fp32 = False,
604
- # layer_name = None,
605
- # sel_channels = 3,
606
- # base_channel = 0,
607
- # img_scale_db = 0,
608
- # img_normalize = False,
609
- # untransform = False,
610
  is_drag=True,
611
- to_pil=True)
 
612
 
613
- if step_idx % global_state['draw_interval'] == 0:
614
- print('Current Source:')
615
- for key_point, p_i, t_i in zip(valid_points, p_to_opt,
616
- t_to_opt):
 
617
  global_state["points"][key_point]["start_temp"] = [
618
  p_i[1],
619
  p_i[0],
@@ -623,79 +621,67 @@ with gr.Blocks() as app:
623
  t_i[0],
624
  ]
625
  start_temp = global_state["points"][key_point][
626
- "start_temp"]
627
- print(f' {start_temp}')
 
628
 
629
- image_result = global_state['generator_params']['image']
630
  image_draw = update_image_draw(
631
  image_result,
632
- global_state['points'],
633
- global_state['mask'],
634
- global_state['show_mask'],
635
  global_state,
636
  )
637
- global_state['images']['image_raw'] = image_result
638
 
639
  yield (
640
  global_state,
641
  step_idx,
642
- global_state['images']['image_show'],
643
- # gr.File.update(visible=False),
644
  gr.Button.update(interactive=False),
645
  gr.Button.update(interactive=False),
646
  gr.Button.update(interactive=False),
647
  gr.Button.update(interactive=False),
648
  gr.Button.update(interactive=False),
649
- # latent space
650
  gr.Radio.update(interactive=False),
651
  gr.Button.update(interactive=False),
652
- # enable stop button in loop
653
  gr.Button.update(interactive=True),
654
-
655
- # update other comps
656
  gr.Dropdown.update(interactive=False),
657
  gr.Number.update(interactive=False),
658
  gr.Number.update(interactive=False),
659
  gr.Button.update(interactive=False),
660
  gr.Button.update(interactive=False),
661
  gr.Checkbox.update(interactive=False),
662
- # gr.Number.update(interactive=False),
663
  gr.Number.update(interactive=False),
664
  )
665
 
666
- # increate step
667
  step_idx += 1
668
 
669
- image_result = global_state['generator_params']['image']
670
- global_state['images']['image_raw'] = image_result
671
- image_draw = update_image_draw(image_result,
672
- global_state['points'],
673
- global_state['mask'],
674
- global_state['show_mask'],
675
- global_state)
676
-
677
- # fp = NamedTemporaryFile(suffix=".png", delete=False)
678
- # image_result.save(fp, "PNG")
679
 
680
- global_state['editing_state'] = 'add_points'
681
 
682
  yield (
683
  global_state,
684
- 0, # reset step to 0 after stop.
685
- global_state['images']['image_show'],
686
- # gr.File.update(visible=True, value=fp.name),
687
  gr.Button.update(interactive=True),
688
  gr.Button.update(interactive=True),
689
  gr.Button.update(interactive=True),
690
  gr.Button.update(interactive=True),
691
  gr.Button.update(interactive=True),
692
- # latent space
693
  gr.Radio.update(interactive=True),
694
  gr.Button.update(interactive=True),
695
- # NOTE: disable stop button with loop finish
696
  gr.Button.update(interactive=False),
697
-
698
- # update other comps
699
  gr.Dropdown.update(interactive=True),
700
  gr.Number.update(interactive=True),
701
  gr.Number.update(interactive=True),
@@ -710,8 +696,6 @@ with gr.Blocks() as app:
710
  global_state,
711
  form_steps_number,
712
  form_image,
713
- # form_download_result_file,
714
- # >>> buttons
715
  form_reset_image,
716
  enable_add_points,
717
  enable_add_mask,
@@ -720,8 +704,6 @@ with gr.Blocks() as app:
720
  form_latent_space,
721
  form_start_btn,
722
  form_stop_btn,
723
- # <<< buttonm
724
- # >>> inputs comps
725
  form_pretrained_dropdown,
726
  form_seed_number,
727
  form_lr_number,
@@ -739,10 +721,12 @@ with gr.Blocks() as app:
739
 
740
  return global_state, gr.Button.update(interactive=False)
741
 
742
- form_stop_btn.click(on_click_stop,
743
- inputs=[global_state],
744
- outputs=[global_state, form_stop_btn],
745
- queue=False)
 
 
746
 
747
  form_draw_interval_number.change(
748
  partial(
@@ -771,17 +755,20 @@ with gr.Blocks() as app:
771
 
772
  # Mask
773
  def on_click_reset_mask(global_state):
774
- global_state['mask'] = np.ones(
775
  (
776
  global_state["images"]["image_raw"].size[1],
777
  global_state["images"]["image_raw"].size[0],
778
  ),
779
  dtype=np.uint8,
780
  )
781
- image_draw = update_image_draw(global_state['images']['image_raw'],
782
- global_state['points'],
783
- global_state['mask'],
784
- global_state['show_mask'], global_state)
 
 
 
785
  return global_state, image_draw
786
 
787
  form_reset_mask_btn.click(
@@ -798,13 +785,16 @@ with gr.Blocks() as app:
798
  3. Set curr image with points and mask
799
  """
800
  global_state = preprocess_mask_info(global_state, image)
801
- global_state['editing_state'] = 'add_mask'
802
- image_raw = global_state['images']['image_raw']
803
- image_draw = update_image_draw(image_raw, global_state['points'],
804
- global_state['mask'], True,
805
- global_state)
806
- return (global_state,
807
- gr.Image.update(value=image_draw, interactive=True))
 
 
 
808
 
809
  def on_click_remove_draw(global_state, image):
810
  """Function to start remove mask mode.
@@ -813,73 +803,89 @@ with gr.Blocks() as app:
813
  3. Set curr image with points and mask
814
  """
815
  global_state = preprocess_mask_info(global_state, image)
816
- global_state['edinting_state'] = 'remove_mask'
817
- image_raw = global_state['images']['image_raw']
818
- image_draw = update_image_draw(image_raw, global_state['points'],
819
- global_state['mask'], True,
820
- global_state)
821
- return (global_state,
822
- gr.Image.update(value=image_draw, interactive=True))
823
-
824
- enable_add_mask.click(on_click_enable_draw,
825
- inputs=[global_state, form_image],
826
- outputs=[
827
- global_state,
828
- form_image,
829
- ],
830
- queue=False)
 
 
 
 
 
831
 
832
  def on_click_add_point(global_state, image: dict):
833
  """Function switch from add mask mode to add points mode.
834
- 1. Updaste mask buffer if need
835
  2. Change global_state['editing_state'] to 'add_points'
836
  3. Set current image with mask
837
  """
838
 
839
  global_state = preprocess_mask_info(global_state, image)
840
- global_state['editing_state'] = 'add_points'
841
- mask = global_state['mask']
842
- image_raw = global_state['images']['image_raw']
843
- image_draw = update_image_draw(image_raw, global_state['points'], mask,
844
- global_state['show_mask'], global_state)
 
 
 
 
 
845
 
846
- return (global_state,
847
- gr.Image.update(value=image_draw, interactive=False))
 
 
848
 
849
- enable_add_points.click(on_click_add_point,
850
- inputs=[global_state, form_image],
851
- outputs=[global_state, form_image],
852
- queue=False)
 
 
853
 
854
  def on_click_image(global_state, evt: gr.SelectData):
855
  """This function only support click for point selection
856
  """
857
  xy = evt.index
858
- if global_state['editing_state'] != 'add_points':
859
- print(f'In {global_state["editing_state"]} state. '
860
- 'Do not add points.')
 
 
861
 
862
- return global_state, global_state['images']['image_show']
863
 
864
  points = global_state["points"]
865
 
866
  point_idx = get_latest_points_pair(points)
867
  if point_idx is None:
868
- points[0] = {'start': xy, 'target': None}
869
- print(f'Click Image - Start - {xy}')
870
- elif points[point_idx].get('target', None) is None:
871
- points[point_idx]['target'] = xy
872
- print(f'Click Image - Target - {xy}')
873
  else:
874
- points[point_idx + 1] = {'start': xy, 'target': None}
875
- print(f'Click Image - Start - {xy}')
876
 
877
- image_raw = global_state['images']['image_raw']
878
  image_draw = update_image_draw(
879
  image_raw,
880
- global_state['points'],
881
- global_state['mask'],
882
- global_state['show_mask'],
883
  global_state,
884
  )
885
 
@@ -898,31 +904,38 @@ with gr.Blocks() as app:
898
  2. re-init network
899
  2. re-draw image
900
  """
901
- clear_state(global_state, target='point')
902
 
903
  renderer: Renderer = global_state["renderer"]
904
  renderer.feat_refs = None
905
 
906
- image_raw = global_state['images']['image_raw']
907
- image_draw = update_image_draw(image_raw, {}, global_state['mask'],
908
- global_state['show_mask'], global_state)
 
 
 
 
 
909
  return global_state, image_draw
910
 
911
- undo_points.click(on_click_clear_points,
912
- inputs=[global_state],
913
- outputs=[global_state, form_image],
914
- queue=False)
 
 
915
 
916
- def on_click_show_mask(global_state, show_mask):
917
  """Function to control whether show mask on image."""
918
- global_state['show_mask'] = show_mask
919
 
920
- image_raw = global_state['images']['image_raw']
921
  image_draw = update_image_draw(
922
  image_raw,
923
- global_state['points'],
924
- global_state['mask'],
925
- global_state['show_mask'],
926
  global_state,
927
  )
928
  return global_state, image_draw
@@ -937,4 +950,4 @@ with gr.Blocks() as app:
937
  print("SHAReD: Start app", parser.parse_args())
938
  gr.close_all()
939
  app.queue(concurrency_count=1, max_size=200, api_open=False)
940
- app.launch(share=args.share, show_api=False)
 
15
  from PIL import Image
16
 
17
  import dnnlib
18
+ from gradio_utils import (
19
+ ImageMask,
20
+ draw_mask_on_image,
21
+ draw_points_on_image,
22
+ get_latest_points_pair,
23
+ get_valid_mask,
24
+ on_change_single_global_state,
25
+ )
26
  from viz.renderer import Renderer, add_watermark_np
27
 
 
28
  # download models from Hugging Face hub
29
  from huggingface_hub import snapshot_download
30
 
31
+ model_dir = Path("./checkpoints")
32
+ snapshot_download("DragGan/DragGan-Models", repo_type="model", local_dir=model_dir)
 
33
 
34
  parser = ArgumentParser()
35
+ parser.add_argument("--share", action="store_true")
36
+ parser.add_argument("--cache-dir", type=str, default="./checkpoints")
37
  args = parser.parse_args()
38
 
39
  cache_dir = args.cache_dir
40
 
41
+ IS_SPACE = "DragGan/DragGan" in os.environ.get("SPACE_ID", "")
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
  TIMEOUT = 80
44
 
45
 
 
57
  2. set global_state['mask'] as full-one mask.
58
  """
59
  if target is None:
60
+ target = ["point", "mask"]
61
  if not isinstance(target, list):
62
  target = [target]
63
+ if "point" in target:
64
+ global_state["points"] = dict()
65
+ print("Clear Points State!")
66
+ if "mask" in target:
67
  image_raw = global_state["images"]["image_raw"]
68
+ global_state["mask"] = np.ones(
69
+ (image_raw.size[1], image_raw.size[0]), dtype=np.uint8
70
+ )
71
+ print("Clear mask State!")
72
 
73
  return global_state
74
 
75
 
76
  def init_images(global_state):
77
+ """This function is called only once when Gradio App is started.
78
  0. pre-process global_state, unpack value from global_state of need
79
  1. Re-init renderer
80
  2. run `renderer._render_drag_impl` with `is_drag=False` to generate
 
87
  else:
88
  state = global_state
89
 
90
+ state["renderer"].init_network(
91
+ state["generator_params"], # res
92
+ valid_checkpoints_dict[state["pretrained_weight"]], # pkl
93
+ state["params"]["seed"], # w0_seed,
94
  None, # w_load
95
+ state["params"]["latent_space"] == "w+", # w_plus
96
+ "const",
97
+ state["params"]["trunc_psi"], # trunc_psi,
98
+ state["params"]["trunc_cutoff"], # trunc_cutoff,
99
  None, # input_transform
100
+ state["params"]["lr"], # lr,
101
+ )
102
+
103
+ state["renderer"]._render_drag_impl(
104
+ state["generator_params"], is_drag=False, to_pil=True
105
  )
106
 
107
+ init_image = state["generator_params"].image
108
+ state["images"]["image_orig"] = init_image
109
+ state["images"]["image_raw"] = init_image
110
+ state["images"]["image_show"] = Image.fromarray(
111
+ add_watermark_np(np.array(init_image))
112
+ )
113
+ state["mask"] = np.ones(
114
+ (init_image.size[1], init_image.size[0]), dtype=np.uint8
115
+ )
 
 
116
  return global_state
117
 
118
 
119
  def update_image_draw(image, points, mask, show_mask, global_state=None):
 
120
  image_draw = draw_points_on_image(image, points)
121
+ if (
122
+ show_mask
123
+ and mask is not None
124
+ and not (mask == 0).all()
125
+ and not (mask == 1).all()
126
+ ):
127
  image_draw = draw_mask_on_image(image_draw, mask)
128
 
129
  image_draw = Image.fromarray(add_watermark_np(np.array(image_draw)))
130
  if global_state is not None:
131
+ global_state["images"]["image_show"] = image_draw
132
  return image_draw
133
 
134
 
 
140
  2.2 global_state is add_mask:
141
  """
142
  if isinstance(image, dict):
143
+ last_mask = get_valid_mask(image["mask"])
144
  else:
145
  last_mask = None
146
+ mask = global_state["mask"]
147
 
148
  # mask in global state is a placeholder with all 1.
149
  if (mask == 1).all():
150
  mask = last_mask
151
 
152
+ editing_mode = global_state["editing_state"]
 
153
 
154
  if last_mask is None:
155
  return global_state
156
 
157
+ if editing_mode == "remove_mask":
158
  updated_mask = np.clip(mask - last_mask, 0, 1)
159
+ print(f"Last editing_state is {editing_mode}, do remove.")
160
+ elif editing_mode == "add_mask":
161
  updated_mask = np.clip(mask + last_mask, 0, 1)
162
+ print(f"Last editing_state is {editing_mode}, do add.")
163
  else:
164
  updated_mask = mask
165
+ print(f"Last editing_state is {editing_mode}, do nothing to mask.")
 
166
 
167
+ global_state["mask"] = updated_mask
 
168
  return global_state
169
 
170
 
 
177
  device = torch.device("cuda")
178
  print(f"GPU memory usage: {torch.cuda.memory_allocated() / 1e9} GB")
179
  print(
180
+ f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9} GB"
181
+ )
182
  device_properties = torch.cuda.get_device_properties(device)
183
+ available_memory = (
184
+ device_properties.total_memory - torch.cuda.max_memory_allocated()
185
+ )
186
  print(f"Available GPU memory: {available_memory / 1e9} GB")
187
  else:
188
  print("No GPU available")
 
191
  # filter large models running on SPACES
192
  allowed_checkpoints = [] # all checkpoints
193
  if IS_SPACE:
194
+ allowed_checkpoints = [
195
+ "stylegan_human_v2_512.pkl",
196
+ "stylegan2_dogs_1024_pytorch.pkl",
197
+ ]
198
 
199
  valid_checkpoints_dict = {
200
+ f.name.split(".")[0]: str(f)
201
+ for f in Path(cache_dir).glob("*.pkl")
202
  if f.name in allowed_checkpoints or not IS_SPACE
203
  }
204
+ print("Valid checkpoint file:")
205
  print(valid_checkpoints_dict)
206
 
207
+ init_pkl = "stylegan_human_v2_512"
208
 
209
  with gr.Blocks() as app:
210
+ gr.Markdown(
211
+ """
212
  # DragGAN - Drag Your GAN
213
  ## Interactive Point-based Manipulation on the Generative Image Manifold
214
  ### Unofficial Gradio Demo
 
220
 
221
  * Official Repo: [XingangPan](https://github.com/XingangPan/DragGAN)
222
  * Gradio Demo by: [LeoXing1996](https://github.com/LeoXing1996) Β© [OpenMMLab MMagic](https://github.com/open-mmlab/mmagic)
223
+ """
224
+ )
225
+
226
+ global_state = gr.State(
227
+ {
228
+ "images": {},
229
+ "temporal_params": {},
230
+ "mask": None, # mask for visualization, 1 for editing and 0 for unchange
231
+ "last_mask": None, # last edited mask
232
+ "show_mask": True, # add button
233
+ "generator_params": dnnlib.EasyDict(),
234
+ "params": {
235
+ "seed": int(np.random.randint(0, 2**32 - 1)),
236
+ "motion_lambda": 20,
237
+ "r1_in_pixels": 3,
238
+ "r2_in_pixels": 12,
239
+ "magnitude_direction_in_pixels": 1.0,
240
+ "latent_space": "w+",
241
+ "trunc_psi": 0.7,
242
+ "trunc_cutoff": None,
243
+ "lr": 0.001,
244
+ },
245
+ "device": device,
246
+ "draw_interval": 1,
247
+ "renderer": Renderer(disable_timing=True),
248
+ "points": {},
249
+ "curr_point": None,
250
+ "curr_type_point": "start",
251
+ "editing_state": "add_points",
252
+ "pretrained_weight": init_pkl,
253
+ }
254
+ )
 
 
 
 
 
255
 
256
  # init image
257
  global_state = init_images(global_state)
258
  with gr.Row():
 
259
  with gr.Row():
 
260
  # Left --> tools
261
  with gr.Column(scale=3):
 
262
  # Pickle
263
  with gr.Row():
 
264
  with gr.Column(scale=1, min_width=10):
265
+ gr.Markdown(value="Pickle", show_label=False)
266
 
267
  with gr.Column(scale=4, min_width=10):
268
  form_pretrained_dropdown = gr.Dropdown(
 
274
  # Latent
275
  with gr.Row():
276
  with gr.Column(scale=1, min_width=10):
277
+ gr.Markdown(value="Latent", show_label=False)
278
 
279
  with gr.Column(scale=4, min_width=10):
280
  form_seed_number = gr.Slider(
281
+ minimum=0,
282
+ maximum=2**32 - 1,
283
  step=1,
284
+ value=global_state.value["params"]["seed"],
285
  interactive=True,
286
  # randomize=True,
287
  label="Seed",
 
289
  form_lr_number = gr.Number(
290
  value=global_state.value["params"]["lr"],
291
  interactive=True,
292
+ label="Step Size",
293
+ )
294
 
295
  with gr.Row():
296
  with gr.Column(scale=2, min_width=10):
297
  form_reset_image = gr.Button("Reset Image")
298
  with gr.Column(scale=3, min_width=10):
299
  form_latent_space = gr.Radio(
300
+ ["w", "w+"],
301
+ value=global_state.value["params"][
302
+ "latent_space"
303
+ ],
304
  interactive=True,
305
+ label="Latent space to optimize",
306
  show_label=False,
307
  )
308
 
309
  # Drag
310
  with gr.Row():
311
  with gr.Column(scale=1, min_width=10):
312
+ gr.Markdown(value="Drag", show_label=False)
313
  with gr.Column(scale=4, min_width=10):
314
  with gr.Row():
315
  with gr.Column(scale=1, min_width=10):
316
+ enable_add_points = gr.Button("Add Points")
317
  with gr.Column(scale=1, min_width=10):
318
+ undo_points = gr.Button("Reset Points")
319
  with gr.Row():
320
  with gr.Column(scale=1, min_width=10):
321
  form_start_btn = gr.Button("Start")
322
  with gr.Column(scale=1, min_width=10):
323
  form_stop_btn = gr.Button("Stop")
324
 
325
+ form_steps_number = gr.Number(
326
+ value=0,
327
+ label="Steps",
328
+ interactive=False,
329
+ )
330
 
331
  # Mask
332
  with gr.Row():
333
  with gr.Column(scale=1, min_width=10):
334
+ gr.Markdown(value="Mask", show_label=False)
335
  with gr.Column(scale=4, min_width=10):
336
+ enable_add_mask = gr.Button("Edit Flexible Area")
337
  with gr.Row():
338
  with gr.Column(scale=1, min_width=10):
339
  form_reset_mask_btn = gr.Button("Reset mask")
340
  with gr.Column(scale=1, min_width=10):
341
  show_mask = gr.Checkbox(
342
+ label="Show Mask",
343
+ value=global_state.value["show_mask"],
344
+ show_label=False,
345
+ )
346
 
347
  with gr.Row():
348
  form_lambda_number = gr.Number(
349
+ value=global_state.value["params"][
350
+ "motion_lambda"
351
+ ],
352
  interactive=True,
353
  label="Lambda",
354
  )
 
357
  value=global_state.value["draw_interval"],
358
  label="Draw Interval (steps)",
359
  interactive=True,
360
+ visible=False,
361
+ )
362
 
363
  # Right --> Image
364
  with gr.Column(scale=8):
365
  form_image = ImageMask(
366
+ value=global_state.value["images"]["image_show"],
367
+ brush_radius=20,
368
+ ).style(
369
+ width=768,
370
+ height=768,
371
+ ) # NOTE: hard image size code here.
372
+
373
+ gr.Markdown(
374
+ """
375
  ## Quick Start
376
 
377
  1. Select desired `Pretrained Model` and adjust `Seed` to generate an
 
390
  mask (this has the same effect as `Reset Image` button).
391
  3. Click `Edit Flexible Area` to create a mask and constrain the
392
  unmasked region to remain unchanged.
393
+ """
394
+ )
395
+ gr.HTML(
396
+ """
397
  <style>
398
  .container {
399
  position: absolute;
 
408
  <img src="https://avatars.githubusercontent.com/u/10245193?s=200&v=4" height="20" width="20" style="display:inline;">
409
  <a href="https://github.com/open-mmlab/mmagic">OpenMMLab MMagic</a>
410
  </div>
411
+ """
412
+ )
413
 
414
  def on_change_pretrained_dropdown(pretrained_value, global_state):
415
  """Function to handle model change.
 
417
  2. Re-init images and clear all states
418
  """
419
 
420
+ global_state["pretrained_weight"] = pretrained_value
421
  init_images(global_state)
422
  clear_state(global_state)
423
 
424
+ return global_state, global_state["images"]["image_show"]
425
 
426
  form_pretrained_dropdown.change(
427
  on_change_pretrained_dropdown,
 
439
  init_images(global_state)
440
  clear_state(global_state)
441
 
442
+ return global_state, global_state["images"]["image_show"]
443
 
444
  form_reset_image.click(
445
  on_click_reset_image,
 
459
  init_images(global_state)
460
  clear_state(global_state)
461
 
462
+ return global_state, global_state["images"]["image_show"]
463
 
464
  form_seed_number.change(
465
  on_change_update_image_seed,
 
474
  2. Re-init images and clear all state
475
  """
476
 
477
+ global_state["params"]["latent_space"] = latent_space
478
  init_images(global_state)
479
  clear_state(global_state)
480
 
481
+ return global_state, global_state["images"]["image_show"]
482
 
483
+ form_latent_space.change(
484
+ on_click_latent_space,
485
+ inputs=[form_latent_space, global_state],
486
+ outputs=[global_state, form_image],
487
+ )
488
 
489
  # ==== Params
490
  form_lambda_number.change(
 
495
 
496
  def on_change_lr(lr, global_state):
497
  if lr == 0:
498
+ print("lr is 0, do nothing.")
499
  return global_state
500
  else:
501
  global_state["params"]["lr"] = lr
502
+ renderer = global_state["renderer"]
503
  renderer.update_lr(lr)
504
+ print("New optimizer: ")
505
  print(renderer.w_optim)
506
  return global_state
507
 
 
522
 
523
  # Prepare the points for the inference
524
  if len(global_state["points"]) == 0:
525
+ image_raw = global_state["images"]["image_raw"]
 
526
  update_image_draw(
527
  image_raw,
528
+ global_state["points"],
529
+ global_state["mask"],
530
+ global_state["show_mask"],
531
  global_state,
532
  )
533
 
534
  yield (
535
  global_state,
536
  0,
537
+ global_state["images"]["image_show"],
 
538
  gr.Button.update(interactive=True),
539
  gr.Button.update(interactive=True),
540
  gr.Button.update(interactive=True),
541
  gr.Button.update(interactive=True),
542
  gr.Button.update(interactive=True),
 
543
  gr.Radio.update(interactive=True),
544
  gr.Button.update(interactive=True),
 
545
  gr.Button.update(interactive=False),
 
 
546
  gr.Dropdown.update(interactive=True),
547
  gr.Number.update(interactive=True),
548
  gr.Number.update(interactive=True),
549
  gr.Button.update(interactive=True),
550
  gr.Button.update(interactive=True),
551
  gr.Checkbox.update(interactive=True),
 
552
  gr.Number.update(interactive=True),
553
  )
554
  else:
 
 
555
  for key_point, point in global_state["points"].items():
556
  try:
557
  p_start = point.get("start_temp", point["start"])
 
567
  t_in_pixels.append(p_end)
568
  valid_points.append(key_point)
569
 
570
+ mask = torch.tensor(global_state["mask"]).float()
571
  drag_mask = 1 - mask
572
 
573
  renderer: Renderer = global_state["renderer"]
574
+ global_state["temporal_params"]["stop"] = False
575
+ global_state["editing_state"] = "running"
576
 
 
577
  p_to_opt = reverse_point_pairs(p_in_pixels)
578
  t_to_opt = reverse_point_pairs(t_in_pixels)
579
+ print("Running with:")
580
+ print(f" Source: {p_in_pixels}")
581
+ print(f" Target: {t_in_pixels}")
582
  step_idx = 0
583
  last_time = time.time()
584
  while True:
585
  print_memory_usage()
586
+ print(f"Running time: {time.time() - last_time}")
 
587
  if IS_SPACE and time.time() - last_time > TIMEOUT:
588
+ print("Timeout break!")
589
  break
590
+ if global_state["temporal_params"]["stop"] or global_state[
591
+ "generator_params"
592
+ ]["stop"]:
593
  break
594
 
 
595
  renderer._render_drag_impl(
596
+ global_state["generator_params"],
597
+ p_to_opt,
598
+ t_to_opt,
599
+ drag_mask,
600
+ global_state["params"]["motion_lambda"],
601
  reg=0,
602
+ feature_idx=5,
603
+ r1=global_state["params"]["r1_in_pixels"],
604
+ r2=global_state["params"]["r2_in_pixels"],
605
+ trunc_psi=global_state["params"]["trunc_psi"],
 
 
 
 
 
 
 
 
 
606
  is_drag=True,
607
+ to_pil=True,
608
+ )
609
 
610
+ if step_idx % global_state["draw_interval"] == 0:
611
+ print("Current Source:")
612
+ for key_point, p_i, t_i in zip(
613
+ valid_points, p_to_opt, t_to_opt
614
+ ):
615
  global_state["points"][key_point]["start_temp"] = [
616
  p_i[1],
617
  p_i[0],
 
621
  t_i[0],
622
  ]
623
  start_temp = global_state["points"][key_point][
624
+ "start_temp"
625
+ ]
626
+ print(f" {start_temp}")
627
 
628
+ image_result = global_state["generator_params"]["image"]
629
  image_draw = update_image_draw(
630
  image_result,
631
+ global_state["points"],
632
+ global_state["mask"],
633
+ global_state["show_mask"],
634
  global_state,
635
  )
636
+ global_state["images"]["image_raw"] = image_result
637
 
638
  yield (
639
  global_state,
640
  step_idx,
641
+ global_state["images"]["image_show"],
 
642
  gr.Button.update(interactive=False),
643
  gr.Button.update(interactive=False),
644
  gr.Button.update(interactive=False),
645
  gr.Button.update(interactive=False),
646
  gr.Button.update(interactive=False),
 
647
  gr.Radio.update(interactive=False),
648
  gr.Button.update(interactive=False),
 
649
  gr.Button.update(interactive=True),
 
 
650
  gr.Dropdown.update(interactive=False),
651
  gr.Number.update(interactive=False),
652
  gr.Number.update(interactive=False),
653
  gr.Button.update(interactive=False),
654
  gr.Button.update(interactive=False),
655
  gr.Checkbox.update(interactive=False),
 
656
  gr.Number.update(interactive=False),
657
  )
658
 
 
659
  step_idx += 1
660
 
661
+ image_result = global_state["generator_params"]["image"]
662
+ global_state["images"]["image_raw"] = image_result
663
+ image_draw = update_image_draw(
664
+ image_result,
665
+ global_state["points"],
666
+ global_state["mask"],
667
+ global_state["show_mask"],
668
+ global_state,
669
+ )
 
670
 
671
+ global_state["editing_state"] = "add_points"
672
 
673
  yield (
674
  global_state,
675
+ 0,
676
+ global_state["images"]["image_show"],
 
677
  gr.Button.update(interactive=True),
678
  gr.Button.update(interactive=True),
679
  gr.Button.update(interactive=True),
680
  gr.Button.update(interactive=True),
681
  gr.Button.update(interactive=True),
 
682
  gr.Radio.update(interactive=True),
683
  gr.Button.update(interactive=True),
 
684
  gr.Button.update(interactive=False),
 
 
685
  gr.Dropdown.update(interactive=True),
686
  gr.Number.update(interactive=True),
687
  gr.Number.update(interactive=True),
 
696
  global_state,
697
  form_steps_number,
698
  form_image,
 
 
699
  form_reset_image,
700
  enable_add_points,
701
  enable_add_mask,
 
704
  form_latent_space,
705
  form_start_btn,
706
  form_stop_btn,
 
 
707
  form_pretrained_dropdown,
708
  form_seed_number,
709
  form_lr_number,
 
721
 
722
  return global_state, gr.Button.update(interactive=False)
723
 
724
+ form_stop_btn.click(
725
+ on_click_stop,
726
+ inputs=[global_state],
727
+ outputs=[global_state, form_stop_btn],
728
+ queue=False,
729
+ )
730
 
731
  form_draw_interval_number.change(
732
  partial(
 
755
 
756
  # Mask
757
  def on_click_reset_mask(global_state):
758
+ global_state["mask"] = np.ones(
759
  (
760
  global_state["images"]["image_raw"].size[1],
761
  global_state["images"]["image_raw"].size[0],
762
  ),
763
  dtype=np.uint8,
764
  )
765
+ image_draw = update_image_draw(
766
+ global_state["images"]["image_raw"],
767
+ global_state["points"],
768
+ global_state["mask"],
769
+ global_state["show_mask"],
770
+ global_state,
771
+ )
772
  return global_state, image_draw
773
 
774
  form_reset_mask_btn.click(
 
785
  3. Set curr image with points and mask
786
  """
787
  global_state = preprocess_mask_info(global_state, image)
788
+ global_state["editing_state"] = "add_mask"
789
+ image_raw = global_state["images"]["image_raw"]
790
+ image_draw = update_image_draw(
791
+ image_raw,
792
+ global_state["points"],
793
+ global_state["mask"],
794
+ True,
795
+ global_state,
796
+ )
797
+ return (global_state, gr.Image.update(value=image_draw, interactive=True))
798
 
799
  def on_click_remove_draw(global_state, image):
800
  """Function to start remove mask mode.
 
803
  3. Set curr image with points and mask
804
  """
805
  global_state = preprocess_mask_info(global_state, image)
806
+ global_state["editing_state"] = "remove_mask"
807
+ image_raw = global_state["images"]["image_raw"]
808
+ image_draw = update_image_draw(
809
+ image_raw,
810
+ global_state["points"],
811
+ global_state["mask"],
812
+ True,
813
+ global_state,
814
+ )
815
+ return (global_state, gr.Image.update(value=image_draw, interactive=True))
816
+
817
+ enable_add_mask.click(
818
+ on_click_enable_draw,
819
+ inputs=[global_state, form_image],
820
+ outputs=[
821
+ global_state,
822
+ form_image,
823
+ ],
824
+ queue=False,
825
+ )
826
 
827
  def on_click_add_point(global_state, image: dict):
828
  """Function switch from add mask mode to add points mode.
829
+ 1. Update mask buffer if need
830
  2. Change global_state['editing_state'] to 'add_points'
831
  3. Set current image with mask
832
  """
833
 
834
  global_state = preprocess_mask_info(global_state, image)
835
+ global_state["editing_state"] = "add_points"
836
+ mask = global_state["mask"]
837
+ image_raw = global_state["images"]["image_raw"]
838
+ image_draw = update_image_draw(
839
+ image_raw,
840
+ global_state["points"],
841
+ mask,
842
+ global_state["show_mask"],
843
+ global_state,
844
+ )
845
 
846
+ return (
847
+ global_state,
848
+ gr.Image.update(value=image_draw, interactive=False),
849
+ )
850
 
851
+ enable_add_points.click(
852
+ on_click_add_point,
853
+ inputs=[global_state, form_image],
854
+ outputs=[global_state, form_image],
855
+ queue=False,
856
+ )
857
 
858
  def on_click_image(global_state, evt: gr.SelectData):
859
  """This function only support click for point selection
860
  """
861
  xy = evt.index
862
+ if global_state["editing_state"] != "add_points":
863
+ print(
864
+ f'In {global_state["editing_state"]} state. '
865
+ "Do not add points."
866
+ )
867
 
868
+ return global_state, global_state["images"]["image_show"]
869
 
870
  points = global_state["points"]
871
 
872
  point_idx = get_latest_points_pair(points)
873
  if point_idx is None:
874
+ points[0] = {"start": xy, "target": None}
875
+ print(f"Click Image - Start - {xy}")
876
+ elif points[point_idx].get("target", None) is None:
877
+ points[point_idx]["target"] = xy
878
+ print(f"Click Image - Target - {xy}")
879
  else:
880
+ points[point_idx + 1] = {"start": xy, "target": None}
881
+ print(f"Click Image - Start - {xy}")
882
 
883
+ image_raw = global_state["images"]["image_raw"]
884
  image_draw = update_image_draw(
885
  image_raw,
886
+ global_state["points"],
887
+ global_state["mask"],
888
+ global_state["show_mask"],
889
  global_state,
890
  )
891
 
 
904
  2. re-init network
905
  2. re-draw image
906
  """
907
+ clear_state(global_state, target="point")
908
 
909
  renderer: Renderer = global_state["renderer"]
910
  renderer.feat_refs = None
911
 
912
+ image_raw = global_state["images"]["image_raw"]
913
+ image_draw = update_image_draw(
914
+ image_raw,
915
+ {},
916
+ global_state["mask"],
917
+ global_state["show_mask"],
918
+ global_state,
919
+ )
920
  return global_state, image_draw
921
 
922
+ undo_points.click(
923
+ on_click_clear_points,
924
+ inputs=[global_state],
925
+ outputs=[global_state, form_image],
926
+ queue=False,
927
+ )
928
 
929
+ def on_click_show_mask(global_state, show_mask_val):
930
  """Function to control whether show mask on image."""
931
+ global_state["show_mask"] = show_mask_val
932
 
933
+ image_raw = global_state["images"]["image_raw"]
934
  image_draw = update_image_draw(
935
  image_raw,
936
+ global_state["points"],
937
+ global_state["mask"],
938
+ global_state["show_mask"],
939
  global_state,
940
  )
941
  return global_state, image_draw
 
950
  print("SHAReD: Start app", parser.parse_args())
951
  gr.close_all()
952
  app.queue(concurrency_count=1, max_size=200, api_open=False)
953
+ app.launch(share=args.share, show_api=False)