Spaces: Running on Zero

Commit · adb1841
Parent(s): 03b77e4

process frames one at a time
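The core of this commit: instead of handing the full frame stack to processor.init_video_session(video=...) up front, the session is created empty and each frame is preprocessed and fed to the model only when it is needed. A minimal sketch of that per-frame flow, assuming the Sam2 video session API exactly as it appears in the diff below (processed_frames, frame=, and frame_idx= are taken from this commit, not from library documentation):

import torch

# Session starts with no frames; they are streamed in one at a time.
inference_session = processor.init_video_session(
    inference_device=device,
    video_storage_device="cpu",
    dtype=dtype,
)

with torch.inference_mode():
    for frame_idx, frame in enumerate(video_frames):
        pixel_values = None
        # Preprocess only frames the session has not cached yet.
        if not inference_session.processed_frames or frame_idx not in inference_session.processed_frames:
            pixel_values = processor(images=frame, device=device, return_tensors="pt").pixel_values[0]
        output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)

This keeps only the tensors for visited frames in memory rather than materializing the whole clip at session init.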
app.py CHANGED
@@ -54,10 +54,12 @@ def try_load_video_frames(video_path_or_url: str) -> tuple[list[Image.Image], di
             cap.release()
             if fps_val and fps_val > 0:
                 info["fps"] = float(fps_val)
-        except Exception:
+        except Exception as e:
+            print(f"Failed to render video with cv2: {e}")
             pass
         return pil_frames, info
-    except Exception:
+    except Exception as e:
+        print(f"Failed to load video with transformers.video_utils: {e}")
         # Fallback to OpenCV
         try:
             import cv2  # type: ignore
@@ -180,14 +182,6 @@ def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoPr
     if GLOBAL_STATE.model_repo_id == desired_repo:
         return GLOBAL_STATE.model, GLOBAL_STATE.processor, GLOBAL_STATE.device, GLOBAL_STATE.dtype
     # Different repo requested: dispose current and reload
-    try:
-        del GLOBAL_STATE.model
-    except Exception:
-        pass
-    try:
-        del GLOBAL_STATE.processor
-    except Exception:
-        pass
     GLOBAL_STATE.model = None
     GLOBAL_STATE.processor = None
     print(f"Loading model from {desired_repo}")
@@ -219,16 +213,8 @@ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
     GLOBAL_STATE.clicks_by_frame_obj.clear()
     GLOBAL_STATE.boxes_by_frame_obj.clear()
     GLOBAL_STATE.composited_frames.clear()
-    # Dispose previous session cleanly
-    try:
-        if GLOBAL_STATE.inference_session is not None:
-            GLOBAL_STATE.inference_session.reset_inference_session()
-    except Exception:
-        pass
     GLOBAL_STATE.inference_session = None
-    gc.collect()
     GLOBAL_STATE.inference_session = processor.init_video_session(
-        video=GLOBAL_STATE.video_frames,
         inference_device=device,
         video_storage_device="cpu",
         dtype=dtype,
@@ -265,40 +251,18 @@ def init_video_session(GLOBAL_STATE: gr.State, video: str | dict) -> tuple[AppSt
     # Enforce max duration of 8 seconds (trim if longer)
     MAX_SECONDS = 8.0
     trimmed_note = ""
-    fps_in = None
-    if isinstance(info, dict) and info.get("fps"):
-        try:
-            fps_in = float(info["fps"]) or None
-        except Exception:
-            fps_in = None
-    if fps_in:
-        max_frames_allowed = int(MAX_SECONDS * fps_in)
-        if len(frames) > max_frames_allowed:
-            frames = frames[:max_frames_allowed]
-            trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
-        if isinstance(info, dict):
-            info["num_frames"] = len(frames)
-    else:
-        # Fallback when FPS unknown: assume ~30 FPS and cap to 240 frames (~8s)
-        max_frames_allowed = 240
-        if len(frames) > max_frames_allowed:
-            frames = frames[:max_frames_allowed]
-            trimmed_note = " (trimmed to 240 frames ~8s @30fps)"
-        if isinstance(info, dict):
-            info["num_frames"] = len(frames)
-
+    fps_in = info.get("fps")
+    max_frames_allowed = int(MAX_SECONDS * fps_in)
+    if len(frames) > max_frames_allowed:
+        frames = frames[:max_frames_allowed]
+        trimmed_note = f" (trimmed to {int(MAX_SECONDS)}s = {len(frames)} frames)"
+    if isinstance(info, dict):
+        info["num_frames"] = len(frames)
     GLOBAL_STATE.video_frames = frames
     # Try to capture original FPS if provided by loader
-    GLOBAL_STATE.video_fps = None
-    if isinstance(info, dict) and info.get("fps"):
-        try:
-            GLOBAL_STATE.video_fps = float(info["fps"]) or None
-        except Exception:
-            GLOBAL_STATE.video_fps = None
-
+    GLOBAL_STATE.video_fps = float(fps_in)
     # Initialize session
     inference_session = processor.init_video_session(
-        video=frames,
         inference_device=device,
         video_storage_device="cpu",
         dtype=dtype,
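As a concrete instance of the trim rule in this hunk: with fps_in = 30.0 and MAX_SECONDS = 8.0, max_frames_allowed = int(8.0 * 30.0) = 240, so a 300-frame clip keeps only its first 240 frames (~8 s). Note the new path assumes the loader reported an FPS. A self-contained check with illustrative values:

MAX_SECONDS = 8.0
fps_in = 30.0                                   # illustrative sample rate
frames = list(range(300))                       # stand-in for 300 decoded frames
max_frames_allowed = int(MAX_SECONDS * fps_in)  # 240
if len(frames) > max_frames_allowed:
    frames = frames[:max_frames_allowed]
assert len(frames) == 240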
@@ -412,6 +376,12 @@ def on_image_click(
     processor = state.processor
     model = state.model
     inference_session = state.inference_session
+    original_size = None
+    pixel_values = None
+    if not inference_session.processed_frames or frame_idx not in inference_session.processed_frames:
+        inputs = processor(images=state.video_frames[frame_idx], device=state.device, return_tensors="pt")
+        original_size = inputs.original_sizes[0]
+        pixel_values = inputs.pixel_values[0]
 
     if state.current_prompt_type == "Boxes":
         # Two-click box input
@@ -443,6 +413,7 @@ def on_image_click(
             obj_ids=int(obj_id),
             input_boxes=[[[x_min, y_min, x_max, y_max]]],
             clear_old_inputs=True,  # For boxes, always clear old inputs
+            original_size=original_size,
         )
 
         frame_boxes = state.boxes_by_frame_obj.setdefault(int(frame_idx), {})
@@ -465,6 +436,7 @@ def on_image_click(
             obj_ids=int(obj_id),
             input_points=[[[[int(x), int(y)]]]],
             input_labels=[[[int(label_int)]]],
+            original_size=original_size,
             clear_old_inputs=bool(clear_old),
         )
 
@@ -477,10 +449,7 @@ def on_image_click(
 
     # Forward on that frame
     with torch.inference_mode():
-        outputs = model(
-            inference_session=inference_session,
-            frame_idx=int(frame_idx),
-        )
+        outputs = model(inference_session=inference_session, frame=pixel_values, frame_idx=int(frame_idx))
 
     H = inference_session.video_height
     W = inference_session.video_width
@@ -509,8 +478,8 @@
 @spaces.GPU()
 def propagate_masks(GLOBAL_STATE: gr.State):
     if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
-        yield "Load a video first.", gr.update()
-        return
+        # yield GLOBAL_STATE, "Load a video first.", gr.update()
+        return GLOBAL_STATE, "Load a video first.", gr.update()
 
     processor = deepcopy(GLOBAL_STATE.processor)
     model = deepcopy(GLOBAL_STATE.model)
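One Python detail relevant to the early-exit path in this hunk: inside a generator function, return value does not emit the value to the consumer; iteration simply stops, and the value only travels on the StopIteration exception. A minimal demonstration (standalone, names are illustrative):

def gen():
    cond = True
    if cond:
        return "early"  # ends iteration; a normal for-loop never sees this value
    yield "never"

try:
    next(gen())
except StopIteration as e:
    print(e.value)  # -> early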
@@ -524,17 +493,19 @@ def propagate_masks(GLOBAL_STATE: gr.State):
     processed = 0
 
     # Initial status; no slider change yet
-    yield f"Propagating masks: {processed}/{total}", gr.update()
+    yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
 
     last_frame_idx = 0
     with torch.inference_mode():
-        for sam2_video_output in model.propagate_in_video_iterator(inference_session):
+        for frame_idx, frame in enumerate(GLOBAL_STATE.video_frames):
+            pixel_values = None
+            if not inference_session.processed_frames or frame_idx not in inference_session.processed_frames:
+                pixel_values = processor(images=frame, device="cuda", return_tensors="pt").pixel_values[0]
+            sam2_video_output = model(inference_session=inference_session, frame=pixel_values, frame_idx=frame_idx)
             H = inference_session.video_height
             W = inference_session.video_width
             pred_masks = sam2_video_output.pred_masks.detach().cpu()
             video_res_masks = processor.post_process_masks([pred_masks], original_sizes=[[H, W]])[0]
-
-            frame_idx = int(sam2_video_output.frame_idx)
             last_frame_idx = frame_idx
             masks_for_frame: dict[int, np.ndarray] = {}
             obj_ids_order = list(inference_session.obj_ids)
@@ -547,15 +518,15 @@ def propagate_masks(GLOBAL_STATE: gr.State):
 
             processed += 1
             # Every 15th frame (or last), move slider to current frame to update preview via slider binding
-            if processed % 15 == 0 or processed == total:
-                yield f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
-            else:
-                yield f"Propagating masks: {processed}/{total}", gr.update()
+            if processed % 30 == 0 or processed == total:
+                yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update(value=frame_idx)
+            # else:
+            #     yield GLOBAL_STATE, f"Propagating masks: {processed}/{total}", gr.update()
 
     text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
 
     # Final status; ensure slider points to last processed frame
-    yield text, gr.update(value=last_frame_idx)
+    yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
 
 
 def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
@@ -785,14 +756,16 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
 
             iio.imwrite(out_path, [fr[:, :, ::-1] for fr in frames_np], plugin="pyav", fps=fps)
             return out_path
-        except Exception:
+        except Exception as e:
+            print(f"Failed to render video with imageio.v3: {e}")
             # Fallbacks
             try:
                 import imageio.v2 as imageio  # type: ignore
 
                 imageio.mimsave(out_path, [fr[:, :, ::-1] for fr in frames_np], fps=fps)
                 return out_path
-            except Exception:
+            except Exception as e:
+                print(f"Failed to render video with imageio.v2: {e}")
                 try:
                     import cv2  # type: ignore
 
@@ -803,6 +776,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
                     writer.release()
                     return out_path
                 except Exception as e:
+                    print(f"Failed to render video with cv2: {e}")
                     raise gr.Error(f"Failed to render video: {e}")
 
     render_btn.click(_render_video, inputs=[GLOBAL_STATE], outputs=[playback_video])
@@ -811,7 +785,7 @@ with gr.Blocks(title="SAM2 Video (Transformers) - Interactive Segmentation", the
     propagate_btn.click(
         propagate_masks,
         inputs=[GLOBAL_STATE],
-        outputs=[propagate_status, frame_slider],
+        outputs=[GLOBAL_STATE, propagate_status, frame_slider],
     )
 
     reset_btn.click(
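Because propagate_masks now yields GLOBAL_STATE as its first output, the click wiring in the last hunk lists GLOBAL_STATE among the outputs as well. A minimal self-contained sketch of this Gradio pattern, where a generator handler streams state plus component updates on each yield (component names here are illustrative, not from app.py):

import gradio as gr

def run(state):
    # Each yield pushes one (state, status, slider) update to the UI.
    for i in range(3):
        yield state, f"step {i + 1}/3", gr.update(value=i)

with gr.Blocks() as demo:
    state = gr.State({})
    status = gr.Textbox(label="status")
    slider = gr.Slider(0, 2, step=1)
    btn = gr.Button("Run")
    btn.click(run, inputs=[state], outputs=[state, status, slider])

Launching the demo and clicking Run streams the three updates as the generator advances, which is how the propagation progress and slider position are surfaced incrementally.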