Spaces:
Runtime error
Runtime error
Commit
·
5da584d
1
Parent(s):
94dd0a9
Upload 65 files
Browse files
- app.py +11 -39
- requirements.txt +5 -1
app.py
CHANGED
|
@@ -13,13 +13,7 @@ import requests
|
|
| 13 |
import json
|
| 14 |
import torchvision
|
| 15 |
import torch
|
| 16 |
-
from tools.interact_tools import SamControler
|
| 17 |
-
from tracker.base_tracker import BaseTracker
|
| 18 |
from tools.painter import mask_painter
|
| 19 |
-
try:
|
| 20 |
-
from mmcv.cnn import ConvModule
|
| 21 |
-
except:
|
| 22 |
-
os.system("mim install mmcv")
|
| 23 |
|
| 24 |
# download checkpoints
|
| 25 |
def download_checkpoint(url, folder, filename):
|
|
@@ -206,7 +200,6 @@ def show_mask(video_state, interactive_state, mask_dropdown):
|
|
| 206 |
|
| 207 |
# tracking vos
|
| 208 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
| 209 |
-
|
| 210 |
model.xmem.clear_memory()
|
| 211 |
if interactive_state["track_end_number"]:
|
| 212 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
|
@@ -226,8 +219,6 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
| 226 |
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
| 227 |
fps = video_state["fps"]
|
| 228 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
| 229 |
-
# clear GPU memory
|
| 230 |
-
model.xmem.clear_memory()
|
| 231 |
|
| 232 |
if interactive_state["track_end_number"]:
|
| 233 |
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
|
@@ -267,7 +258,6 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
| 267 |
|
| 268 |
# inpaint
|
| 269 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
| 270 |
-
|
| 271 |
frames = np.asarray(video_state["origin_images"])
|
| 272 |
fps = video_state["fps"]
|
| 273 |
inpaint_masks = np.asarray(video_state["masks"])
|
|
@@ -314,44 +304,27 @@ def generate_video_from_frames(frames, output_path, fps=30):
|
|
| 314 |
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
| 315 |
return output_path
|
| 316 |
|
| 317 |
-
|
| 318 |
-
# args, defined in track_anything.py
|
| 319 |
-
args = parse_augment()
|
| 320 |
-
|
| 321 |
# check and download checkpoints if needed
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
'vit_l': "sam_vit_l_0b3195.pth",
|
| 325 |
-
"vit_b": "sam_vit_b_01ec64.pth"
|
| 326 |
-
}
|
| 327 |
-
SAM_checkpoint_url_dict = {
|
| 328 |
-
'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
| 329 |
-
'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
|
| 330 |
-
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
|
| 331 |
-
}
|
| 332 |
-
sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
|
| 333 |
-
sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
|
| 334 |
xmem_checkpoint = "XMem-s012.pth"
|
| 335 |
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
| 336 |
e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
|
| 337 |
e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
|
| 338 |
|
| 339 |
-
|
| 340 |
folder ="./checkpoints"
|
| 341 |
-
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder,
|
| 342 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
| 343 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
| 344 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
# initialize sam, xmem, e2fgvi models
|
| 347 |
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
|
| 348 |
|
| 349 |
-
|
| 350 |
-
title = """<p><h1 align="center">Track-Anything</h1></p>
|
| 351 |
-
"""
|
| 352 |
-
description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. I To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
|
| 353 |
-
|
| 354 |
-
|
| 355 |
with gr.Blocks() as iface:
|
| 356 |
"""
|
| 357 |
state for
|
|
@@ -383,8 +356,7 @@ with gr.Blocks() as iface:
|
|
| 383 |
"fps": 30
|
| 384 |
}
|
| 385 |
)
|
| 386 |
-
|
| 387 |
-
gr.Markdown(description)
|
| 388 |
with gr.Row():
|
| 389 |
|
| 390 |
# for user video input
|
|
@@ -393,7 +365,7 @@ with gr.Blocks() as iface:
|
|
| 393 |
video_input = gr.Video(autosize=True)
|
| 394 |
with gr.Column():
|
| 395 |
video_info = gr.Textbox()
|
| 396 |
-
|
| 397 |
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
|
| 398 |
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
|
| 399 |
|
|
@@ -562,7 +534,7 @@ with gr.Blocks() as iface:
|
|
| 562 |
# cache_examples=True,
|
| 563 |
)
|
| 564 |
iface.queue(concurrency_count=1)
|
| 565 |
-
iface.launch(debug=True, enable_queue=True)
|
| 566 |
|
| 567 |
|
| 568 |
|
|
|
|
| 13 |
import json
|
| 14 |
import torchvision
|
| 15 |
import torch
|
|
|
|
|
|
|
| 16 |
from tools.painter import mask_painter
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# download checkpoints
|
| 19 |
def download_checkpoint(url, folder, filename):
|
|
|
|
| 200 |
|
| 201 |
# tracking vos
|
| 202 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
|
|
| 203 |
model.xmem.clear_memory()
|
| 204 |
if interactive_state["track_end_number"]:
|
| 205 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
|
|
|
| 219 |
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
| 220 |
fps = video_state["fps"]
|
| 221 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
|
|
|
|
|
|
| 222 |
|
| 223 |
if interactive_state["track_end_number"]:
|
| 224 |
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
|
|
|
| 258 |
|
| 259 |
# inpaint
|
| 260 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
|
|
|
| 261 |
frames = np.asarray(video_state["origin_images"])
|
| 262 |
fps = video_state["fps"]
|
| 263 |
inpaint_masks = np.asarray(video_state["masks"])
|
|
|
|
| 304 |
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
| 305 |
return output_path
|
| 306 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
# check and download checkpoints if needed
|
| 308 |
+
SAM_checkpoint = "sam_vit_h_4b8939.pth"
|
| 309 |
+
sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
xmem_checkpoint = "XMem-s012.pth"
|
| 311 |
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
| 312 |
e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
|
| 313 |
e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
|
| 314 |
|
|
|
|
| 315 |
folder ="./checkpoints"
|
| 316 |
+
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
|
| 317 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
| 318 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
| 319 |
+
# args, defined in track_anything.py
|
| 320 |
+
args = parse_augment()
|
| 321 |
+
# args.port = 12315
|
| 322 |
+
# args.device = "cuda:2"
|
| 323 |
+
# args.mask_save = True
|
| 324 |
|
| 325 |
# initialize sam, xmem, e2fgvi models
|
| 326 |
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
|
| 327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
with gr.Blocks() as iface:
|
| 329 |
"""
|
| 330 |
state for
|
|
|
|
| 356 |
"fps": 30
|
| 357 |
}
|
| 358 |
)
|
| 359 |
+
|
|
|
|
| 360 |
with gr.Row():
|
| 361 |
|
| 362 |
# for user video input
|
|
|
|
| 365 |
video_input = gr.Video(autosize=True)
|
| 366 |
with gr.Column():
|
| 367 |
video_info = gr.Textbox()
|
| 368 |
+
video_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
|
| 369 |
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
|
| 370 |
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
|
| 371 |
|
|
|
|
| 534 |
# cache_examples=True,
|
| 535 |
)
|
| 536 |
iface.queue(concurrency_count=1)
|
| 537 |
+
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
|
| 538 |
|
| 539 |
|
| 540 |
|
requirements.txt
CHANGED
|
@@ -10,6 +10,10 @@ gradio==3.25.0
|
|
| 10 |
opencv-python
|
| 11 |
pycocotools
|
| 12 |
matplotlib
|
|
|
|
|
|
|
|
|
|
| 13 |
pyyaml
|
| 14 |
av
|
| 15 |
-
|
|
|
|
|
|
| 10 |
opencv-python
|
| 11 |
pycocotools
|
| 12 |
matplotlib
|
| 13 |
+
onnxruntime
|
| 14 |
+
onnx
|
| 15 |
+
metaseg==0.6.1
|
| 16 |
pyyaml
|
| 17 |
av
|
| 18 |
+
mmcv-full
|
| 19 |
+
mmengine
|