update
Browse files
app.py
CHANGED
|
@@ -123,32 +123,37 @@ def get_prompt(click_state, click_input):
|
|
| 123 |
return click_state
|
| 124 |
|
| 125 |
|
| 126 |
-
def load_video(video_input):
|
| 127 |
"""
|
| 128 |
-
Load video and extract first frame for mask generation
|
| 129 |
"""
|
| 130 |
if video_input is None:
|
| 131 |
-
return
|
| 132 |
gr.update(visible=False), gr.update(visible=False), \
|
| 133 |
gr.update(visible=False), gr.update(visible=False)
|
| 134 |
|
| 135 |
-
# Extract
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
if
|
| 139 |
-
return
|
| 140 |
gr.update(visible=False), gr.update(visible=False), \
|
| 141 |
gr.update(visible=False), gr.update(visible=False)
|
|
|
|
|
|
|
| 142 |
|
| 143 |
-
# Initialize video state
|
| 144 |
video_state = {
|
| 145 |
-
"
|
| 146 |
-
"
|
| 147 |
"first_frame_mask": None,
|
| 148 |
"masks": None,
|
| 149 |
}
|
| 150 |
|
| 151 |
-
first_frame_pil = Image.fromarray(
|
| 152 |
|
| 153 |
return video_state, first_frame_pil, \
|
| 154 |
gr.update(visible=True), gr.update(visible=True), \
|
|
@@ -156,35 +161,6 @@ def load_video(video_input):
|
|
| 156 |
|
| 157 |
|
| 158 |
@spaces.GPU
|
| 159 |
-
def sam_refine_gpu(first_frame_list, points, labels):
|
| 160 |
-
"""
|
| 161 |
-
GPU function: Generate mask with SAM2
|
| 162 |
-
|
| 163 |
-
Args:
|
| 164 |
-
first_frame_list: First frame as list
|
| 165 |
-
points: List of [x, y] coordinates
|
| 166 |
-
labels: List of labels (1=positive, 0=negative)
|
| 167 |
-
|
| 168 |
-
Returns:
|
| 169 |
-
mask as list
|
| 170 |
-
"""
|
| 171 |
-
# Lazy load models on first use
|
| 172 |
-
initialize_models()
|
| 173 |
-
|
| 174 |
-
# Convert to numpy
|
| 175 |
-
first_frame = np.array(first_frame_list, dtype=np.uint8)
|
| 176 |
-
|
| 177 |
-
# Generate mask with SAM2
|
| 178 |
-
mask = sam2_tracker.get_first_frame_mask(
|
| 179 |
-
frame=first_frame,
|
| 180 |
-
points=points,
|
| 181 |
-
labels=labels
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
-
# Return as list for pickling
|
| 185 |
-
return mask.tolist() if hasattr(mask, 'tolist') else mask
|
| 186 |
-
|
| 187 |
-
|
| 188 |
def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
| 189 |
"""
|
| 190 |
Add click and update mask on first frame
|
|
@@ -195,7 +171,10 @@ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
|
| 195 |
click_state: [[points], [labels]]
|
| 196 |
evt: Gradio SelectData event with click coordinates
|
| 197 |
"""
|
| 198 |
-
|
|
|
|
|
|
|
|
|
|
| 199 |
return None, video_state, click_state
|
| 200 |
|
| 201 |
# Add new click
|
|
@@ -207,20 +186,18 @@ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
|
| 207 |
|
| 208 |
print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")
|
| 209 |
|
| 210 |
-
#
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
click_state[
|
|
|
|
| 215 |
)
|
| 216 |
|
| 217 |
-
# Store mask
|
| 218 |
-
video_state["first_frame_mask"] =
|
| 219 |
|
| 220 |
# Visualize mask and points
|
| 221 |
-
first_frame = np.array(video_state["frames"][0], dtype=np.uint8)
|
| 222 |
-
mask = np.array(mask_list, dtype=np.uint8)
|
| 223 |
-
|
| 224 |
painted_image = mask_painter(
|
| 225 |
first_frame.copy(),
|
| 226 |
mask,
|
|
@@ -268,7 +245,7 @@ def clear_clicks(video_state, click_state):
|
|
| 268 |
click_state = [[], []]
|
| 269 |
|
| 270 |
if video_state is not None and "frames" in video_state:
|
| 271 |
-
first_frame =
|
| 272 |
video_state["first_frame_mask"] = None
|
| 273 |
return Image.fromarray(first_frame), video_state, click_state
|
| 274 |
|
|
@@ -285,8 +262,7 @@ def propagate_masks(video_state, click_state):
|
|
| 285 |
if len(click_state[0]) == 0:
|
| 286 |
return video_state, "⚠️ Please add at least one point first", gr.update(visible=False)
|
| 287 |
|
| 288 |
-
|
| 289 |
-
frames = [np.array(f, dtype=np.uint8) for f in video_state["frames"]]
|
| 290 |
|
| 291 |
# Track through video
|
| 292 |
print(f"Tracking object through {len(frames)} frames...")
|
|
@@ -296,8 +272,7 @@ def propagate_masks(video_state, click_state):
|
|
| 296 |
labels=click_state[1]
|
| 297 |
)
|
| 298 |
|
| 299 |
-
|
| 300 |
-
video_state["masks"] = [m.tolist() if hasattr(m, 'tolist') else m for m in masks]
|
| 301 |
|
| 302 |
status_msg = f"✓ Generated {len(masks)} masks. Ready to run VideoMaMa!"
|
| 303 |
|
|
@@ -305,88 +280,38 @@ def propagate_masks(video_state, click_state):
|
|
| 305 |
|
| 306 |
|
| 307 |
@spaces.GPU(duration=120)
|
| 308 |
-
def
|
| 309 |
"""
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
Args:
|
| 313 |
-
frames_list: List of frames as lists
|
| 314 |
-
points: List of [x, y] coordinates
|
| 315 |
-
labels: List of labels (1=positive, 0=negative)
|
| 316 |
-
|
| 317 |
-
Returns:
|
| 318 |
-
Tuple of (masks_list, output_frames_list, greenscreen_frames_list)
|
| 319 |
"""
|
| 320 |
# Lazy load models on first use
|
| 321 |
initialize_models()
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
-
#
|
| 327 |
-
|
| 328 |
masks = sam2_tracker.track_video(
|
| 329 |
frames=frames,
|
| 330 |
-
points=
|
| 331 |
-
labels=
|
| 332 |
)
|
|
|
|
|
|
|
| 333 |
print(f"✓ Generated {len(masks)} masks")
|
| 334 |
|
| 335 |
# Step 2: Run VideoMaMa
|
| 336 |
print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
|
| 337 |
output_frames = videomama(videomama_pipeline, frames, masks)
|
| 338 |
|
| 339 |
-
# Create greenscreen composite
|
| 340 |
-
greenscreen_frames = []
|
| 341 |
-
for orig_frame, output_frame in zip(frames, output_frames):
|
| 342 |
-
# Extract alpha matte from VideoMaMa output
|
| 343 |
-
gray = cv2.cvtColor(output_frame, cv2.COLOR_RGB2GRAY)
|
| 344 |
-
alpha = np.clip(gray.astype(np.float32) / 255.0, 0, 1)
|
| 345 |
-
alpha_3ch = np.stack([alpha, alpha, alpha], axis=-1)
|
| 346 |
-
|
| 347 |
-
# Create green background
|
| 348 |
-
green_bg = np.zeros_like(orig_frame)
|
| 349 |
-
green_bg[:, :] = [156, 251, 165] # Green screen color
|
| 350 |
-
|
| 351 |
-
# Composite: original_RGB * alpha + green * (1 - alpha)
|
| 352 |
-
composite = (orig_frame.astype(np.float32) * alpha_3ch +
|
| 353 |
-
green_bg.astype(np.float32) * (1 - alpha_3ch)).astype(np.uint8)
|
| 354 |
-
greenscreen_frames.append(composite)
|
| 355 |
-
|
| 356 |
-
# Convert to lists for pickling
|
| 357 |
-
masks_list = [m.tolist() if hasattr(m, 'tolist') else m for m in masks]
|
| 358 |
-
output_frames_list = [f.tolist() for f in output_frames]
|
| 359 |
-
greenscreen_frames_list = [f.tolist() for f in greenscreen_frames]
|
| 360 |
-
|
| 361 |
-
return masks_list, output_frames_list, greenscreen_frames_list
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
def run_videomama_with_sam2(video_state, click_state):
|
| 365 |
-
"""
|
| 366 |
-
Run SAM2 propagation and VideoMaMa inference together
|
| 367 |
-
"""
|
| 368 |
-
if video_state is None or "frames" not in video_state:
|
| 369 |
-
return video_state, None, None, None, "⚠️ No video loaded"
|
| 370 |
-
|
| 371 |
-
if len(click_state[0]) == 0:
|
| 372 |
-
return video_state, None, None, None, "⚠️ Please add at least one point first"
|
| 373 |
-
|
| 374 |
-
# Call GPU function with plain data (no Gradio State objects)
|
| 375 |
-
masks_list, output_frames_list, greenscreen_frames_list = run_videomama_with_sam2_gpu(
|
| 376 |
-
video_state["frames"],
|
| 377 |
-
click_state[0],
|
| 378 |
-
click_state[1]
|
| 379 |
-
)
|
| 380 |
-
|
| 381 |
-
# Store masks
|
| 382 |
-
video_state["masks"] = masks_list
|
| 383 |
-
|
| 384 |
-
# Convert back to numpy for video saving
|
| 385 |
-
frames = [np.array(f, dtype=np.uint8) for f in video_state["frames"]]
|
| 386 |
-
masks = [np.array(m, dtype=np.uint8) for m in masks_list]
|
| 387 |
-
output_frames = [np.array(f, dtype=np.uint8) for f in output_frames_list]
|
| 388 |
-
greenscreen_frames = [np.array(f, dtype=np.uint8) for f in greenscreen_frames_list]
|
| 389 |
-
|
| 390 |
# Save output videos
|
| 391 |
output_dir = Path("outputs")
|
| 392 |
output_dir.mkdir(exist_ok=True)
|
|
@@ -403,7 +328,25 @@ def run_videomama_with_sam2(video_state, click_state):
|
|
| 403 |
mask_frames_rgb = [np.stack([m, m, m], axis=-1) for m in masks]
|
| 404 |
save_video(mask_frames_rgb, mask_video_path, video_state["fps"])
|
| 405 |
|
| 406 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
save_video(greenscreen_frames, greenscreen_path, video_state["fps"])
|
| 408 |
|
| 409 |
status_msg = f"✓ Complete! Generated {len(output_frames)} frames."
|
|
@@ -515,7 +458,7 @@ with gr.Blocks(title="VideoMaMa Demo") as demo:
|
|
| 515 |
# Event handlers
|
| 516 |
load_button.click(
|
| 517 |
fn=load_video,
|
| 518 |
-
inputs=[video_input],
|
| 519 |
outputs=[video_state, first_frame_display,
|
| 520 |
point_prompt, clear_button, run_button, status_text]
|
| 521 |
)
|
|
|
|
| 123 |
return click_state
|
| 124 |
|
| 125 |
|
| 126 |
+
def load_video(video_input, video_state):
|
| 127 |
"""
|
| 128 |
+
Load video, store path, and extract first frame for mask generation
|
| 129 |
"""
|
| 130 |
if video_input is None:
|
| 131 |
+
return video_state, None, \
|
| 132 |
gr.update(visible=False), gr.update(visible=False), \
|
| 133 |
gr.update(visible=False), gr.update(visible=False)
|
| 134 |
|
| 135 |
+
# Extract ONLY the first frame for the UI to save memory/bandwidth
|
| 136 |
+
# We will load the full video inside the GPU function later
|
| 137 |
+
cap = cv2.VideoCapture(video_input)
|
| 138 |
+
ret, first_frame = cap.read()
|
| 139 |
+
cap.release()
|
| 140 |
|
| 141 |
+
if not ret:
|
| 142 |
+
return video_state, None, \
|
| 143 |
gr.update(visible=False), gr.update(visible=False), \
|
| 144 |
gr.update(visible=False), gr.update(visible=False)
|
| 145 |
+
|
| 146 |
+
first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
| 147 |
|
| 148 |
+
# Initialize video state with PATH, not full frames
|
| 149 |
video_state = {
|
| 150 |
+
"video_path": video_input, # <--- Store Path
|
| 151 |
+
"first_frame": first_frame_rgb, # <--- Store only one frame
|
| 152 |
"first_frame_mask": None,
|
| 153 |
"masks": None,
|
| 154 |
}
|
| 155 |
|
| 156 |
+
first_frame_pil = Image.fromarray(first_frame_rgb)
|
| 157 |
|
| 158 |
return video_state, first_frame_pil, \
|
| 159 |
gr.update(visible=True), gr.update(visible=True), \
|
|
|
|
| 161 |
|
| 162 |
|
| 163 |
@spaces.GPU
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
|
| 165 |
"""
|
| 166 |
Add click and update mask on first frame
|
|
|
|
| 171 |
click_state: [[points], [labels]]
|
| 172 |
evt: Gradio SelectData event with click coordinates
|
| 173 |
"""
|
| 174 |
+
# Lazy load models on first use
|
| 175 |
+
initialize_models()
|
| 176 |
+
|
| 177 |
+
if video_state is None or "first_frame" not in video_state: # Check for first_frame
|
| 178 |
return None, video_state, click_state
|
| 179 |
|
| 180 |
# Add new click
|
|
|
|
| 186 |
|
| 187 |
print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")
|
| 188 |
|
| 189 |
+
# Generate mask with SAM2
|
| 190 |
+
first_frame = video_state["first_frame"]
|
| 191 |
+
mask = sam2_tracker.get_first_frame_mask(
|
| 192 |
+
frame=first_frame,
|
| 193 |
+
points=click_state[0],
|
| 194 |
+
labels=click_state[1]
|
| 195 |
)
|
| 196 |
|
| 197 |
+
# Store mask in video state
|
| 198 |
+
video_state["first_frame_mask"] = mask
|
| 199 |
|
| 200 |
# Visualize mask and points
|
|
|
|
|
|
|
|
|
|
| 201 |
painted_image = mask_painter(
|
| 202 |
first_frame.copy(),
|
| 203 |
mask,
|
|
|
|
| 245 |
click_state = [[], []]
|
| 246 |
|
| 247 |
if video_state is not None and "frames" in video_state:
|
| 248 |
+
first_frame = video_state["frames"][0]
|
| 249 |
video_state["first_frame_mask"] = None
|
| 250 |
return Image.fromarray(first_frame), video_state, click_state
|
| 251 |
|
|
|
|
| 262 |
if len(click_state[0]) == 0:
|
| 263 |
return video_state, "⚠️ Please add at least one point first", gr.update(visible=False)
|
| 264 |
|
| 265 |
+
frames = video_state["frames"]
|
|
|
|
| 266 |
|
| 267 |
# Track through video
|
| 268 |
print(f"Tracking object through {len(frames)} frames...")
|
|
|
|
| 272 |
labels=click_state[1]
|
| 273 |
)
|
| 274 |
|
| 275 |
+
video_state["masks"] = masks
|
|
|
|
| 276 |
|
| 277 |
status_msg = f"✓ Generated {len(masks)} masks. Ready to run VideoMaMa!"
|
| 278 |
|
|
|
|
| 280 |
|
| 281 |
|
| 282 |
@spaces.GPU(duration=120)
|
| 283 |
+
def run_videomama_with_sam2(video_state, click_state):
|
| 284 |
"""
|
| 285 |
+
Run SAM2 propagation and VideoMaMa inference together
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
"""
|
| 287 |
# Lazy load models on first use
|
| 288 |
initialize_models()
|
| 289 |
|
| 290 |
+
if video_state is None or "video_path" not in video_state:
|
| 291 |
+
return video_state, None, None, None, "⚠️ No video loaded"
|
| 292 |
+
|
| 293 |
+
if len(click_state[0]) == 0:
|
| 294 |
+
return video_state, None, None, None, "⚠️ Please add at least one point first"
|
| 295 |
+
|
| 296 |
+
# RELOAD FRAMES HERE inside the GPU worker
|
| 297 |
+
print(f"Loading frames from {video_state['video_path']}...")
|
| 298 |
+
frames, fps = extract_frames_from_video(video_state["video_path"], max_frames=50)
|
| 299 |
|
| 300 |
+
# Update state with FPS just in case (though we likely don't need to return it)
|
| 301 |
+
video_state["fps"] = fps
|
| 302 |
masks = sam2_tracker.track_video(
|
| 303 |
frames=frames,
|
| 304 |
+
points=click_state[0],
|
| 305 |
+
labels=click_state[1]
|
| 306 |
)
|
| 307 |
+
|
| 308 |
+
video_state["masks"] = masks
|
| 309 |
print(f"✓ Generated {len(masks)} masks")
|
| 310 |
|
| 311 |
# Step 2: Run VideoMaMa
|
| 312 |
print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
|
| 313 |
output_frames = videomama(videomama_pipeline, frames, masks)
|
| 314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
# Save output videos
|
| 316 |
output_dir = Path("outputs")
|
| 317 |
output_dir.mkdir(exist_ok=True)
|
|
|
|
| 328 |
mask_frames_rgb = [np.stack([m, m, m], axis=-1) for m in masks]
|
| 329 |
save_video(mask_frames_rgb, mask_video_path, video_state["fps"])
|
| 330 |
|
| 331 |
+
# Create greenscreen composite: RGB * VideoMaMa_alpha + green * (1 - VideoMaMa_alpha)
|
| 332 |
+
# VideoMaMa output_frames already contain the alpha matte result
|
| 333 |
+
greenscreen_frames = []
|
| 334 |
+
for orig_frame, output_frame in zip(frames, output_frames):
|
| 335 |
+
# Extract alpha matte from VideoMaMa output
|
| 336 |
+
# VideoMaMa outputs matted foreground, we use its intensity as alpha
|
| 337 |
+
gray = cv2.cvtColor(output_frame, cv2.COLOR_RGB2GRAY)
|
| 338 |
+
alpha = np.clip(gray.astype(np.float32) / 255.0, 0, 1)
|
| 339 |
+
alpha_3ch = np.stack([alpha, alpha, alpha], axis=-1)
|
| 340 |
+
|
| 341 |
+
# Create green background
|
| 342 |
+
green_bg = np.zeros_like(orig_frame)
|
| 343 |
+
green_bg[:, :] = [156, 251, 165] # Green screen color
|
| 344 |
+
|
| 345 |
+
# Composite: original_RGB * alpha + green * (1 - alpha)
|
| 346 |
+
composite = (orig_frame.astype(np.float32) * alpha_3ch +
|
| 347 |
+
green_bg.astype(np.float32) * (1 - alpha_3ch)).astype(np.uint8)
|
| 348 |
+
greenscreen_frames.append(composite)
|
| 349 |
+
|
| 350 |
save_video(greenscreen_frames, greenscreen_path, video_state["fps"])
|
| 351 |
|
| 352 |
status_msg = f"✓ Complete! Generated {len(output_frames)} frames."
|
|
|
|
| 458 |
# Event handlers
|
| 459 |
load_button.click(
|
| 460 |
fn=load_video,
|
| 461 |
+
inputs=[video_input, video_state],
|
| 462 |
outputs=[video_state, first_frame_display,
|
| 463 |
point_prompt, clear_button, run_button, status_text]
|
| 464 |
)
|