pizb committed on
Commit
e1a1811
·
1 Parent(s): 638f167
Files changed (1) hide show
  1. app.py +51 -43
app.py CHANGED
@@ -125,35 +125,30 @@ def get_prompt(click_state, click_input):
125
 
126
  def load_video(video_input, video_state):
127
  """
128
- Load video, store path, and extract first frame for mask generation
129
  """
130
  if video_input is None:
131
  return video_state, None, \
132
  gr.update(visible=False), gr.update(visible=False), \
133
  gr.update(visible=False), gr.update(visible=False)
134
 
135
- # Extract ONLY the first frame for the UI to save memory/bandwidth
136
- # We will load the full video inside the GPU function later
137
- cap = cv2.VideoCapture(video_input)
138
- ret, first_frame = cap.read()
139
- cap.release()
140
 
141
- if not ret:
142
  return video_state, None, \
143
  gr.update(visible=False), gr.update(visible=False), \
144
  gr.update(visible=False), gr.update(visible=False)
145
-
146
- first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
147
 
148
- # Initialize video state with PATH, not full frames
149
  video_state = {
150
- "video_path": video_input, # <--- Store Path
151
- "first_frame": first_frame_rgb, # <--- Store only one frame
152
  "first_frame_mask": None,
153
  "masks": None,
154
  }
155
 
156
- first_frame_pil = Image.fromarray(first_frame_rgb)
157
 
158
  return video_state, first_frame_pil, \
159
  gr.update(visible=True), gr.update(visible=True), \
@@ -161,6 +156,17 @@ def load_video(video_input, video_state):
161
 
162
 
163
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
164
  def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
165
  """
166
  Add click and update mask on first frame
@@ -171,10 +177,7 @@ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
171
  click_state: [[points], [labels]]
172
  evt: Gradio SelectData event with click coordinates
173
  """
174
- # Lazy load models on first use
175
- initialize_models()
176
-
177
- if video_state is None or "first_frame" not in video_state: # Check for first_frame
178
  return None, video_state, click_state
179
 
180
  # Add new click
@@ -186,13 +189,9 @@ def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
186
 
187
  print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")
188
 
189
- # Generate mask with SAM2
190
- first_frame = video_state["first_frame"]
191
- mask = sam2_tracker.get_first_frame_mask(
192
- frame=first_frame,
193
- points=click_state[0],
194
- labels=click_state[1]
195
- )
196
 
197
  # Store mask in video state
198
  video_state["first_frame_mask"] = mask
@@ -280,37 +279,46 @@ def propagate_masks(video_state, click_state):
280
 
281
 
282
  @spaces.GPU(duration=120)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  def run_videomama_with_sam2(video_state, click_state):
284
  """
285
  Run SAM2 propagation and VideoMaMa inference together
286
  """
287
- # Lazy load models on first use
288
- initialize_models()
289
-
290
- if video_state is None or "video_path" not in video_state:
291
  return video_state, None, None, None, "⚠️ No video loaded"
292
 
293
  if len(click_state[0]) == 0:
294
  return video_state, None, None, None, "⚠️ Please add at least one point first"
295
 
296
- # RELOAD FRAMES HERE inside the GPU worker
297
- print(f"Loading frames from {video_state['video_path']}...")
298
- frames, fps = extract_frames_from_video(video_state["video_path"], max_frames=50)
299
 
300
- # Update state with FPS just in case (though we likely don't need to return it)
301
- video_state["fps"] = fps
302
- masks = sam2_tracker.track_video(
303
- frames=frames,
304
- points=click_state[0],
305
- labels=click_state[1]
306
  )
307
 
308
  video_state["masks"] = masks
309
- print(f"✓ Generated {len(masks)} masks")
310
-
311
- # Step 2: Run VideoMaMa
312
- print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
313
- output_frames = videomama(videomama_pipeline, frames, masks)
314
 
315
  # Save output videos
316
  output_dir = Path("outputs")
@@ -513,4 +521,4 @@ if __name__ == "__main__":
513
  # server_port=7860,
514
  # share=True
515
  # )
516
- demo.launch()
 
125
 
126
  def load_video(video_input, video_state):
127
  """
128
+ Load video and extract first frame for mask generation
129
  """
130
  if video_input is None:
131
  return video_state, None, \
132
  gr.update(visible=False), gr.update(visible=False), \
133
  gr.update(visible=False), gr.update(visible=False)
134
 
135
+ # Extract frames
136
+ frames, fps = extract_frames_from_video(video_input, max_frames=50)
 
 
 
137
 
138
+ if len(frames) == 0:
139
  return video_state, None, \
140
  gr.update(visible=False), gr.update(visible=False), \
141
  gr.update(visible=False), gr.update(visible=False)
 
 
142
 
143
+ # Initialize video state
144
  video_state = {
145
+ "frames": frames,
146
+ "fps": fps,
147
  "first_frame_mask": None,
148
  "masks": None,
149
  }
150
 
151
+ first_frame_pil = Image.fromarray(frames[0])
152
 
153
  return video_state, first_frame_pil, \
154
  gr.update(visible=True), gr.update(visible=True), \
 
156
 
157
 
158
  @spaces.GPU
159
+ def generate_sam2_mask(first_frame, points, labels):
160
+ """GPU-intensive SAM2 mask generation"""
161
+ initialize_models()
162
+ mask = sam2_tracker.get_first_frame_mask(
163
+ frame=first_frame,
164
+ points=points,
165
+ labels=labels
166
+ )
167
+ return mask
168
+
169
+
170
  def sam_refine(video_state, point_prompt, click_state, evt: gr.SelectData):
171
  """
172
  Add click and update mask on first frame
 
177
  click_state: [[points], [labels]]
178
  evt: Gradio SelectData event with click coordinates
179
  """
180
+ if video_state is None or "frames" not in video_state:
 
 
 
181
  return None, video_state, click_state
182
 
183
  # Add new click
 
189
 
190
  print(f"Added {point_prompt} click at ({x}, {y}). Total clicks: {len(click_state[0])}")
191
 
192
+ # Generate mask with SAM2 (GPU operation)
193
+ first_frame = video_state["frames"][0]
194
+ mask = generate_sam2_mask(first_frame, click_state[0], click_state[1])
 
 
 
 
195
 
196
  # Store mask in video state
197
  video_state["first_frame_mask"] = mask
 
279
 
280
 
281
  @spaces.GPU(duration=120)
282
+ def process_video_with_models(frames, points, labels):
283
+ """GPU-intensive video processing with SAM2 and VideoMaMa"""
284
+ initialize_models()
285
+
286
+ # Step 1: Track through video with SAM2
287
+ print(f"🎯 Tracking object through {len(frames)} frames with SAM2...")
288
+ masks = sam2_tracker.track_video(
289
+ frames=frames,
290
+ points=points,
291
+ labels=labels
292
+ )
293
+ print(f"✓ Generated {len(masks)} masks")
294
+
295
+ # Step 2: Run VideoMaMa
296
+ print(f"🎨 Running VideoMaMa on {len(frames)} frames...")
297
+ output_frames = videomama(videomama_pipeline, frames, masks)
298
+
299
+ return masks, output_frames
300
+
301
+
302
  def run_videomama_with_sam2(video_state, click_state):
303
  """
304
  Run SAM2 propagation and VideoMaMa inference together
305
  """
306
+ if video_state is None or "frames" not in video_state:
 
 
 
307
  return video_state, None, None, None, "⚠️ No video loaded"
308
 
309
  if len(click_state[0]) == 0:
310
  return video_state, None, None, None, "⚠️ Please add at least one point first"
311
 
312
+ frames = video_state["frames"]
 
 
313
 
314
+ # Run GPU-intensive processing
315
+ masks, output_frames = process_video_with_models(
316
+ frames,
317
+ click_state[0],
318
+ click_state[1]
 
319
  )
320
 
321
  video_state["masks"] = masks
 
 
 
 
 
322
 
323
  # Save output videos
324
  output_dir = Path("outputs")
 
521
  # server_port=7860,
522
  # share=True
523
  # )
524
+ demo.launch()