Spaces:
Runtime error
Runtime error
added colormap options
Browse files- app.py +45 -54
- utils/improc.py +1 -1
app.py
CHANGED
|
@@ -517,15 +517,16 @@ def choose_rate8(video_preview, video_fps, tracks, visibs):
|
|
| 517 |
# def choose_rate16(video_preview, video_fps, tracks, visibs):
|
| 518 |
# return choose_rate(16, video_preview, video_fps, tracks, visibs)
|
| 519 |
|
| 520 |
-
def
|
| 521 |
print('rate', rate)
|
|
|
|
| 522 |
print('video_preview', video_preview.shape)
|
| 523 |
T, H, W,_ = video_preview.shape
|
| 524 |
tracks_ = tracks.reshape(H,W,T,2)[::rate,::rate].reshape(-1,T,2)
|
| 525 |
visibs_ = visibs.reshape(H,W,T)[::rate,::rate].reshape(-1,T)
|
| 526 |
-
return paint_video(video_preview, video_fps, tracks_, visibs_, rate=rate)
|
| 527 |
# return video_preview_array[int(frame_num)]
|
| 528 |
-
|
| 529 |
def preprocess_video_input(video_path):
|
| 530 |
video_arr = mediapy.read_video(video_path)
|
| 531 |
video_fps = video_arr.metadata.fps
|
|
@@ -553,27 +554,15 @@ def preprocess_video_input(video_path):
|
|
| 553 |
preview_video = np.array(preview_video)
|
| 554 |
input_video = np.array(input_video)
|
| 555 |
|
| 556 |
-
interactive = True
|
| 557 |
-
|
| 558 |
return (
|
| 559 |
video_arr, # Original video
|
| 560 |
preview_video, # Original preview video, resized for faster processing
|
| 561 |
preview_video.copy(), # Copy of preview video for visualization
|
| 562 |
input_video, # Resized video input for model
|
| 563 |
-
# None, # video_feature, # Extracted feature
|
| 564 |
video_fps, # Set the video FPS
|
| 565 |
-
# gr.update(open=True), # open/close the video input drawer
|
| 566 |
-
# tracking_mode, # Set the tracking mode
|
| 567 |
preview_video[0], # Set the preview frame to the first frame
|
| 568 |
-
gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=
|
| 569 |
-
|
| 570 |
-
[[] for _ in range(num_frames)], # Set query_points_color to empty
|
| 571 |
-
[[] for _ in range(num_frames)],
|
| 572 |
-
0, # Set query count to 0
|
| 573 |
-
gr.update(interactive=interactive), # Make the buttons interactive
|
| 574 |
-
gr.update(interactive=interactive),
|
| 575 |
-
gr.update(interactive=interactive),
|
| 576 |
-
gr.update(interactive=True),
|
| 577 |
# gr.update(interactive=True),
|
| 578 |
# gr.update(interactive=True),
|
| 579 |
# gr.update(interactive=True),
|
|
@@ -581,22 +570,30 @@ def preprocess_video_input(video_path):
|
|
| 581 |
)
|
| 582 |
|
| 583 |
|
| 584 |
-
def paint_video(video_preview, video_fps, tracks, visibs, rate=1):
|
| 585 |
print('video_preview', video_preview.shape)
|
|
|
|
| 586 |
T, H, W, _ = video_preview.shape
|
| 587 |
query_count = tracks.shape[0]
|
| 588 |
-
cmap
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors,rate=rate)#=max(rate//2,1))
|
| 601 |
# save video
|
| 602 |
video_file_name = uuid.uuid4().hex + ".mp4"
|
|
@@ -630,9 +627,6 @@ def track(
|
|
| 630 |
video_input,
|
| 631 |
video_fps,
|
| 632 |
query_frame,
|
| 633 |
-
query_points,
|
| 634 |
-
query_points_color,
|
| 635 |
-
query_count,
|
| 636 |
):
|
| 637 |
# tracking_mode = 'selected'
|
| 638 |
# if query_count == 0:
|
|
@@ -788,7 +782,7 @@ def track(
|
|
| 788 |
# print('sc', sc)
|
| 789 |
# tracks = tracks * sc
|
| 790 |
|
| 791 |
-
return paint_video(video_preview, video_fps, tracks, visibs), tracks, visibs, gr.update(interactive=True,
|
| 792 |
# gr.update(interactive=True),
|
| 793 |
# gr.update(interactive=True),
|
| 794 |
# gr.update(interactive=True),
|
|
@@ -863,11 +857,6 @@ with gr.Blocks() as demo:
|
|
| 863 |
video_input = gr.State()
|
| 864 |
video_fps = gr.State(24)
|
| 865 |
|
| 866 |
-
query_points = gr.State([])
|
| 867 |
-
query_points_color = gr.State([])
|
| 868 |
-
is_tracked_query = gr.State([])
|
| 869 |
-
query_count = gr.State(0)
|
| 870 |
-
|
| 871 |
# rate = gr.State([])
|
| 872 |
tracks = gr.State([])
|
| 873 |
visibs = gr.State([])
|
|
@@ -875,14 +864,13 @@ with gr.Blocks() as demo:
|
|
| 875 |
gr.Markdown("# ⚡ AllTracker: Efficient Dense Point Tracking at High Resolution")
|
| 876 |
gr.Markdown("<div style='text-align: left;'> \
|
| 877 |
<p>Welcome to <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a>! This demo runs our model to perform all-pixel tracking in a video of your choice.</p> \
|
| 878 |
-
<p>To get started, simply upload
|
| 879 |
<p>After picking a video, click \"Submit\" to load the frames into the app, and optionally choose a query frame (using the slider), and then click \"Track\".</p> \
|
| 880 |
<p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub repo</a>, or <a href='https://arxiv.org/abs/2506.07310' target='_blank'>paper</a>.</p> \
|
| 881 |
<p>Initial code for this Gradio app came from LocoTrack and CoTracker -- big thanks to those authors!</p> \
|
| 882 |
</div>"
|
| 883 |
)
|
| 884 |
|
| 885 |
-
|
| 886 |
gr.Markdown("## Step 1: Select a video, and click \"Submit\".")
|
| 887 |
with gr.Row():
|
| 888 |
with gr.Column():
|
|
@@ -891,7 +879,6 @@ with gr.Blocks() as demo:
|
|
| 891 |
with gr.Row():
|
| 892 |
submit = gr.Button("Submit")
|
| 893 |
with gr.Column():
|
| 894 |
-
# with gr.Accordion("Sample videos", open=True) as video_in_drawer:
|
| 895 |
with gr.Row():
|
| 896 |
butterfly = os.path.join(os.path.dirname(__file__), "videos", "butterfly_800.mp4")
|
| 897 |
monkey = os.path.join(os.path.dirname(__file__), "videos", "monkey_800.mp4")
|
|
@@ -951,6 +938,9 @@ with gr.Blocks() as demo:
|
|
| 951 |
# rate_slider = gr.Slider(
|
| 952 |
# minimum=1, maximum=16, value=1, step=1, label="Choose subsampling rate", interactive=False)
|
| 953 |
rate_radio = gr.Radio([1, 2, 4, 8, 16], value=1, label="Choose visualization subsampling", interactive=False)
|
|
|
|
|
|
|
|
|
|
| 954 |
|
| 955 |
with gr.Row():
|
| 956 |
output_video = gr.Video(
|
|
@@ -971,13 +961,8 @@ with gr.Blocks() as demo:
|
|
| 971 |
video_queried_preview,
|
| 972 |
video_input,
|
| 973 |
video_fps,
|
| 974 |
-
# video_in_drawer,
|
| 975 |
current_frame,
|
| 976 |
query_frame_slider,
|
| 977 |
-
query_points,
|
| 978 |
-
query_points_color,
|
| 979 |
-
is_tracked_query,
|
| 980 |
-
query_count,
|
| 981 |
# undo,
|
| 982 |
# clear_frame,
|
| 983 |
# clear_all,
|
|
@@ -1081,15 +1066,13 @@ with gr.Blocks() as demo:
|
|
| 1081 |
video_input,
|
| 1082 |
video_fps,
|
| 1083 |
query_frame_slider,
|
| 1084 |
-
query_points,
|
| 1085 |
-
query_points_color,
|
| 1086 |
-
query_count,
|
| 1087 |
],
|
| 1088 |
outputs = [
|
| 1089 |
output_video,
|
| 1090 |
tracks,
|
| 1091 |
visibs,
|
| 1092 |
rate_radio,
|
|
|
|
| 1093 |
# rate1_button,
|
| 1094 |
# rate2_button,
|
| 1095 |
# rate4_button,
|
|
@@ -1108,8 +1091,16 @@ with gr.Blocks() as demo:
|
|
| 1108 |
# queue = False
|
| 1109 |
# )
|
| 1110 |
rate_radio.change(
|
| 1111 |
-
fn =
|
| 1112 |
-
inputs = [rate_radio, video_preview, video_fps, tracks, visibs],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1113 |
outputs = [
|
| 1114 |
output_video,
|
| 1115 |
],
|
|
@@ -1153,5 +1144,5 @@ with gr.Blocks() as demo:
|
|
| 1153 |
|
| 1154 |
|
| 1155 |
# demo.launch(show_api=False, show_error=True, debug=False, share=False)
|
| 1156 |
-
|
| 1157 |
-
demo.launch(show_api=False, show_error=True, debug=False, share=False)
|
|
|
|
| 517 |
# def choose_rate16(video_preview, video_fps, tracks, visibs):
|
| 518 |
# return choose_rate(16, video_preview, video_fps, tracks, visibs)
|
| 519 |
|
| 520 |
+
def update_vis(rate, cmap, video_preview, query_frame, video_fps, tracks, visibs):
|
| 521 |
print('rate', rate)
|
| 522 |
+
print('cmap', cmap)
|
| 523 |
print('video_preview', video_preview.shape)
|
| 524 |
T, H, W,_ = video_preview.shape
|
| 525 |
tracks_ = tracks.reshape(H,W,T,2)[::rate,::rate].reshape(-1,T,2)
|
| 526 |
visibs_ = visibs.reshape(H,W,T)[::rate,::rate].reshape(-1,T)
|
| 527 |
+
return paint_video(video_preview, query_frame, video_fps, tracks_, visibs_, rate=rate, cmap=cmap)
|
| 528 |
# return video_preview_array[int(frame_num)]
|
| 529 |
+
|
| 530 |
def preprocess_video_input(video_path):
|
| 531 |
video_arr = mediapy.read_video(video_path)
|
| 532 |
video_fps = video_arr.metadata.fps
|
|
|
|
| 554 |
preview_video = np.array(preview_video)
|
| 555 |
input_video = np.array(input_video)
|
| 556 |
|
|
|
|
|
|
|
| 557 |
return (
|
| 558 |
video_arr, # Original video
|
| 559 |
preview_video, # Original preview video, resized for faster processing
|
| 560 |
preview_video.copy(), # Copy of preview video for visualization
|
| 561 |
input_video, # Resized video input for model
|
|
|
|
| 562 |
video_fps, # Set the video FPS
|
|
|
|
|
|
|
| 563 |
preview_video[0], # Set the preview frame to the first frame
|
| 564 |
+
gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=True), # Set slider interactive
|
| 565 |
+
gr.update(interactive=True), # make track button interactive
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
# gr.update(interactive=True),
|
| 567 |
# gr.update(interactive=True),
|
| 568 |
# gr.update(interactive=True),
|
|
|
|
| 570 |
)
|
| 571 |
|
| 572 |
|
| 573 |
+
def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, cmap="gist_rainbow"):
|
| 574 |
print('video_preview', video_preview.shape)
|
| 575 |
+
print('tracks', tracks.shape)
|
| 576 |
T, H, W, _ = video_preview.shape
|
| 577 |
query_count = tracks.shape[0]
|
| 578 |
+
print('cmap', cmap)
|
| 579 |
+
|
| 580 |
+
if cmap=="bremm":
|
| 581 |
+
xy0 = tracks[:,query_frame] # N,2
|
| 582 |
+
colors = utils.improc.get_2d_colors(xy0, H, W)
|
| 583 |
+
else:
|
| 584 |
+
cmap_ = matplotlib.colormaps.get_cmap(cmap)
|
| 585 |
+
query_points_color = [[]]
|
| 586 |
+
for i in range(query_count):
|
| 587 |
+
# Choose the color for the point from matplotlib colormap
|
| 588 |
+
color = cmap_(i / float(query_count))
|
| 589 |
+
color = (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255))
|
| 590 |
+
query_points_color[0].append(color)
|
| 591 |
+
# make color array
|
| 592 |
+
colors = []
|
| 593 |
+
for frame_colors in query_points_color:
|
| 594 |
+
colors.extend(frame_colors)
|
| 595 |
+
colors = np.array(colors)
|
| 596 |
+
|
| 597 |
painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors,rate=rate)#=max(rate//2,1))
|
| 598 |
# save video
|
| 599 |
video_file_name = uuid.uuid4().hex + ".mp4"
|
|
|
|
| 627 |
video_input,
|
| 628 |
video_fps,
|
| 629 |
query_frame,
|
|
|
|
|
|
|
|
|
|
| 630 |
):
|
| 631 |
# tracking_mode = 'selected'
|
| 632 |
# if query_count == 0:
|
|
|
|
| 782 |
# print('sc', sc)
|
| 783 |
# tracks = tracks * sc
|
| 784 |
|
| 785 |
+
return paint_video(video_preview, query_frame, video_fps, tracks, visibs), tracks, visibs, gr.update(interactive=True), gr.update(interactive=True)
|
| 786 |
# gr.update(interactive=True),
|
| 787 |
# gr.update(interactive=True),
|
| 788 |
# gr.update(interactive=True),
|
|
|
|
| 857 |
video_input = gr.State()
|
| 858 |
video_fps = gr.State(24)
|
| 859 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 860 |
# rate = gr.State([])
|
| 861 |
tracks = gr.State([])
|
| 862 |
visibs = gr.State([])
|
|
|
|
| 864 |
gr.Markdown("# ⚡ AllTracker: Efficient Dense Point Tracking at High Resolution")
|
| 865 |
gr.Markdown("<div style='text-align: left;'> \
|
| 866 |
<p>Welcome to <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a>! This demo runs our model to perform all-pixel tracking in a video of your choice.</p> \
|
| 867 |
+
<p>To get started, simply upload an mp4, or select one of the example videos. The shorter the video, the faster the processing. We recommend submitting videos under 20 seconds long.</p> \
|
| 868 |
<p>After picking a video, click \"Submit\" to load the frames into the app, and optionally choose a query frame (using the slider), and then click \"Track\".</p> \
|
| 869 |
<p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub repo</a>, or <a href='https://arxiv.org/abs/2506.07310' target='_blank'>paper</a>.</p> \
|
| 870 |
<p>Initial code for this Gradio app came from LocoTrack and CoTracker -- big thanks to those authors!</p> \
|
| 871 |
</div>"
|
| 872 |
)
|
| 873 |
|
|
|
|
| 874 |
gr.Markdown("## Step 1: Select a video, and click \"Submit\".")
|
| 875 |
with gr.Row():
|
| 876 |
with gr.Column():
|
|
|
|
| 879 |
with gr.Row():
|
| 880 |
submit = gr.Button("Submit")
|
| 881 |
with gr.Column():
|
|
|
|
| 882 |
with gr.Row():
|
| 883 |
butterfly = os.path.join(os.path.dirname(__file__), "videos", "butterfly_800.mp4")
|
| 884 |
monkey = os.path.join(os.path.dirname(__file__), "videos", "monkey_800.mp4")
|
|
|
|
| 938 |
# rate_slider = gr.Slider(
|
| 939 |
# minimum=1, maximum=16, value=1, step=1, label="Choose subsampling rate", interactive=False)
|
| 940 |
rate_radio = gr.Radio([1, 2, 4, 8, 16], value=1, label="Choose visualization subsampling", interactive=False)
|
| 941 |
+
|
| 942 |
+
with gr.Row():
|
| 943 |
+
cmap_radio = gr.Radio(["gist_rainbow", "rainbow", "jet", "turbo", "bremm"], value="gist_rainbow", label="Choose colormap", interactive=False)
|
| 944 |
|
| 945 |
with gr.Row():
|
| 946 |
output_video = gr.Video(
|
|
|
|
| 961 |
video_queried_preview,
|
| 962 |
video_input,
|
| 963 |
video_fps,
|
|
|
|
| 964 |
current_frame,
|
| 965 |
query_frame_slider,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 966 |
# undo,
|
| 967 |
# clear_frame,
|
| 968 |
# clear_all,
|
|
|
|
| 1066 |
video_input,
|
| 1067 |
video_fps,
|
| 1068 |
query_frame_slider,
|
|
|
|
|
|
|
|
|
|
| 1069 |
],
|
| 1070 |
outputs = [
|
| 1071 |
output_video,
|
| 1072 |
tracks,
|
| 1073 |
visibs,
|
| 1074 |
rate_radio,
|
| 1075 |
+
cmap_radio,
|
| 1076 |
# rate1_button,
|
| 1077 |
# rate2_button,
|
| 1078 |
# rate4_button,
|
|
|
|
| 1091 |
# queue = False
|
| 1092 |
# )
|
| 1093 |
rate_radio.change(
|
| 1094 |
+
fn = update_vis,
|
| 1095 |
+
inputs = [rate_radio, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
|
| 1096 |
+
outputs = [
|
| 1097 |
+
output_video,
|
| 1098 |
+
],
|
| 1099 |
+
queue = False
|
| 1100 |
+
)
|
| 1101 |
+
cmap_radio.change(
|
| 1102 |
+
fn = update_vis,
|
| 1103 |
+
inputs = [rate_radio, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
|
| 1104 |
outputs = [
|
| 1105 |
output_video,
|
| 1106 |
],
|
|
|
|
| 1144 |
|
| 1145 |
|
| 1146 |
# demo.launch(show_api=False, show_error=True, debug=False, share=False)
|
| 1147 |
+
demo.launch(show_api=False, show_error=True, debug=False, share=True)
|
| 1148 |
+
# demo.launch(show_api=False, show_error=True, debug=False, share=False)
|
utils/improc.py
CHANGED
|
@@ -58,7 +58,7 @@ def flow2color(flow, clip=0.0):
|
|
| 58 |
flow = (flow*255.0).type(torch.ByteTensor)
|
| 59 |
return flow
|
| 60 |
|
| 61 |
-
COLORMAP_FILE = "./
|
| 62 |
class ColorMap2d:
|
| 63 |
def __init__(self, filename=None):
|
| 64 |
self._colormap_file = filename or COLORMAP_FILE
|
|
|
|
| 58 |
flow = (flow*255.0).type(torch.ByteTensor)
|
| 59 |
return flow
|
| 60 |
|
| 61 |
+
COLORMAP_FILE = "./bremm.png"
|
| 62 |
class ColorMap2d:
|
| 63 |
def __init__(self, filename=None):
|
| 64 |
self._colormap_file = filename or COLORMAP_FILE
|