Rename

app.py CHANGED
@@ -129,18 +129,18 @@ class AppState:
 
 
 def init_video_session(
-    GLOBAL_STATE: AppState, video: str | dict, active_tab: str = "point_box"
+    state: AppState, video: str | dict, active_tab: str = "point_box"
 ) -> tuple[AppState, int, int, Image.Image, str]:
-    GLOBAL_STATE.video_frames = []
-    GLOBAL_STATE.masks_by_frame = {}
-    GLOBAL_STATE.color_by_obj = {}
-    GLOBAL_STATE.color_by_prompt = {}
-    GLOBAL_STATE.text_prompts_by_frame_obj = {}
-    GLOBAL_STATE.clicks_by_frame_obj = {}
-    GLOBAL_STATE.boxes_by_frame_obj = {}
-    GLOBAL_STATE.composited_frames = {}
-    GLOBAL_STATE.inference_session = None
-    GLOBAL_STATE.active_tab = active_tab
+    state.video_frames = []
+    state.masks_by_frame = {}
+    state.color_by_obj = {}
+    state.color_by_prompt = {}
+    state.text_prompts_by_frame_obj = {}
+    state.clicks_by_frame_obj = {}
+    state.boxes_by_frame_obj = {}
+    state.composited_frames = {}
+    state.inference_session = None
+    state.active_tab = active_tab
 
     video_path: str | None = None
     if isinstance(video, dict):
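The hunk above threads the app state through `init_video_session` explicitly: the function now receives a `state` argument and re-initializes its per-session fields instead of mutating a module-level `GLOBAL_STATE`. The `AppState` definition itself sits above this diff, so the sketch below only infers its shape from the assignments here and the attribute reads later in the file; the field types and the `num_frames` property are assumptions:

```python
from dataclasses import dataclass, field
from typing import Any

from PIL import Image


@dataclass
class AppState:
    # Sketch only: fields inferred from the assignments in init_video_session
    # and attribute reads elsewhere in this diff; the real class may differ.
    video_frames: list[Image.Image] = field(default_factory=list)
    video_fps: float | None = None
    masks_by_frame: dict[int, dict[int, Any]] = field(default_factory=dict)
    color_by_obj: dict[int, tuple[int, int, int]] = field(default_factory=dict)
    color_by_prompt: dict[str, tuple[int, int, int]] = field(default_factory=dict)
    text_prompts_by_frame_obj: dict[int, dict[int, str]] = field(default_factory=dict)
    clicks_by_frame_obj: dict[int, dict[int, list]] = field(default_factory=dict)
    boxes_by_frame_obj: dict[int, dict[int, list]] = field(default_factory=dict)
    composited_frames: dict[int, Image.Image] = field(default_factory=dict)
    inference_session: Any = None
    active_tab: str = "point_box"
    current_frame_idx: int = 0

    @property
    def num_frames(self) -> int:
        # state.num_frames is read throughout the diff; a derived count fits that usage.
        return len(self.video_frames)
```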
@@ -165,14 +165,14 @@ def init_video_session(
         trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
     if isinstance(info, dict):
         info["num_frames"] = len(frames)
-    GLOBAL_STATE.video_frames = frames
-    GLOBAL_STATE.video_fps = float(fps_in) if fps_in else None
+    state.video_frames = frames
+    state.video_fps = float(fps_in) if fps_in else None
 
     raw_video = [np.array(frame) for frame in frames]
 
     if active_tab == "text":
         processor = TEXT_VIDEO_PROCESSOR
-        GLOBAL_STATE.inference_session = processor.init_video_session(
+        state.inference_session = processor.init_video_session(
             video=frames,
             inference_device=DEVICE,
             inference_state_device=DEVICE,
@@ -182,7 +182,7 @@ def init_video_session(
         )
     else:
         processor = TRACKER_PROCESSOR
-        GLOBAL_STATE.inference_session = processor.init_video_session(
+        state.inference_session = processor.init_video_session(
             video=raw_video,
             inference_device=DEVICE,
             inference_state_device=DEVICE,
@@ -195,15 +195,15 @@ def init_video_session(
     max_idx = len(frames) - 1
     if active_tab == "text":
         status = (
-            f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
+            f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. "
             f"Device: {DEVICE}, dtype: bfloat16. Ready for text prompting."
         )
     else:
         status = (
-            f"Loaded {len(frames)} frames @ {GLOBAL_STATE.video_fps or 'unknown'} fps{trimmed_note}. "
+            f"Loaded {len(frames)} frames @ {state.video_fps or 'unknown'} fps{trimmed_note}. "
             f"Device: {DEVICE}, dtype: bfloat16. Video session initialized."
         )
-    return GLOBAL_STATE, 0, max_idx, first_frame, status
+    return state, 0, max_idx, first_frame, status
 
 
 def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
@@ -596,24 +596,24 @@ def _get_active_prompts_display(state: AppState) -> str:
     return "**Active prompts:** None"
 
 
-def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
-    if GLOBAL_STATE is None:
-        return GLOBAL_STATE, "Load a video first.", gr.update()
+def propagate_masks(state: AppState) -> Iterator[tuple[AppState, str, dict]]:
+    if state is None:
+        return state, "Load a video first.", gr.update()
 
-    if GLOBAL_STATE.active_tab != "text" and GLOBAL_STATE.inference_session is None:
-        return GLOBAL_STATE, "Load a video first.", gr.update()
+    if state.active_tab != "text" and state.inference_session is None:
+        return state, "Load a video first.", gr.update()
 
-    total = max(1, GLOBAL_STATE.num_frames)
+    total = max(1, state.num_frames)
     processed = 0
 
-    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
+    yield state, f"Propagating masks: {processed}/{total}", gr.update()
 
     last_frame_idx = 0
 
     with torch.no_grad():
-        if GLOBAL_STATE.active_tab == "text":
-            if GLOBAL_STATE.inference_session is None:
-                yield GLOBAL_STATE, "Text video model not loaded.", gr.update()
+        if state.active_tab == "text":
+            if state.inference_session is None:
+                yield state, "Text video model not loaded.", gr.update()
                 return
 
             model = TEXT_VIDEO_MODEL
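`propagate_masks` is a generator: every `yield` hands Gradio an intermediate `(state, status, slider_update)` triple, which is how the UI shows live progress during a long propagation run. A minimal, self-contained sketch of the same streaming pattern, with illustrative component names that are not from this app:

```python
import time

import gradio as gr


def long_job(n: int):
    # Generator handler: Gradio re-renders the output on every yield,
    # so the textbox shows live progress instead of only the final string.
    for i in range(1, int(n) + 1):
        time.sleep(0.1)  # stand-in for propagating masks through one frame
        yield f"Propagating masks: {i}/{int(n)}"
    yield f"Propagated masks across {int(n)} frames."


with gr.Blocks() as demo:
    frames = gr.Slider(1, 100, value=10, step=1, label="Frames")
    status = gr.Textbox(label="Status")
    gr.Button("Run").click(long_job, inputs=frames, outputs=status)

if __name__ == "__main__":
    demo.queue().launch()  # older Gradio versions need the explicit queue for streaming
```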
@@ -621,7 +621,7 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
 
             # Collect all unique prompts from existing frame annotations
             text_prompt_to_obj_ids = {}
-            for frame_idx, frame_texts in GLOBAL_STATE.text_prompts_by_frame_obj.items():
+            for frame_idx, frame_texts in state.text_prompts_by_frame_obj.items():
                 for obj_id, text_prompt in frame_texts.items():
                     if text_prompt not in text_prompt_to_obj_ids:
                         text_prompt_to_obj_ids[text_prompt] = []
@@ -629,8 +629,8 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
                     text_prompt_to_obj_ids[text_prompt].append(obj_id)
 
             # Also check if there are prompts already in the inference session
-            if hasattr(GLOBAL_STATE.inference_session, "prompts") and GLOBAL_STATE.inference_session.prompts:
-                for prompt_text in GLOBAL_STATE.inference_session.prompts.values():
+            if hasattr(state.inference_session, "prompts") and state.inference_session.prompts:
+                for prompt_text in state.inference_session.prompts.values():
                     if prompt_text not in text_prompt_to_obj_ids:
                         text_prompt_to_obj_ids[prompt_text] = []
 
@@ -638,31 +638,29 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
                 text_prompt_to_obj_ids[text_prompt].sort()
 
             if not text_prompt_to_obj_ids:
-                yield GLOBAL_STATE, "No text prompts found. Please add a text prompt first.", gr.update()
+                yield state, "No text prompts found. Please add a text prompt first.", gr.update()
                 return
 
             # Add all prompts to the inference session (processor handles deduplication)
             for text_prompt in text_prompt_to_obj_ids:
-                GLOBAL_STATE.inference_session = processor.add_text_prompt(
-                    inference_session=GLOBAL_STATE.inference_session,
+                state.inference_session = processor.add_text_prompt(
+                    inference_session=state.inference_session,
                     text=text_prompt,
                 )
 
-            earliest_frame = (
-                min(GLOBAL_STATE.text_prompts_by_frame_obj.keys()) if GLOBAL_STATE.text_prompts_by_frame_obj else 0
-            )
+            earliest_frame = min(state.text_prompts_by_frame_obj.keys()) if state.text_prompts_by_frame_obj else 0
 
-            frames_to_track = GLOBAL_STATE.num_frames - earliest_frame
+            frames_to_track = state.num_frames - earliest_frame
 
             outputs_per_frame = {}
 
             for model_outputs in model.propagate_in_video_iterator(
-                inference_session=GLOBAL_STATE.inference_session,
+                inference_session=state.inference_session,
                 start_frame_idx=earliest_frame,
                 max_frame_num_to_track=frames_to_track,
             ):
                 processed_outputs = processor.postprocess_outputs(
-                    GLOBAL_STATE.inference_session,
+                    state.inference_session,
                     model_outputs,
                 )
                 frame_idx = model_outputs.frame_idx
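Stripped of the UI bookkeeping, the text branch registers every collected prompt and then streams per-frame outputs. A condensed sketch of that control flow, using the same processor/model calls the hunk above uses (`add_text_prompt`, `propagate_in_video_iterator`, `postprocess_outputs`); the wrapper function itself is hypothetical:

```python
import torch


def propagate_text_prompts(processor, model, session, prompts, start_frame, n_frames):
    """Hypothetical wrapper condensing the loop above; yields one processed output per frame."""
    with torch.no_grad():
        # The processor deduplicates prompts, so re-adding is safe.
        for text in prompts:
            session = processor.add_text_prompt(inference_session=session, text=text)
        for model_outputs in model.propagate_in_video_iterator(
            inference_session=session,
            start_frame_idx=start_frame,
            max_frame_num_to_track=n_frames,
        ):
            # Each processed output carries object_ids, masks, scores, prompt_to_obj_ids.
            yield processor.postprocess_outputs(session, model_outputs)
```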
@@ -673,8 +671,8 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
                 scores = processed_outputs["scores"]
                 prompt_to_obj_ids = processed_outputs.get("prompt_to_obj_ids", {})
 
-                masks_for_frame = GLOBAL_STATE.masks_by_frame.setdefault(frame_idx, {})
-                frame_texts = GLOBAL_STATE.text_prompts_by_frame_obj.setdefault(frame_idx, {})
+                masks_for_frame = state.masks_by_frame.setdefault(frame_idx, {})
+                frame_texts = state.text_prompts_by_frame_obj.setdefault(frame_idx, {})
 
                 num_objects = len(object_ids)
                 if num_objects > 0:
@@ -701,137 +699,131 @@ def propagate_masks(GLOBAL_STATE: AppState) -> Iterator[tuple[AppState, str, dict]]:
                     # Store prompt and assign color
                     if found_prompt:
                         frame_texts[current_obj_id] = found_prompt.strip()
-                        _ensure_color_for_obj(GLOBAL_STATE, current_obj_id)
+                        _ensure_color_for_obj(state, current_obj_id)
 
-                GLOBAL_STATE.composited_frames.pop(frame_idx, None)
+                state.composited_frames.pop(frame_idx, None)
                 last_frame_idx = frame_idx
                 processed += 1
                 if processed % 30 == 0 or processed == total:
-                    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+                    yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
         else:
-            if GLOBAL_STATE.inference_session is None:
-                yield GLOBAL_STATE, "Tracker model not loaded.", gr.update()
+            if state.inference_session is None:
+                yield state, "Tracker model not loaded.", gr.update()
                 return
 
             model = TRACKER_MODEL
            processor = TRACKER_PROCESSOR
 
-            for sam2_video_output in model.propagate_in_video_iterator(
-                inference_session=GLOBAL_STATE.inference_session
-            ):
+            for sam2_video_output in model.propagate_in_video_iterator(inference_session=state.inference_session):
                 video_res_masks = processor.post_process_masks(
                     [sam2_video_output.pred_masks],
-                    original_sizes=[
-                        [GLOBAL_STATE.inference_session.video_height, GLOBAL_STATE.inference_session.video_width]
-                    ],
+                    original_sizes=[[state.inference_session.video_height, state.inference_session.video_width]],
                 )[0]
 
                 frame_idx = sam2_video_output.frame_idx
-                for i, out_obj_id in enumerate(GLOBAL_STATE.inference_session.obj_ids):
-                    _ensure_color_for_obj(GLOBAL_STATE, int(out_obj_id))
+                for i, out_obj_id in enumerate(state.inference_session.obj_ids):
+                    _ensure_color_for_obj(state, int(out_obj_id))
                     mask_2d = video_res_masks[i].cpu().numpy()
-                    masks_for_frame = GLOBAL_STATE.masks_by_frame.setdefault(frame_idx, {})
+                    masks_for_frame = state.masks_by_frame.setdefault(frame_idx, {})
                     masks_for_frame[int(out_obj_id)] = mask_2d
-                    GLOBAL_STATE.composited_frames.pop(frame_idx, None)
+                    state.composited_frames.pop(frame_idx, None)
 
                 last_frame_idx = frame_idx
                 processed += 1
                 if processed % 30 == 0 or processed == total:
-                    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+                    yield state, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
 
     text = f"Propagated masks across {processed} frames."
-    yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
+    yield state, text, gr.update(value=last_frame_idx)
 
 
-def reset_prompts(GLOBAL_STATE: AppState) -> tuple[AppState, Image.Image, str, str]:
+def reset_prompts(state: AppState) -> tuple[AppState, Image.Image, str, str]:
     """Reset prompts and all outputs, but keep processed frames and cached vision features."""
-    if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
-        active_prompts = _get_active_prompts_display(GLOBAL_STATE)
-        return GLOBAL_STATE, None, "No active session to reset.", active_prompts
+    if state is None or state.inference_session is None:
+        active_prompts = _get_active_prompts_display(state)
+        return state, None, "No active session to reset.", active_prompts
 
-    if GLOBAL_STATE.active_tab != "text":
-        active_prompts = _get_active_prompts_display(GLOBAL_STATE)
-        return GLOBAL_STATE, None, "Reset prompts is only available for text prompting mode.", active_prompts
+    if state.active_tab != "text":
+        active_prompts = _get_active_prompts_display(state)
+        return state, None, "Reset prompts is only available for text prompting mode.", active_prompts
 
     # Reset inference session tracking data but keep cache and processed frames
-    if hasattr(GLOBAL_STATE.inference_session, "reset_tracking_data"):
-        GLOBAL_STATE.inference_session.reset_tracking_data()
+    if hasattr(state.inference_session, "reset_tracking_data"):
+        state.inference_session.reset_tracking_data()
 
     # Manually clear prompts (reset_tracking_data doesn't clear prompts themselves)
-    if hasattr(GLOBAL_STATE.inference_session, "prompts"):
-        GLOBAL_STATE.inference_session.prompts.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "prompt_input_ids"):
-        GLOBAL_STATE.inference_session.prompt_input_ids.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "prompt_embeddings"):
-        GLOBAL_STATE.inference_session.prompt_embeddings.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "prompt_attention_masks"):
-        GLOBAL_STATE.inference_session.prompt_attention_masks.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_prompt_id"):
-        GLOBAL_STATE.inference_session.obj_id_to_prompt_id.clear()
+    if hasattr(state.inference_session, "prompts"):
+        state.inference_session.prompts.clear()
+    if hasattr(state.inference_session, "prompt_input_ids"):
+        state.inference_session.prompt_input_ids.clear()
+    if hasattr(state.inference_session, "prompt_embeddings"):
+        state.inference_session.prompt_embeddings.clear()
+    if hasattr(state.inference_session, "prompt_attention_masks"):
+        state.inference_session.prompt_attention_masks.clear()
+    if hasattr(state.inference_session, "obj_id_to_prompt_id"):
+        state.inference_session.obj_id_to_prompt_id.clear()
 
     # Reset detection-tracking fusion state
-    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_score"):
-        GLOBAL_STATE.inference_session.obj_id_to_score.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_tracker_score_frame_wise"):
-        GLOBAL_STATE.inference_session.obj_id_to_tracker_score_frame_wise.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_last_occluded"):
-        GLOBAL_STATE.inference_session.obj_id_to_last_occluded.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "max_obj_id"):
-        GLOBAL_STATE.inference_session.max_obj_id = -1
-    if hasattr(GLOBAL_STATE.inference_session, "obj_first_frame_idx"):
-        GLOBAL_STATE.inference_session.obj_first_frame_idx.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "unmatched_frame_inds"):
-        GLOBAL_STATE.inference_session.unmatched_frame_inds.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "overlap_pair_to_frame_inds"):
-        GLOBAL_STATE.inference_session.overlap_pair_to_frame_inds.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "trk_keep_alive"):
-        GLOBAL_STATE.inference_session.trk_keep_alive.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "removed_obj_ids"):
-        GLOBAL_STATE.inference_session.removed_obj_ids.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "suppressed_obj_ids"):
-        GLOBAL_STATE.inference_session.suppressed_obj_ids.clear()
-    if hasattr(GLOBAL_STATE.inference_session, "hotstart_removed_obj_ids"):
-        GLOBAL_STATE.inference_session.hotstart_removed_obj_ids.clear()
+    if hasattr(state.inference_session, "obj_id_to_score"):
+        state.inference_session.obj_id_to_score.clear()
+    if hasattr(state.inference_session, "obj_id_to_tracker_score_frame_wise"):
+        state.inference_session.obj_id_to_tracker_score_frame_wise.clear()
+    if hasattr(state.inference_session, "obj_id_to_last_occluded"):
+        state.inference_session.obj_id_to_last_occluded.clear()
+    if hasattr(state.inference_session, "max_obj_id"):
+        state.inference_session.max_obj_id = -1
+    if hasattr(state.inference_session, "obj_first_frame_idx"):
+        state.inference_session.obj_first_frame_idx.clear()
+    if hasattr(state.inference_session, "unmatched_frame_inds"):
+        state.inference_session.unmatched_frame_inds.clear()
+    if hasattr(state.inference_session, "overlap_pair_to_frame_inds"):
+        state.inference_session.overlap_pair_to_frame_inds.clear()
+    if hasattr(state.inference_session, "trk_keep_alive"):
+        state.inference_session.trk_keep_alive.clear()
+    if hasattr(state.inference_session, "removed_obj_ids"):
+        state.inference_session.removed_obj_ids.clear()
+    if hasattr(state.inference_session, "suppressed_obj_ids"):
+        state.inference_session.suppressed_obj_ids.clear()
+    if hasattr(state.inference_session, "hotstart_removed_obj_ids"):
+        state.inference_session.hotstart_removed_obj_ids.clear()
 
     # Clear all app state outputs
-    GLOBAL_STATE.masks_by_frame.clear()
-    GLOBAL_STATE.text_prompts_by_frame_obj.clear()
-    GLOBAL_STATE.composited_frames.clear()
-    GLOBAL_STATE.color_by_obj.clear()
-    GLOBAL_STATE.color_by_prompt.clear()
+    state.masks_by_frame.clear()
+    state.text_prompts_by_frame_obj.clear()
+    state.composited_frames.clear()
+    state.color_by_obj.clear()
+    state.color_by_prompt.clear()
 
     # Update display
-    current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
-    current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
-    preview_img = update_frame_display(GLOBAL_STATE, current_idx)
-    active_prompts = _get_active_prompts_display(GLOBAL_STATE)
+    current_idx = int(getattr(state, "current_frame_idx", 0))
+    current_idx = max(0, min(current_idx, state.num_frames - 1))
+    preview_img = update_frame_display(state, current_idx)
+    active_prompts = _get_active_prompts_display(state)
     status = "Prompts and outputs reset. Processed frames and cached vision features preserved."
 
-    return GLOBAL_STATE, preview_img, status, active_prompts
+    return state, preview_img, status, active_prompts
 
 
-def reset_session(GLOBAL_STATE: AppState) -> tuple[AppState, Image.Image, int, int, str, str]:
-    if not GLOBAL_STATE.video_frames:
-        return GLOBAL_STATE, None, 0, 0, "Session reset. Load a new video.", "**Active prompts:** None"
+def reset_session(state: AppState) -> tuple[AppState, Image.Image, int, int, str, str]:
+    if not state.video_frames:
+        return state, None, 0, 0, "Session reset. Load a new video.", "**Active prompts:** None"
 
-    if GLOBAL_STATE.active_tab == "text":
-        if GLOBAL_STATE.video_frames:
+    if state.active_tab == "text":
+        if state.video_frames:
             processor = TEXT_VIDEO_PROCESSOR
-            GLOBAL_STATE.inference_session = processor.init_video_session(
-                video=GLOBAL_STATE.video_frames,
+            state.inference_session = processor.init_video_session(
+                video=state.video_frames,
                 inference_device=DEVICE,
                 processing_device="cpu",
                 video_storage_device="cpu",
                 dtype=DTYPE,
             )
-        elif GLOBAL_STATE.inference_session is not None and hasattr(
-            GLOBAL_STATE.inference_session, "reset_inference_session"
-        ):
-            GLOBAL_STATE.inference_session.reset_inference_session()
-    elif GLOBAL_STATE.video_frames:
+        elif state.inference_session is not None and hasattr(state.inference_session, "reset_inference_session"):
+            state.inference_session.reset_inference_session()
+    elif state.video_frames:
         processor = TRACKER_PROCESSOR
-        raw_video = [np.array(frame) for frame in GLOBAL_STATE.video_frames]
-        GLOBAL_STATE.inference_session = processor.init_video_session(
+        raw_video = [np.array(frame) for frame in state.video_frames]
+        state.inference_session = processor.init_video_session(
             video=raw_video,
             inference_device=DEVICE,
             video_storage_device="cpu",
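Every clear in `reset_prompts` is wrapped in `hasattr`, so the reset degrades gracefully when a given inference-session implementation lacks some of these attributes. The repetition could be collapsed with a small helper; a sketch (the helper is hypothetical, and it only covers the container attributes with `.clear()`, not scalars like `max_obj_id`):

```python
def _clear_session_attrs(session, names: tuple[str, ...]) -> None:
    # Hypothetical helper mirroring the hasattr-guarded blocks above:
    # clear each container attribute the session actually has.
    for name in names:
        if hasattr(session, name):
            getattr(session, name).clear()


# Usage sketch against the attribute names cleared in reset_prompts:
# _clear_session_attrs(state.inference_session, (
#     "prompts", "prompt_input_ids", "prompt_embeddings",
#     "prompt_attention_masks", "obj_id_to_prompt_id",
# ))
```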
@@ -839,44 +831,44 @@ def reset_session(GLOBAL_STATE: AppState) -> tuple[AppState, Image.Image, int, int, str, str]:
             dtype=DTYPE,
         )
 
-    GLOBAL_STATE.masks_by_frame.clear()
-    GLOBAL_STATE.clicks_by_frame_obj.clear()
-    GLOBAL_STATE.boxes_by_frame_obj.clear()
-    GLOBAL_STATE.text_prompts_by_frame_obj.clear()
-    GLOBAL_STATE.composited_frames.clear()
-    GLOBAL_STATE.color_by_obj.clear()
-    GLOBAL_STATE.color_by_prompt.clear()
-    GLOBAL_STATE.pending_box_start = None
-    GLOBAL_STATE.pending_box_start_frame_idx = None
-    GLOBAL_STATE.pending_box_start_obj_id = None
+    state.masks_by_frame.clear()
+    state.clicks_by_frame_obj.clear()
+    state.boxes_by_frame_obj.clear()
+    state.text_prompts_by_frame_obj.clear()
+    state.composited_frames.clear()
+    state.color_by_obj.clear()
+    state.color_by_prompt.clear()
+    state.pending_box_start = None
+    state.pending_box_start_frame_idx = None
+    state.pending_box_start_obj_id = None
 
     gc.collect()
 
-    current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
-    current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
-    preview_img = update_frame_display(GLOBAL_STATE, current_idx)
-    slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
+    current_idx = int(getattr(state, "current_frame_idx", 0))
+    current_idx = max(0, min(current_idx, state.num_frames - 1))
+    preview_img = update_frame_display(state, current_idx)
+    slider_minmax = gr.update(minimum=0, maximum=max(state.num_frames - 1, 0), interactive=True)
     slider_value = gr.update(value=current_idx)
     status = "Session reset. Prompts cleared; video preserved."
-    active_prompts = _get_active_prompts_display(GLOBAL_STATE)
-    return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status, active_prompts
+    active_prompts = _get_active_prompts_display(state)
+    return state, preview_img, slider_minmax, slider_value, status, active_prompts
 
 
-def _on_video_change_pointbox(GLOBAL_STATE: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str]:
-    GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video, "point_box")
+def _on_video_change_pointbox(state: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str]:
+    state, min_idx, max_idx, first_frame, status = init_video_session(state, video, "point_box")
     return (
-        GLOBAL_STATE,
+        state,
         gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
         first_frame,
        status,
     )
 
 
-def _on_video_change_text(GLOBAL_STATE: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str, str]:
-    GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video, "text")
-    active_prompts = _get_active_prompts_display(GLOBAL_STATE)
+def _on_video_change_text(state: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str, str]:
+    state, min_idx, max_idx, first_frame, status = init_video_session(state, video, "text")
+    active_prompts = _get_active_prompts_display(state)
     return (
-        GLOBAL_STATE,
+        state,
         gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
         first_frame,
         status,
@@ -885,7 +877,7 @@ def _on_video_change_text(GLOBAL_STATE: AppState, video: str | dict) -> tuple[AppState, dict, Image.Image, str, str]:
 
 
 with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
-    GLOBAL_STATE = gr.State(AppState())
+    app_state = gr.State(AppState())
 
     gr.Markdown(
         """
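The one-line change above is the motivation for the whole rename: the value lives in `gr.State`, which Gradio keeps per browser session, so calling it `GLOBAL_STATE` was misleading. Each handler receives the session's `AppState`, mutates or replaces it, and returns it as the first output so the stored copy stays current. A minimal sketch of that round-trip, with illustrative component names:

```python
import gradio as gr


def bump(state: dict) -> tuple[dict, str]:
    # State comes in as an input and goes back out as an output;
    # gr.State keeps a separate copy for each connected browser session.
    state["clicks"] = state.get("clicks", 0) + 1
    return state, f"Clicked {state['clicks']} times in this session"


with gr.Blocks() as demo:
    session_state = gr.State({})  # per-session, despite living at Blocks scope
    label = gr.Textbox(label="Counter")
    gr.Button("Click me").click(bump, inputs=[session_state], outputs=[session_state, label])

if __name__ == "__main__":
    demo.launch()
```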
@@ -953,9 +945,9 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
             with gr.Row():
                 gr.Examples(
                     examples=examples_list_text,
-                    inputs=[GLOBAL_STATE, video_in_text],
+                    inputs=[app_state, video_in_text],
                     fn=_on_video_change_text,
-                    outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text, active_prompts_display],
+                    outputs=[app_state, frame_slider_text, preview_text, load_status_text, active_prompts_display],
                     label="Examples",
                     cache_examples=False,
                     examples_per_page=5,
@@ -1016,9 +1008,9 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
             with gr.Row():
                 gr.Examples(
                     examples=examples_list_pointbox,
-                    inputs=[GLOBAL_STATE, video_in_pointbox],
+                    inputs=[app_state, video_in_pointbox],
                     fn=_on_video_change_pointbox,
-                    outputs=[GLOBAL_STATE, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
+                    outputs=[app_state, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
                     label="Examples",
                     cache_examples=False,
                     examples_per_page=5,
@@ -1026,8 +1018,8 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     video_in_pointbox.change(
         _on_video_change_pointbox,
-        inputs=[GLOBAL_STATE, video_in_pointbox],
-        outputs=[GLOBAL_STATE, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
+        inputs=[app_state, video_in_pointbox],
+        outputs=[app_state, frame_slider_pointbox, preview_pointbox, load_status_pointbox],
         show_progress=True,
     )
 
@@ -1038,14 +1030,14 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     frame_slider_pointbox.change(
         _sync_frame_idx_pointbox,
-        inputs=[GLOBAL_STATE, frame_slider_pointbox],
+        inputs=[app_state, frame_slider_pointbox],
         outputs=preview_pointbox,
     )
 
     video_in_text.change(
         _on_video_change_text,
-        inputs=[GLOBAL_STATE, video_in_text],
-        outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text, active_prompts_display],
+        inputs=[app_state, video_in_text],
+        outputs=[app_state, frame_slider_text, preview_text, load_status_text, active_prompts_display],
         show_progress=True,
     )
 
@@ -1056,7 +1048,7 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     frame_slider_text.change(
         _sync_frame_idx_text,
-        inputs=[GLOBAL_STATE, frame_slider_text],
+        inputs=[app_state, frame_slider_text],
         outputs=preview_text,
     )
 
@@ -1065,14 +1057,14 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
             s.current_obj_id = int(oid)
         return gr.update()
 
-    obj_id_inp.change(_sync_obj_id, inputs=[GLOBAL_STATE, obj_id_inp], outputs=[])
+    obj_id_inp.change(_sync_obj_id, inputs=[app_state, obj_id_inp], outputs=[])
 
     def _sync_label(s: AppState, lab: str):
         if s is not None and lab is not None:
             s.current_label = str(lab)
         return gr.update()
 
-    label_radio.change(_sync_label, inputs=[GLOBAL_STATE, label_radio], outputs=[])
+    label_radio.change(_sync_label, inputs=[app_state, label_radio], outputs=[])
 
     def _sync_prompt_type(s: AppState, val: str):
         if s is not None and val is not None:
@@ -1087,13 +1079,13 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     prompt_type.change(
         _sync_prompt_type,
-        inputs=[GLOBAL_STATE, prompt_type],
+        inputs=[app_state, prompt_type],
         outputs=[label_radio, clear_old_chk],
     )
 
     preview_pointbox.select(
         on_image_click,
-        [preview_pointbox, GLOBAL_STATE, frame_slider_pointbox, obj_id_inp, label_radio, clear_old_chk],
+        [preview_pointbox, app_state, frame_slider_pointbox, obj_id_inp, label_radio, clear_old_chk],
         preview_pointbox,
     )
 
@@ -1103,14 +1095,14 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
 
     text_apply_btn.click(
         _on_text_apply,
-        inputs=[GLOBAL_STATE, frame_slider_text, text_prompt_input],
+        inputs=[app_state, frame_slider_text, text_prompt_input],
         outputs=[preview_text, text_status, active_prompts_display],
     )
 
     reset_prompts_btn.click(
         reset_prompts,
-        inputs=[GLOBAL_STATE],
-        outputs=[GLOBAL_STATE, preview_text, text_status, active_prompts_display],
+        inputs=[app_state],
+        outputs=[app_state, preview_text, text_status, active_prompts_display],
     )
 
     def _render_video(s: AppState):
@@ -1139,32 +1131,32 @@ with gr.Blocks(title="SAM3", theme=Soft(primary_hue="blue", secondary_hue="rose", neutral_hue="slate")) as demo:
             print(f"Failed to render video with cv2: {e}")
             raise gr.Error(f"Failed to render video: {e}")
 
-    render_btn_pointbox.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video_pointbox])
-    render_btn_text.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video_text])
+    render_btn_pointbox.click(_render_video, inputs=[app_state], outputs=[playback_video_pointbox])
+    render_btn_text.click(_render_video, inputs=[app_state], outputs=[playback_video_text])
 
     propagate_btn_pointbox.click(
         propagate_masks,
-        inputs=[GLOBAL_STATE],
-        outputs=[GLOBAL_STATE, propagate_status_pointbox, frame_slider_pointbox],
+        inputs=[app_state],
+        outputs=[app_state, propagate_status_pointbox, frame_slider_pointbox],
     )
 
     propagate_btn_text.click(
         propagate_masks,
-        inputs=[GLOBAL_STATE],
-        outputs=[GLOBAL_STATE, propagate_status_text, frame_slider_text],
+        inputs=[app_state],
+        outputs=[app_state, propagate_status_text, frame_slider_text],
     )
 
     reset_btn_pointbox.click(
         reset_session,
-        inputs=GLOBAL_STATE,
-        outputs=[GLOBAL_STATE, preview_pointbox, frame_slider_pointbox, frame_slider_pointbox, load_status_pointbox],
+        inputs=app_state,
+        outputs=[app_state, preview_pointbox, frame_slider_pointbox, frame_slider_pointbox, load_status_pointbox],
     )
 
     reset_btn_text.click(
         reset_session,
-        inputs=GLOBAL_STATE,
+        inputs=app_state,
         outputs=[
-            GLOBAL_STATE,
+            app_state,
             preview_text,
             frame_slider_text,
             frame_slider_text,
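Two wiring details in this final hunk are easy to miss. First, `inputs=app_state` passes a single component where the other bindings pass a list; Gradio accepts either form. Second, `frame_slider_pointbox` (and `frame_slider_text`) appears twice in `reset_session`'s outputs on purpose: the handler returns two separate updates for the same slider, one setting its min/max range and one setting its value.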