Mirko Trasciatti committed on
Commit
8457ca9
·
1 Parent(s): a2fc2ab

Restore to last known working version (single object only)

Browse files
Files changed (1) hide show
  1. app.py +34 -215
app.py CHANGED
@@ -65,138 +65,6 @@ def load_video_cv2(video_path):
65
  return frames, {'fps': fps}
66
 
67
 
68
- @spaces.GPU
69
- def segment_video_multi_objects(video_file, annotations_json, remove_bg):
70
- """
71
- Segment video with MULTIPLE objects.
72
-
73
- annotations_json: JSON string with format:
74
- [
75
- {"x": 360, "y": 640, "frame": 0, "obj_id": 1},
76
- {"x": 360, "y": 640, "frame": 189, "obj_id": 2}
77
- ]
78
- """
79
- global device, model, processor
80
-
81
- if model is None:
82
- initialize_model()
83
-
84
- try:
85
- if video_file is None:
86
- return None, "❌ Error: No video file provided"
87
-
88
- video_path = str(video_file)
89
- if not os.path.exists(video_path):
90
- return None, f"❌ Error: Video file not found: {video_path}"
91
-
92
- print(f"Processing video from: {video_path}")
93
-
94
- # Parse annotations
95
- try:
96
- annotations = json.loads(annotations_json)
97
- except:
98
- return None, f"❌ Error: Invalid JSON format for annotations"
99
-
100
- if not annotations or len(annotations) == 0:
101
- return None, "❌ Error: No annotations provided"
102
-
103
- print(f"Processing {len(annotations)} objects...")
104
-
105
- # Load video
106
- video_frames, video_info = load_video_cv2(video_path)
107
- fps = video_info.get('fps', 30.0)
108
-
109
- # Initialize inference session
110
- dtype = torch.float32
111
- inference_session = processor.init_video_session(
112
- video=video_frames,
113
- inference_device=device,
114
- dtype=dtype,
115
- )
116
-
117
- # Add all annotations
118
- for ann in annotations:
119
- x = int(ann['x'])
120
- y = int(ann['y'])
121
- frame = int(ann['frame'])
122
- obj_id = int(ann.get('obj_id', 1))
123
-
124
- print(f" Adding object {obj_id} at ({x}, {y}) on frame {frame}")
125
-
126
- processor.add_inputs_to_inference_session(
127
- inference_session=inference_session,
128
- frame_idx=frame,
129
- obj_ids=obj_id,
130
- input_points=[[[[x, y]]]],
131
- input_labels=[[[1]]],
132
- )
133
-
134
- # Run inference on this frame
135
- model(
136
- inference_session=inference_session,
137
- frame_idx=frame,
138
- )
139
-
140
- # Propagate through video (will track ALL objects)
141
- video_segments = {}
142
- for sam2_output in model.propagate_in_video_iterator(inference_session):
143
- video_res_masks = processor.post_process_masks(
144
- [sam2_output.pred_masks],
145
- original_sizes=[[inference_session.video_height, inference_session.video_width]],
146
- binarize=False,
147
- )[0]
148
- video_segments[sam2_output.frame_idx] = video_res_masks
149
-
150
- # Create output video
151
- output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
152
- first_frame = np.array(video_frames[0])
153
- height, width = first_frame.shape[:2]
154
-
155
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
156
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
157
-
158
- for frame_idx, frame_pil in enumerate(video_frames):
159
- frame = np.array(frame_pil)
160
-
161
- if frame_idx in video_segments:
162
- mask = video_segments[frame_idx].cpu().numpy()
163
-
164
- # Combine ALL object masks
165
- if mask.ndim == 4:
166
- # Shape: [batch, num_objects, height, width]
167
- # Combine across object dimension
168
- mask = mask[0] # Remove batch dim
169
- if mask.ndim == 3:
170
- # Combine all object masks with max (OR operation)
171
- mask = mask.max(axis=0)
172
-
173
- if mask.shape != (height, width):
174
- mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
175
-
176
- mask_binary = (mask > 0.5).astype(np.uint8)
177
-
178
- if remove_bg:
179
- background = np.zeros_like(frame)
180
- mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2)
181
- frame = frame * mask_3d + background * (1 - mask_3d)
182
-
183
- frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR)
184
- out.write(frame_bgr)
185
-
186
- out.release()
187
-
188
- if os.path.exists(output_path):
189
- return output_path, f"✅ Success! Processed {len(annotations)} objects across {len(video_segments)} frames"
190
- else:
191
- return None, f"❌ Error: Output file was not created"
192
-
193
- except Exception as e:
194
- import traceback
195
- error_details = traceback.format_exc()
196
- print(f"Error in segment_video_multi_objects: {error_details}")
197
- return None, f"❌ Error: {str(e)}"
198
-
199
-
200
  @spaces.GPU
201
  def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg):
202
  """Simple video segmentation with a single point."""
@@ -261,7 +129,7 @@ def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg):
261
  video_segments[sam2_output.frame_idx] = video_res_masks
262
 
263
  # Create output video
264
- output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
265
  first_frame = np.array(video_frames[0])
266
  height, width = first_frame.shape[:2]
267
 
@@ -316,95 +184,46 @@ def create_app():
316
  # 🎥 SAM2 Video Background Remover
317
 
318
  Remove backgrounds from videos by tracking objects with Meta's SAM2.
 
 
 
 
 
319
  """)
320
 
321
- with gr.Tabs():
322
- # Tab 1: Single Object
323
- with gr.Tab("Single Object"):
324
- gr.Markdown("""
325
- **Track ONE object:**
326
- 1. Upload a video
327
- 2. Enter X, Y coordinates of the object to track
328
- 3. Click "Process Video"
329
- """)
330
 
331
  with gr.Row():
332
- with gr.Column():
333
- video_input = gr.File(label="Upload Video", file_types=["video"])
334
-
335
- with gr.Row():
336
- point_x = gr.Textbox(label="Point X", value="320")
337
- point_y = gr.Textbox(label="Point Y", value="240")
338
-
339
- frame_idx = gr.Textbox(label="Frame Index", value="0")
340
- remove_bg = gr.Checkbox(label="Remove Background", value=True)
341
-
342
- process_btn = gr.Button("🎬 Process Video", variant="primary")
343
-
344
- with gr.Column():
345
- output_video = gr.File(label="Output Video")
346
- status_text = gr.Textbox(label="Status", lines=3)
347
 
348
- process_btn.click(
349
- fn=segment_video_simple,
350
- inputs=[video_input, point_x, point_y, frame_idx, remove_bg],
351
- outputs=[output_video, status_text]
352
- )
353
 
354
- gr.Markdown("""
355
- ### Tips:
356
- - Point X, Y: Coordinates of the object in the video
357
- - Frame Index: Usually 0 (first frame)
358
- - Portrait and landscape videos are both supported!
359
- """)
360
 
361
- # Tab 2: Multiple Objects
362
- with gr.Tab("Multiple Objects"):
363
- gr.Markdown("""
364
- **Track MULTIPLE objects:**
365
- 1. Upload a video
366
- 2. Enter annotations as JSON (see example below)
367
- 3. Click "Process Video"
368
-
369
- **Example JSON** (ball at frame 0, player at frame 189):
370
- ```json
371
- [
372
- {"x": 360, "y": 640, "frame": 0, "obj_id": 1},
373
- {"x": 360, "y": 640, "frame": 189, "obj_id": 2}
374
- ]
375
- ```
376
- """)
377
-
378
- with gr.Row():
379
- with gr.Column():
380
- video_input_multi = gr.File(label="Upload Video", file_types=["video"])
381
-
382
- annotations_json = gr.Textbox(
383
- label="Annotations (JSON)",
384
- value='[{"x": 360, "y": 640, "frame": 0, "obj_id": 1}, {"x": 360, "y": 640, "frame": 189, "obj_id": 2}]',
385
- lines=5
386
- )
387
-
388
- remove_bg_multi = gr.Checkbox(label="Remove Background", value=True)
389
-
390
- process_btn_multi = gr.Button("🎬 Process Multiple Objects", variant="primary")
391
-
392
- with gr.Column():
393
- output_video_multi = gr.File(label="Output Video")
394
- status_text_multi = gr.Textbox(label="Status", lines=3)
395
-
396
- process_btn_multi.click(
397
- fn=segment_video_multi_objects,
398
- inputs=[video_input_multi, annotations_json, remove_bg_multi],
399
- outputs=[output_video_multi, status_text_multi]
400
- )
401
-
402
- gr.Markdown("""
403
- ### Tips:
404
- - Each object needs: `x`, `y`, `frame`, and unique `obj_id`
405
- - Pick frames where each object is clearly visible
406
- - All objects will be tracked and combined in the output!
407
- """)
408
 
409
  return app
410
 
 
65
  return frames, {'fps': fps}
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  @spaces.GPU
69
  def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg):
70
  """Simple video segmentation with a single point."""
 
129
  video_segments[sam2_output.frame_idx] = video_res_masks
130
 
131
  # Create output video
132
+ output_path = tempfile.mktemp(suffix=".mp4")
133
  first_frame = np.array(video_frames[0])
134
  height, width = first_frame.shape[:2]
135
 
 
184
  # 🎥 SAM2 Video Background Remover
185
 
186
  Remove backgrounds from videos by tracking objects with Meta's SAM2.
187
+
188
+ **How to use:**
189
+ 1. Upload a video
190
+ 2. Enter X, Y coordinates of the object to track (from first frame)
191
+ 3. Click "Process Video"
192
  """)
193
 
194
+ with gr.Row():
195
+ with gr.Column():
196
+ # Using gr.File instead of gr.Video for better API compatibility
197
+ video_input = gr.File(label="Upload Video", file_types=["video"])
 
 
 
 
 
198
 
199
  with gr.Row():
200
+ point_x = gr.Textbox(label="Point X", value="320")
201
+ point_y = gr.Textbox(label="Point Y", value="240")
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
+ frame_idx = gr.Textbox(label="Frame Index", value="0")
204
+ remove_bg = gr.Checkbox(label="Remove Background", value=True)
 
 
 
205
 
206
+ process_btn = gr.Button("🎬 Process Video", variant="primary")
 
 
 
 
 
207
 
208
+ with gr.Column():
209
+ output_video = gr.File(label="Output Video")
210
+ status_text = gr.Textbox(label="Status", lines=3)
211
+
212
+ process_btn.click(
213
+ fn=segment_video_simple,
214
+ inputs=[video_input, point_x, point_y, frame_idx, remove_bg],
215
+ outputs=[output_video, status_text]
216
+ )
217
+
218
+ gr.Markdown("""
219
+ ### Tips:
220
+ - Point X, Y: Coordinates of the object in the video
221
+ - For a 720x1280 portrait video, center is typically X=360, Y=640
222
+ - For a 1920x1080 landscape video, center is typically X=960, Y=540
223
+ - Frame Index: Usually 0 (first frame)
224
+ - Processing time depends on video length (CPU processing is slow)
225
+ - Portrait and landscape videos are both supported!
226
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  return app
229