Spaces:
Runtime error
Runtime error
added overlay option; fixed bugs
Browse files
- app.py +48 -22
- utils/improc.py +1 -0
app.py
CHANGED
|
@@ -20,6 +20,7 @@ import random
|
|
| 20 |
from typing import List, Optional, Sequence, Tuple
|
| 21 |
import spaces
|
| 22 |
import numpy as np
|
|
|
|
| 23 |
import utils.basic
|
| 24 |
import utils.improc
|
| 25 |
|
|
@@ -105,12 +106,16 @@ def paint_point_track_gpu_scatter(
|
|
| 105 |
visibles: np.ndarray,
|
| 106 |
colormap: Optional[List[Tuple[int, int, int]]] = None,
|
| 107 |
rate: int = 1,
|
|
|
|
| 108 |
# sharpness: float = 0.1,
|
| 109 |
) -> np.ndarray:
|
| 110 |
print('starting vis')
|
| 111 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 112 |
frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device) # [T,C,H,W]
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
| 114 |
point_tracks_t = torch.from_numpy(point_tracks).to(device) # [P,T,2]
|
| 115 |
visibles_t = torch.from_numpy(visibles).to(device) # [P,T]
|
| 116 |
T, C, H, W = frames_t.shape
|
|
@@ -517,14 +522,14 @@ def choose_rate8(video_preview, video_fps, tracks, visibs):
|
|
| 517 |
# def choose_rate16(video_preview, video_fps, tracks, visibs):
|
| 518 |
# return choose_rate(16, video_preview, video_fps, tracks, visibs)
|
| 519 |
|
| 520 |
-
def update_vis(rate, cmap, video_preview, query_frame, video_fps, tracks, visibs):
|
| 521 |
print('rate', rate)
|
| 522 |
print('cmap', cmap)
|
| 523 |
print('video_preview', video_preview.shape)
|
| 524 |
T, H, W,_ = video_preview.shape
|
| 525 |
tracks_ = tracks.reshape(H,W,T,2)[::rate,::rate].reshape(-1,T,2)
|
| 526 |
visibs_ = visibs.reshape(H,W,T)[::rate,::rate].reshape(-1,T)
|
| 527 |
-
return paint_video(video_preview, query_frame, video_fps, tracks_, visibs_, rate=rate, cmap=cmap)
|
| 528 |
# return video_preview_array[int(frame_num)]
|
| 529 |
|
| 530 |
def preprocess_video_input(video_path):
|
|
@@ -570,15 +575,19 @@ def preprocess_video_input(video_path):
|
|
| 570 |
)
|
| 571 |
|
| 572 |
|
| 573 |
-
def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, cmap="gist_rainbow"):
|
| 574 |
print('video_preview', video_preview.shape)
|
| 575 |
print('tracks', tracks.shape)
|
| 576 |
T, H, W, _ = video_preview.shape
|
| 577 |
query_count = tracks.shape[0]
|
| 578 |
print('cmap', cmap)
|
| 579 |
-
|
| 580 |
if cmap=="bremm":
|
|
|
|
| 581 |
xy0 = tracks[:,query_frame] # N,2
|
|
|
|
|
|
|
|
|
|
| 582 |
colors = utils.improc.get_2d_colors(xy0, H, W)
|
| 583 |
else:
|
| 584 |
cmap_ = matplotlib.colormaps.get_cmap(cmap)
|
|
@@ -594,7 +603,7 @@ def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, c
|
|
| 594 |
colors.extend(frame_colors)
|
| 595 |
colors = np.array(colors)
|
| 596 |
|
| 597 |
-
painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors,rate=rate)#=max(rate//2,1))
|
| 598 |
# save video
|
| 599 |
video_file_name = uuid.uuid4().hex + ".mp4"
|
| 600 |
video_path = os.path.join(os.path.dirname(__file__), "tmp")
|
|
@@ -609,7 +618,7 @@ def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, c
|
|
| 609 |
im = PIL.Image.fromarray(painted_video[ti])
|
| 610 |
# im.save(temp_out_f, "PNG", subsampling=0, quality=80)
|
| 611 |
im.save(temp_out_f)
|
| 612 |
-
print('saved', temp_out_f)
|
| 613 |
# os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.png" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
|
| 614 |
os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.jpg" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
|
| 615 |
print('saved', video_file_path)
|
|
@@ -617,16 +626,19 @@ def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, c
|
|
| 617 |
# temp_out_f = '%s/%03d.png' % (video_path, ti)
|
| 618 |
temp_out_f = '%s/%03d.jpg' % (video_path, ti)
|
| 619 |
os.remove(temp_out_f)
|
| 620 |
-
print('deleted', temp_out_f)
|
| 621 |
return video_file_path
|
| 622 |
|
| 623 |
|
| 624 |
@spaces.GPU
|
| 625 |
def track(
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
|
|
|
|
|
|
|
|
|
| 630 |
):
|
| 631 |
# tracking_mode = 'selected'
|
| 632 |
# if query_count == 0:
|
|
@@ -774,7 +786,8 @@ def track(
|
|
| 774 |
tracks = traj_maps_e.permute(0,3,4,1,2).reshape(-1,T,2).numpy()
|
| 775 |
visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
|
| 776 |
confs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
|
| 777 |
-
visibs = (visibs * confs) > 0.
|
|
|
|
| 778 |
# visibs = (confs) > 0.1 # N,T
|
| 779 |
|
| 780 |
|
|
@@ -782,7 +795,7 @@ def track(
|
|
| 782 |
# print('sc', sc)
|
| 783 |
# tracks = tracks * sc
|
| 784 |
|
| 785 |
-
return
|
| 786 |
# gr.update(interactive=True),
|
| 787 |
# gr.update(interactive=True),
|
| 788 |
# gr.update(interactive=True),
|
|
@@ -863,7 +876,7 @@ with gr.Blocks() as demo:
|
|
| 863 |
|
| 864 |
gr.Markdown("# ⚡ AllTracker: Efficient Dense Point Tracking at High Resolution")
|
| 865 |
gr.Markdown("<div style='text-align: left;'> \
|
| 866 |
-
<p>
|
| 867 |
<p>To get started, simply upload an mp4, or select one of the example videos. The shorter the video, the faster the processing. We recommend submitting videos under 20 seconds long.</p> \
|
| 868 |
<p>After picking a video, click \"Submit\" to load the frames into the app, and optionally choose a query frame (using the slider), and then click \"Track\".</p> \
|
| 869 |
<p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub repo</a>, or <a href='https://arxiv.org/abs/2506.07310' target='_blank'>paper</a>.</p> \
|
|
@@ -909,7 +922,7 @@ with gr.Blocks() as demo:
|
|
| 909 |
with gr.Column():
|
| 910 |
with gr.Row():
|
| 911 |
query_frame_slider = gr.Slider(
|
| 912 |
-
minimum=0, maximum=100, value=0, step=1, label="
|
| 913 |
# with gr.Row():
|
| 914 |
# undo = gr.Button("Undo", interactive=False)
|
| 915 |
# clear_frame = gr.Button("Clear Frame", interactive=False)
|
|
@@ -937,11 +950,12 @@ with gr.Blocks() as demo:
|
|
| 937 |
with gr.Row():
|
| 938 |
# rate_slider = gr.Slider(
|
| 939 |
# minimum=1, maximum=16, value=1, step=1, label="Choose subsampling rate", interactive=False)
|
| 940 |
-
rate_radio = gr.Radio([1, 2, 4, 8, 16], value=1, label="
|
| 941 |
-
|
| 942 |
with gr.Row():
|
| 943 |
-
cmap_radio = gr.Radio(["gist_rainbow", "rainbow", "jet", "turbo", "bremm"], value="gist_rainbow", label="
|
| 944 |
-
|
|
|
|
|
|
|
| 945 |
with gr.Row():
|
| 946 |
output_video = gr.Video(
|
| 947 |
label="Output video",
|
|
@@ -1066,12 +1080,16 @@ with gr.Blocks() as demo:
|
|
| 1066 |
video_input,
|
| 1067 |
video_fps,
|
| 1068 |
query_frame_slider,
|
|
|
|
|
|
|
|
|
|
| 1069 |
],
|
| 1070 |
outputs = [
|
| 1071 |
output_video,
|
| 1072 |
tracks,
|
| 1073 |
visibs,
|
| 1074 |
rate_radio,
|
|
|
|
| 1075 |
cmap_radio,
|
| 1076 |
# rate1_button,
|
| 1077 |
# rate2_button,
|
|
@@ -1092,7 +1110,7 @@ with gr.Blocks() as demo:
|
|
| 1092 |
# )
|
| 1093 |
rate_radio.change(
|
| 1094 |
fn = update_vis,
|
| 1095 |
-
inputs = [rate_radio, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
|
| 1096 |
outputs = [
|
| 1097 |
output_video,
|
| 1098 |
],
|
|
@@ -1100,7 +1118,15 @@ with gr.Blocks() as demo:
|
|
| 1100 |
)
|
| 1101 |
cmap_radio.change(
|
| 1102 |
fn = update_vis,
|
| 1103 |
-
inputs = [rate_radio, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1104 |
outputs = [
|
| 1105 |
output_video,
|
| 1106 |
],
|
|
|
|
| 20 |
from typing import List, Optional, Sequence, Tuple
|
| 21 |
import spaces
|
| 22 |
import numpy as np
|
| 23 |
+
import utils.py
|
| 24 |
import utils.basic
|
| 25 |
import utils.improc
|
| 26 |
|
|
|
|
| 106 |
visibles: np.ndarray,
|
| 107 |
colormap: Optional[List[Tuple[int, int, int]]] = None,
|
| 108 |
rate: int = 1,
|
| 109 |
+
show_bkg=True,
|
| 110 |
# sharpness: float = 0.1,
|
| 111 |
) -> np.ndarray:
|
| 112 |
print('starting vis')
|
| 113 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 114 |
frames_t = torch.from_numpy(frames).float().permute(0, 3, 1, 2).to(device) # [T,C,H,W]
|
| 115 |
+
if show_bkg:
|
| 116 |
+
frames_t = frames_t * 0.5 # darken, to see the point tracks better
|
| 117 |
+
else:
|
| 118 |
+
frames_t = frames_t * 0.0 # black out
|
| 119 |
point_tracks_t = torch.from_numpy(point_tracks).to(device) # [P,T,2]
|
| 120 |
visibles_t = torch.from_numpy(visibles).to(device) # [P,T]
|
| 121 |
T, C, H, W = frames_t.shape
|
|
|
|
| 522 |
# def choose_rate16(video_preview, video_fps, tracks, visibs):
|
| 523 |
# return choose_rate(16, video_preview, video_fps, tracks, visibs)
|
| 524 |
|
| 525 |
+
def update_vis(rate, show_bkg, cmap, video_preview, query_frame, video_fps, tracks, visibs):
|
| 526 |
print('rate', rate)
|
| 527 |
print('cmap', cmap)
|
| 528 |
print('video_preview', video_preview.shape)
|
| 529 |
T, H, W,_ = video_preview.shape
|
| 530 |
tracks_ = tracks.reshape(H,W,T,2)[::rate,::rate].reshape(-1,T,2)
|
| 531 |
visibs_ = visibs.reshape(H,W,T)[::rate,::rate].reshape(-1,T)
|
| 532 |
+
return paint_video(video_preview, query_frame, video_fps, tracks_, visibs_, rate=rate, show_bkg=show_bkg, cmap=cmap)
|
| 533 |
# return video_preview_array[int(frame_num)]
|
| 534 |
|
| 535 |
def preprocess_video_input(video_path):
|
|
|
|
| 575 |
)
|
| 576 |
|
| 577 |
|
| 578 |
+
def paint_video(video_preview, query_frame, video_fps, tracks, visibs, rate=1, show_bkg=True, cmap="gist_rainbow"):
|
| 579 |
print('video_preview', video_preview.shape)
|
| 580 |
print('tracks', tracks.shape)
|
| 581 |
T, H, W, _ = video_preview.shape
|
| 582 |
query_count = tracks.shape[0]
|
| 583 |
print('cmap', cmap)
|
| 584 |
+
print('query_frame', query_frame)
|
| 585 |
if cmap=="bremm":
|
| 586 |
+
# xy0 = tracks
|
| 587 |
xy0 = tracks[:,query_frame] # N,2
|
| 588 |
+
# print('xyQ', xy0[:10])
|
| 589 |
+
# print('xy0', tracks[:10,0])
|
| 590 |
+
# print('xy1', tracks[:10,1])
|
| 591 |
colors = utils.improc.get_2d_colors(xy0, H, W)
|
| 592 |
else:
|
| 593 |
cmap_ = matplotlib.colormaps.get_cmap(cmap)
|
|
|
|
| 603 |
colors.extend(frame_colors)
|
| 604 |
colors = np.array(colors)
|
| 605 |
|
| 606 |
+
painted_video = paint_point_track_gpu_scatter(video_preview,tracks,visibs,colors,rate=rate,show_bkg=show_bkg)#=max(rate//2,1))
|
| 607 |
# save video
|
| 608 |
video_file_name = uuid.uuid4().hex + ".mp4"
|
| 609 |
video_path = os.path.join(os.path.dirname(__file__), "tmp")
|
|
|
|
| 618 |
im = PIL.Image.fromarray(painted_video[ti])
|
| 619 |
# im.save(temp_out_f, "PNG", subsampling=0, quality=80)
|
| 620 |
im.save(temp_out_f)
|
| 621 |
+
# print('saved', temp_out_f)
|
| 622 |
# os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.png" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
|
| 623 |
os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "%s/*.jpg" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (video_fps, video_path, video_file_path))
|
| 624 |
print('saved', video_file_path)
|
|
|
|
| 626 |
# temp_out_f = '%s/%03d.png' % (video_path, ti)
|
| 627 |
temp_out_f = '%s/%03d.jpg' % (video_path, ti)
|
| 628 |
os.remove(temp_out_f)
|
| 629 |
+
# print('deleted', temp_out_f)
|
| 630 |
return video_file_path
|
| 631 |
|
| 632 |
|
| 633 |
@spaces.GPU
|
| 634 |
def track(
|
| 635 |
+
video_preview,
|
| 636 |
+
video_input,
|
| 637 |
+
video_fps,
|
| 638 |
+
query_frame,
|
| 639 |
+
rate,
|
| 640 |
+
show_bkg,
|
| 641 |
+
cmap,
|
| 642 |
):
|
| 643 |
# tracking_mode = 'selected'
|
| 644 |
# if query_count == 0:
|
|
|
|
| 786 |
tracks = traj_maps_e.permute(0,3,4,1,2).reshape(-1,T,2).numpy()
|
| 787 |
visibs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
|
| 788 |
confs = visconf_maps_e.permute(0,3,4,1,2).reshape(-1,T,2)[:,:,0].numpy()
|
| 789 |
+
# visibs = (visibs * confs) > 0.2 # N,T
|
| 790 |
+
visibs = (confs) > 0.1 # N,T
|
| 791 |
# visibs = (confs) > 0.1 # N,T
|
| 792 |
|
| 793 |
|
|
|
|
| 795 |
# print('sc', sc)
|
| 796 |
# tracks = tracks * sc
|
| 797 |
|
| 798 |
+
return update_vis(rate, show_bkg, cmap, video_preview, query_frame, video_fps, tracks, visibs), tracks, visibs, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
|
| 799 |
# gr.update(interactive=True),
|
| 800 |
# gr.update(interactive=True),
|
| 801 |
# gr.update(interactive=True),
|
|
|
|
| 876 |
|
| 877 |
gr.Markdown("# ⚡ AllTracker: Efficient Dense Point Tracking at High Resolution")
|
| 878 |
gr.Markdown("<div style='text-align: left;'> \
|
| 879 |
+
<p>This demo runs <a href='https://alltracker.github.io/' target='_blank'>AllTracker</a> to perform all-pixel tracking in a video of your choice.</p> \
|
| 880 |
<p>To get started, simply upload an mp4, or select one of the example videos. The shorter the video, the faster the processing. We recommend submitting videos under 20 seconds long.</p> \
|
| 881 |
<p>After picking a video, click \"Submit\" to load the frames into the app, and optionally choose a query frame (using the slider), and then click \"Track\".</p> \
|
| 882 |
<p>For full info on how this works, check out our <a href='https://github.com/aharley/alltracker/' target='_blank'>GitHub repo</a>, or <a href='https://arxiv.org/abs/2506.07310' target='_blank'>paper</a>.</p> \
|
|
|
|
| 922 |
with gr.Column():
|
| 923 |
with gr.Row():
|
| 924 |
query_frame_slider = gr.Slider(
|
| 925 |
+
minimum=0, maximum=100, value=0, step=1, label="Query frame", interactive=False)
|
| 926 |
# with gr.Row():
|
| 927 |
# undo = gr.Button("Undo", interactive=False)
|
| 928 |
# clear_frame = gr.Button("Clear Frame", interactive=False)
|
|
|
|
| 950 |
with gr.Row():
|
| 951 |
# rate_slider = gr.Slider(
|
| 952 |
# minimum=1, maximum=16, value=1, step=1, label="Choose subsampling rate", interactive=False)
|
| 953 |
+
rate_radio = gr.Radio([1, 2, 4, 8, 16], value=1, label="Subsampling rate", interactive=False)
|
|
|
|
| 954 |
with gr.Row():
|
| 955 |
+
cmap_radio = gr.Radio(["gist_rainbow", "rainbow", "jet", "turbo", "bremm"], value="gist_rainbow", label="Colormap", interactive=False)
|
| 956 |
+
with gr.Row():
|
| 957 |
+
bkg_check = gr.Checkbox(value=True, label="Overlay tracks on video", interactive=False)
|
| 958 |
+
|
| 959 |
with gr.Row():
|
| 960 |
output_video = gr.Video(
|
| 961 |
label="Output video",
|
|
|
|
| 1080 |
video_input,
|
| 1081 |
video_fps,
|
| 1082 |
query_frame_slider,
|
| 1083 |
+
rate_radio,
|
| 1084 |
+
bkg_check,
|
| 1085 |
+
cmap_radio,
|
| 1086 |
],
|
| 1087 |
outputs = [
|
| 1088 |
output_video,
|
| 1089 |
tracks,
|
| 1090 |
visibs,
|
| 1091 |
rate_radio,
|
| 1092 |
+
bkg_check,
|
| 1093 |
cmap_radio,
|
| 1094 |
# rate1_button,
|
| 1095 |
# rate2_button,
|
|
|
|
| 1110 |
# )
|
| 1111 |
rate_radio.change(
|
| 1112 |
fn = update_vis,
|
| 1113 |
+
inputs = [rate_radio, bkg_check, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
|
| 1114 |
outputs = [
|
| 1115 |
output_video,
|
| 1116 |
],
|
|
|
|
| 1118 |
)
|
| 1119 |
cmap_radio.change(
|
| 1120 |
fn = update_vis,
|
| 1121 |
+
inputs = [rate_radio, bkg_check, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
|
| 1122 |
+
outputs = [
|
| 1123 |
+
output_video,
|
| 1124 |
+
],
|
| 1125 |
+
queue = False
|
| 1126 |
+
)
|
| 1127 |
+
bkg_check.change(
|
| 1128 |
+
fn = update_vis,
|
| 1129 |
+
inputs = [rate_radio, bkg_check, cmap_radio, video_preview, query_frame_slider, video_fps, tracks, visibs],
|
| 1130 |
outputs = [
|
| 1131 |
output_video,
|
| 1132 |
],
|
utils/improc.py
CHANGED
|
@@ -81,6 +81,7 @@ class ColorMap2d:
|
|
| 81 |
|
| 82 |
def get_2d_colors(xys, H, W):
|
| 83 |
N,D = xys.shape
|
|
|
|
| 84 |
assert(D==2)
|
| 85 |
bremm = ColorMap2d()
|
| 86 |
xys[:,0] /= float(W-1)
|
|
|
|
| 81 |
|
| 82 |
def get_2d_colors(xys, H, W):
|
| 83 |
N,D = xys.shape
|
| 84 |
+
xys = xys.copy()
|
| 85 |
assert(D==2)
|
| 86 |
bremm = ColorMap2d()
|
| 87 |
xys[:,0] /= float(W-1)
|