pizb committed on
Commit
0b67fec
·
1 Parent(s): d04036a

sam2 update

Browse files
Files changed (4) hide show
  1. app.py +2 -5
  2. sam2.1_hiera_large.pt +3 -0
  3. sam2_wrapper.py +17 -5
  4. videomama_wrapper.py +18 -7
app.py CHANGED
@@ -3,10 +3,6 @@ VideoMaMa Gradio Demo
3
  Interactive video matting with SAM2 mask tracking
4
  """
5
 
6
- import sys
7
- sys.path.append("../")
8
- sys.path.append("../../")
9
-
10
  import os
11
  import json
12
  import time
@@ -379,7 +375,8 @@ button {border-radius: 8px !important;}
379
  """
380
 
381
  # Build Gradio interface
382
- with gr.Blocks(css=custom_css, title="VideoMaMa Demo") as demo:
 
383
  gr.HTML('<div class="title-text">VideoMaMa Interactive Demo</div>')
384
  gr.Markdown(
385
  '<div class="description-text">🎬 Upload a video → 🖱️ Click to mark object → ✅ Generate masks → 🎨 Run VideoMaMa</div>'
 
3
  Interactive video matting with SAM2 mask tracking
4
  """
5
 
 
 
 
 
6
  import os
7
  import json
8
  import time
 
375
  """
376
 
377
  # Build Gradio interface
378
+ with gr.Blocks(title="VideoMaMa Demo") as demo:
379
+ gr.HTML(f"<style>{custom_css}</style>")
380
  gr.HTML('<div class="title-text">VideoMaMa Interactive Demo</div>')
381
  gr.Markdown(
382
  '<div class="description-text">🎬 Upload a video → 🖱️ Click to mark object → ✅ Generate masks → 🎨 Run VideoMaMa</div>'
sam2.1_hiera_large.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2647878d5dfa5098f2f8649825738a9345572bae2d4350a2468587ece47dd318
3
+ size 898083611
sam2_wrapper.py CHANGED
@@ -3,9 +3,7 @@ SAM2 Wrapper for Video Mask Tracking
3
  Handles mask generation and propagation through video
4
  """
5
 
6
- import sys
7
- sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")
8
-
9
  import cv2
10
  import numpy as np
11
  import torch
@@ -163,8 +161,22 @@ def load_sam2_tracker(device="cuda"):
163
  Returns:
164
  SAM2VideoTracker instance
165
  """
166
- checkpoint_path = "/home/cvlab19/project/samuel/CVPR/sam2/checkpoints/sam2.1_hiera_large.pt"
167
- config_file = "configs/sam2.1/sam2.1_hiera_l.yaml"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  print(f"Loading SAM2 from {checkpoint_path}...")
170
  tracker = SAM2VideoTracker(checkpoint_path, config_file, device)
 
3
  Handles mask generation and propagation through video
4
  """
5
 
6
+ import os
 
 
7
  import cv2
8
  import numpy as np
9
  import torch
 
161
  Returns:
162
  SAM2VideoTracker instance
163
  """
164
+ # Use relative paths that work on Hugging Face Space
165
+ # The checkpoint file should be in the root directory or checkpoints/
166
+ checkpoint_path = "sam2.1_hiera_large.pt"
167
+ config_file = "sam2_hiera_l.yaml"
168
+
169
+ # Check if checkpoint exists
170
+ if not os.path.exists(checkpoint_path):
171
+ # Try alternative path
172
+ alt_checkpoint_path = os.path.join("checkpoints", "sam2.1_hiera_large.pt")
173
+ if os.path.exists(alt_checkpoint_path):
174
+ checkpoint_path = alt_checkpoint_path
175
+ else:
176
+ raise FileNotFoundError(
177
+ f"SAM2 checkpoint not found at {checkpoint_path} or {alt_checkpoint_path}. "
178
+ "Please run download_checkpoints.sh first or ensure sam2.1_hiera_large.pt is in the root directory."
179
+ )
180
 
181
  print(f"Loading SAM2 from {checkpoint_path}...")
182
  tracker = SAM2VideoTracker(checkpoint_path, config_file, device)
videomama_wrapper.py CHANGED
@@ -3,10 +3,7 @@ VideoMaMa Inference Wrapper
3
  Handles video matting with mask conditioning
4
  """
5
 
6
- import sys
7
- sys.path.append("../")
8
- sys.path.append("../../")
9
-
10
  import torch
11
  import numpy as np
12
  from PIL import Image
@@ -70,9 +67,23 @@ def load_videomama_pipeline(device="cuda"):
70
  Returns:
71
  VideoInferencePipeline instance
72
  """
73
- # Local paths for testing
74
- base_model_path = "/home/cvlab19/project/samuel/data/CVPR/pretrained_models/stable-video-diffusion-img2vid-xt"
75
- unet_checkpoint_path = "/home/cvlab19/project/samuel/data/CVPR/pretrained_models/videomama"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")
78
 
 
3
  Handles video matting with mask conditioning
4
  """
5
 
6
+ import os
 
 
 
7
  import torch
8
  import numpy as np
9
  from PIL import Image
 
67
  Returns:
68
  VideoInferencePipeline instance
69
  """
70
+ # Use relative paths for Hugging Face Space
71
+ # Checkpoints should be downloaded via download_checkpoints.sh
72
+ base_model_path = os.path.join("checkpoints", "stable-video-diffusion-img2vid-xt")
73
+ unet_checkpoint_path = os.path.join("checkpoints", "videomama")
74
+
75
+ # Check if checkpoints exist
76
+ if not os.path.exists(base_model_path):
77
+ raise FileNotFoundError(
78
+ f"SVD base model not found at {base_model_path}. "
79
+ "Please run download_checkpoints.sh first."
80
+ )
81
+
82
+ if not os.path.exists(unet_checkpoint_path):
83
+ raise FileNotFoundError(
84
+ f"VideoMaMa checkpoint not found at {unet_checkpoint_path}. "
85
+ "Please run download_checkpoints.sh first."
86
+ )
87
 
88
  print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")
89