gpu support
Browse files
- app.py +17 -3
- requirements.txt +3 -0
app.py
CHANGED
|
@@ -3,6 +3,9 @@ VideoMaMa Gradio Demo
|
|
| 3 |
Interactive video matting with SAM2 mask tracking
|
| 4 |
"""
|
| 5 |
|
|
|
|
|
|
|
|
|
|
| 6 |
import os
|
| 7 |
import json
|
| 8 |
import time
|
|
@@ -47,9 +50,12 @@ POINT_ALPHA = 0.9
|
|
| 47 |
POINT_RADIUS = 15
|
| 48 |
|
| 49 |
def initialize_models():
|
| 50 |
-
"""Initialize SAM2 and VideoMaMa models"""
|
| 51 |
global sam2_tracker, videomama_pipeline
|
| 52 |
|
|
|
|
|
|
|
|
|
|
| 53 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 54 |
print(f"Using device: {device}")
|
| 55 |
|
|
@@ -149,6 +155,7 @@ def load_video(video_input, video_state):
|
|
| 149 |
gr.update(visible=True), gr.update(visible=False)
|
| 150 |
|
| 151 |
|
|
|
|
| 152 |
def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
| 153 |
"""
|
| 154 |
Add click and update mask on first frame
|
|
@@ -159,6 +166,9 @@ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
|
| 159 |
click_state: [[points], [labels]]
|
| 160 |
evt: Gradio SelectData event with click coordinates
|
| 161 |
"""
|
|
|
|
|
|
|
|
|
|
| 162 |
if video_state is None or "frames" not in video_state:
|
| 163 |
return None, video_state, click_state
|
| 164 |
|
|
@@ -264,10 +274,14 @@ def propagate_masks(video_state, click_state):
|
|
| 264 |
return video_state, status_msg, gr.update(visible=True)
|
| 265 |
|
| 266 |
|
|
|
|
| 267 |
def run_videomama_with_sam2(video_state, click_state):
|
| 268 |
"""
|
| 269 |
Run SAM2 propagation and VideoMaMa inference together
|
| 270 |
"""
|
|
|
|
|
|
|
|
|
|
| 271 |
if video_state is None or "frames" not in video_state:
|
| 272 |
return video_state, None, None, None, "⚠️ No video loaded"
|
| 273 |
|
|
@@ -482,8 +496,8 @@ if __name__ == "__main__":
|
|
| 482 |
print("VideoMaMa Interactive Demo")
|
| 483 |
print("=" * 60)
|
| 484 |
|
| 485 |
-
#
|
| 486 |
-
initialize_models()
|
| 487 |
|
| 488 |
# Launch demo
|
| 489 |
demo.queue()
|
|
|
|
| 3 |
Interactive video matting with SAM2 mask tracking
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
# CRITICAL: Import spaces FIRST before any CUDA-related imports
|
| 7 |
+
import spaces
|
| 8 |
+
|
| 9 |
import os
|
| 10 |
import json
|
| 11 |
import time
|
|
|
|
| 50 |
POINT_RADIUS = 15
|
| 51 |
|
| 52 |
def initialize_models():
|
| 53 |
+
"""Initialize SAM2 and VideoMaMa models (lazy loading)"""
|
| 54 |
global sam2_tracker, videomama_pipeline
|
| 55 |
|
| 56 |
+
if sam2_tracker is not None and videomama_pipeline is not None:
|
| 57 |
+
return # Already initialized
|
| 58 |
+
|
| 59 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 60 |
print(f"Using device: {device}")
|
| 61 |
|
|
|
|
| 155 |
gr.update(visible=True), gr.update(visible=False)
|
| 156 |
|
| 157 |
|
| 158 |
+
@spaces.GPU
|
| 159 |
def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
| 160 |
"""
|
| 161 |
Add click and update mask on first frame
|
|
|
|
| 166 |
click_state: [[points], [labels]]
|
| 167 |
evt: Gradio SelectData event with click coordinates
|
| 168 |
"""
|
| 169 |
+
# Lazy load models on first use
|
| 170 |
+
initialize_models()
|
| 171 |
+
|
| 172 |
if video_state is None or "frames" not in video_state:
|
| 173 |
return None, video_state, click_state
|
| 174 |
|
|
|
|
| 274 |
return video_state, status_msg, gr.update(visible=True)
|
| 275 |
|
| 276 |
|
| 277 |
+
@spaces.GPU(duration=120)
|
| 278 |
def run_videomama_with_sam2(video_state, click_state):
|
| 279 |
"""
|
| 280 |
Run SAM2 propagation and VideoMaMa inference together
|
| 281 |
"""
|
| 282 |
+
# Lazy load models on first use
|
| 283 |
+
initialize_models()
|
| 284 |
+
|
| 285 |
if video_state is None or "frames" not in video_state:
|
| 286 |
return video_state, None, None, None, "⚠️ No video loaded"
|
| 287 |
|
|
|
|
| 496 |
print("VideoMaMa Interactive Demo")
|
| 497 |
print("=" * 60)
|
| 498 |
|
| 499 |
+
# Models will be initialized on first use (lazy loading for ZeroGPU)
|
| 500 |
+
# initialize_models()
|
| 501 |
|
| 502 |
# Launch demo
|
| 503 |
demo.queue()
|
requirements.txt
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
# Hugging Face Space Requirements for VideoMaMa Demo
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
# Core frameworks
|
| 4 |
torch>=2.0.0
|
| 5 |
torchvision>=0.15.0
|
|
|
|
| 1 |
# Hugging Face Space Requirements for VideoMaMa Demo
|
| 2 |
|
| 3 |
+
# CRITICAL: Hugging Face ZeroGPU support
|
| 4 |
+
spaces
|
| 5 |
+
|
| 6 |
# Core frameworks
|
| 7 |
torch>=2.0.0
|
| 8 |
torchvision>=0.15.0
|