pizb commited on
Commit
d33e75e
·
1 Parent(s): dea893d

initial update

Browse files
.gitignore ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ venv/
25
+ env/
26
+ ENV/
27
+ .venv
28
+
29
+ # IDE
30
+ .vscode/
31
+ .idea/
32
+ *.swp
33
+ *.swo
34
+ *~
35
+
36
+ # Gradio
37
+ flagged/
38
+
39
+ # Temporary files
40
+ *.tmp
41
+ temp/
42
+ temp_*/
43
+ *.log
44
+
45
+ # Model checkpoints (download separately)
46
+ checkpoints/*.pt
47
+ checkpoints/*.pth
48
+ checkpoints/*.safetensors
49
+ checkpoints/*.bin
50
+
51
+ # Videos
52
+ samples/*.mp4
53
+ samples/*.avi
54
+ samples/*.mov
55
+ *.mp4
56
+ *.avi
57
+ *.mov
58
+
59
+ # OS
60
+ .DS_Store
61
+ Thumbs.db
62
+ *.bak
63
+
64
+ # Jupyter
65
+ .ipynb_checkpoints/
.hf_gitignore ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ *.egg-info/
8
+ dist/
9
+ build/
10
+
11
+ # Virtual environments
12
+ venv/
13
+ env/
14
+ ENV/
15
+
16
+ # IDE
17
+ .vscode/
18
+ .idea/
19
+ *.swp
20
+ *.swo
21
+ *~
22
+
23
+ # OS
24
+ .DS_Store
25
+ Thumbs.db
26
+
27
+ # Model checkpoints (will be downloaded)
28
+ checkpoints/
29
+ *.pt
30
+ *.pth
31
+ *.safetensors
32
+ *.bin
33
+
34
+ # Outputs
35
+ outputs/
36
+ output_*.mp4
37
+ masks_*.mp4
38
+ greenscreen_*.mp4
39
+
40
+ # Temporary files
41
+ *.tmp
42
+ tmp/
43
+ temp/
44
+
45
+ # Logs
46
+ *.log
47
+ logs/
app.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VideoMaMa Gradio Demo
3
+ Interactive video matting with SAM2 mask tracking
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("../")
8
+ sys.path.append("../../")
9
+
10
+ import os
11
+ import json
12
+ import time
13
+ import cv2
14
+ import torch
15
+ import numpy as np
16
+ import gradio as gr
17
+ from PIL import Image
18
+ from pathlib import Path
19
+
20
+ from sam2_wrapper import load_sam2_tracker
21
+ from videomama_wrapper import load_videomama_pipeline, videomama
22
+ from tools.painter import mask_painter, point_painter
23
+
24
+ import warnings
25
+ warnings.filterwarnings("ignore")
26
+
27
+ # Global models
28
+ sam2_tracker = None
29
+ videomama_pipeline = None
30
+
31
+ # Constants
32
+ MASK_COLOR = 3
33
+ MASK_ALPHA = 0.7
34
+ CONTOUR_COLOR = 1
35
+ CONTOUR_WIDTH = 5
36
+ POINT_COLOR_POS = 8 # Positive points - orange
37
+ POINT_COLOR_NEG = 1 # Negative points - red
38
+ POINT_ALPHA = 0.9
39
+ POINT_RADIUS = 15
40
+
41
def initialize_models():
    """Initialize SAM2 and VideoMaMa models.

    Loads both models onto CUDA when available (CPU otherwise) and stores
    them in the module-level globals (`sam2_tracker`, `videomama_pipeline`)
    that the Gradio callbacks read. Must be called once before the demo
    serves requests.
    """
    global sam2_tracker, videomama_pipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load SAM2 (interactive segmentation + video mask propagation)
    sam2_tracker = load_sam2_tracker(device=device)

    # Load VideoMaMa (video matting pipeline)
    videomama_pipeline = load_videomama_pipeline(device=device)

    print("All models initialized successfully!")
55
+
56
+
57
def extract_frames_from_video(video_path, max_frames=50):
    """
    Extract up to `max_frames` RGB frames from a video file.

    Args:
        video_path: Path to video file.
        max_frames: Maximum number of frames to extract (demo cap).

    Returns:
        frames: List of numpy arrays (H, W, 3), uint8 RGB.
        fps: FPS reported by the container; falls back to 30.0 when the
             container reports 0/NaN (some codecs do), so downstream
             cv2.VideoWriter calls get a usable rate.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Unreadable/missing file: return an empty frame list so callers'
        # existing `len(frames) == 0` checks handle it gracefully.
        print(f"Failed to open video: {video_path}")
        return [], 30.0

    fps = cap.get(cv2.CAP_PROP_FPS)
    # Guard against 0 or NaN FPS (fps != fps is the NaN test).
    if not fps or fps != fps or fps <= 0:
        fps = 30.0

    frames = []
    while len(frames) < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes BGR; the rest of the app works in RGB.
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    cap.release()
    print(f"Extracted {len(frames)} frames from video (FPS: {fps})")

    return frames, fps
85
+
86
+
87
def get_prompt(click_state, click_input):
    """
    Merge a JSON-encoded batch of clicks into the accumulated click state.

    Args:
        click_state: [[points], [labels]] accumulated so far.
        click_input: JSON string of the form "[[x, y, label], ...]".

    Returns:
        The same click_state object, updated in place.
    """
    for item in json.loads(click_input):
        click_state[0].append(item[:2])   # [x, y]
        click_state[1].append(item[2])    # 1 = positive, 0 = negative

    return click_state
110
+
111
+
112
def load_video(video_input, video_state):
    """
    Load an uploaded video, stash its frames in state, and show frame 0.

    Returns a 6-tuple matching the Gradio outputs: (video_state,
    first-frame image, point_prompt update, clear_button update,
    run_button update, status_text update).
    """
    empty_result = (
        video_state,
        None,
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )

    if video_input is None:
        return empty_result

    # Demo keeps at most 50 frames to bound memory/latency.
    frames, fps = extract_frames_from_video(video_input, max_frames=50)
    if not frames:
        return empty_result

    # Fresh state for the new clip.
    video_state = {
        "frames": frames,
        "fps": fps,
        "first_frame_mask": None,
        "masks": None,
    }

    return (
        video_state,
        Image.fromarray(frames[0]),
        gr.update(visible=True),   # point_prompt
        gr.update(visible=True),   # clear_button
        gr.update(visible=True),   # run_button
        gr.update(visible=False),  # status_text
    )
142
+
143
+
144
def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
    """
    Register a click on the first frame and refresh the SAM2 mask preview.

    Args:
        video_state: Dict holding the loaded video data.
        point_prompt: "Positive" (object) or "Negative" (background).
        click_state: [[points], [labels]] accumulated so far.
        evt: Gradio SelectData event carrying the click coordinates.

    Returns:
        (painted first-frame PIL image, video_state, click_state)
    """
    if video_state is None or "frames" not in video_state:
        return None, video_state, click_state

    # Record the new click.
    x, y = evt.index[0], evt.index[1]
    click_state[0].append([x, y])
    click_state[1].append(1 if point_prompt == "Positive" else 0)

    print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")

    # Re-run SAM2 on the first frame with the full click history.
    first_frame = video_state["frames"][0]
    mask = sam2_tracker.get_first_frame_mask(
        frame=first_frame,
        points=click_state[0],
        labels=click_state[1]
    )
    video_state["first_frame_mask"] = mask

    # Overlay the mask, then the click markers (positive and negative).
    painted = mask_painter(
        first_frame.copy(),
        mask,
        MASK_COLOR,
        MASK_ALPHA,
        CONTOUR_COLOR,
        CONTOUR_WIDTH
    )

    for wanted_label, color in ((1, POINT_COLOR_POS), (0, POINT_COLOR_NEG)):
        selected = np.array([pt for pt, lbl in zip(click_state[0], click_state[1])
                             if lbl == wanted_label])
        if len(selected) > 0:
            painted = point_painter(
                painted,
                selected,
                color,
                POINT_ALPHA,
                POINT_RADIUS,
                CONTOUR_COLOR,
                CONTOUR_WIDTH
            )

    return Image.fromarray(painted), video_state, click_state
218
+
219
+
220
def clear_clicks(video_state, click_state):
    """Drop all recorded clicks and restore the unannotated first frame."""
    click_state = [[], []]

    # No video loaded: nothing to display, just hand back the reset state.
    if video_state is None or "frames" not in video_state:
        return None, video_state, click_state

    video_state["first_frame_mask"] = None
    return Image.fromarray(video_state["frames"][0]), video_state, click_state
230
+
231
+
232
def propagate_masks(video_state, click_state):
    """
    Propagate the first-frame point prompts through the whole clip with SAM2.

    Returns:
        (video_state, status message, run-button visibility update)
    """
    if video_state is None or "frames" not in video_state:
        return video_state, "No video loaded", gr.update(visible=False)

    if not click_state[0]:
        return video_state, "⚠️ Please add at least one point first", gr.update(visible=False)

    frames = video_state["frames"]
    print(f"Tracking object through {len(frames)} frames...")

    # One mask per frame, stored for the matting step.
    masks = sam2_tracker.track_video(
        frames=frames,
        points=click_state[0],
        labels=click_state[1]
    )
    video_state["masks"] = masks

    status_msg = f"✓ Generated {len(masks)} masks. Ready to run VideoMaMa!"
    return video_state, status_msg, gr.update(visible=True)
257
+
258
+
259
def run_videomama_with_sam2(video_state, click_state):
    """
    Run SAM2 propagation and VideoMaMa inference together.

    Tracks the clicked object through every cached frame with SAM2, feeds
    frames + masks to VideoMaMa, then writes three mp4s under outputs/:
    the matting result, the mask track, and a greenscreen composite.

    Returns:
        (video_state, output_video_path, mask_video_path, greenscreen_path,
        status message) — paths as str, or Nones plus a warning on error.
    """
    if video_state is None or "frames" not in video_state:
        return video_state, None, None, None, "⚠️ No video loaded"

    if len(click_state[0]) == 0:
        return video_state, None, None, None, "⚠️ Please add at least one point first"

    frames = video_state["frames"]

    # Step 1: Track through video with SAM2
    print(f"🎯 Tracking object through {len(frames)} frames with SAM2...")
    masks = sam2_tracker.track_video(
        frames=frames,
        points=click_state[0],
        labels=click_state[1]
    )

    video_state["masks"] = masks
    print(f"✓ Generated {len(masks)} masks")

    # Step 2: Run VideoMaMa
    print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
    output_frames = videomama(videomama_pipeline, frames, masks)

    # Save output videos (timestamped so repeated runs don't overwrite)
    output_dir = Path("outputs")
    output_dir.mkdir(exist_ok=True)

    timestamp = int(time.time())
    output_video_path = output_dir / f"output_{timestamp}.mp4"
    mask_video_path = output_dir / f"masks_{timestamp}.mp4"
    greenscreen_path = output_dir / f"greenscreen_{timestamp}.mp4"

    # Save matting result
    save_video(output_frames, output_video_path, video_state["fps"])

    # Save mask video (for visualization): replicate the single-channel
    # mask into 3 channels so save_video writes it as a normal RGB clip.
    mask_frames_rgb = [np.stack([m, m, m], axis=-1) for m in masks]
    save_video(mask_frames_rgb, mask_video_path, video_state["fps"])

    # Create greenscreen composite: RGB * VideoMaMa_alpha + green * (1 - VideoMaMa_alpha)
    # VideoMaMa output_frames already contain the alpha matte result
    greenscreen_frames = []
    for orig_frame, output_frame in zip(frames, output_frames):
        # Extract alpha matte from VideoMaMa output.
        # NOTE(review): this treats the grayscale intensity of the model
        # output as the alpha channel — assumes output_frame is (close to)
        # a matte; confirm against videomama_wrapper's output format.
        gray = cv2.cvtColor(output_frame, cv2.COLOR_RGB2GRAY)
        alpha = np.clip(gray.astype(np.float32) / 255.0, 0, 1)
        alpha_3ch = np.stack([alpha, alpha, alpha], axis=-1)

        # Create green background (RGB triple, not a chroma-key constant)
        green_bg = np.zeros_like(orig_frame)
        green_bg[:, :] = [156, 251, 165]  # Green screen color

        # Composite: original_RGB * alpha + green * (1 - alpha)
        composite = (orig_frame.astype(np.float32) * alpha_3ch +
                     green_bg.astype(np.float32) * (1 - alpha_3ch)).astype(np.uint8)
        greenscreen_frames.append(composite)

    save_video(greenscreen_frames, greenscreen_path, video_state["fps"])

    status_msg = f"✓ Complete! Generated {len(output_frames)} frames."

    return video_state, str(output_video_path), str(mask_video_path), str(greenscreen_path), status_msg
326
+
327
+
328
def save_video(frames, output_path, fps):
    """Write a list of RGB (or single-channel) frames to an mp4 file."""
    if len(frames) == 0:
        return

    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(
        str(output_path),
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (width, height)
    )

    for frame in frames:
        # cv2.VideoWriter expects 3-channel BGR input.
        code = cv2.COLOR_GRAY2BGR if len(frame.shape) == 2 else cv2.COLOR_RGB2BGR
        writer.write(cv2.cvtColor(frame, code))

    writer.release()
    print(f"Saved video to {output_path}")
346
+
347
+
348
def restart():
    """Reset every UI state and output back to its initial empty value."""
    return (
        None,                       # video_state
        [[], []],                   # click_state
        None,                       # first_frame_display
        gr.update(visible=False),   # point_prompt
        gr.update(visible=False),   # clear_button
        gr.update(visible=False),   # run_button
        None,                       # output_video
        None,                       # mask_video
        None,                       # greenscreen_video
        ""                          # status_text
    )
353
+
354
+
355
# CSS styling for the Gradio app (gradient title, highlighted run button)
custom_css = """
.gradio-container {width: 90% !important; margin: 0 auto;}
.title-text {text-align: center; font-size: 48px; font-weight: bold;
background: linear-gradient(to right, #8b5cf6, #10b981);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;}
.description-text {text-align: center; font-size: 18px; margin: 20px 0;}
button {border-radius: 8px !important;}
.green_button {background-color: #10b981 !important; color: white !important;}
.red_button {background-color: #ef4444 !important; color: white !important;}
.run_matting_button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 50%, #f093fb 100%) !important;
color: white !important;
font-weight: bold !important;
font-size: 18px !important;
padding: 20px !important;
box-shadow: 0 4px 15px 0 rgba(102, 126, 234, 0.75) !important;
border: none !important;
}
.run_matting_button:hover {
background: linear-gradient(135deg, #764ba2 0%, #667eea 50%, #f093fb 100%) !important;
box-shadow: 0 6px 20px 0 rgba(102, 126, 234, 0.9) !important;
transform: translateY(-2px) !important;
}
"""

# Build Gradio interface
with gr.Blocks(css=custom_css, title="VideoMaMa Demo") as demo:
    gr.HTML('<div class="title-text">VideoMaMa Interactive Demo</div>')
    gr.Markdown(
        '<div class="description-text">🎬 Upload a video → 🖱️ Click to mark object → ✅ Generate masks → 🎨 Run VideoMaMa</div>'
    )

    # Per-session state (not shared between users)
    video_state = gr.State(None)
    click_state = gr.State([[], []])  # [[points], [labels]]

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Step 1: Upload Video")
            video_input = gr.Video(label="Input Video")
            load_button = gr.Button("📁 Load Video", variant="primary")

            gr.Markdown("### Step 2: Mark Object")
            # Hidden until a video is loaded (see load_video outputs)
            point_prompt = gr.Radio(
                choices=["Positive", "Negative"],
                value="Positive",
                label="Click Type",
                info="Positive: object, Negative: background",
                visible=False
            )
            clear_button = gr.Button("🗑️ Clear Clicks", visible=False)

        with gr.Column(scale=1):
            gr.Markdown("### First Frame (Click to Add Points)")
            first_frame_display = gr.Image(
                label="First Frame",
                type="pil",
                interactive=True
            )
            run_button = gr.Button("🚀 Run Matting", visible=False, elem_classes="run_matting_button", size="lg")

    status_text = gr.Textbox(label="Status", value="", interactive=False, visible=False)

    gr.Markdown("### Outputs")
    with gr.Row():
        with gr.Column():
            output_video = gr.Video(label="Matting Result", autoplay=True)
        with gr.Column():
            greenscreen_video = gr.Video(label="Greenscreen Composite", autoplay=True)
        with gr.Column():
            mask_video = gr.Video(label="Mask Track", autoplay=True)

    # Event handlers
    load_button.click(
        fn=load_video,
        inputs=[video_input, video_state],
        outputs=[video_state, first_frame_display,
                 point_prompt, clear_button, run_button, status_text]
    )

    # Each click on the first frame re-runs SAM2 and repaints the preview
    first_frame_display.select(
        fn=sam_refine,
        inputs=[video_state, point_prompt, click_state],
        outputs=[first_frame_display, video_state, click_state]
    )

    clear_button.click(
        fn=clear_clicks,
        inputs=[video_state, click_state],
        outputs=[first_frame_display, video_state, click_state]
    )

    # Runs SAM2 propagation + VideoMaMa in one go (propagate_masks is unused
    # by the UI; run_videomama_with_sam2 performs the tracking itself)
    run_button.click(
        fn=run_videomama_with_sam2,
        inputs=[video_state, click_state],
        outputs=[video_state, output_video, mask_video, greenscreen_video, status_text]
    )

    # Selecting a new input video resets everything
    video_input.change(
        fn=restart,
        inputs=[],
        outputs=[video_state, click_state, first_frame_display,
                 point_prompt, clear_button, run_button,
                 output_video, mask_video, greenscreen_video, status_text]
    )

    # Examples (only shown if a samples/ directory with mp4s exists)
    gr.Markdown("---\n### 📦 Example Videos")
    example_dir = Path("samples")
    if example_dir.exists():
        examples = [str(p) for p in sorted(example_dir.glob("*.mp4"))]
        if examples:
            gr.Examples(examples=examples, inputs=[video_input])
469
+
470
+
471
if __name__ == "__main__":
    print("=" * 60)
    print("VideoMaMa Interactive Demo")
    print("=" * 60)

    # Load SAM2 + VideoMaMa before serving any requests.
    initialize_models()

    # Launch demo with request queueing enabled (long-running inference).
    demo.queue()
    # NOTE(review): server_name="127.0.0.1" binds only to localhost; a
    # hosted deployment (e.g. HF Space) usually needs "0.0.0.0" — confirm
    # the deployment target. share=True additionally opens a public tunnel.
    demo.launch(
        server_name="127.0.0.1",
        server_port=7860,
        share=True
    )
download_checkpoints.sh ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Download model checkpoints for VideoMaMa demo
#
# Fetches the SAM 2.1 Hiera-Large checkpoint (the filename and config that
# sam2_wrapper.py actually loads) and checks the VideoMaMa UNet layout.

set -e

echo "🔽 Downloading model checkpoints for VideoMaMa demo..."
echo ""

# Create checkpoints directory
echo "Creating checkpoints directory..."
mkdir -p checkpoints
echo "✓ Directory created"
echo ""

# Download SAM2 checkpoint.
# NOTE: sam2_wrapper.py loads checkpoints/sam2.1_hiera_large.pt with the
# sam2.1 config, so the SAM 2.1 release checkpoint is fetched here (the
# previous URL pointed at the older SAM 2.0 file name).
SAM2_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
SAM2_OUT="checkpoints/sam2.1_hiera_large.pt"

echo "Downloading SAM2 checkpoint..."
echo "URL: ${SAM2_URL}"
echo "This may take a few minutes (file size: ~900MB)..."

if command -v wget &> /dev/null; then
    wget "${SAM2_URL}" -O "${SAM2_OUT}"
elif command -v curl &> /dev/null; then
    curl -L "${SAM2_URL}" -o "${SAM2_OUT}"
else
    echo "❌ Error: Neither wget nor curl is available. Please install one of them."
    exit 1
fi

echo "✓ SAM2 checkpoint downloaded successfully"
echo ""

# Check if VideoMaMa checkpoint exists
echo "Checking VideoMaMa checkpoint..."
if [ -d "checkpoints/videomama_unet" ]; then
    if [ -f "checkpoints/videomama_unet/config.json" ] && \
       { [ -f "checkpoints/videomama_unet/diffusion_pytorch_model.safetensors" ] || \
         [ -f "checkpoints/videomama_unet/diffusion_pytorch_model.bin" ]; }; then
        echo "✓ VideoMaMa checkpoint already exists"
    else
        echo "⚠️ VideoMaMa checkpoint directory exists but is incomplete"
        echo "   Please add the following files to checkpoints/videomama_unet/:"
        echo "   - config.json"
        echo "   - diffusion_pytorch_model.safetensors (or .bin)"
    fi
else
    echo "⚠️ VideoMaMa checkpoint not found"
    echo ""
    echo "📝 Manual step required:"
    echo "   1. Create directory: checkpoints/videomama_unet/"
    echo "   2. Copy your trained VideoMaMa checkpoint files:"
    echo "      - config.json"
    echo "      - diffusion_pytorch_model.safetensors (or .bin)"
    echo ""
    echo "   Example:"
    echo "   mkdir -p checkpoints/videomama_unet"
    echo "   cp /path/to/your/checkpoint/* checkpoints/videomama_unet/"
fi

echo ""
# Fixed: 'echo "="*70' printed the literal '=*70' — shell has no string
# repetition operator; build a real separator line instead.
SEPARATOR=$(printf '=%.0s' {1..70})
echo "${SEPARATOR}"
echo "✨ Checkpoint download complete!"
echo "${SEPARATOR}"
echo ""
echo "Next steps:"
echo "1. Verify checkpoints are in place:"
echo "   python test_setup.py"
echo ""
echo "2. (Optional) Add sample videos:"
echo "   mkdir -p samples"
echo "   cp your_sample.mp4 samples/"
echo ""
echo "3. Test locally:"
echo "   python app.py"
echo ""
echo "4. Deploy to Hugging Face Space"
echo ""
+ echo ""
requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Space Requirements for VideoMaMa Demo
2
+
3
+ # Core frameworks
4
+ torch>=2.0.0
5
+ torchvision>=0.15.0
6
+ diffusers>=0.24.0
7
+ transformers>=4.30.0
8
+
9
+ # Gradio for UI
10
+ gradio==4.31.0
11
+
12
+ # Image and video processing
13
+ opencv-python>=4.8.0
14
+ opencv-contrib-python>=4.8.0
15
+ Pillow>=10.0.0
16
+ numpy>=1.24.0
17
+ scipy>=1.10.0
18
+
19
+ # SAM2 dependencies
20
+ segment-anything-2 @ git+https://github.com/facebookresearch/segment-anything-2.git
21
+
22
+ # Additional utilities
23
+ accelerate>=0.20.0
24
+ einops>=0.6.0
25
+ tqdm>=4.65.0
26
+ safetensors>=0.3.0
27
+
28
+ # For video export
29
+ imageio>=2.31.0
30
+ imageio-ffmpeg>=0.4.9
31
+ pydantic==2.10.6
sam2_hiera_l.yaml ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Configuration for SAM2
2
+ # This file should be placed alongside the SAM2 checkpoint
3
+
4
+ # SAM 2 Hiera Large Configuration
5
+ model:
6
+ _target_: sam2.modeling.sam2_base.SAM2Base
7
+ image_encoder:
8
+ _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
9
+ trunk:
10
+ _target_: sam2.modeling.backbones.hieradet.Hiera
11
+ embed_dim: 144
12
+ num_heads: 2
13
+ stages: [2, 6, 36, 4]
14
+ global_att_blocks: [23, 33, 43]
15
+ window_pos_embed_bkg_spatial_size: [7, 7]
16
+ window_spec: [8, 4, 16, 8]
17
+ neck:
18
+ _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19
+ position_encoding:
20
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21
+ num_pos_feats: 256
22
+ normalize: true
23
+ scale: null
24
+ temperature: 10000
25
+ d_model: 256
26
+ backbone_channel_list: [1152, 576, 288, 144]
27
+ fpn_top_down_levels: [2, 3]
28
+ fpn_interp_model: nearest
29
+
30
+ memory_attention:
31
+ _target_: sam2.modeling.memory_attention.MemoryAttention
32
+ d_model: 256
33
+ pos_enc_at_input: true
34
+ layer:
35
+ _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36
+ activation: relu
37
+ dim_feedforward: 2048
38
+ dropout: 0.1
39
+ pos_enc_at_attn: false
40
+ self_attention:
41
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
42
+ rope_theta: 10000.0
43
+ feat_sizes: [32, 32]
44
+ embedding_dim: 256
45
+ num_heads: 1
46
+ downsample_rate: 1
47
+ dropout: 0.1
48
+ d_model: 256
49
+ pos_enc_at_cross_attn_keys: true
50
+ pos_enc_at_cross_attn_queries: false
51
+ cross_attention:
52
+ _target_: sam2.modeling.sam.transformer.RoPEAttention
53
+ rope_theta: 10000.0
54
+ feat_sizes: [32, 32]
55
+ rope_k_repeat: True
56
+ embedding_dim: 256
57
+ num_heads: 1
58
+ downsample_rate: 1
59
+ dropout: 0.1
60
+ kv_in_dim: 64
61
+ num_layers: 4
62
+
63
+ memory_encoder:
64
+ _target_: sam2.modeling.memory_encoder.MemoryEncoder
65
+ out_dim: 64
66
+ position_encoding:
67
+ _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68
+ num_pos_feats: 64
69
+ normalize: true
70
+ scale: null
71
+ temperature: 10000
72
+ mask_downsampler:
73
+ _target_: sam2.modeling.memory_encoder.MaskDownSampler
74
+ kernel_size: 3
75
+ stride: 2
76
+ padding: 1
77
+ fuser:
78
+ _target_: sam2.modeling.memory_encoder.Fuser
79
+ layer:
80
+ _target_: sam2.modeling.memory_encoder.CXBlock
81
+ dim: 256
82
+ kernel_size: 7
83
+ padding: 3
84
+ layer_scale_init_value: 1e-6
85
+ use_dwconv: True
86
+ num_layers: 2
87
+
88
+ num_maskmem: 7
89
+ image_size: 1024
90
+ sigmoid_scale_for_mem_enc: 20.0
91
+ sigmoid_bias_for_mem_enc: -10.0
92
+ use_mask_input_as_output_without_sam: true
93
+ directly_add_no_mem_embed: true
94
+ use_high_res_features_in_sam: true
95
+ multimask_output_in_sam: true
96
+ multimask_min_pt_num: 0
97
+ multimask_max_pt_num: 1
98
+ multimask_output_for_tracking: true
99
+ use_multimask_token_for_obj_ptr: true
100
+ iou_prediction_use_sigmoid: True
101
+ memory_temporal_stride_for_eval: 1
102
+ non_overlap_masks_for_mem_enc: true
103
+ use_obj_ptrs_in_encoder: true
104
+ max_obj_ptrs_in_encoder: 16
105
+ add_tpos_enc_to_obj_ptrs: false
106
+ proj_tpos_enc_in_obj_ptrs: false
107
+ use_signed_tpos_enc_to_obj_ptrs: false
108
+ only_obj_ptrs_in_the_past_for_eval: true
109
+ pred_obj_scores: true
110
+ pred_obj_scores_mlp: true
111
+ fixed_no_obj_ptr: true
112
+ soft_no_obj_ptr: false
113
+ use_mlp_for_obj_ptr_proj: true
114
+ no_obj_embed_spatial: true
115
+
116
+ sam_mask_decoder_extra_args:
117
+ dynamic_multimask_via_stability: true
118
+ dynamic_multimask_stability_delta: 0.05
119
+ dynamic_multimask_stability_thresh: 0.98
120
+ pred_obj_scores: true
121
+ pred_obj_scores_mlp: true
122
+ use_multimask_token_for_obj_ptr: true
123
+
124
+ compile_image_encoder: False
sam2_wrapper.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM2 Wrapper for Video Mask Tracking
3
+ Handles mask generation and propagation through video
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ from PIL import Image
13
+ from pathlib import Path
14
+ from typing import List, Tuple
15
+ import tempfile
16
+ import shutil
17
+
18
+ from sam2.build_sam import build_sam2_video_predictor
19
+
20
+
21
class SAM2VideoTracker:
    """Thin wrapper around SAM2's video predictor for point-prompted tracking."""

    def __init__(self, checkpoint_path, config_file, device="cuda"):
        """
        Initialize SAM2 video tracker.

        Args:
            checkpoint_path: Path to SAM2 checkpoint (.pt file)
            config_file: SAM2 config name passed to build_sam2_video_predictor
            device: Device to run on ("cuda" or "cpu")
        """
        self.device = device
        self.predictor = build_sam2_video_predictor(
            config_file=config_file,
            ckpt_path=checkpoint_path,
            device=device
        )
        print(f"SAM2 video tracker initialized on {device}")

    def track_video(self, frames: List[np.ndarray], points: List[List[int]],
                    labels: List[int]) -> List[np.ndarray]:
        """
        Track object through video using SAM2.

        Args:
            frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
            points: List of [x, y] prompt coordinates (all on frame 0)
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            masks: List of numpy arrays, [(H,W)]*n, uint8 masks (0 or 255)
        """
        # SAM2's init_state consumes a directory of image frames, so the
        # in-memory frames are staged on disk in a temp dir first.
        temp_dir = Path(tempfile.mkdtemp())
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(exist_ok=True)

        try:
            # Save frames to temp directory (zero-padded names keep order)
            print(f"Saving {len(frames)} frames to temporary directory...")
            for i, frame in enumerate(frames):
                frame_path = frames_dir / f"{i:05d}.jpg"
                # NOTE(review): JPEG is lossy; masks can differ slightly
                # from lossless input — presumably acceptable for a demo.
                Image.fromarray(frame).save(frame_path, quality=95)

            # Initialize SAM2 video predictor
            print("Initializing SAM2 inference state...")
            inference_state = self.predictor.init_state(video_path=str(frames_dir))

            # Add prompts on first frame (SAM2 expects float32 points,
            # int32 labels)
            points_array = np.array(points, dtype=np.float32)
            labels_array = np.array(labels, dtype=np.int32)

            print(f"Adding {len(points)} point prompts on first frame...")
            # NOTE(review): newer sam2 releases rename this method to
            # add_new_points_or_box — confirm against the pinned version.
            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=1,
                points=points_array,
                labels=labels_array,
            )

            # Propagate through video
            print("Propagating masks through video...")
            masks = []
            for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(inference_state):
                # Get mask for object ID 1 (the single tracked object)
                # object_ids can be a tensor or a list
                obj_ids_list = object_ids.tolist() if hasattr(object_ids, 'tolist') else object_ids

                if 1 in obj_ids_list:
                    mask_idx = obj_ids_list.index(1)
                    # Threshold logits at 0 -> boolean, then scale to 0/255
                    mask = (mask_logits[mask_idx] > 0.0).cpu().numpy()
                    mask_uint8 = (mask.squeeze() * 255).astype(np.uint8)
                    masks.append(mask_uint8)
                else:
                    # No mask for this frame, use empty mask
                    h, w = frames[0].shape[:2]
                    masks.append(np.zeros((h, w), dtype=np.uint8))

            print(f"Generated {len(masks)} masks")
            return masks

        finally:
            # Clean up temporary directory
            shutil.rmtree(temp_dir, ignore_errors=True)

    def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]],
                             labels: List[int]) -> np.ndarray:
        """
        Get mask for first frame only (used for the interactive preview).

        Args:
            frame: np.ndarray, (H, W, 3), uint8 RGB frame
            points: List of [x, y] coordinates
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            mask: np.ndarray, (H, W), uint8 mask (0 or 255)
        """
        # Same disk-staging dance as track_video, but with a single frame.
        temp_dir = Path(tempfile.mkdtemp())
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(exist_ok=True)

        try:
            # Save single frame
            frame_path = frames_dir / "00000.jpg"
            Image.fromarray(frame).save(frame_path, quality=95)

            # Initialize SAM2
            inference_state = self.predictor.init_state(video_path=str(frames_dir))

            # Add prompts
            points_array = np.array(points, dtype=np.float32)
            labels_array = np.array(labels, dtype=np.int32)

            _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=0,
                obj_id=1,
                points=points_array,
                labels=labels_array,
            )

            # Get mask (empty fallback when the predictor returns nothing)
            if len(out_mask_logits) > 0:
                mask = (out_mask_logits[0] > 0.0).cpu().numpy()
                mask_uint8 = (mask.squeeze() * 255).astype(np.uint8)
                return mask_uint8
            else:
                return np.zeros(frame.shape[:2], dtype=np.uint8)

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
+ shutil.rmtree(temp_dir, ignore_errors=True)
154
+
155
+
156
+ def load_sam2_tracker(device="cuda"):
157
+ """
158
+ Load SAM2 video tracker with pretrained weights
159
+
160
+ Args:
161
+ device: Device to run on
162
+
163
+ Returns:
164
+ SAM2VideoTracker instance
165
+ """
166
+ checkpoint_path = "/home/cvlab19/project/samuel/CVPR/sam2/checkpoints/sam2.1_hiera_large.pt"
167
+ config_file = "configs/sam2.1/sam2.1_hiera_l.yaml"
168
+
169
+ print(f"Loading SAM2 from {checkpoint_path}...")
170
+ tracker = SAM2VideoTracker(checkpoint_path, config_file, device)
171
+
172
+ return tracker
sam2_wrapper_hf.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM2 Wrapper for Video Mask Tracking - Hugging Face Space Version
3
+ Handles mask generation and propagation through video
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ from pathlib import Path
9
+
10
+ # Add SAM2 to path if installed
11
+ try:
12
+ import sam2
13
+ except ImportError:
14
+ # Try to add from common locations
15
+ possible_paths = [
16
+ "/home/cvlab19/project/samuel/CVPR/sam2",
17
+ "./sam2"
18
+ ]
19
+ for path in possible_paths:
20
+ if os.path.exists(path):
21
+ sys.path.append(path)
22
+ break
23
+
24
+ import cv2
25
+ import numpy as np
26
+ import torch
27
+ from PIL import Image
28
+ from typing import List, Tuple
29
+ import tempfile
30
+ import shutil
31
+
32
+ from sam2.build_sam import build_sam2_video_predictor
33
+
34
+
35
class SAM2VideoTracker:
    """Thin wrapper around the SAM2 video predictor for point-prompted tracking."""

    def __init__(self, checkpoint_path, config_file, device="cuda"):
        """
        Initialize SAM2 video tracker.

        Args:
            checkpoint_path: Path to SAM2 checkpoint
            config_file: Path to SAM2 config file
            device: Device to run on
        """
        self.device = device
        self.predictor = build_sam2_video_predictor(
            config_file=config_file,
            ckpt_path=checkpoint_path,
            device=device
        )
        print(f"SAM2 video tracker initialized on {device}")

    def _prompt_first_frame(self, inference_state, points, labels):
        """Add all point prompts for object id 1 on frame 0.

        Returns:
            (out_obj_ids, out_mask_logits) as produced by the predictor.
        """
        points_array = np.array(points, dtype=np.float32)
        labels_array = np.array(labels, dtype=np.int32)
        _, out_obj_ids, out_mask_logits = self.predictor.add_new_points(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=1,
            points=points_array,
            labels=labels_array,
        )
        return out_obj_ids, out_mask_logits

    @staticmethod
    def _logits_to_mask(logits):
        """Threshold mask logits at 0 and convert to a uint8 {0, 255} mask."""
        mask = (logits > 0.0).cpu().numpy()
        return (mask.squeeze() * 255).astype(np.uint8)

    def track_video(self, frames: List[np.ndarray], points: List[List[int]],
                    labels: List[int]) -> List[np.ndarray]:
        """
        Track object through video using SAM2.

        Args:
            frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
            points: List of [x, y] coordinates for prompts
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            masks: List of numpy arrays, [(H,W)]*n, uint8 binary masks
        """
        # SAM2's init_state reads frames from disk, so stage them in a
        # throwaway directory that is always cleaned up afterwards.
        temp_dir = Path(tempfile.mkdtemp())
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(exist_ok=True)

        try:
            print(f"Saving {len(frames)} frames to temporary directory...")
            for i, frame in enumerate(frames):
                Image.fromarray(frame).save(frames_dir / f"{i:05d}.jpg", quality=95)

            print("Initializing SAM2 inference state...")
            inference_state = self.predictor.init_state(video_path=str(frames_dir))

            print(f"Adding {len(points)} point prompts on first frame...")
            self._prompt_first_frame(inference_state, points, labels)

            print("Propagating masks through video...")
            masks = []
            for frame_idx, object_ids, mask_logits in self.predictor.propagate_in_video(inference_state):
                # object_ids may be a tensor or a plain list depending on version.
                obj_ids_list = object_ids.tolist() if hasattr(object_ids, 'tolist') else object_ids

                if 1 in obj_ids_list:
                    masks.append(self._logits_to_mask(mask_logits[obj_ids_list.index(1)]))
                else:
                    # No mask for this frame, use empty mask.
                    h, w = frames[0].shape[:2]
                    masks.append(np.zeros((h, w), dtype=np.uint8))

            print(f"Generated {len(masks)} masks")
            return masks

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

    def get_first_frame_mask(self, frame: np.ndarray, points: List[List[int]],
                             labels: List[int]) -> np.ndarray:
        """
        Get mask for first frame only (for preview).

        Args:
            frame: np.ndarray, (H, W, 3), uint8 RGB frame
            points: List of [x, y] coordinates
            labels: List of labels (1 for positive, 0 for negative)

        Returns:
            mask: np.ndarray, (H, W), uint8 binary mask
        """
        temp_dir = Path(tempfile.mkdtemp())
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir(exist_ok=True)

        try:
            # Stage a single frame and prompt it exactly like track_video does.
            Image.fromarray(frame).save(frames_dir / "00000.jpg", quality=95)
            inference_state = self.predictor.init_state(video_path=str(frames_dir))

            _, out_mask_logits = self._prompt_first_frame(inference_state, points, labels)

            if len(out_mask_logits) > 0:
                return self._logits_to_mask(out_mask_logits[0])
            return np.zeros(frame.shape[:2], dtype=np.uint8)

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)
167
+
168
+
169
def load_sam2_tracker(checkpoint_path=None, device="cuda"):
    """
    Load SAM2 video tracker with pretrained weights.

    Args:
        checkpoint_path: Path to SAM2 checkpoint (if None, uses default location)
        device: Device to run on

    Returns:
        SAM2VideoTracker instance

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    # Use provided path or default
    if checkpoint_path is None:
        checkpoint_path = "checkpoints/sam2.1_hiera_large.pt"

    # Fail fast with an actionable message (same behavior as the VideoMaMa
    # loader) instead of an opaque error deep inside the SAM2 builder.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"SAM2 checkpoint not found at {checkpoint_path}. "
            f"Please ensure models are downloaded correctly."
        )

    # Config file should be in the SAM2 repo; fall back to a local yaml
    # when the packaged config is not present on disk.
    config_file = "configs/sam2.1/sam2.1_hiera_l.yaml"
    if not os.path.exists(config_file):
        config_file = "sam2_hiera_l.yaml"

    print(f"Loading SAM2 from {checkpoint_path}...")
    print(f"Using config: {config_file}")

    tracker = SAM2VideoTracker(checkpoint_path, config_file, device)

    return tracker
tools/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Tools module
tools/base_segmenter.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM2 Base Segmenter
3
+ Adapted from MatAnyone demo
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")
8
+
9
+ import torch
10
+ import numpy as np
11
+ from sam2.build_sam import build_sam2_video_predictor
12
+
13
+
14
class BaseSegmenter:
    """SAM2-backed segmenter skeleton; predict() is currently a stub."""

    def __init__(self, SAM_checkpoint, model_type, device):
        """
        Initialize SAM2 segmenter.

        Args:
            SAM_checkpoint: Path to SAM2 checkpoint
            model_type: SAM2 model config file
            device: Device to run on
        """
        self.device = device
        self.model_type = model_type

        # Build SAM2 video predictor
        self.sam_predictor = build_sam2_video_predictor(
            config_file=model_type,
            ckpt_path=SAM_checkpoint,
            device=device,
        )

        # NOTE: attribute keeps the original "orignal" spelling so existing
        # callers that read it keep working.
        self.orignal_image = None
        self.inference_state = None

    def set_image(self, image: np.ndarray):
        """Set the current image for segmentation."""
        self.orignal_image = image

    def reset_image(self):
        """Reset the current image and any cached inference state."""
        self.orignal_image = None
        self.inference_state = None

    def predict(self, prompts, prompt_type, multimask=True):
        """
        Predict mask from prompts.

        Args:
            prompts: Dictionary with point_coords, point_labels, mask_input
            prompt_type: 'point' or 'both'
            multimask: Whether to return multiple masks

        Returns:
            masks, scores, logits

        NOTE(review): stub implementation — ignores the prompts and returns
        an all-False mask, a unit score and zero logits shaped like the
        current image. Real SAM2 prediction is still to be wired in.
        """
        height, width = self.orignal_image.shape[:2]

        empty_mask = np.zeros((height, width), dtype=bool)
        unit_score = np.array([1.0])
        zero_logit = np.zeros((height, width), dtype=np.float32)

        return np.array([empty_mask]), unit_score, np.array([zero_logit])
tools/interact_tools.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM2 Interaction Tools
3
+ Handles SAM2 mask generation with user clicks
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2")
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ from .base_segmenter import BaseSegmenter
12
+ from .painter import mask_painter, point_painter
13
+
14
+
15
# Default visualization settings. Color values are IDs into the palette
# defined in tools.painter (applied modulo the palette length).
mask_color = 3        # blue overlay for the object mask
mask_alpha = 0.7      # mask blend transparency
contour_color = 1     # red outline around the mask
contour_width = 5     # outline thickness in pixels
# NOTE(review): the _ne/_ps suffixes read like negative/positive, but usage
# in SamControler.first_frame_click paints POSITIVE clicks with
# point_color_ne and NEGATIVE clicks with point_color_ps — confirm naming.
point_color_ne = 8    # color ID used for positive points (orange in palette)
point_color_ps = 50   # color ID used for negative points (50 % 10 -> black)
point_alpha = 0.9     # point fill transparency
point_radius = 15     # point circle radius in pixels
23
+
24
+
25
class SamControler:
    """User-click front end over BaseSegmenter: predicts a mask from point
    prompts and renders an overlay with the clicked points."""

    def __init__(self, SAM_checkpoint, model_type, device):
        """
        Initialize SAM controller.

        Args:
            SAM_checkpoint: Path to SAM2 checkpoint
            model_type: SAM2 model config file
            device: Device to run on
        """
        # BaseSegmenter wraps the SAM2 predictor and exposes predict().
        self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
        self.device = device

    def first_frame_click(self, image: np.ndarray, points: np.ndarray,
                          labels: np.ndarray, multimask=True, mask_color=3):
        """
        Generate mask from clicks on first frame.

        Args:
            image: np.ndarray, (H, W, 3), RGB image
            points: np.ndarray, (N, 2), [x, y] coordinates
            labels: np.ndarray, (N,), 1 for positive, 0 for negative
            multimask: bool, whether to generate multiple masks
            mask_color: int, color ID forwarded to mask_painter (shadows the
                module-level constant of the same name)

        Returns:
            mask: np.ndarray, (H, W), binary mask
            logit: np.ndarray, (H, W), mask logits
            painted_image: PIL.Image, visualization with mask and points
        """
        # Label of the MOST RECENT click (1 = positive, 0 = negative).
        # NOTE(review): only the last click decides whether the two-pass
        # refinement below runs — it does not check whether any positive
        # click exists anywhere in `labels`; confirm intended behavior.
        neg_flag = labels[-1]

        if neg_flag == 1:  # last click positive: predict, then refine
            # First pass with points only.
            prompts = {
                'point_coords': points,
                'point_labels': labels,
            }
            masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

            # Second pass: feed the best logit back as a mask prompt.
            prompts = {
                'point_coords': points,
                'point_labels': labels,
                'mask_input': logit[None, :, :]
            }
            masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
        else:  # last click negative: single point-only pass
            prompts = {
                'point_coords': points,
                'point_labels': labels,
            }
            masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

        # Overlay the chosen mask on the input image.
        painted_image = mask_painter(
            image,
            mask.astype('uint8'),
            mask_color,
            mask_alpha,
            contour_color,
            contour_width
        )

        # Draw positive clicks (label > 0).
        positive_points = np.squeeze(points[np.argwhere(labels > 0)], axis=1)
        if len(positive_points) > 0:
            painted_image = point_painter(
                painted_image,
                positive_points,
                point_color_ne,
                point_alpha,
                point_radius,
                contour_color,
                contour_width
            )

        # Draw negative clicks (label < 1).
        negative_points = np.squeeze(points[np.argwhere(labels < 1)], axis=1)
        if len(negative_points) > 0:
            painted_image = point_painter(
                painted_image,
                negative_points,
                point_color_ps,
                point_alpha,
                point_radius,
                contour_color,
                contour_width
            )

        painted_image = Image.fromarray(painted_image)

        return mask, logit, painted_image
tools/painter.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mask and point painting utilities
3
+ Adapted from MatAnyone demo
4
+ """
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+
11
def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7,
                 contour_color=1, contour_width=5):
    """
    Alpha-blend a colored mask onto an image and optionally outline it.

    Args:
        input_image: np.ndarray, (H, W, 3)
        input_mask: np.ndarray, (H, W), binary mask
        mask_color: int, palette color ID for the mask fill
        mask_alpha: float, blend transparency for the fill
        contour_color: int, palette color ID for the outline
        contour_width: int, outline thickness (0 disables the outline)

    Returns:
        painted_image: np.ndarray, (H, W, 3)
    """
    assert input_image.shape[:2] == input_mask.shape, "Image and mask must have same dimensions"

    # Fixed 10-entry RGB palette; color IDs are taken modulo its length.
    palette = np.array([
        [0, 0, 0],        # 0: black
        [255, 0, 0],      # 1: red
        [0, 255, 0],      # 2: green
        [0, 0, 255],      # 3: blue
        [255, 255, 0],    # 4: yellow
        [255, 0, 255],    # 5: magenta
        [0, 255, 255],    # 6: cyan
        [128, 128, 128],  # 7: gray
        [255, 165, 0],    # 8: orange
        [128, 0, 128],    # 9: purple
    ])
    fill_rgb = palette[mask_color % len(palette)]
    edge_rgb = palette[contour_color % len(palette)]

    canvas = input_image.copy()
    region = input_mask > 0

    # Blend the fill color into the masked pixels only.
    blended = canvas[region] * (1 - mask_alpha) + fill_rgb * mask_alpha
    canvas[region] = blended.astype(np.uint8)

    # Trace the mask boundary when an outline is requested.
    if contour_width > 0:
        outlines, _ = cv2.findContours(
            input_mask.astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(canvas, outlines, -1, edge_rgb.tolist(), contour_width)

    return canvas
74
+
75
+
76
def point_painter(input_image, input_points, point_color=8, point_alpha=0.9,
                  point_radius=15, contour_color=2, contour_width=3):
    """
    Draw alpha-blended circles at the given points, optionally outlined.

    Args:
        input_image: np.ndarray, (H, W, 3)
        input_points: np.ndarray, (N, 2), [x, y] coordinates
        point_color: int, palette color ID for the circle fill
        point_alpha: float, fill transparency
        point_radius: int, circle radius in pixels
        contour_color: int, palette color ID for the circle outline
        contour_width: int, outline thickness (0 disables the outline)

    Returns:
        painted_image: np.ndarray, (H, W, 3); the input array itself when
        there are no points to draw.
    """
    if len(input_points) == 0:
        return input_image

    # Fixed 10-entry RGB palette; color IDs are taken modulo its length.
    palette = np.array([
        [0, 0, 0],        # 0: black
        [255, 0, 0],      # 1: red
        [0, 255, 0],      # 2: green
        [0, 0, 255],      # 3: blue
        [255, 255, 0],    # 4: yellow
        [255, 0, 255],    # 5: magenta
        [0, 255, 255],    # 6: cyan
        [128, 128, 128],  # 7: gray
        [255, 165, 0],    # 8: orange
        [128, 0, 128],    # 9: purple
    ])
    fill = palette[point_color % len(palette)].tolist()
    edge = palette[contour_color % len(palette)].tolist()

    canvas = input_image.copy()

    for pt in input_points:
        cx, cy = int(pt[0]), int(pt[1])

        # Paint the filled circle on a scratch layer, then alpha-blend it
        # back into the canvas in place.
        layer = canvas.copy()
        cv2.circle(layer, (cx, cy), point_radius, fill, -1)
        cv2.addWeighted(layer, point_alpha, canvas, 1 - point_alpha, 0, canvas)

        if contour_width > 0:
            cv2.circle(canvas, (cx, cy), point_radius, edge, contour_width)

    return canvas
videomama_wrapper.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VideoMaMa Inference Wrapper
3
+ Handles video matting with mask conditioning
4
+ """
5
+
6
+ import sys
7
+ sys.path.append("../")
8
+ sys.path.append("../../")
9
+
10
+ import torch
11
+ import numpy as np
12
+ from PIL import Image
13
+ from pathlib import Path
14
+ from typing import List
15
+ import tqdm
16
+
17
+ from pipeline_svd_mask import VideoInferencePipeline
18
+
19
+
20
def videomama(pipeline, frames_np, mask_frames_np):
    """
    Run VideoMaMa inference on video frames with mask conditioning.

    Args:
        pipeline: VideoInferencePipeline instance
        frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
        mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks

    Returns:
        output_frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB outputs
    """
    # numpy -> PIL; masks become single-channel 'L' images.
    pil_frames = [Image.fromarray(arr) for arr in frames_np]
    pil_masks = [Image.fromarray(arr, mode='L') for arr in mask_frames_np]

    # The model expects a fixed 1024x576 input resolution.
    model_size = (1024, 576)
    frames_resized = [im.resize(model_size, Image.Resampling.BILINEAR)
                      for im in pil_frames]
    masks_resized = [im.resize(model_size, Image.Resampling.BILINEAR)
                     for im in pil_masks]

    print(f"Running VideoMaMa inference on {len(frames_resized)} frames...")
    predictions = pipeline.run(
        cond_frames=frames_resized,
        mask_frames=masks_resized,
        seed=42,
        mask_cond_mode="vae"
    )

    # Restore the caller's original resolution, then convert back to numpy.
    native_size = pil_frames[0].size
    restored = [im.resize(native_size, Image.Resampling.BILINEAR)
                for im in predictions]

    return [np.array(im) for im in restored]
61
+
62
+
63
def load_videomama_pipeline(device="cuda", base_model_path=None, unet_checkpoint_path=None):
    """
    Load VideoMaMa pipeline with pretrained weights.

    Args:
        device: Device to run on
        base_model_path: Path to the SVD base model; falls back to the
            local development default when None.
        unet_checkpoint_path: Path to the VideoMaMa UNet checkpoint; falls
            back to the local development default when None.

    Returns:
        VideoInferencePipeline instance

    Raises:
        FileNotFoundError: if either model directory is missing.
    """
    # Previously hard-coded local paths, now overridable for other machines.
    if base_model_path is None:
        base_model_path = "/home/cvlab19/project/samuel/data/CVPR/pretrained_models/stable-video-diffusion-img2vid-xt"
    if unet_checkpoint_path is None:
        unet_checkpoint_path = "/home/cvlab19/project/samuel/data/CVPR/pretrained_models/videomama"

    # Fail fast with clear messages instead of a deep loader traceback.
    if not Path(base_model_path).exists():
        raise FileNotFoundError(f"SVD base model not found at {base_model_path}")
    if not Path(unet_checkpoint_path).exists():
        raise FileNotFoundError(f"VideoMaMa checkpoint not found at {unet_checkpoint_path}")

    print(f"Loading VideoMaMa pipeline from {unet_checkpoint_path}...")

    pipeline = VideoInferencePipeline(
        base_model_path=base_model_path,
        unet_checkpoint_path=unet_checkpoint_path,
        weight_dtype=torch.float16,
        device=device
    )

    print("VideoMaMa pipeline loaded successfully!")

    return pipeline
videomama_wrapper_hf.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VideoMaMa Inference Wrapper - Hugging Face Space Version
3
+ Handles video matting with mask conditioning
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ from pathlib import Path
9
+
10
+ # Add parent directories to path for imports
11
+ sys.path.append(str(Path(__file__).parent))
12
+ sys.path.append(str(Path(__file__).parent.parent))
13
+
14
+ import torch
15
+ import numpy as np
16
+ from PIL import Image
17
+ from typing import List
18
+
19
+ from pipeline_svd_mask import VideoInferencePipeline
20
+
21
+
22
def videomama(pipeline, frames_np, mask_frames_np):
    """
    Run VideoMaMa inference on video frames with mask conditioning.

    Args:
        pipeline: VideoInferencePipeline instance
        frames_np: List of numpy arrays, [(H,W,3)]*n, uint8 RGB frames
        mask_frames_np: List of numpy arrays, [(H,W)]*n, uint8 grayscale masks

    Returns:
        output_frames: List of numpy arrays, [(H,W,3)]*n, uint8 RGB outputs
    """
    # numpy -> PIL; masks become single-channel 'L' images.
    rgb_images = [Image.fromarray(arr) for arr in frames_np]
    mask_images = [Image.fromarray(arr, mode='L') for arr in mask_frames_np]

    # The model operates at a fixed 1024x576 resolution.
    model_size = (1024, 576)
    frames_resized = [im.resize(model_size, Image.Resampling.BILINEAR)
                      for im in rgb_images]
    masks_resized = [im.resize(model_size, Image.Resampling.BILINEAR)
                     for im in mask_images]

    print(f"Running VideoMaMa inference on {len(frames_resized)} frames...")
    result_images = pipeline.run(
        cond_frames=frames_resized,
        mask_frames=masks_resized,
        seed=42,
        mask_cond_mode="vae"
    )

    # Scale the outputs back to the caller's resolution and return numpy.
    original_size = rgb_images[0].size
    outputs = [im.resize(original_size, Image.Resampling.BILINEAR)
               for im in result_images]

    return [np.array(im) for im in outputs]
63
+
64
+
65
def load_videomama_pipeline(base_model_path=None, unet_checkpoint_path=None, device="cuda"):
    """
    Load VideoMaMa pipeline with pretrained weights.

    Args:
        base_model_path: Path to SVD base model (if None, uses default)
        unet_checkpoint_path: Path to VideoMaMa UNet checkpoint (if None, uses default)
        device: Device to run on

    Returns:
        VideoInferencePipeline instance

    Raises:
        FileNotFoundError: if either model path is missing on disk.
    """
    # Fall back to the bundled checkpoint locations when not specified.
    if base_model_path is None:
        base_model_path = "checkpoints/stable-video-diffusion-img2vid-xt"
    if unet_checkpoint_path is None:
        unet_checkpoint_path = "checkpoints/videomama"

    # Validate both locations up front (base model first) so users get an
    # actionable error rather than a deep loader traceback.
    required = (
        (base_model_path,
         f"SVD base model not found at {base_model_path}. "
         f"Please ensure models are downloaded correctly."),
        (unet_checkpoint_path,
         f"VideoMaMa checkpoint not found at {unet_checkpoint_path}. "
         f"Please upload your VideoMaMa model to Hugging Face Hub and update the download logic."),
    )
    for model_path, error_message in required:
        if not os.path.exists(model_path):
            raise FileNotFoundError(error_message)

    print(f"Loading VideoMaMa pipeline...")
    print(f"  Base model: {base_model_path}")
    print(f"  UNet checkpoint: {unet_checkpoint_path}")

    pipeline = VideoInferencePipeline(
        base_model_path=base_model_path,
        unet_checkpoint_path=unet_checkpoint_path,
        weight_dtype=torch.float16,
        device=device
    )

    print("VideoMaMa pipeline loaded successfully!")

    return pipeline