pizb committed on
Commit
638f167
·
1 Parent(s): 6cec120
Files changed (1) hide show
  1. app.py +68 -125
app.py CHANGED
@@ -123,32 +123,37 @@ def get_prompt(click_state, click_input):
123
  return click_state
124
 
125
 
126
- def load_video(video_input):
127
  """
128
- Load video and extract first frame for mask generation
129
  """
130
  if video_input is None:
131
- return None, None, \
132
  gr.update(visible=False), gr.update(visible=False), \
133
  gr.update(visible=False), gr.update(visible=False)
134
 
135
- # Extract frames
136
- frames, fps = extract_frames_from_video(video_input, max_frames=50)
 
 
 
137
 
138
- if len(frames) == 0:
139
- return None, None, \
140
  gr.update(visible=False), gr.update(visible=False), \
141
  gr.update(visible=False), gr.update(visible=False)
 
 
142
 
143
- # Initialize video state - convert frames to list for pickling
144
  video_state = {
145
- "frames": [frame.tolist() for frame in frames], # Convert numpy to list
146
- "fps": float(fps), # Ensure JSON serializable
147
  "first_frame_mask": None,
148
  "masks": None,
149
  }
150
 
151
- first_frame_pil = Image.fromarray(frames[0])
152
 
153
  return video_state, first_frame_pil, \
154
  gr.update(visible=True), gr.update(visible=True), \
@@ -156,35 +161,6 @@ def load_video(video_input):
156
 
157
 
158
  @spaces.GPU
159
- def sam_refine_gpu(first_frame_list, points, labels):
160
- """
161
- GPU function: Generate mask with SAM2
162
-
163
- Args:
164
- first_frame_list: First frame as list
165
- points: List of [x, y] coordinates
166
- labels: List of labels (1=positive, 0=negative)
167
-
168
- Returns:
169
- mask as list
170
- """
171
- # Lazy load models on first use
172
- initialize_models()
173
-
174
- # Convert to numpy
175
- first_frame = np.array(first_frame_list, dtype=np.uint8)
176
-
177
- # Generate mask with SAM2
178
- mask = sam2_tracker.get_first_frame_mask(
179
- frame=first_frame,
180
- points=points,
181
- labels=labels
182
- )
183
-
184
- # Return as list for pickling
185
- return mask.tolist() if hasattr(mask, 'tolist') else mask
186
-
187
-
188
  def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
189
  """
190
  Add click and update mask on first frame
@@ -195,7 +171,10 @@ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
195
  click_state: [[points], [labels]]
196
  evt: Gradio SelectData event with click coordinates
197
  """
198
- if video_state is None or "frames" not in video_state:
 
 
 
199
  return None, video_state, click_state
200
 
201
  # Add new click
@@ -207,20 +186,18 @@ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
207
 
208
  print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")
209
 
210
- # Call GPU function with plain data (no Gradio State objects)
211
- mask_list = sam_refine_gpu(
212
- video_state["frames"][0],
213
- click_state[0],
214
- click_state[1]
 
215
  )
216
 
217
- # Store mask as list
218
- video_state["first_frame_mask"] = mask_list
219
 
220
  # Visualize mask and points
221
- first_frame = np.array(video_state["frames"][0], dtype=np.uint8)
222
- mask = np.array(mask_list, dtype=np.uint8)
223
-
224
  painted_image = mask_painter(
225
  first_frame.copy(),
226
  mask,
@@ -268,7 +245,7 @@ def clear_clicks(video_state, click_state):
268
  click_state = [[], []]
269
 
270
  if video_state is not None and "frames" in video_state:
271
- first_frame = np.array(video_state["frames"][0], dtype=np.uint8)
272
  video_state["first_frame_mask"] = None
273
  return Image.fromarray(first_frame), video_state, click_state
274
 
@@ -285,8 +262,7 @@ def propagate_masks(video_state, click_state):
285
  if len(click_state[0]) == 0:
286
  return video_state, "⚠️ Please add at least one point first", gr.update(visible=False)
287
 
288
- # Convert frames back to numpy arrays
289
- frames = [np.array(f, dtype=np.uint8) for f in video_state["frames"]]
290
 
291
  # Track through video
292
  print(f"Tracking object through {len(frames)} frames...")
@@ -296,8 +272,7 @@ def propagate_masks(video_state, click_state):
296
  labels=click_state[1]
297
  )
298
 
299
- # Convert masks to lists for pickling
300
- video_state["masks"] = [m.tolist() if hasattr(m, 'tolist') else m for m in masks]
301
 
302
  status_msg = f"✓ Generated {len(masks)} masks. Ready to run VideoMaMa!"
303
 
@@ -305,88 +280,38 @@ def propagate_masks(video_state, click_state):
305
 
306
 
307
  @spaces.GPU(duration=120)
308
- def run_videomama_with_sam2_gpu(frames_list, points, labels):
309
  """
310
- GPU function: Run SAM2 propagation and VideoMaMa inference
311
-
312
- Args:
313
- frames_list: List of frames as lists
314
- points: List of [x, y] coordinates
315
- labels: List of labels (1=positive, 0=negative)
316
-
317
- Returns:
318
- Tuple of (masks_list, output_frames_list, greenscreen_frames_list)
319
  """
320
  # Lazy load models on first use
321
  initialize_models()
322
 
323
- # Convert frames back to numpy arrays
324
- frames = [np.array(f, dtype=np.uint8) for f in frames_list]
 
 
 
 
 
 
 
325
 
326
- # Step 1: Track through video with SAM2
327
- print(f"🎯 Tracking object through {len(frames)} frames with SAM2...")
328
  masks = sam2_tracker.track_video(
329
  frames=frames,
330
- points=points,
331
- labels=labels
332
  )
 
 
333
  print(f"✓ Generated {len(masks)} masks")
334
 
335
  # Step 2: Run VideoMaMa
336
  print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
337
  output_frames = videomama(videomama_pipeline, frames, masks)
338
 
339
- # Create greenscreen composite
340
- greenscreen_frames = []
341
- for orig_frame, output_frame in zip(frames, output_frames):
342
- # Extract alpha matte from VideoMaMa output
343
- gray = cv2.cvtColor(output_frame, cv2.COLOR_RGB2GRAY)
344
- alpha = np.clip(gray.astype(np.float32) / 255.0, 0, 1)
345
- alpha_3ch = np.stack([alpha, alpha, alpha], axis=-1)
346
-
347
- # Create green background
348
- green_bg = np.zeros_like(orig_frame)
349
- green_bg[:, :] = [156, 251, 165] # Green screen color
350
-
351
- # Composite: original_RGB * alpha + green * (1 - alpha)
352
- composite = (orig_frame.astype(np.float32) * alpha_3ch +
353
- green_bg.astype(np.float32) * (1 - alpha_3ch)).astype(np.uint8)
354
- greenscreen_frames.append(composite)
355
-
356
- # Convert to lists for pickling
357
- masks_list = [m.tolist() if hasattr(m, 'tolist') else m for m in masks]
358
- output_frames_list = [f.tolist() for f in output_frames]
359
- greenscreen_frames_list = [f.tolist() for f in greenscreen_frames]
360
-
361
- return masks_list, output_frames_list, greenscreen_frames_list
362
-
363
-
364
- def run_videomama_with_sam2(video_state, click_state):
365
- """
366
- Run SAM2 propagation and VideoMaMa inference together
367
- """
368
- if video_state is None or "frames" not in video_state:
369
- return video_state, None, None, None, "⚠️ No video loaded"
370
-
371
- if len(click_state[0]) == 0:
372
- return video_state, None, None, None, "⚠️ Please add at least one point first"
373
-
374
- # Call GPU function with plain data (no Gradio State objects)
375
- masks_list, output_frames_list, greenscreen_frames_list = run_videomama_with_sam2_gpu(
376
- video_state["frames"],
377
- click_state[0],
378
- click_state[1]
379
- )
380
-
381
- # Store masks
382
- video_state["masks"] = masks_list
383
-
384
- # Convert back to numpy for video saving
385
- frames = [np.array(f, dtype=np.uint8) for f in video_state["frames"]]
386
- masks = [np.array(m, dtype=np.uint8) for m in masks_list]
387
- output_frames = [np.array(f, dtype=np.uint8) for f in output_frames_list]
388
- greenscreen_frames = [np.array(f, dtype=np.uint8) for f in greenscreen_frames_list]
389
-
390
  # Save output videos
391
  output_dir = Path("outputs")
392
  output_dir.mkdir(exist_ok=True)
@@ -403,7 +328,25 @@ def run_videomama_with_sam2(video_state, click_state):
403
  mask_frames_rgb = [np.stack([m, m, m], axis=-1) for m in masks]
404
  save_video(mask_frames_rgb, mask_video_path, video_state["fps"])
405
 
406
- # Save greenscreen composite
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  save_video(greenscreen_frames, greenscreen_path, video_state["fps"])
408
 
409
  status_msg = f"✓ Complete! Generated {len(output_frames)} frames."
@@ -515,7 +458,7 @@ with gr.Blocks(title="VideoMaMa Demo") as demo:
515
  # Event handlers
516
  load_button.click(
517
  fn=load_video,
518
- inputs=[video_input],
519
  outputs=[video_state, first_frame_display,
520
  point_prompt, clear_button, run_button, status_text]
521
  )
 
123
  return click_state
124
 
125
 
126
+ def load_video(video_input, video_state):
127
  """
128
+ Load video, store path, and extract first frame for mask generation
129
  """
130
  if video_input is None:
131
+ return video_state, None, \
132
  gr.update(visible=False), gr.update(visible=False), \
133
  gr.update(visible=False), gr.update(visible=False)
134
 
135
+ # Extract ONLY the first frame for the UI to save memory/bandwidth
136
+ # We will load the full video inside the GPU function later
137
+ cap = cv2.VideoCapture(video_input)
138
+ ret, first_frame = cap.read()
139
+ cap.release()
140
 
141
+ if not ret:
142
+ return video_state, None, \
143
  gr.update(visible=False), gr.update(visible=False), \
144
  gr.update(visible=False), gr.update(visible=False)
145
+
146
+ first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
147
 
148
+ # Initialize video state with PATH, not full frames
149
  video_state = {
150
+ "video_path": video_input, # <--- Store Path
151
+ "first_frame": first_frame_rgb, # <--- Store only one frame
152
  "first_frame_mask": None,
153
  "masks": None,
154
  }
155
 
156
+ first_frame_pil = Image.fromarray(first_frame_rgb)
157
 
158
  return video_state, first_frame_pil, \
159
  gr.update(visible=True), gr.update(visible=True), \
 
161
 
162
 
163
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
165
  """
166
  Add click and update mask on first frame
 
171
  click_state: [[points], [labels]]
172
  evt: Gradio SelectData event with click coordinates
173
  """
174
+ # Lazy load models on first use
175
+ initialize_models()
176
+
177
+ if video_state is None or "first_frame" not in video_state: # Check for first_frame
178
  return None, video_state, click_state
179
 
180
  # Add new click
 
186
 
187
  print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")
188
 
189
+ # Generate mask with SAM2
190
+ first_frame = video_state["first_frame"]
191
+ mask = sam2_tracker.get_first_frame_mask(
192
+ frame=first_frame,
193
+ points=click_state[0],
194
+ labels=click_state[1]
195
  )
196
 
197
+ # Store mask in video state
198
+ video_state["first_frame_mask"] = mask
199
 
200
  # Visualize mask and points
 
 
 
201
  painted_image = mask_painter(
202
  first_frame.copy(),
203
  mask,
 
245
  click_state = [[], []]
246
 
247
  if video_state is not None and "frames" in video_state:
248
+ first_frame = video_state["frames"][0]
249
  video_state["first_frame_mask"] = None
250
  return Image.fromarray(first_frame), video_state, click_state
251
 
 
262
  if len(click_state[0]) == 0:
263
  return video_state, "⚠️ Please add at least one point first", gr.update(visible=False)
264
 
265
+ frames = video_state["frames"]
 
266
 
267
  # Track through video
268
  print(f"Tracking object through {len(frames)} frames...")
 
272
  labels=click_state[1]
273
  )
274
 
275
+ video_state["masks"] = masks
 
276
 
277
  status_msg = f"✓ Generated {len(masks)} masks. Ready to run VideoMaMa!"
278
 
 
280
 
281
 
282
  @spaces.GPU(duration=120)
283
+ def run_videomama_with_sam2(video_state, click_state):
284
  """
285
+ Run SAM2 propagation and VideoMaMa inference together
 
 
 
 
 
 
 
 
286
  """
287
  # Lazy load models on first use
288
  initialize_models()
289
 
290
+ if video_state is None or "video_path" not in video_state:
291
+ return video_state, None, None, None, "⚠️ No video loaded"
292
+
293
+ if len(click_state[0]) == 0:
294
+ return video_state, None, None, None, "⚠️ Please add at least one point first"
295
+
296
+ # RELOAD FRAMES HERE inside the GPU worker
297
+ print(f"Loading frames from {video_state['video_path']}...")
298
+ frames, fps = extract_frames_from_video(video_state["video_path"], max_frames=50)
299
 
300
+ # Update state with FPS just in case (though we likely don't need to return it)
301
+ video_state["fps"] = fps
302
  masks = sam2_tracker.track_video(
303
  frames=frames,
304
+ points=click_state[0],
305
+ labels=click_state[1]
306
  )
307
+
308
+ video_state["masks"] = masks
309
  print(f"✓ Generated {len(masks)} masks")
310
 
311
  # Step 2: Run VideoMaMa
312
  print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
313
  output_frames = videomama(videomama_pipeline, frames, masks)
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  # Save output videos
316
  output_dir = Path("outputs")
317
  output_dir.mkdir(exist_ok=True)
 
328
  mask_frames_rgb = [np.stack([m, m, m], axis=-1) for m in masks]
329
  save_video(mask_frames_rgb, mask_video_path, video_state["fps"])
330
 
331
+ # Create greenscreen composite: RGB * VideoMaMa_alpha + green * (1 - VideoMaMa_alpha)
332
+ # VideoMaMa output_frames already contain the alpha matte result
333
+ greenscreen_frames = []
334
+ for orig_frame, output_frame in zip(frames, output_frames):
335
+ # Extract alpha matte from VideoMaMa output
336
+ # VideoMaMa outputs matted foreground, we use its intensity as alpha
337
+ gray = cv2.cvtColor(output_frame, cv2.COLOR_RGB2GRAY)
338
+ alpha = np.clip(gray.astype(np.float32) / 255.0, 0, 1)
339
+ alpha_3ch = np.stack([alpha, alpha, alpha], axis=-1)
340
+
341
+ # Create green background
342
+ green_bg = np.zeros_like(orig_frame)
343
+ green_bg[:, :] = [156, 251, 165] # Green screen color
344
+
345
+ # Composite: original_RGB * alpha + green * (1 - alpha)
346
+ composite = (orig_frame.astype(np.float32) * alpha_3ch +
347
+ green_bg.astype(np.float32) * (1 - alpha_3ch)).astype(np.uint8)
348
+ greenscreen_frames.append(composite)
349
+
350
  save_video(greenscreen_frames, greenscreen_path, video_state["fps"])
351
 
352
  status_msg = f"✓ Complete! Generated {len(output_frames)} frames."
 
458
  # Event handlers
459
  load_button.click(
460
  fn=load_video,
461
+ inputs=[video_input, video_state],
462
  outputs=[video_state, first_frame_display,
463
  point_prompt, clear_button, run_button, status_text]
464
  )