sam2 update
Browse files
- app.py +2 -5
- sam2.1_hiera_large.pt +3 -0
- sam2_wrapper.py +17 -5
- videomama_wrapper.py +18 -7
app.py
CHANGED
|
@@ -3,10 +3,6 @@ VideoMaMa Gradio Demo
|
|
| 3 |
Interactive video matting with SAM2 mask tracking
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import sys
|
| 7 |
-
sys.path.append("../")
|
| 8 |
-
sys.path.append("../../")
|
| 9 |
-
|
| 10 |
import os
|
| 11 |
import json
|
| 12 |
import time
|
|
@@ -379,7 +375,8 @@ button {border-radius: 8px !important;}
|
|
| 379 |
"""
|
| 380 |
|
| 381 |
# Build Gradio interface
|
| 382 |
-
with gr.Blocks(
|
|
|
|
| 383 |
gr.HTML('<div class="title-text">VideoMaMa Interactive Demo</div>')
|
| 384 |
gr.Markdown(
|
| 385 |
'<div class="description-text">🎬 Upload a video → 🖱️ Click to mark object → ✅ Generate masks → 🎨 Run VideoMaMa</div>'
|
|
|
|
| 3 |
Interactive video matting with SAM2 mask tracking
|
| 4 |
"""
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import os
|
| 7 |
import json
|
| 8 |
import time
|
|
|
|
| 375 |
"""
|
| 376 |
|
| 377 |
# Build Gradio interface
|
| 378 |
+
with gr.Blocks(title="VideoMaMa Demo") as demo:
|
| 379 |
+
gr.HTML(f"<style>{custom_css}</style>")
|
| 380 |
gr.HTML('<div class="title-text">VideoMaMa Interactive Demo</div>')
|
| 381 |
gr.Markdown(
|
| 382 |
'<div class="description-text">🎬 Upload a video → 🖱️ Click to mark object → ✅ Generate masks → 🎨 Run VideoMaMa</div>'
|
sam2.1_hiera_large.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2647878d5dfa5098f2f8649825738a9345572bae2d4350a2468587ece47dd318
|
| 3 |
+
size 898083611
|
sam2_wrapper.py
CHANGED
|
@@ -3,9 +3,7 @@ SAM2 Wrapper for Video Mask Tracking
|
|
| 3 |
Handles mask generation and propagation through video
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import sys
|
| 7 |
-
sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")
|
| 8 |
-
|
| 9 |
import cv2
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
|
@@ -163,8 +161,22 @@ def load_sam2_tracker(device="cuda"):
|
|
| 163 |
Returns:
|
| 164 |
SAM2VideoTracker instance
|
| 165 |
"""
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
print(f"Loading SAM2 from {checkpoint_path}...")
|
| 170 |
tracker = SAM2VideoTracker(checkpoint_path, config_file, device)
|
|
|
|
| 3 |
Handles mask generation and propagation through video
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
import os
|
|
|
|
|
|
|
| 7 |
import cv2
|
| 8 |
import numpy as np
|
| 9 |
import torch
|
|
|
|
| 161 |
Returns:
|
| 162 |
SAM2VideoTracker instance
|
| 163 |
"""
|
| 164 |
+
# Use relative paths that work on Hugging Face Space
|
| 165 |
+
# The checkpoint file should be in the root directory or checkpoints/
|
| 166 |
+
checkpoint_path = "sam2.1_hiera_large.pt"
|
| 167 |
+
config_file = "sam2_hiera_l.yaml"
|
| 168 |
+
|
| 169 |
+
# Check if checkpoint exists
|
| 170 |
+
if not os.path.exists(checkpoint_path):
|
| 171 |
+
# Try alternative path
|
| 172 |
+
alt_checkpoint_path = os.path.join("checkpoints", "sam2.1_hiera_large.pt")
|
| 173 |
+
if os.path.exists(alt_checkpoint_path):
|
| 174 |
+
checkpoint_path = alt_checkpoint_path
|
| 175 |
+
else:
|
| 176 |
+
raise FileNotFoundError(
|
| 177 |
+
f"SAM2 checkpoint not found at {checkpoint_path} or {alt_checkpoint_path}. "
|
| 178 |
+
"Please run download_checkpoints.sh first or ensure sam2.1_hiera_large.pt is in the root directory."
|
| 179 |
+
)
|
| 180 |
|
| 181 |
print(f"Loading SAM2 from {checkpoint_path}...")
|
| 182 |
tracker = SAM2VideoTracker(checkpoint_path, config_file, device)
|
videomama_wrapper.py
CHANGED
|
@@ -3,10 +3,7 @@ VideoMaMa Inference Wrapper
|
|
| 3 |
Handles video matting with mask conditioning
|
| 4 |
"""
|
| 5 |
|
| 6 |
-
import sys
|
| 7 |
-
sys.path.append("../")
|
| 8 |
-
sys.path.append("../../")
|
| 9 |
-
|
| 10 |
import torch
|
| 11 |
import numpy as np
|
| 12 |
from PIL import Image
|
|
@@ -70,9 +67,23 @@ def load_videomama_pipeline(device="cuda"):
|
|
| 70 |
Returns:
|
| 71 |
VideoInferencePipeline instance
|
| 72 |
"""
|
| 73 |
-
#
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")
|
| 78 |
|
|
|
|
| 3 |
Handles video matting with mask conditioning
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
import os
|
|
|
|
|
|
|
|
|
|
| 7 |
import torch
|
| 8 |
import numpy as np
|
| 9 |
from PIL import Image
|
|
|
|
| 67 |
Returns:
|
| 68 |
VideoInferencePipeline instance
|
| 69 |
"""
|
| 70 |
+
# Use relative paths for Hugging Face Space
|
| 71 |
+
# Checkpoints should be downloaded via download_checkpoints.sh
|
| 72 |
+
base_model_path = os.path.join("checkpoints", "stable-video-diffusion-img2vid-xt")
|
| 73 |
+
unet_checkpoint_path = os.path.join("checkpoints", "videomama")
|
| 74 |
+
|
| 75 |
+
# Check if checkpoints exist
|
| 76 |
+
if not os.path.exists(base_model_path):
|
| 77 |
+
raise FileNotFoundError(
|
| 78 |
+
f"SVD base model not found at {base_model_path}. "
|
| 79 |
+
"Please run download_checkpoints.sh first."
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
if not os.path.exists(unet_checkpoint_path):
|
| 83 |
+
raise FileNotFoundError(
|
| 84 |
+
f"VideoMaMa checkpoint not found at {unet_checkpoint_path}. "
|
| 85 |
+
"Please run download_checkpoints.sh first."
|
| 86 |
+
)
|
| 87 |
|
| 88 |
print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")
|
| 89 |
|