MidasRev commited on
Commit
fdf6ae7
·
verified ·
1 Parent(s): 42d68a6

Add auto logo mask and tune ZeroGPU

Browse files
Files changed (1) hide show
  1. web-demos/hugging_face/app.py +736 -683
web-demos/hugging_face/app.py CHANGED
@@ -1,684 +1,737 @@
1
- import sys
2
- sys.path.append("../../")
3
-
4
- import os
5
- import json
6
- import time
7
- import psutil
8
- import argparse
9
-
10
- import cv2
11
- import torch
12
- import torchvision
13
- import numpy as np
14
- import gradio as gr
15
- import spaces
16
-
17
- from tools.painter import mask_painter
18
- from track_anything import TrackingAnything
19
-
20
- from model.misc import get_device
21
- from utils.download_util import load_file_from_url, download_url_to_file
22
-
23
- # make sample videos into mp4 as git does not allow mp4 without lfs
24
- sample_videos_path = os.path.join('/home/user/app/web-demos/hugging_face/', "test_sample/")
25
- download_url_to_file("https://github-production-user-asset-6210df.s3.amazonaws.com/14334509/281805130-e57c7016-5a6d-4d3b-9df9-b4ea6372cc87.mp4", os.path.join(sample_videos_path, "test-sample0.mp4"))
26
- download_url_to_file("https://github-production-user-asset-6210df.s3.amazonaws.com/14334509/281828039-5def0fc9-3a22-45b7-838d-6bf78b6772c3.mp4", os.path.join(sample_videos_path, "test-sample1.mp4"))
27
- download_url_to_file("https://github-production-user-asset-6210df.s3.amazonaws.com/76810782/281807801-69b9f70c-1e56-428d-9b1b-4870c5e533a7.mp4", os.path.join(sample_videos_path, "test-sample2.mp4"))
28
- download_url_to_file("https://github-production-user-asset-6210df.s3.amazonaws.com/76810782/281808625-ad98f03f-99c7-4008-acf1-3d7beb48f13b.mp4", os.path.join(sample_videos_path, "test-sample3.mp4"))
29
- download_url_to_file("https://github-production-user-asset-6210df.s3.amazonaws.com/14334509/281828066-ee09ae82-916f-4a2e-a6c7-6fc50645fd20.mp4", os.path.join(sample_videos_path, "test-sample4.mp4"))
30
-
31
-
32
- def parse_augment():
33
- parser = argparse.ArgumentParser()
34
- parser.add_argument('--device', type=str, default=None)
35
- parser.add_argument('--sam_model_type', type=str, default="vit_h")
36
- parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
37
- parser.add_argument('--mask_save', default=False)
38
- args = parser.parse_args()
39
-
40
- if not args.device:
41
- args.device = str(get_device())
42
-
43
- return args
44
-
45
- # convert points input to prompt state
46
- def get_prompt(click_state, click_input):
47
- inputs = json.loads(click_input)
48
- points = click_state[0]
49
- labels = click_state[1]
50
- for input in inputs:
51
- points.append(input[:2])
52
- labels.append(input[2])
53
- click_state[0] = points
54
- click_state[1] = labels
55
- prompt = {
56
- "prompt_type":["click"],
57
- "input_point":click_state[0],
58
- "input_label":click_state[1],
59
- "multimask_output":"True",
60
- }
61
- return prompt
62
-
63
- # extract frames from upload video
64
- def get_frames_from_video(video_input, video_state):
65
- """
66
- Args:
67
- video_path:str
68
- timestamp:float64
69
- Return
70
- [[0:nearest_frame], [nearest_frame:], nearest_frame]
71
- """
72
- video_path = video_input
73
- frames = []
74
- user_name = time.time()
75
- status_ok = True
76
- operation_log = [("[Must Do]", "Click image"), (": Video uploaded! Try to click the image shown in step2 to add masks.\n", None)]
77
- try:
78
- cap = cv2.VideoCapture(video_path)
79
- fps = cap.get(cv2.CAP_PROP_FPS)
80
- length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
81
-
82
- if length > 1200:
83
- operation_log = [("You uploaded a video with more than 500 frames. Stop the video extraction. Kindly lower the video frame rate to a value below 500. We highly recommend deploying the demo locally for long video processing.", "Error")]
84
- ret, frame = cap.read()
85
- if ret == True:
86
- original_h, original_w = frame.shape[:2]
87
- scale_factor = min(1, 1280/max(original_h, original_w))
88
- target_h, target_w = int(original_h*scale_factor), int(original_w*scale_factor)
89
- frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
90
- status_ok = False
91
- else:
92
- while cap.isOpened():
93
- ret, frame = cap.read()
94
- if ret == True:
95
- # resize input image
96
- original_h, original_w = frame.shape[:2]
97
- scale_factor = min(1, 1280/max(original_h, original_w))
98
- target_h, target_w = int(original_h*scale_factor), int(original_w*scale_factor)
99
- if scale_factor != 1:
100
- frame = cv2.resize(frame, (target_w, target_h))
101
- frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
102
- else:
103
- break
104
- t = len(frames)
105
- if t > 0:
106
- print(f'Inp video shape: t_{t}, s_{original_h}x{original_w} to s_{target_h}x{target_w}')
107
- else:
108
- print(f'Inp video shape: t_{t}, no input video!!!')
109
- except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
110
- status_ok = False
111
- print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
112
-
113
- # initialize video_state
114
- if frames[0].shape[0] > 720 or frames[0].shape[1] > 720:
115
- operation_log = [(f"Video uploaded! Try to click the image shown in step2 to add masks. (You uploaded a video with a size of {original_w}x{original_h}, and the length of its longest edge exceeds 720 pixels. We may resize the input video during processing.)", "Normal")]
116
-
117
- video_state = {
118
- "user_name": user_name,
119
- "video_name": os.path.split(video_path)[-1],
120
- "origin_images": frames,
121
- "painted_images": frames.copy(),
122
- "masks": [np.zeros((target_h, target_w), np.uint8)]*len(frames),
123
- "logits": [None]*len(frames),
124
- "select_frame_number": 0,
125
- "fps": fps
126
- }
127
- video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), length, (original_w, original_h))
128
- model.samcontroler.sam_controler.reset_image()
129
- model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
130
- return video_state, video_info, video_state["origin_images"][0], gr.update(visible=status_ok, maximum=len(frames), value=1), gr.update(visible=status_ok, maximum=len(frames), value=len(frames)), \
131
- gr.update(visible=status_ok), gr.update(visible=status_ok), \
132
- gr.update(visible=status_ok), gr.update(visible=status_ok),\
133
- gr.update(visible=status_ok), gr.update(visible=status_ok), \
134
- gr.update(visible=status_ok), gr.update(visible=status_ok), \
135
- gr.update(visible=status_ok), gr.update(visible=status_ok), \
136
- gr.update(visible=status_ok), gr.update(visible=status_ok, choices=[], value=[]), \
137
- gr.update(visible=True, value=operation_log), gr.update(visible=status_ok, value=operation_log)
138
-
139
- # get the select frame from gradio slider
140
- def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown):
141
-
142
- # images = video_state[1]
143
- image_selection_slider -= 1
144
- video_state["select_frame_number"] = image_selection_slider
145
-
146
- # once select a new template frame, set the image in sam
147
-
148
- model.samcontroler.sam_controler.reset_image()
149
- model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
150
-
151
- operation_log = [("",""), ("Select tracking start frame {}. Try to click the image to add masks for tracking.".format(image_selection_slider),"Normal")]
152
-
153
- return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log, operation_log
154
-
155
- # set the tracking end frame
156
- def get_end_number(track_pause_number_slider, video_state, interactive_state):
157
- interactive_state["track_end_number"] = track_pause_number_slider
158
- operation_log = [("",""),("Select tracking finish frame {}.Try to click the image to add masks for tracking.".format(track_pause_number_slider),"Normal")]
159
-
160
- return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log, operation_log
161
-
162
- # use sam to get the mask
163
- @spaces.GPU(duration=60)
164
- def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
165
- """
166
- Args:
167
- template_frame: PIL.Image
168
- point_prompt: flag for positive or negative button click
169
- click_state: [[points], [labels]]
170
- """
171
- if point_prompt == "Positive":
172
- coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
173
- interactive_state["positive_click_times"] += 1
174
- else:
175
- coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
176
- interactive_state["negative_click_times"] += 1
177
-
178
- # prompt for sam model
179
- model.samcontroler.sam_controler.reset_image()
180
- model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
181
- prompt = get_prompt(click_state=click_state, click_input=coordinate)
182
-
183
- mask, logit, painted_image = model.first_frame_click(
184
- image=video_state["origin_images"][video_state["select_frame_number"]],
185
- points=np.array(prompt["input_point"]),
186
- labels=np.array(prompt["input_label"]),
187
- multimask=prompt["multimask_output"],
188
- )
189
- video_state["masks"][video_state["select_frame_number"]] = mask
190
- video_state["logits"][video_state["select_frame_number"]] = logit
191
- video_state["painted_images"][video_state["select_frame_number"]] = painted_image
192
-
193
- operation_log = [("[Must Do]", "Add mask"), (": add the current displayed mask for video segmentation.\n", None),
194
- ("[Optional]", "Remove mask"), (": remove all added masks.\n", None),
195
- ("[Optional]", "Clear clicks"), (": clear current displayed mask.\n", None),
196
- ("[Optional]", "Click image"), (": Try to click the image shown in step2 if you want to generate more masks.\n", None)]
197
- return painted_image, video_state, interactive_state, operation_log, operation_log
198
-
199
- def add_multi_mask(video_state, interactive_state, mask_dropdown):
200
- try:
201
- mask = video_state["masks"][video_state["select_frame_number"]]
202
- interactive_state["multi_mask"]["masks"].append(mask)
203
- interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
204
- mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
205
- select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown)
206
- operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
207
- except:
208
- operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")]
209
- return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log
210
-
211
- def clear_click(video_state, click_state):
212
- click_state = [[],[]]
213
- template_frame = video_state["origin_images"][video_state["select_frame_number"]]
214
- operation_log = [("",""), ("Cleared points history and refresh the image.","Normal")]
215
- return template_frame, click_state, operation_log, operation_log
216
-
217
- def remove_multi_mask(interactive_state, mask_dropdown):
218
- interactive_state["multi_mask"]["mask_names"]= []
219
- interactive_state["multi_mask"]["masks"] = []
220
-
221
- operation_log = [("",""), ("Remove all masks. Try to add new masks","Normal")]
222
- return interactive_state, gr.update(choices=[],value=[]), operation_log, operation_log
223
-
224
- def show_mask(video_state, interactive_state, mask_dropdown):
225
- mask_dropdown.sort()
226
- select_frame = video_state["origin_images"][video_state["select_frame_number"]]
227
- for i in range(len(mask_dropdown)):
228
- mask_number = int(mask_dropdown[i].split("_")[1]) - 1
229
- mask = interactive_state["multi_mask"]["masks"][mask_number]
230
- select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
231
-
232
- operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")]
233
- return select_frame, operation_log, operation_log
234
-
235
- # tracking vos
236
- @spaces.GPU(duration=120)
237
- def vos_tracking_video(video_state, interactive_state, mask_dropdown):
238
- operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
239
- model.cutie.clear_memory()
240
- if interactive_state["track_end_number"]:
241
- following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
242
- else:
243
- following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
244
-
245
- if interactive_state["multi_mask"]["masks"]:
246
- if len(mask_dropdown) == 0:
247
- mask_dropdown = ["mask_001"]
248
- mask_dropdown.sort()
249
- template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
250
- for i in range(1,len(mask_dropdown)):
251
- mask_number = int(mask_dropdown[i].split("_")[1]) - 1
252
- template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
253
- video_state["masks"][video_state["select_frame_number"]]= template_mask
254
- else:
255
- template_mask = video_state["masks"][video_state["select_frame_number"]]
256
- fps = video_state["fps"]
257
-
258
- # operation error
259
- if len(np.unique(template_mask))==1:
260
- template_mask[0][0]=1
261
- operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
262
- # return video_output, video_state, interactive_state, operation_error
263
- masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
264
- # clear GPU memory
265
- model.cutie.clear_memory()
266
-
267
- if interactive_state["track_end_number"]:
268
- video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
269
- video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
270
- video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
271
- else:
272
- video_state["masks"][video_state["select_frame_number"]:] = masks
273
- video_state["logits"][video_state["select_frame_number"]:] = logits
274
- video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
275
-
276
- video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
277
- interactive_state["inference_times"] += 1
278
-
279
- print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"],
280
- interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
281
- interactive_state["positive_click_times"],
282
- interactive_state["negative_click_times"]))
283
-
284
- #### shanggao code for mask save
285
- if interactive_state["mask_save"]:
286
- if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
287
- os.makedirs('./result/mask/{}'.format(video_state["video_name"].split('.')[0]))
288
- i = 0
289
- print("save mask")
290
- for mask in video_state["masks"]:
291
- np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
292
- i+=1
293
- # save_mask(video_state["masks"], video_state["video_name"])
294
- #### shanggao code for mask save
295
- return video_output, video_state, interactive_state, operation_log, operation_log
296
-
297
- # inpaint
298
- @spaces.GPU(duration=120)
299
- def inpaint_video(video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown):
300
- operation_log = [("",""), ("Inpainting finished!","Normal")]
301
-
302
- frames = np.asarray(video_state["origin_images"])
303
- fps = video_state["fps"]
304
- inpaint_masks = np.asarray(video_state["masks"])
305
- if len(mask_dropdown) == 0:
306
- mask_dropdown = ["mask_001"]
307
- mask_dropdown.sort()
308
- # convert mask_dropdown to mask numbers
309
- inpaint_mask_numbers = [int(mask_dropdown[i].split("_")[1]) for i in range(len(mask_dropdown))]
310
- # interate through all masks and remove the masks that are not in mask_dropdown
311
- unique_masks = np.unique(inpaint_masks)
312
- num_masks = len(unique_masks) - 1
313
- for i in range(1, num_masks + 1):
314
- if i in inpaint_mask_numbers:
315
- continue
316
- inpaint_masks[inpaint_masks==i] = 0
317
-
318
- # inpaint for videos
319
- inpainted_frames = model.baseinpainter.inpaint(frames,
320
- inpaint_masks,
321
- ratio=resize_ratio_number,
322
- dilate_radius=dilate_radius_number,
323
- raft_iter=raft_iter_number,
324
- subvideo_length=subvideo_length_number,
325
- neighbor_length=neighbor_length_number,
326
- ref_stride=ref_stride_number) # numpy array, T, H, W, 3
327
-
328
- video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
329
-
330
- return video_output, operation_log, operation_log
331
-
332
- # generate video after vos inference
333
- def generate_video_from_frames(frames, output_path, fps=30):
334
- """
335
- Generates a video from a list of frames.
336
-
337
- Args:
338
- frames (list of numpy arrays): The frames to include in the video.
339
- output_path (str): The path to save the generated video.
340
- fps (int, optional): The frame rate of the output video. Defaults to 30.
341
- """
342
- frames = torch.from_numpy(np.asarray(frames))
343
- if not os.path.exists(os.path.dirname(output_path)):
344
- os.makedirs(os.path.dirname(output_path))
345
- torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
346
- return output_path
347
-
348
- def restart():
349
- operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started! (Kindly ensure that the uploaded video consists of fewer than 500 frames in total)", "Normal")]
350
- return {
351
- "user_name": "",
352
- "video_name": "",
353
- "origin_images": None,
354
- "painted_images": None,
355
- "masks": None,
356
- "inpaint_masks": None,
357
- "logits": None,
358
- "select_frame_number": 0,
359
- "fps": 30
360
- }, {
361
- "inference_times": 0,
362
- "negative_click_times" : 0,
363
- "positive_click_times": 0,
364
- "mask_save": args.mask_save,
365
- "multi_mask": {
366
- "mask_names": [],
367
- "masks": []
368
- },
369
- "track_end_number": None,
370
- }, [[],[]], None, None, None, \
371
- gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
372
- gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
373
- gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
374
- gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), "", \
375
- gr.update(visible=True, value=operation_log), gr.update(visible=False, value=operation_log)
376
-
377
-
378
- # args, defined in track_anything.py
379
- args = parse_augment()
380
- pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
381
- sam_checkpoint_url_dict = {
382
- 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
383
- 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
384
- 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
385
- }
386
- checkpoint_fodler = os.path.join('..', '..', 'weights')
387
-
388
- sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_fodler)
389
- cutie_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'cutie-base-mega.pth'), checkpoint_fodler)
390
- propainter_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'ProPainter.pth'), checkpoint_fodler)
391
- raft_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'raft-things.pth'), checkpoint_fodler)
392
- flow_completion_checkpoint = load_file_from_url(os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'), checkpoint_fodler)
393
-
394
- # initialize sam, cutie, propainter models
395
- model = TrackingAnything(sam_checkpoint, cutie_checkpoint, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args)
396
-
397
-
398
- title = r"""<h1 align="center">ProPainter: Improving Propagation and Transformer for Video Inpainting</h1>"""
399
-
400
- description = r"""
401
- <center><img src='https://github.com/sczhou/ProPainter/raw/main/assets/propainter_logo1_glow.png' alt='Propainter logo' style="width:180px; margin-bottom:20px"></center>
402
- <b>Official Gradio demo</b> for <a href='https://github.com/sczhou/ProPainter' target='_blank'><b>Improving Propagation and Transformer for Video Inpainting (ICCV 2023)</b></a>.<br>
403
- 🔥 Propainter is a robust inpainting algorithm.<br>
404
- 🤗 Try to drop your video, add the masks and get the the inpainting results!<br>
405
- """
406
- article = r"""
407
- If ProPainter is helpful, please help to ⭐ the <a href='https://github.com/sczhou/ProPainter' target='_blank'>Github Repo</a>. Thanks!
408
- [![GitHub Stars](https://img.shields.io/github/stars/sczhou/ProPainter?style=social)](https://github.com/sczhou/ProPainter)
409
-
410
- ---
411
-
412
- 📝 **Citation**
413
- <br>
414
- If our work is useful for your research, please consider citing:
415
- ```bibtex
416
- @inproceedings{zhou2023propainter,
417
- title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting},
418
- author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K and Loy, Chen Change},
419
- booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)},
420
- year={2023}
421
- }
422
- ```
423
-
424
- 📋 **License**
425
- <br>
426
- This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">S-Lab License 1.0</a>.
427
- Redistribution and use for non-commercial purposes should follow this license.
428
-
429
- 📧 **Contact**
430
- <br>
431
- If you have any questions, please feel free to reach me out at <b>shangchenzhou@gmail.com</b>.
432
- <div>
433
- 🤗 Find Me:
434
- <a href="https://twitter.com/ShangchenZhou"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/ShangchenZhou?label=%40ShangchenZhou&style=social" alt="Twitter Follow"></a>
435
- <a href="https://github.com/sczhou"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/sczhou?style=social" alt="Github Follow"></a>
436
- </div>
437
-
438
- """
439
- css = """
440
- .gradio-container {width: 85% !important}
441
- .gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important;}
442
- span.svelte-s1r2yt {font-size: 17px !important; font-weight: bold !important; color: #d30f2f !important;}
443
- button {border-radius: 8px !important;}
444
- .add_button {background-color: #4CAF50 !important;}
445
- .remove_button {background-color: #f44336 !important;}
446
- .clear_button {background-color: gray !important;}
447
- .mask_button_group {gap: 10px !important;}
448
- .video {height: 300px !important;}
449
- .image {height: 300px !important;}
450
- .video .wrap.svelte-lcpz3o {display: flex !important; align-items: center !important; justify-content: center !important;}
451
- .video .wrap.svelte-lcpz3o > :first-child {height: 100% !important;}
452
- .margin_center {width: 50% !important; margin: auto !important;}
453
- .jc_center {justify-content: center !important;}
454
- """
455
-
456
- with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface:
457
- click_state = gr.State([[],[]])
458
-
459
- interactive_state = gr.State({
460
- "inference_times": 0,
461
- "negative_click_times" : 0,
462
- "positive_click_times": 0,
463
- "mask_save": args.mask_save,
464
- "multi_mask": {
465
- "mask_names": [],
466
- "masks": []
467
- },
468
- "track_end_number": None,
469
- }
470
- )
471
-
472
- video_state = gr.State(
473
- {
474
- "user_name": "",
475
- "video_name": "",
476
- "origin_images": None,
477
- "painted_images": None,
478
- "masks": None,
479
- "inpaint_masks": None,
480
- "logits": None,
481
- "select_frame_number": 0,
482
- "fps": 30
483
- }
484
- )
485
-
486
- gr.Markdown(title)
487
- gr.Markdown(description)
488
-
489
- with gr.Group(elem_classes="gr-monochrome-group"):
490
- with gr.Row():
491
- with gr.Accordion('ProPainter Parameters (click to expand)', open=False):
492
- with gr.Row():
493
- resize_ratio_number = gr.Slider(label='Resize ratio',
494
- minimum=0.01,
495
- maximum=1.0,
496
- step=0.01,
497
- value=1.0)
498
- raft_iter_number = gr.Slider(label='Iterations for RAFT inference.',
499
- minimum=5,
500
- maximum=20,
501
- step=1,
502
- value=20,)
503
- with gr.Row():
504
- dilate_radius_number = gr.Slider(label='Mask dilation for video and flow masking.',
505
- minimum=0,
506
- maximum=10,
507
- step=1,
508
- value=8,)
509
-
510
- subvideo_length_number = gr.Slider(label='Length of sub-video for long video inference.',
511
- minimum=40,
512
- maximum=200,
513
- step=1,
514
- value=80,)
515
- with gr.Row():
516
- neighbor_length_number = gr.Slider(label='Length of local neighboring frames.',
517
- minimum=5,
518
- maximum=20,
519
- step=1,
520
- value=10,)
521
-
522
- ref_stride_number = gr.Slider(label='Stride of global reference frames.',
523
- minimum=5,
524
- maximum=20,
525
- step=1,
526
- value=10,)
527
-
528
- with gr.Column():
529
- # input video
530
- gr.Markdown("## Step1: Upload video")
531
- with gr.Row(equal_height=True):
532
- with gr.Column(scale=2):
533
- video_input = gr.Video(elem_classes="video")
534
- extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
535
- with gr.Column(scale=2):
536
- run_status = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started! (Kindly ensure that the uploaded video consists of fewer than 500 frames in total)", "Normal")],
537
- color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"})
538
- video_info = gr.Textbox(label="Video Info")
539
-
540
-
541
- # add masks
542
- step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False)
543
- with gr.Row(equal_height=True):
544
- with gr.Column(scale=2):
545
- template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
546
- image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
547
- track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
548
- with gr.Column(scale=2, elem_classes="jc_center"):
549
- run_status2 = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started! (Kindly ensure that the uploaded video consists of fewer than 500 frames in total)", "Normal")],
550
- color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"},
551
- visible=False)
552
- with gr.Column():
553
- point_prompt = gr.Radio(
554
- choices=["Positive", "Negative"],
555
- value="Positive",
556
- label="Point prompt",
557
- interactive=True,
558
- visible=False,
559
- min_width=100,
560
- scale=1,)
561
- with gr.Row(elem_classes="mask_button_group"):
562
- Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False, elem_classes="add_button")
563
- remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False, elem_classes="remove_button")
564
- clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False, elem_classes="clear_button")
565
- mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
566
-
567
- # output video
568
- step3_title = gr.Markdown("---\n## Step3: Track masks and get the inpainting result", visible=False)
569
- with gr.Row(equal_height=True):
570
- with gr.Column(scale=2):
571
- tracking_video_output = gr.Video(visible=False, elem_classes="video")
572
- tracking_video_predict_button = gr.Button(value="1. Tracking", visible=False, elem_classes="margin_center")
573
- with gr.Column(scale=2):
574
- inpaiting_video_output = gr.Video(visible=False, elem_classes="video")
575
- inpaint_video_predict_button = gr.Button(value="2. Inpainting", visible=False, elem_classes="margin_center")
576
-
577
- # first step: get the video information
578
- extract_frames_button.click(
579
- fn=get_frames_from_video,
580
- inputs=[
581
- video_input, video_state
582
- ],
583
- outputs=[video_state, video_info, template_frame,
584
- image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
585
- tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button, inpaint_video_predict_button, step2_title, step3_title,mask_dropdown, run_status, run_status2]
586
- )
587
-
588
- # second step: select images from slider
589
- image_selection_slider.release(fn=select_template,
590
- inputs=[image_selection_slider, video_state, interactive_state],
591
- outputs=[template_frame, video_state, interactive_state, run_status, run_status2], api_name="select_image")
592
- track_pause_number_slider.release(fn=get_end_number,
593
- inputs=[track_pause_number_slider, video_state, interactive_state],
594
- outputs=[template_frame, interactive_state, run_status, run_status2], api_name="end_image")
595
-
596
- # click select image to get mask using sam
597
- template_frame.select(
598
- fn=sam_refine,
599
- inputs=[video_state, point_prompt, click_state, interactive_state],
600
- outputs=[template_frame, video_state, interactive_state, run_status, run_status2]
601
- )
602
-
603
- # add different mask
604
- Add_mask_button.click(
605
- fn=add_multi_mask,
606
- inputs=[video_state, interactive_state, mask_dropdown],
607
- outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status, run_status2]
608
- )
609
-
610
- remove_mask_button.click(
611
- fn=remove_multi_mask,
612
- inputs=[interactive_state, mask_dropdown],
613
- outputs=[interactive_state, mask_dropdown, run_status, run_status2]
614
- )
615
-
616
- # tracking video from select image and mask
617
- tracking_video_predict_button.click(
618
- fn=vos_tracking_video,
619
- inputs=[video_state, interactive_state, mask_dropdown],
620
- outputs=[tracking_video_output, video_state, interactive_state, run_status, run_status2]
621
- )
622
-
623
- # inpaint video from select image and mask
624
- inpaint_video_predict_button.click(
625
- fn=inpaint_video,
626
- inputs=[video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown],
627
- outputs=[inpaiting_video_output, run_status, run_status2]
628
- )
629
-
630
- # click to get mask
631
- mask_dropdown.change(
632
- fn=show_mask,
633
- inputs=[video_state, interactive_state, mask_dropdown],
634
- outputs=[template_frame, run_status, run_status2]
635
- )
636
-
637
- # clear input
638
- video_input.change(
639
- fn=restart,
640
- inputs=[],
641
- outputs=[
642
- video_state,
643
- interactive_state,
644
- click_state,
645
- tracking_video_output, inpaiting_video_output,
646
- template_frame,
647
- tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
648
- Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2
649
- ],
650
- queue=False,
651
- show_progress=False)
652
-
653
- video_input.clear(
654
- fn=restart,
655
- inputs=[],
656
- outputs=[
657
- video_state,
658
- interactive_state,
659
- click_state,
660
- tracking_video_output, inpaiting_video_output,
661
- template_frame,
662
- tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
663
- Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2
664
- ],
665
- queue=False,
666
- show_progress=False)
667
-
668
- # points clear
669
- clear_button_click.click(
670
- fn = clear_click,
671
- inputs = [video_state, click_state,],
672
- outputs = [template_frame,click_state, run_status, run_status2],
673
- )
674
-
675
- # set example
676
- gr.Markdown("## Examples")
677
- gr.Examples(
678
- examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.mp4", "test-sample1.mp4", "test-sample2.mp4", "test-sample3.mp4", "test-sample4.mp4"]],
679
- inputs=[video_input],
680
- )
681
- gr.Markdown(article)
682
-
683
- iface.queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684
  iface.launch(debug=True)
 
1
+ import sys
2
+ sys.path.append("../../")
3
+
4
+ import os
5
+ import json
6
+ import time
7
+ import psutil
8
+ import argparse
9
+
10
+ import cv2
11
+ import torch
12
+ import torchvision
13
+ import numpy as np
14
+ import gradio as gr
15
+ import spaces
16
+
17
+ from tools.painter import mask_painter
18
+ from track_anything import TrackingAnything
19
+
20
+ from model.misc import get_device
21
+ from utils.download_util import load_file_from_url, download_url_to_file
22
+
23
+ # make sample videos into mp4 as git does not allow mp4 without lfs
24
# Directory the demo clips are downloaded into.
sample_videos_path = os.path.join('/home/user/app/web-demos/hugging_face/', "test_sample/")

# Remote sources of the demo clips; position N is saved as "test-sampleN.mp4".
_sample_video_urls = (
    "https://github-production-user-asset-6210df.s3.amazonaws.com/14334509/281805130-e57c7016-5a6d-4d3b-9df9-b4ea6372cc87.mp4",
    "https://github-production-user-asset-6210df.s3.amazonaws.com/14334509/281828039-5def0fc9-3a22-45b7-838d-6bf78b6772c3.mp4",
    "https://github-production-user-asset-6210df.s3.amazonaws.com/76810782/281807801-69b9f70c-1e56-428d-9b1b-4870c5e533a7.mp4",
    "https://github-production-user-asset-6210df.s3.amazonaws.com/76810782/281808625-ad98f03f-99c7-4008-acf1-3d7beb48f13b.mp4",
    "https://github-production-user-asset-6210df.s3.amazonaws.com/14334509/281828066-ee09ae82-916f-4a2e-a6c7-6fc50645fd20.mp4",
)
for _sample_index, _sample_url in enumerate(_sample_video_urls):
    download_url_to_file(_sample_url, os.path.join(sample_videos_path, "test-sample{}.mp4".format(_sample_index)))
30
+
31
+
32
def _str2bool(value):
    """Convert a command-line token such as ``"false"`` or ``"1"`` to a bool."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def parse_augment(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``device``, ``sam_model_type``, ``port`` and
        ``mask_save``. When ``--device`` is not given, ``device`` is filled in
        with ``str(get_device())``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--sam_model_type', type=str, default="vit_h")
    parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
    # Without an explicit type, "--mask_save False" would store the truthy
    # string "False". A bare "--mask_save" flag enables saving.
    parser.add_argument('--mask_save', type=_str2bool, nargs='?', const=True, default=False)
    args = parser.parse_args(argv)

    if not args.device:
        args.device = str(get_device())

    return args
44
+
45
+ # convert points input to prompt state
46
def get_prompt(click_state, click_input):
    """Append freshly clicked points to ``click_state`` and build a SAM prompt.

    Args:
        click_state: two-element list ``[points, labels]``; both inner lists
            are extended in place.
        click_input: JSON string encoding a list of ``[x, y, label]`` triples.

    Returns:
        dict with the keys ``prompt_type``, ``input_point``, ``input_label``
        and ``multimask_output``; the point and label entries are the same
        list objects that are stored in ``click_state``.
    """
    points, labels = click_state[0], click_state[1]
    for click in json.loads(click_input):
        points.append(click[:2])
        labels.append(click[2])
    return {
        "prompt_type": ["click"],
        "input_point": points,
        "input_label": labels,
        "multimask_output": "True",
    }
62
+
63
+ # extract frames from upload video
64
def get_frames_from_video(video_input, video_state):
    """Decode the uploaded video into RGB frames and initialise ``video_state``.

    Frames whose longest edge exceeds 1280 pixels are downscaled. When the
    container reports more than ``max_frames`` frames, only the first frame is
    decoded, an error is logged, and the step-2/step-3 widgets stay hidden.

    Args:
        video_input: path of the uploaded video file.
        video_state: current video state; returned unchanged when no frame
            could be decoded.

    Returns:
        A 19-tuple matching the ``outputs`` list of the "Get video info"
        button: the new video state, the info text, the first frame, two
        slider updates, eleven visibility updates, the mask dropdown update
        and the two status-log updates.
    """
    max_frames = 1200  # frame-count cap enforced by this demo
    max_edge = 1280    # longest edge kept after resizing
    video_path = video_input
    frames = []
    user_name = time.time()
    status_ok = True
    operation_log = [("[Must Do]", "Click image"), (": Video uploaded! Try to click the image shown in step2 to add masks.\n", None)]
    fps, length = 30, 0
    original_h = original_w = target_h = target_w = 0
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        too_long = length > max_frames
        if too_long:
            operation_log = [(f"You uploaded a video with more than {max_frames} frames. Stop the video extraction. Kindly shorten the video to fewer than {max_frames} frames. We highly recommend deploying the demo locally for long video processing.", "Error")]
            status_ok = False

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # resize input image
            original_h, original_w = frame.shape[:2]
            scale_factor = min(1, max_edge / max(original_h, original_w))
            target_h, target_w = int(original_h * scale_factor), int(original_w * scale_factor)
            if scale_factor != 1:
                frame = cv2.resize(frame, (target_w, target_h))
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if too_long:
                # Only one preview frame is kept for over-long videos.
                break

        t = len(frames)
        if t > 0:
            print(f'Inp video shape: t_{t}, s_{original_h}x{original_w} to s_{target_h}x{target_w}')
        else:
            print(f'Inp video shape: t_{t}, no input video!!!')
    except (OSError, TypeError, ValueError, KeyError, SyntaxError, cv2.error) as e:
        status_ok = False
        print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
    finally:
        # Always free the decoder handle, even when reading failed midway.
        if cap is not None:
            cap.release()

    if not frames:
        # Nothing could be decoded: keep every widget hidden and report it
        # instead of crashing on frames[0].
        operation_log = [("Failed to read any frame from the uploaded video. Please upload a valid video file.", "Error")]
        hidden = tuple(gr.update(visible=False) for _ in range(13))
        return (video_state, "", None) + hidden + (
            gr.update(visible=False, choices=[], value=[]),
            gr.update(visible=True, value=operation_log),
            gr.update(visible=False, value=operation_log),
        )

    # initialize video_state
    if status_ok and (frames[0].shape[0] > 720 or frames[0].shape[1] > 720):
        operation_log = [(f"Video uploaded! Try to click the image shown in step2 to add masks. (You uploaded a video with a size of {original_w}x{original_h}, and the length of its longest edge exceeds 720 pixels. We may resize the input video during processing.)", "Normal")]

    video_state = {
        "user_name": user_name,
        "video_name": os.path.split(video_path)[-1],
        "origin_images": frames,
        "painted_images": frames.copy(),
        # One independent array per frame; "[zeros] * n" would make all
        # frames share (and jointly mutate) a single mask.
        "masks": [np.zeros((target_h, target_w), np.uint8) for _ in frames],
        "logits": [None]*len(frames),
        "select_frame_number": 0,
        "fps": fps
    }
    video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), length, (original_w, original_h))
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
    return video_state, video_info, video_state["origin_images"][0], gr.update(visible=status_ok, maximum=len(frames), value=1), gr.update(visible=status_ok, maximum=len(frames), value=len(frames)), \
        gr.update(visible=status_ok), gr.update(visible=status_ok), \
        gr.update(visible=status_ok), gr.update(visible=status_ok),\
        gr.update(visible=status_ok), gr.update(visible=status_ok), \
        gr.update(visible=status_ok), gr.update(visible=status_ok), \
        gr.update(visible=status_ok), gr.update(visible=status_ok), \
        gr.update(visible=status_ok), gr.update(visible=status_ok, choices=[], value=[]), \
        gr.update(visible=True, value=operation_log), gr.update(visible=status_ok, value=operation_log)
138
+
139
+ # get the select frame from gradio slider
140
def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown=None):
    """Make the frame chosen on the "Track start frame" slider the template frame.

    Args:
        image_selection_slider: 1-based frame number from the slider.
        video_state: video state dict; ``select_frame_number`` is updated.
        interactive_state: passed through unchanged.
        mask_dropdown: unused. It defaults to ``None`` because the slider
            event is wired with only three inputs.

    Returns:
        (painted frame, video_state, interactive_state, log, log)
    """
    # The slider is 1-based, the frame lists are 0-based.
    image_selection_slider -= 1
    video_state["select_frame_number"] = image_selection_slider

    # once select a new template frame, set the image in sam
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])

    operation_log = [("",""), ("Select tracking start frame {}. Try to click the image to add masks for tracking.".format(image_selection_slider),"Normal")]

    return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log, operation_log
154
+
155
+ # set the tracking end frame
156
def get_end_number(track_pause_number_slider, video_state, interactive_state):
    """Record the tracking end frame and return that frame for preview.

    Args:
        track_pause_number_slider: 1-based frame number from the "Track end
            frame" slider. It is stored as-is because the tracker uses it as
            an exclusive slice end.
        video_state: video state dict holding ``painted_images``.
        interactive_state: updated in place with ``track_end_number``.

    Returns:
        (preview frame, interactive_state, log, log)
    """
    interactive_state["track_end_number"] = track_pause_number_slider
    operation_log = [("",""),("Select tracking finish frame {}.Try to click the image to add masks for tracking.".format(track_pause_number_slider),"Normal")]

    # The slider counts from 1 up to len(frames); convert to a 0-based index
    # so that selecting the last frame does not run off the end of the list.
    preview_index = max(int(track_pause_number_slider) - 1, 0)
    return video_state["painted_images"][preview_index], interactive_state, operation_log, operation_log
161
+
162
+ # use sam to get the mask
163
@spaces.GPU(duration=60)
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
    """Run SAM on the template frame using the accumulated click prompts.

    Args:
        video_state: video state dict; mask, logit and painted image of the
            selected frame are overwritten.
        point_prompt: "Positive" marks the click as foreground, anything else
            as background.
        click_state: [[points], [labels]], extended with the new click.
        interactive_state: click counters are incremented here.
        evt: Gradio select event; ``evt.index`` supplies the click position.

    Returns:
        (painted image, video_state, interactive_state, log, log)
    """
    click_x, click_y = evt.index[0], evt.index[1]
    if point_prompt == "Positive":
        coordinate = "[[{},{},1]]".format(click_x, click_y)
        counter_key = "positive_click_times"
    else:
        coordinate = "[[{},{},0]]".format(click_x, click_y)
        counter_key = "negative_click_times"
    interactive_state[counter_key] += 1

    # prompt for sam model
    frame_index = video_state["select_frame_number"]
    template_image = video_state["origin_images"][frame_index]
    sam = model.samcontroler.sam_controler
    sam.reset_image()
    sam.set_image(template_image)
    prompt = get_prompt(click_state=click_state, click_input=coordinate)

    mask, logit, painted_image = model.first_frame_click(
        image=template_image,
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )
    for state_key, value in (("masks", mask), ("logits", logit), ("painted_images", painted_image)):
        video_state[state_key][frame_index] = value

    operation_log = [("[Must Do]", "Add mask"), (": add the current displayed mask for video segmentation.\n", None),
                     ("[Optional]", "Remove mask"), (": remove all added masks.\n", None),
                     ("[Optional]", "Clear clicks"), (": clear current displayed mask.\n", None),
                     ("[Optional]", "Click image"), (": Try to click the image shown in step2 if you want to generate more masks.\n", None)]
    return painted_image, video_state, interactive_state, operation_log, operation_log
198
+
199
def add_multi_mask(video_state, interactive_state, mask_dropdown):
    """Store the mask of the selected frame as a new named mask.

    Args:
        video_state: video state dict providing the current frame's mask.
        interactive_state: its ``multi_mask`` lists receive the new entry.
        mask_dropdown: list of currently selected mask names; the new name is
            appended to it.

    Returns:
        (interactive_state, dropdown update, preview frame, reset click
        state, log, log). On failure the preview slot is a no-op update so
        the displayed image is left untouched.
    """
    # Fallback for the error path; without it the return statement below
    # would itself fail on an unbound name.
    select_frame = gr.update()
    try:
        mask = video_state["masks"][video_state["select_frame_number"]]
        interactive_state["multi_mask"]["masks"].append(mask)
        mask_name = "mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))
        interactive_state["multi_mask"]["mask_names"].append(mask_name)
        mask_dropdown.append(mask_name)
        select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown)
        operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
    except Exception:
        # Best effort by design: any failure (typically no video or no mask
        # yet) is reported to the user rather than raised.
        operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")]
    return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log
210
+
211
def clear_click(video_state, click_state):
    """Discard the recorded clicks and show the unpainted template frame.

    Returns:
        (original frame at ``select_frame_number``, empty click state, log, log)
    """
    frame_index = video_state["select_frame_number"]
    operation_log = [("",""), ("Cleared points history and refresh the image.","Normal")]
    return video_state["origin_images"][frame_index], [[],[]], operation_log, operation_log
216
+
217
def remove_multi_mask(interactive_state, mask_dropdown):
    """Drop every stored mask and empty the mask dropdown.

    Returns:
        (interactive_state, dropdown update with no choices, log, log)
    """
    multi_mask = interactive_state["multi_mask"]
    multi_mask["mask_names"] = []
    multi_mask["masks"] = []

    operation_log = [("",""), ("Remove all masks. Try to add new masks","Normal")]
    cleared_dropdown = gr.update(choices=[],value=[])
    return interactive_state, cleared_dropdown, operation_log, operation_log
223
+
224
def show_mask(video_state, interactive_state, mask_dropdown):
    """Paint the selected masks onto the current template frame.

    Args:
        video_state: video state dict providing the unpainted frame.
        interactive_state: holds the stored masks under ``multi_mask``.
        mask_dropdown: list of names like ``"mask_001"``; sorted in place.

    Returns:
        (painted frame, log, log)
    """
    mask_dropdown.sort()
    select_frame = video_state["origin_images"][video_state["select_frame_number"]]
    for mask_name in mask_dropdown:
        # "mask_003" -> index 2 in the stored mask list.
        mask_number = int(mask_name.split("_")[1]) - 1
        mask = interactive_state["multi_mask"]["masks"][mask_number]
        select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)

    operation_log = [("",""), ("Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown),"Normal")]
    return select_frame, operation_log, operation_log
234
+
235
+ # tracking vos
236
@spaces.GPU(duration=120)
def auto_mask_logo(video_state, interactive_state):
    """Build a per-frame mask of bright regions inside two fixed horizontal bands.

    For every frame, pixels brighter than a fixed threshold inside the two
    bands are collected, connected components smaller than a minimum area are
    discarded, and the remainder is dilated. The result replaces
    ``video_state["masks"]`` and is registered as ``mask_001``.

    NOTE(review): only OpenCV/NumPy calls appear in this body, so the GPU
    decorator may be unnecessary here; confirm before removing it.

    Args:
        video_state: video state dict; ``masks`` and the painted preview of
            the current frame are overwritten.
        interactive_state: its ``multi_mask`` entry is reset to the new mask.

    Returns:
        (preview image, video_state, interactive_state, dropdown update,
        log, log). The preview is ``None`` when no video has been loaded.
    """
    if not video_state["origin_images"]:
        operation_log = [("Please upload a video first.", "Error"), ("", "")]
        return None, video_state, interactive_state, gr.update(choices=[], value=[]), operation_log, operation_log

    frames = video_state["origin_images"]
    height, width = frames[0].shape[:2]
    # Band coordinates are given for a 1920x1080 layout and rescaled to the
    # actual frame size.
    x0 = int(width * 100 / 1920)
    x1 = int(width * 1820 / 1920)
    bands = (
        (int(height * 340 / 1080), int(height * 515 / 1080)),
        (int(height * 565 / 1080), int(height * 700 / 1080)),
    )
    brightness_threshold = 105

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
    min_component_area = max(48, (height * width) // 40000)
    masks = []

    for frame in frames:
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        mask = np.zeros((height, width), np.uint8)
        for y0, y1 in bands:
            mask[y0:y1, x0:x1] = (gray[y0:y1, x0:x1] > brightness_threshold).astype(np.uint8)

        _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
        # Lookup table indexed by component label: True for components that
        # are large enough. Label 0 is the background and is never kept.
        keep = stats[:, cv2.CC_STAT_AREA] >= min_component_area
        keep[0] = False
        filtered = keep[labels].astype(np.uint8)

        masks.append(cv2.dilate(filtered, kernel).astype(np.uint8))

    video_state["masks"] = masks
    current_index = min(video_state["select_frame_number"], len(masks) - 1)
    preview = mask_painter(video_state["origin_images"][current_index], masks[current_index].astype('uint8'), mask_color=2)
    video_state["painted_images"][current_index] = preview
    interactive_state["multi_mask"]["mask_names"] = ["mask_001"]
    interactive_state["multi_mask"]["masks"] = [masks[current_index]]
    operation_log = [("", ""), ("Auto logo mask generated. Run inpainting directly or refine with manual clicks.", "Normal")]
    return preview, video_state, interactive_state, gr.update(choices=["mask_001"], value=["mask_001"]), operation_log, operation_log
279
+
280
+
281
@spaces.GPU(duration=120)
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
    """Propagate the template mask(s) through the selected frame range.

    Args:
        video_state: video state dict; ``masks``, ``logits`` and
            ``painted_images`` of the tracked range are overwritten.
        interactive_state: provides ``track_end_number``, the stored masks
            and the ``mask_save`` flag; ``inference_times`` is incremented.
        mask_dropdown: list of selected mask names such as ``"mask_001"``.

    Returns:
        (path of the rendered tracking video, video_state,
        interactive_state, log, log). The path is ``None`` when the selected
        range contains no frame.
    """
    operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
    model.cutie.clear_memory()

    start = video_state["select_frame_number"]
    # A falsy end number (None/0) means "track up to the last frame".
    track_range = slice(start, interactive_state["track_end_number"] or None)
    following_frames = video_state["origin_images"][track_range]
    if len(following_frames) == 0:
        # The end frame lies before the start frame: nothing to track.
        operation_log = [("The tracking end frame must be after the tracking start frame.", "Error"), ("","")]
        return None, video_state, interactive_state, operation_log, operation_log

    stored_masks = interactive_state["multi_mask"]["masks"]
    if stored_masks:
        if len(mask_dropdown) == 0:
            mask_dropdown = ["mask_001"]
        mask_dropdown.sort()
        # "mask_003" -> object id 3, stored at list index 2.
        mask_ids = [int(name.split("_")[1]) for name in mask_dropdown]
        template_mask = stored_masks[mask_ids[0] - 1] * mask_ids[0]
        for mask_id in mask_ids[1:]:
            template_mask = np.clip(template_mask + stored_masks[mask_id - 1] * mask_id, 0, mask_id)
        video_state["masks"][start] = template_mask
    else:
        template_mask = video_state["masks"][start]
    fps = video_state["fps"]

    # operation error
    if len(np.unique(template_mask))==1:
        # Work on a copy so the placeholder pixel does not leak into the mask
        # kept in video_state.
        template_mask = template_mask.copy()
        template_mask[0][0]=1
        operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
    masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
    # clear GPU memory
    model.cutie.clear_memory()

    video_state["masks"][track_range] = masks
    video_state["logits"][track_range] = logits
    video_state["painted_images"][track_range] = painted_images

    video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
    interactive_state["inference_times"] += 1

    print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"],
                                                                                                                       interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
                                                                                                                       interactive_state["positive_click_times"],
                                                                                                                       interactive_state["negative_click_times"]))

    # Optionally dump every frame's mask as ./result/mask/<video stem>/<index>.npy
    if interactive_state["mask_save"]:
        mask_dir = './result/mask/{}'.format(video_state["video_name"].split('.')[0])
        os.makedirs(mask_dir, exist_ok=True)
        print("save mask")
        for i, mask in enumerate(video_state["masks"]):
            np.save(os.path.join(mask_dir, '{:05d}.npy'.format(i)), mask)
    return video_output, video_state, interactive_state, operation_log, operation_log
341
+
342
+ # inpaint
343
@spaces.GPU(duration=120)
def inpaint_video(video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown):
    """Inpaint the regions covered by the selected masks in every frame.

    Args:
        video_state: video state dict providing frames, masks, fps and name.
        resize_ratio_number, dilate_radius_number, raft_iter_number,
        subvideo_length_number, neighbor_length_number, ref_stride_number:
            forwarded unchanged to ``model.baseinpainter.inpaint``.
        mask_dropdown: list of selected mask names such as ``"mask_001"``;
            an empty selection falls back to ``mask_001``.

    Returns:
        (path of the inpainted video, log, log). The path is ``None`` when no
        video has been loaded.
    """
    operation_log = [("",""), ("Inpainting finished!","Normal")]

    if not video_state["origin_images"] or video_state["masks"] is None:
        operation_log = [("Please upload a video and add a mask first.", "Error"), ("","")]
        return None, operation_log, operation_log

    frames = np.asarray(video_state["origin_images"])
    fps = video_state["fps"]
    # np.array copies, so zeroing pixels below never touches the stored masks.
    inpaint_masks = np.array(video_state["masks"])
    if len(mask_dropdown) == 0:
        mask_dropdown = ["mask_001"]
    # convert mask_dropdown to mask numbers
    inpaint_mask_numbers = [int(name.split("_")[1]) for name in mask_dropdown]
    # Zero every pixel whose object id is not selected. Testing the ids that
    # are actually present also handles non-contiguous ids (e.g. only 1 and 3).
    inpaint_masks[~np.isin(inpaint_masks, inpaint_mask_numbers)] = 0

    # inpaint for videos
    inpainted_frames = model.baseinpainter.inpaint(frames,
                                                   inpaint_masks,
                                                   ratio=resize_ratio_number,
                                                   dilate_radius=dilate_radius_number,
                                                   raft_iter=raft_iter_number,
                                                   subvideo_length=subvideo_length_number,
                                                   neighbor_length=neighbor_length_number,
                                                   ref_stride=ref_stride_number)   # numpy array, T, H, W, 3

    video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video

    return video_output, operation_log, operation_log
376
+
377
+ # generate video after vos inference
378
def generate_video_from_frames(frames, output_path, fps=30):
    """
    Generates a video from a list of frames.

    Args:
        frames (list of numpy arrays): The frames to include in the video.
        output_path (str): The path to save the generated video.
        fps (int, optional): The frame rate of the output video. Defaults to 30.

    Returns:
        str: ``output_path``, for use as the value of a video component.
    """
    frames = torch.from_numpy(np.asarray(frames))
    output_dir = os.path.dirname(output_path)
    # A bare file name has no directory part, and os.makedirs('') would fail.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
    return output_path
392
+
393
def restart():
    """Return the values that reset every state and widget to its start-up form.

    Returns:
        A 24-tuple matching the ``outputs`` of the video change/clear events:
        fresh video state, fresh interactive state, empty click state, three
        ``None`` values, fifteen "hide" updates, an empty info text and the
        two status-log updates.
    """
    operation_log = [("",""), ("Try to upload your video and click the Get video info button to get started! (Kindly ensure that the uploaded video consists of fewer than 500 frames in total)", "Normal")]
    fresh_video_state = {
        "user_name": "",
        "video_name": "",
        "origin_images": None,
        "painted_images": None,
        "masks": None,
        "inpaint_masks": None,
        "logits": None,
        "select_frame_number": 0,
        "fps": 30
    }
    fresh_interactive_state = {
        "inference_times": 0,
        "negative_click_times" : 0,
        "positive_click_times": 0,
        "mask_save": args.mask_save,
        "multi_mask": {
            "mask_names": [],
            "masks": []
        },
        "track_end_number": None,
    }
    # One "hide" update per step-2/step-3 widget listed in the event outputs.
    hidden_widgets = tuple(gr.update(visible=False) for _ in range(15))
    return (
        fresh_video_state,
        fresh_interactive_state,
        [[],[]],
        None, None, None,
        *hidden_widgets,
        "",
        gr.update(visible=True, value=operation_log),
        gr.update(visible=False, value=operation_log),
    )
421
+
422
+
423
+ # args, defined in track_anything.py
424
# Command-line options (device, SAM variant, port, mask saving).
args = parse_augment()
pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
sam_checkpoint_url_dict = {
    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
checkpoint_fodler = os.path.join('..', '..', 'weights')

# Fetch the weights: SAM first, then the four files from the ProPainter release.
sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_fodler)
cutie_checkpoint, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint = [
    load_file_from_url(os.path.join(pretrain_model_url, weight_file), checkpoint_fodler)
    for weight_file in ('cutie-base-mega.pth', 'ProPainter.pth', 'raft-things.pth', 'recurrent_flow_completion.pth')
]

# initialize sam, cutie, propainter models
model = TrackingAnything(sam_checkpoint, cutie_checkpoint, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args)
441
+
442
+
443
+ title = r"""<h1 align="center">ProPainter: Improving Propagation and Transformer for Video Inpainting</h1>"""
444
+
445
# Introductory HTML shown under the page title (typo "the the" fixed).
description = r"""
<center><img src='https://github.com/sczhou/ProPainter/raw/main/assets/propainter_logo1_glow.png' alt='Propainter logo' style="width:180px; margin-bottom:20px"></center>
<b>Official Gradio demo</b> for <a href='https://github.com/sczhou/ProPainter' target='_blank'><b>Improving Propagation and Transformer for Video Inpainting (ICCV 2023)</b></a>.<br>
🔥 Propainter is a robust inpainting algorithm.<br>
🤗 Try to drop your video, add the masks and get the inpainting results!<br>
"""
451
+ article = r"""
452
+ If ProPainter is helpful, please help to ⭐ the <a href='https://github.com/sczhou/ProPainter' target='_blank'>Github Repo</a>. Thanks!
453
+ [![GitHub Stars](https://img.shields.io/github/stars/sczhou/ProPainter?style=social)](https://github.com/sczhou/ProPainter)
454
+
455
+ ---
456
+
457
+ 📝 **Citation**
458
+ <br>
459
+ If our work is useful for your research, please consider citing:
460
+ ```bibtex
461
+ @inproceedings{zhou2023propainter,
462
+ title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting},
463
+ author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K and Loy, Chen Change},
464
+ booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)},
465
+ year={2023}
466
+ }
467
+ ```
468
+
469
+ 📋 **License**
470
+ <br>
471
+ This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">S-Lab License 1.0</a>.
472
+ Redistribution and use for non-commercial purposes should follow this license.
473
+
474
+ 📧 **Contact**
475
+ <br>
476
+ If you have any questions, please feel free to reach me out at <b>shangchenzhou@gmail.com</b>.
477
+ <div>
478
+ 🤗 Find Me:
479
+ <a href="https://twitter.com/ShangchenZhou"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/ShangchenZhou?label=%40ShangchenZhou&style=social" alt="Twitter Follow"></a>
480
+ <a href="https://github.com/sczhou"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/sczhou?style=social" alt="Github Follow"></a>
481
+ </div>
482
+
483
+ """
484
+ css = """
485
+ .gradio-container {width: 85% !important}
486
+ .gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important;}
487
+ span.svelte-s1r2yt {font-size: 17px !important; font-weight: bold !important; color: #d30f2f !important;}
488
+ button {border-radius: 8px !important;}
489
+ .add_button {background-color: #4CAF50 !important;}
490
+ .remove_button {background-color: #f44336 !important;}
491
+ .clear_button {background-color: gray !important;}
492
+ .mask_button_group {gap: 10px !important;}
493
+ .video {height: 300px !important;}
494
+ .image {height: 300px !important;}
495
+ .video .wrap.svelte-lcpz3o {display: flex !important; align-items: center !important; justify-content: center !important;}
496
+ .video .wrap.svelte-lcpz3o > :first-child {height: 100% !important;}
497
+ .margin_center {width: 50% !important; margin: auto !important;}
498
+ .jc_center {justify-content: center !important;}
499
+ """
500
+
501
+ with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface:
502
+ click_state = gr.State([[],[]])
503
+
504
+ interactive_state = gr.State({
505
+ "inference_times": 0,
506
+ "negative_click_times" : 0,
507
+ "positive_click_times": 0,
508
+ "mask_save": args.mask_save,
509
+ "multi_mask": {
510
+ "mask_names": [],
511
+ "masks": []
512
+ },
513
+ "track_end_number": None,
514
+ }
515
+ )
516
+
517
+ video_state = gr.State(
518
+ {
519
+ "user_name": "",
520
+ "video_name": "",
521
+ "origin_images": None,
522
+ "painted_images": None,
523
+ "masks": None,
524
+ "inpaint_masks": None,
525
+ "logits": None,
526
+ "select_frame_number": 0,
527
+ "fps": 30
528
+ }
529
+ )
530
+
531
+ gr.Markdown(title)
532
+ gr.Markdown(description)
533
+
534
+ with gr.Group(elem_classes="gr-monochrome-group"):
535
+ with gr.Row():
536
+ with gr.Accordion('ProPainter Parameters (click to expand)', open=False):
537
+ with gr.Row():
538
+ resize_ratio_number = gr.Slider(label='Resize ratio',
539
+ minimum=0.01,
540
+ maximum=1.0,
541
+ step=0.01,
542
+ value=1.0)
543
+ raft_iter_number = gr.Slider(label='Iterations for RAFT inference.',
544
+ minimum=5,
545
+ maximum=20,
546
+ step=1,
547
+ value=20,)
548
+ with gr.Row():
549
+ dilate_radius_number = gr.Slider(label='Mask dilation for video and flow masking.',
550
+ minimum=0,
551
+ maximum=10,
552
+ step=1,
553
+ value=8,)
554
+
555
+ subvideo_length_number = gr.Slider(label='Length of sub-video for long video inference.',
556
+ minimum=40,
557
+ maximum=200,
558
+ step=1,
559
+ value=80,)
560
+ with gr.Row():
561
+ neighbor_length_number = gr.Slider(label='Length of local neighboring frames.',
562
+ minimum=5,
563
+ maximum=20,
564
+ step=1,
565
+ value=10,)
566
+
567
+ ref_stride_number = gr.Slider(label='Stride of global reference frames.',
568
+ minimum=5,
569
+ maximum=20,
570
+ step=1,
571
+ value=10,)
572
+
573
+ with gr.Column():
574
+ # input video
575
+ gr.Markdown("## Step1: Upload video")
576
+ with gr.Row(equal_height=True):
577
+ with gr.Column(scale=2):
578
+ video_input = gr.Video(elem_classes="video")
579
+ extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
580
+ with gr.Column(scale=2):
581
+ run_status = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started! (Kindly ensure that the uploaded video consists of fewer than 500 frames in total)", "Normal")],
582
+ color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"})
583
+ video_info = gr.Textbox(label="Video Info")
584
+
585
+
586
+ # add masks
587
+ step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False)
588
+ with gr.Row(equal_height=True):
589
+ with gr.Column(scale=2):
590
+ template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
591
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
592
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
593
+ with gr.Column(scale=2, elem_classes="jc_center"):
594
+ run_status2 = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started! (Kindly ensure that the uploaded video consists of fewer than 500 frames in total)", "Normal")],
595
+ color_map={"Normal": "green", "Error": "red", "Clear clicks": "gray", "Add mask": "green", "Remove mask": "red"},
596
+ visible=False)
597
+ with gr.Column():
598
+ point_prompt = gr.Radio(
599
+ choices=["Positive", "Negative"],
600
+ value="Positive",
601
+ label="Point prompt",
602
+ interactive=True,
603
+ visible=False,
604
+ min_width=100,
605
+ scale=1,)
606
+ with gr.Row(elem_classes="mask_button_group"):
607
+ Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False, elem_classes="add_button")
608
+ remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False, elem_classes="remove_button")
609
+ clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False, elem_classes="clear_button")
610
+ auto_mask_button = gr.Button(value="Auto logo mask", interactive=True, visible=False)
611
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
612
+
613
+ # output video
614
+ step3_title = gr.Markdown("---\n## Step3: Track masks and get the inpainting result", visible=False)
615
+ with gr.Row(equal_height=True):
616
+ with gr.Column(scale=2):
617
+ tracking_video_output = gr.Video(visible=False, elem_classes="video")
618
+ tracking_video_predict_button = gr.Button(value="1. Tracking", visible=False, elem_classes="margin_center")
619
+ with gr.Column(scale=2):
620
+ inpaiting_video_output = gr.Video(visible=False, elem_classes="video")
621
+ inpaint_video_predict_button = gr.Button(value="2. Inpainting", visible=False, elem_classes="margin_center")
622
+
623
+ # first step: get the video information
624
+ extract_frames_button.click(
625
+ fn=get_frames_from_video,
626
+ inputs=[
627
+ video_input, video_state
628
+ ],
629
+ outputs=[video_state, video_info, template_frame,
630
+ image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
631
+ tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button, inpaint_video_predict_button, step2_title, step3_title,mask_dropdown, run_status, run_status2]
632
+ )
633
+
634
+ # second step: select the track start and end frames with the sliders
635
+ image_selection_slider.release(fn=select_template,
636
+ inputs=[image_selection_slider, video_state, interactive_state],
637
+ outputs=[template_frame, video_state, interactive_state, run_status, run_status2], api_name="select_image")
638
+ track_pause_number_slider.release(fn=get_end_number,
639
+ inputs=[track_pause_number_slider, video_state, interactive_state],
640
+ outputs=[template_frame, interactive_state, run_status, run_status2], api_name="end_image")
641
+
642
+ # click on the template frame to get a mask using SAM
643
+ template_frame.select(
644
+ fn=sam_refine,
645
+ inputs=[video_state, point_prompt, click_state, interactive_state],
646
+ outputs=[template_frame, video_state, interactive_state, run_status, run_status2]
647
+ )
648
+
649
+ # add a different mask
650
+ Add_mask_button.click(
651
+ fn=add_multi_mask,
652
+ inputs=[video_state, interactive_state, mask_dropdown],
653
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status, run_status2]
654
+ )
655
+
656
+ remove_mask_button.click(
657
+ fn=remove_multi_mask,
658
+ inputs=[interactive_state, mask_dropdown],
659
+ outputs=[interactive_state, mask_dropdown, run_status, run_status2]
660
+ )
661
+ auto_mask_button.click(
662
+ fn=auto_mask_logo,
663
+ inputs=[video_state, interactive_state],
664
+ outputs=[template_frame, video_state, interactive_state, mask_dropdown, run_status, run_status2],
665
+ api_name="auto_logo_mask"
666
+ )
667
+
668
+
669
+ # track the video from the selected image and mask
670
+ tracking_video_predict_button.click(
671
+ fn=vos_tracking_video,
672
+ inputs=[video_state, interactive_state, mask_dropdown],
673
+ outputs=[tracking_video_output, video_state, interactive_state, run_status, run_status2]
674
+ )
675
+
676
+ # inpaint the video from the selected image and mask
677
+ inpaint_video_predict_button.click(
678
+ fn=inpaint_video,
679
+ inputs=[video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown],
680
+ outputs=[inpaiting_video_output, run_status, run_status2]
681
+ )
682
+
683
+ # show the selected masks when the mask dropdown selection changes
684
+ mask_dropdown.change(
685
+ fn=show_mask,
686
+ inputs=[video_state, interactive_state, mask_dropdown],
687
+ outputs=[template_frame, run_status, run_status2]
688
+ )
689
+
690
+ # reset all state and outputs when the input video changes or is cleared
691
+ video_input.change(
692
+ fn=restart,
693
+ inputs=[],
694
+ outputs=[
695
+ video_state,
696
+ interactive_state,
697
+ click_state,
698
+ tracking_video_output, inpaiting_video_output,
699
+ template_frame,
700
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
701
+ Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2
702
+ ],
703
+ queue=False,
704
+ show_progress=False)
705
+
706
+ video_input.clear(
707
+ fn=restart,
708
+ inputs=[],
709
+ outputs=[
710
+ video_state,
711
+ interactive_state,
712
+ click_state,
713
+ tracking_video_output, inpaiting_video_output,
714
+ template_frame,
715
+ tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
716
+ Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2
717
+ ],
718
+ queue=False,
719
+ show_progress=False)
720
+
721
+ # clear the clicked points
722
+ clear_button_click.click(
723
+ fn = clear_click,
724
+ inputs = [video_state, click_state,],
725
+ outputs = [template_frame,click_state, run_status, run_status2],
726
+ )
727
+
728
+ # set up the example videos
729
+ gr.Markdown("## Examples")
730
+ gr.Examples(
731
+ examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.mp4", "test-sample1.mp4", "test-sample2.mp4", "test-sample3.mp4", "test-sample4.mp4"]],
732
+ inputs=[video_input],
733
+ )
734
+ gr.Markdown(article)
735
+
736
# Turn on the request queue, then launch the interface with debug=True.
iface.queue().launch(debug=True)