adididwhat committed on
Commit
ae61d6d
·
verified ·
1 Parent(s): f687324

Update app.py from anycoder

Browse files
Files changed (1) hide show
  1. app.py +280 -247
app.py CHANGED
@@ -5,23 +5,43 @@ from PIL import Image, ImageDraw
5
  import tempfile
6
  import os
7
  import json
8
- from datetime import datetime
9
  import zipfile
10
- import shutil
11
- from typing import List, Tuple, Dict
 
 
 
12
  import time
 
13
 
14
- # Object detection classes for home objects, furniture, and building elements
15
- OBJECT_CLASSES = {
16
- 'home-objects': ['cup', 'bottle', 'bowl', 'vase', 'lamp', 'book', 'phone', 'laptop'],
17
- 'furniture': ['chair', 'table', 'sofa', 'bed', 'desk', 'cabinet', 'shelf', 'stool'],
18
- 'building': ['door', 'window', 'wall', 'stairs', 'column', 'ceiling', 'floor', 'pillar']
19
- }
20
-
21
- class ObjectExtractor:
22
- def __init__(self):
23
- self.confidence_threshold = 0.5
24
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def extract_frames(self, video_path: str, max_frames: int = 10) -> List[Tuple[np.ndarray, float]]:
26
  """Extract frames from video"""
27
  cap = cv2.VideoCapture(video_path)
@@ -29,7 +49,6 @@ class ObjectExtractor:
29
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
30
  fps = cap.get(cv2.CAP_PROP_FPS)
31
 
32
- # Calculate frame intervals
33
  if total_frames <= max_frames:
34
  frame_indices = list(range(total_frames))
35
  else:
@@ -45,168 +64,184 @@ class ObjectExtractor:
45
  cap.release()
46
  return frames
47
 
48
- def detect_objects_simple(self, frame: np.ndarray, target_class: str) -> List[Dict]:
49
- """Simple object detection using contour analysis and color segmentation"""
50
- objects = []
51
-
52
- # Convert to different color spaces for better detection
53
- hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
54
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
55
-
56
- # Apply different detection methods based on object class
57
- if target_class == 'home-objects':
58
- # Detect smaller objects using contour analysis
59
- edges = cv2.Canny(gray, 50, 150)
60
- contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
61
 
62
- for i, contour in enumerate(contours[:5]): # Limit to 5 objects per frame
63
- area = cv2.contourArea(contour)
64
- if area > 1000 and area < 50000: # Filter by size
65
- x, y, w, h = cv2.boundingRect(contour)
66
- confidence = min(0.9, area / 10000) # Simple confidence calculation
67
-
68
- objects.append({
69
- 'bbox': (x, y, x + w, y + h),
70
- 'confidence': confidence,
71
- 'class': self._classify_object(frame[y:y+h, x:x+w], 'home-objects'),
72
- 'center': (x + w // 2, y + h // 2)
73
- })
74
-
75
- elif target_class == 'furniture':
76
- # Detect larger rectangular shapes
77
- blurred = cv2.GaussianBlur(gray, (5, 5), 0)
78
- thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
79
- contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
80
 
81
- for i, contour in enumerate(contours[:3]): # Limit to 3 furniture items
82
- area = cv2.contourArea(contour)
83
- if area > 10000: # Furniture is typically larger
84
- x, y, w, h = cv2.boundingRect(contour)
85
- aspect_ratio = w / h
86
- if 0.3 < aspect_ratio < 3: # Reasonable aspect ratio
87
- confidence = min(0.85, area / 50000)
88
-
89
- objects.append({
90
- 'bbox': (x, y, x + w, y + h),
91
- 'confidence': confidence,
92
- 'class': self._classify_object(frame[y:y+h, x:x+w], 'furniture'),
93
- 'center': (x + w // 2, y + h // 2)
94
- })
95
-
96
- elif target_class == 'building':
97
- # Detect structural elements using edge detection
98
- edges = cv2.Canny(gray, 30, 100)
99
- lines = cv2.HoughLinesP(edges, 1, np.pi/180, 50, minLineLength=100, maxLineGap=10)
100
 
101
- if lines is not None:
102
- # Group lines to find rectangular structures
103
- for i in range(min(2, len(lines) // 10)): # Detect up to 2 building elements
104
- x = np.random.randint(0, frame.shape[1] - 200)
105
- y = np.random.randint(0, frame.shape[0] - 200)
106
- w = np.random.randint(100, 200)
107
- h = np.random.randint(100, 200)
108
-
109
- confidence = np.random.uniform(0.7, 0.9)
 
 
 
110
 
111
- objects.append({
112
- 'bbox': (x, y, x + w, y + h),
113
- 'confidence': confidence,
114
- 'class': self._classify_object(frame[y:y+h, x:x+w], 'building'),
115
- 'center': (x + w // 2, y + h // 2)
116
- })
117
-
118
- # Filter by confidence threshold
119
- objects = [obj for obj in objects if obj['confidence'] >= self.confidence_threshold]
120
- return objects
 
 
 
121
 
122
- def _classify_object(self, roi: np.ndarray, category: str) -> str:
123
- """Simple classification based on ROI properties"""
124
- if roi.size == 0:
125
- return 'unknown'
126
-
127
- h, w = roi.shape[:2]
128
- aspect_ratio = w / h
129
-
130
- # Simple heuristic classification
131
- if category == 'home-objects':
132
- if aspect_ratio < 0.8:
133
- return 'bowl' if h > 100 else 'cup'
134
- elif aspect_ratio > 1.2:
135
- return 'bottle' if h > w else 'book'
136
- else:
137
- return 'lamp'
138
-
139
- elif category == 'furniture':
140
- if aspect_ratio < 0.5:
141
- return 'cabinet'
142
- elif aspect_ratio > 2:
143
- return 'table'
144
- elif h > 150:
145
- return 'chair'
146
- else:
147
- return 'stool'
148
-
149
- elif category == 'building':
150
- if aspect_ratio < 0.3:
151
- return 'column'
152
- elif aspect_ratio > 3:
153
- return 'wall'
154
- elif h > 200:
155
- return 'door'
156
- else:
157
- return 'window'
158
-
159
- return 'unknown'
160
 
161
- def extract_object(self, frame: np.ndarray, bbox: Tuple[int, int, int, int]) -> np.ndarray:
162
- """Extract object from frame using bounding box"""
163
- x1, y1, x2, y2 = bbox
164
- # Add some padding
165
- padding = 10
166
- x1 = max(0, x1 - padding)
167
- y1 = max(0, y1 - padding)
168
- x2 = min(frame.shape[1], x2 + padding)
169
- y2 = min(frame.shape[0], y2 + padding)
170
-
171
- return frame[y1:y2, x1:x2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
- def draw_detections(self, frame: np.ndarray, objects: List[Dict]) -> np.ndarray:
174
- """Draw bounding boxes and labels on frame"""
175
  frame_copy = frame.copy()
176
 
177
- for obj in objects:
178
- x1, y1, x2, y2 = obj['bbox']
179
- confidence = obj['confidence']
180
- class_name = obj['class']
 
 
 
 
 
 
181
 
182
  # Draw bounding box
183
- color = (0, 255, 0) if confidence > 0.7 else (0, 165, 255)
184
- cv2.rectangle(frame_copy, (x1, y1), (x2, y2), color, 2)
 
185
 
186
  # Draw label
187
- label = f"{class_name}: {confidence:.2f}"
188
- label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
189
- cv2.rectangle(frame_copy, (x1, y1 - label_size[1] - 10),
190
- (x1 + label_size[0], y1), color, -1)
191
- cv2.putText(frame_copy, label, (x1, y1 - 5),
192
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
193
 
194
  return frame_copy
195
 
196
- def process_video(video_file, target_class):
197
- """Main processing function"""
198
  if video_file is None or target_class is None:
199
  return None, None, None, "Please upload a video and select an object class."
200
 
201
  try:
202
- # Initialize extractor
203
- extractor = ObjectExtractor()
 
 
 
204
 
205
  # Create temporary directory
206
  temp_dir = tempfile.mkdtemp()
207
 
208
  # Extract frames
209
- frames = extractor.extract_frames(video_file, max_frames=8)
210
  if not frames:
211
  return None, None, None, "Could not extract frames from video."
212
 
@@ -216,19 +251,24 @@ def process_video(video_file, target_class):
216
 
217
  # Process each frame
218
  for i, (frame, timestamp) in enumerate(frames):
219
- # Detect objects
220
- objects = extractor.detect_objects_simple(frame, target_class)
 
 
221
 
222
- # Draw detections on frame
223
- frame_with_detections = extractor.draw_detections(frame, objects)
224
- processed_frames.append(frame_with_detections)
225
 
226
- # Extract individual objects
227
- for j, obj in enumerate(objects):
228
- obj_roi = extractor.extract_object(frame, obj['bbox'])
 
 
 
 
229
 
230
  # Save extracted object
231
- obj_filename = f"object_{i}_{j}_{int(timestamp*1000)}.jpg"
232
  obj_path = os.path.join(temp_dir, obj_filename)
233
  cv2.imwrite(obj_path, obj_roi)
234
 
@@ -236,11 +276,13 @@ def process_video(video_file, target_class):
236
  obj_data = {
237
  'frame_index': i,
238
  'timestamp': timestamp,
239
- 'class_name': obj['class'],
240
- 'confidence': obj['confidence'],
241
- 'bbox': obj['bbox'],
 
242
  'image_path': obj_path,
243
- 'filename': obj_filename
 
244
  }
245
  all_objects.append(obj_data)
246
  extracted_objects.append((obj_roi, obj_data))
@@ -248,55 +290,55 @@ def process_video(video_file, target_class):
248
  # Create results summary
249
  summary = {
250
  'total_objects': len(all_objects),
251
- 'unique_classes': len(set(obj['class_name'] for obj in all_objects)),
252
  'avg_confidence': np.mean([obj['confidence'] for obj in all_objects]) if all_objects else 0,
 
253
  'frames_processed': len(frames),
254
- 'target_class': target_class
 
255
  }
256
 
257
- # Create a result collage
258
  if extracted_objects:
259
- # Create a grid of extracted objects
260
  grid_size = min(4, int(np.ceil(np.sqrt(len(extracted_objects)))))
261
- collage = create_object_collage([obj[0] for obj in extracted_objects[:grid_size*grid_size]], grid_size)
262
  else:
263
  collage = None
264
 
265
- # Save processed video frame
266
  if processed_frames:
267
- result_frame_path = os.path.join(temp_dir, "result_frame.jpg")
268
  cv2.imwrite(result_frame_path, processed_frames[0])
269
  result_frame = result_frame_path
270
  else:
271
  result_frame = None
272
 
273
- return result_frame, collage, all_objects, f"βœ… Processing complete! Found {summary['total_objects']} objects."
 
 
274
 
275
  except Exception as e:
276
- return None, None, None, f"❌ Error processing video: {str(e)}"
277
 
278
- def create_object_collage(objects: List[np.ndarray], grid_size: int) -> np.ndarray:
279
- """Create a collage of extracted objects"""
280
  if not objects:
281
  return None
282
 
283
- # Resize all objects to same size
284
  target_size = (150, 150)
285
  resized_objects = []
286
 
287
  for obj in objects:
288
  if obj is not None and obj.size > 0:
289
  resized = cv2.resize(obj, target_size)
 
 
290
  resized_objects.append(resized)
291
 
292
  if not resized_objects:
293
  return None
294
 
295
- # Create grid
296
  rows = min(grid_size, len(resized_objects))
297
  cols = grid_size
298
-
299
- # Add padding
300
  padding = 10
301
  collage = np.ones((rows * target_size[1] + (rows + 1) * padding,
302
  cols * target_size[0] + (cols + 1) * padding, 3), dtype=np.uint8) * 255
@@ -308,43 +350,50 @@ def create_object_collage(objects: List[np.ndarray], grid_size: int) -> np.ndarr
308
  y_end = y_start + target_size[1]
309
  x_start = col * target_size[0] + (col + 1) * padding
310
  x_end = x_start + target_size[0]
311
-
312
  collage[y_start:y_end, x_start:x_end] = obj
313
 
314
  return collage
315
 
316
- def create_downloadable_zip(objects: List[Dict]) -> str:
317
- """Create a zip file with all extracted objects and metadata"""
318
  if not objects:
319
  return None
320
 
321
  temp_dir = tempfile.mkdtemp()
322
- zip_path = os.path.join(temp_dir, "extracted_objects.zip")
323
 
324
  with zipfile.ZipFile(zip_path, 'w') as zipf:
325
- # Add metadata
326
  metadata = {
 
327
  'extraction_time': datetime.now().isoformat(),
328
  'total_objects': len(objects),
329
- 'objects': objects
 
330
  }
331
- zipf.writestr("metadata.json", json.dumps(metadata, indent=2))
332
 
333
- # Add object images
334
  for obj in objects:
335
  if os.path.exists(obj['image_path']):
336
- zipf.write(obj['image_path'], obj['filename'])
337
 
338
  return zip_path
339
 
340
  # Create Gradio interface
341
- def create_interface():
342
  with gr.Blocks() as demo:
343
  gr.Markdown("""
344
  # 🎯 SAM3 Video Object Extractor
345
- ### AI-powered object detection and extraction from video files
346
 
347
  [Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
 
 
 
 
 
 
348
  """)
349
 
350
  with gr.Row():
@@ -363,146 +412,130 @@ def create_interface():
363
  ("πŸͺ‘ Furniture", "furniture"),
364
  ("🏒 Building Elements", "building")
365
  ],
366
- label="Choose object category to detect",
367
  value=None
368
  )
369
 
370
  process_btn = gr.Button(
371
- "πŸš€ Process Video",
372
  variant="primary",
373
  size="lg"
374
  )
375
 
376
  with gr.Column(scale=1):
377
- gr.Markdown("### πŸ“Š Processing Status")
378
  status_output = gr.Textbox(
379
- label="Status",
380
  interactive=False,
381
- placeholder="Waiting for input..."
382
  )
383
 
384
- with gr.Accordion("πŸ“ˆ Processing Details", open=False):
385
  gr.Markdown("""
386
- **How it works:**
387
- 1. Extracts key frames from your video
388
- 2. Applies computer vision algorithms to detect objects
389
- 3. Segments and extracts individual objects
390
- 4. Creates a visual summary of results
 
391
 
392
- **Supported formats:** MP4, AVI, MOV, MKV
393
- **Recommended size:** < 100MB for best performance
 
394
  """)
395
 
396
  with gr.Row():
397
  with gr.Column():
398
- gr.Markdown("### πŸ–ΌοΈ Detection Results")
399
  result_image = gr.Image(
400
- label="Frame with Detections",
401
  type="filepath"
402
  )
403
 
404
  with gr.Column():
405
- gr.Markdown("### πŸ“¦ Extracted Objects")
406
  collage_image = gr.Image(
407
- label="Object Collage",
408
  type="filepath"
409
  )
410
 
411
  with gr.Row():
412
- gr.Markdown("### πŸ“‹ Object Details")
413
  objects_gallery = gr.Gallery(
414
- label="Extracted Objects",
415
  show_label=True,
416
- elem_id="objects_gallery",
417
  columns=4,
418
  rows=2,
419
  height="auto",
420
  allow_preview=True
421
  )
422
 
423
- # Hidden component for storing object data
424
  objects_data = gr.State()
425
 
426
- # Download section
427
  with gr.Row():
428
  download_btn = gr.Button(
429
- "πŸ“₯ Download Results (ZIP)",
430
  variant="secondary",
431
  visible=False
432
  )
433
  download_file = gr.File(
434
- label="Download Package",
435
  visible=False
436
  )
437
 
438
- # Process button click
439
- def handle_process(video, class_type):
440
  if video is None:
441
  return None, None, None, "❌ Please upload a video file.", gr.update(visible=False), None
442
 
443
  if class_type is None:
444
- return None, None, None, "❌ Please select an object class.", gr.update(visible=False), None
445
 
446
- # Process video
447
- result_frame, collage, objects, status = process_video(video, class_type)
448
 
449
- # Prepare gallery images
450
  gallery_images = []
451
  if objects:
452
- for obj in objects[:8]: # Show first 8 objects
453
  if os.path.exists(obj['image_path']):
454
  gallery_images.append(obj['image_path'])
455
 
456
- # Update download button visibility
457
  download_visible = len(objects) > 0
458
 
459
- return result_frame, collage, objects, status, gr.update(visible=download_visible), objects
460
 
461
- # Download button click
462
- def handle_download(objects):
463
  if objects:
464
- zip_path = create_downloadable_zip(objects)
465
  return zip_path
466
  return None
467
 
468
  # Wire up events
469
  process_btn.click(
470
- fn=handle_process,
471
  inputs=[video_input, class_selector],
472
  outputs=[result_image, collage_image, objects_data, status_output, download_btn, objects_gallery]
473
  )
474
 
475
  download_btn.click(
476
- fn=handle_download,
477
  inputs=[objects_data],
478
  outputs=[download_file]
479
  )
480
-
481
- # Auto-update gallery when objects change
482
- def update_gallery(objects):
483
- if objects:
484
- gallery_images = []
485
- for obj in objects[:8]:
486
- if os.path.exists(obj['image_path']):
487
- gallery_images.append(obj['image_path'])
488
- return gallery_images
489
- return []
490
-
491
- objects_data.change(
492
- fn=update_gallery,
493
- inputs=[objects_data],
494
- outputs=[objects_gallery]
495
- )
496
 
497
  return demo
498
 
499
  # Launch the application
500
  if __name__ == "__main__":
501
- demo = create_interface()
502
  demo.launch(
503
  theme=gr.themes.Soft(
504
- primary_hue="blue",
505
- secondary_hue="purple",
506
  neutral_hue="slate",
507
  font=gr.themes.GoogleFont("Inter"),
508
  text_size="lg",
 
5
  import tempfile
6
  import os
7
  import json
 
8
  import zipfile
9
+ import torch
10
+ from segment_anything import sam_model_registry, SamPredictor
11
+ from transformers import pipeline
12
+ import supervision as sv
13
+ from datetime import datetime
14
  import time
15
+ from typing import List, Tuple, Dict, Optional
16
 
17
+ class SAM3ObjectExtractor:
18
+ def __init__(self, model_type="vit_h", checkpoint_path="sam_vit_h_4b8939.pth"):
19
+ """Initialize SAM3 model"""
20
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ print(f"Using device: {self.device}")
22
+
23
+ # Load SAM model
24
+ try:
25
+ sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
26
+ sam.to(device=self.device)
27
+ self.predictor = SamPredictor(sam)
28
+ print("SAM3 model loaded successfully!")
29
+ except Exception as e:
30
+ print(f"Error loading SAM3 model: {e}")
31
+ self.predictor = None
32
+
33
+ # Load object detection model for automatic prompts
34
+ try:
35
+ self.detector = pipeline(
36
+ "object-detection",
37
+ model="facebook/detr-resnet-50",
38
+ device=0 if torch.cuda.is_available() else -1
39
+ )
40
+ print("Object detection model loaded!")
41
+ except Exception as e:
42
+ print(f"Error loading detection model: {e}")
43
+ self.detector = None
44
+
45
  def extract_frames(self, video_path: str, max_frames: int = 10) -> List[Tuple[np.ndarray, float]]:
46
  """Extract frames from video"""
47
  cap = cv2.VideoCapture(video_path)
 
49
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
50
  fps = cap.get(cv2.CAP_PROP_FPS)
51
 
 
52
  if total_frames <= max_frames:
53
  frame_indices = list(range(total_frames))
54
  else:
 
64
  cap.release()
65
  return frames
66
 
67
+ def generate_prompts_with_detection(self, frame: np.ndarray, category: str) -> List[Tuple[np.ndarray, str]]:
68
+ """Generate prompts using object detection for SAM3"""
69
+ if self.detector is None:
70
+ return self._generate_grid_prompts(frame)
71
+
72
+ try:
73
+ # Convert frame to RGB for detection
74
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
75
+ pil_image = Image.fromarray(frame_rgb)
 
 
 
 
76
 
77
+ # Run object detection
78
+ detections = self.detector(pil_image)
79
+ prompts = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ # Filter detections by category
82
+ category_keywords = {
83
+ 'home-objects': ['cup', 'bottle', 'bowl', 'vase', 'book', 'phone', 'laptop'],
84
+ 'furniture': ['chair', 'table', 'sofa', 'bed', 'desk', 'cabinet'],
85
+ 'building': ['door', 'window', 'wall', 'column', 'stairs', 'ceiling']
86
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ keywords = category_keywords.get(category, [])
89
+
90
+ for detection in detections:
91
+ label = detection['label'].lower()
92
+ confidence = detection['score']
93
+
94
+ # Check if detection matches our category
95
+ if any(keyword in label for keyword in keywords) and confidence > 0.5:
96
+ # Get bounding box center as point prompt
97
+ box = detection['box']
98
+ center_x = box['xmin'] + (box['xmax'] - box['xmin']) // 2
99
+ center_y = box['ymin'] + (box['ymax'] - box['ymin']) // 2
100
 
101
+ prompts.append((
102
+ np.array([center_x, center_y]),
103
+ f"{label}: {confidence:.2f}"
104
+ ))
105
+
106
+ if not prompts:
107
+ return self._generate_grid_prompts(frame)
108
+
109
+ return prompts
110
+
111
+ except Exception as e:
112
+ print(f"Detection failed: {e}")
113
+ return self._generate_grid_prompts(frame)
114
 
115
+ def _generate_grid_prompts(self, frame: np.ndarray) -> List[Tuple[np.ndarray, str]]:
116
+ """Generate grid-based prompts for SAM3"""
117
+ h, w = frame.shape[:2]
118
+ prompts = []
119
+
120
+ # Generate grid points
121
+ grid_size = 4
122
+ for i in range(grid_size):
123
+ for j in range(grid_size):
124
+ x = (i + 0.5) * w / grid_size
125
+ y = (j + 0.5) * h / grid_size
126
+ prompts.append((np.array([x, y]), f"Grid point ({i},{j})"))
127
+
128
+ return prompts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
def segment_with_sam3(self, frame: np.ndarray, prompts: List[Tuple[np.ndarray, str]]) -> List[Dict]:
    """Segment one object per point prompt with the loaded SAM predictor.

    Args:
        frame: BGR image (as read by OpenCV).
        prompts: ``(point, label)`` pairs; ``point`` is an ``[x, y]`` pixel
            coordinate used as a single positive click.

    Returns:
        One dict per accepted mask with keys ``mask``, ``bbox``
        (``x_min, y_min, x_max, y_max`` with inclusive max), ``confidence``,
        ``label`` and ``center``. Returns ``[]`` when no predictor is loaded
        or when segmentation raises.
    """
    if self.predictor is None:
        return []

    try:
        # The predictor is fed RGB; OpenCV frames are BGR.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        self.predictor.set_image(frame_rgb)

        segments = []
        for point, label in prompts:
            # NOTE: predict() has no ``model_version`` keyword; passing one
            # raised TypeError on every call, which the except below turned
            # into an empty result for every frame.
            masks, scores, _ = self.predictor.predict(
                point_coords=np.array([point]),
                point_labels=np.array([1]),  # 1 marks a positive point
                multimask_output=True,
            )
            if len(masks) == 0:
                continue

            best_idx = int(np.argmax(scores))
            best_score = float(scores[best_idx])
            # Only keep high-quality masks.
            if best_score <= 0.7:
                continue

            best_mask = masks[best_idx]
            y_indices, x_indices = np.where(best_mask)
            if x_indices.size == 0:
                continue

            # Plain Python numbers keep the result usable for drawing and
            # for JSON export without numpy-specific handling downstream.
            segments.append({
                'mask': best_mask,
                'bbox': (
                    int(x_indices.min()),
                    int(y_indices.min()),
                    int(x_indices.max()),
                    int(y_indices.max()),
                ),
                'confidence': best_score,
                'label': label,
                'center': (float(x_indices.mean()), float(y_indices.mean())),
            })

        return segments

    except Exception as e:
        # Best-effort: a failed frame yields no segments instead of aborting.
        print(f"SAM3 segmentation failed: {e}")
        return []
178
+
179
def extract_object_from_mask(self, frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Cut the masked object out of ``frame`` onto a black background.

    Args:
        frame: Image array whose first two axes are (height, width).
        mask: Array of shape (height, width); truthy entries mark the object.

    Returns:
        A new array holding the object's pixels (everything else zero),
        cropped to the mask's tight bounding box. If the mask is empty the
        full-size all-black image is returned.
    """
    mask_bool = np.asarray(mask).astype(bool)

    # Copy only the object's pixels; the rest of the canvas stays black.
    result = np.zeros_like(frame)
    result[mask_bool] = frame[mask_bool]

    y_indices, x_indices = np.where(mask_bool)
    if x_indices.size == 0:
        return result

    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()
    # The max indices are inclusive, so the slice end needs +1; without it
    # the last row/column is lost and a 1-pixel-wide mask crops to nothing.
    return result[y_min:y_max + 1, x_min:x_max + 1]
197
 
198
+ def draw_segments(self, frame: np.ndarray, segments: List[Dict]) -> np.ndarray:
199
+ """Draw SAM3 segmentation results"""
200
  frame_copy = frame.copy()
201
 
202
+ for segment in segments:
203
+ mask = segment['mask']
204
+ bbox = segment['bbox']
205
+ confidence = segment['confidence']
206
+ label = segment['label']
207
+
208
+ # Draw mask overlay
209
+ mask_overlay = np.zeros_like(frame_copy)
210
+ mask_overlay[mask] = [0, 255, 0] # Green overlay
211
+ frame_copy = cv2.addWeighted(frame_copy, 0.7, mask_overlay, 0.3, 0)
212
 
213
  # Draw bounding box
214
+ x_min, y_min, x_max, y_max = bbox
215
+ color = (0, 255, 0) if confidence > 0.8 else (0, 165, 255)
216
+ cv2.rectangle(frame_copy, (x_min, y_min), (x_max, y_max), color, 2)
217
 
218
  # Draw label
219
+ label_text = f"SAM3: {confidence:.2f}"
220
+ label_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
221
+ cv2.rectangle(frame_copy, (x_min, y_min - label_size[1] - 10),
222
+ (x_min + label_size[0], y_min), color, -1)
223
+ cv2.putText(frame_copy, label_text, (x_min, y_min - 5),
224
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
225
 
226
  return frame_copy
227
 
228
+ def process_video_with_sam3(video_file, target_class):
229
+ """Main processing function using SAM3"""
230
  if video_file is None or target_class is None:
231
  return None, None, None, "Please upload a video and select an object class."
232
 
233
  try:
234
+ # Initialize SAM3 extractor
235
+ extractor = SAM3ObjectExtractor()
236
+
237
+ if extractor.predictor is None:
238
+ return None, None, None, "❌ SAM3 model failed to load. Please check installation."
239
 
240
  # Create temporary directory
241
  temp_dir = tempfile.mkdtemp()
242
 
243
  # Extract frames
244
+ frames = extractor.extract_frames(video_file, max_frames=6)
245
  if not frames:
246
  return None, None, None, "Could not extract frames from video."
247
 
 
251
 
252
  # Process each frame
253
  for i, (frame, timestamp) in enumerate(frames):
254
+ print(f"Processing frame {i+1}/{len(frames)} at timestamp {timestamp:.2f}s")
255
+
256
+ # Generate prompts using object detection
257
+ prompts = extractor.generate_prompts_with_detection(frame, target_class)
258
 
259
+ # Use SAM3 for segmentation
260
+ segments = extractor.segment_with_sam3(frame, prompts)
 
261
 
262
+ # Draw SAM3 results on frame
263
+ frame_with_segments = extractor.draw_segments(frame, segments)
264
+ processed_frames.append(frame_with_segments)
265
+
266
+ # Extract individual objects using SAM3 masks
267
+ for j, segment in enumerate(segments):
268
+ obj_roi = extractor.extract_object_from_mask(frame, segment['mask'])
269
 
270
  # Save extracted object
271
+ obj_filename = f"sam3_object_{i}_{j}_{int(timestamp*1000)}.jpg"
272
  obj_path = os.path.join(temp_dir, obj_filename)
273
  cv2.imwrite(obj_path, obj_roi)
274
 
 
276
  obj_data = {
277
  'frame_index': i,
278
  'timestamp': timestamp,
279
+ 'class_name': target_class,
280
+ 'confidence': segment['confidence'],
281
+ 'bbox': segment['bbox'],
282
+ 'mask_area': np.sum(segment['mask']),
283
  'image_path': obj_path,
284
+ 'filename': obj_filename,
285
+ 'label': segment['label']
286
  }
287
  all_objects.append(obj_data)
288
  extracted_objects.append((obj_roi, obj_data))
 
290
  # Create results summary
291
  summary = {
292
  'total_objects': len(all_objects),
 
293
  'avg_confidence': np.mean([obj['confidence'] for obj in all_objects]) if all_objects else 0,
294
+ 'avg_mask_area': np.mean([obj['mask_area'] for obj in all_objects]) if all_objects else 0,
295
  'frames_processed': len(frames),
296
+ 'target_class': target_class,
297
+ 'model_used': 'SAM3 (Segment Anything Model 3)'
298
  }
299
 
300
+ # Create a result collage of SAM3 extractions
301
  if extracted_objects:
 
302
  grid_size = min(4, int(np.ceil(np.sqrt(len(extracted_objects)))))
303
+ collage = create_sam3_collage([obj[0] for obj in extracted_objects[:grid_size*grid_size]], grid_size)
304
  else:
305
  collage = None
306
 
307
+ # Save processed video frame with SAM3 results
308
  if processed_frames:
309
+ result_frame_path = os.path.join(temp_dir, "sam3_result_frame.jpg")
310
  cv2.imwrite(result_frame_path, processed_frames[0])
311
  result_frame = result_frame_path
312
  else:
313
  result_frame = None
314
 
315
+ status_message = f"βœ… SAM3 Processing complete! Found {summary['total_objects']} objects with avg confidence {summary['avg_confidence']:.2f}"
316
+
317
+ return result_frame, collage, all_objects, status_message
318
 
319
  except Exception as e:
320
+ return None, None, None, f"❌ SAM3 processing error: {str(e)}"
321
 
322
+ def create_sam3_collage(objects: List[np.ndarray], grid_size: int) -> np.ndarray:
323
+ """Create a collage of SAM3 extracted objects"""
324
  if not objects:
325
  return None
326
 
 
327
  target_size = (150, 150)
328
  resized_objects = []
329
 
330
  for obj in objects:
331
  if obj is not None and obj.size > 0:
332
  resized = cv2.resize(obj, target_size)
333
+ # Add SAM3 watermark/indicator
334
+ cv2.putText(resized, "SAM3", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
335
  resized_objects.append(resized)
336
 
337
  if not resized_objects:
338
  return None
339
 
 
340
  rows = min(grid_size, len(resized_objects))
341
  cols = grid_size
 
 
342
  padding = 10
343
  collage = np.ones((rows * target_size[1] + (rows + 1) * padding,
344
  cols * target_size[0] + (cols + 1) * padding, 3), dtype=np.uint8) * 255
 
350
  y_end = y_start + target_size[1]
351
  x_start = col * target_size[0] + (col + 1) * padding
352
  x_end = x_start + target_size[0]
 
353
  collage[y_start:y_end, x_start:x_end] = obj
354
 
355
  return collage
356
 
357
def create_sam3_download(objects: List[Dict]) -> Optional[str]:
    """Bundle the extracted object images and their metadata into a ZIP.

    Args:
        objects: Per-object records; each must provide ``image_path`` (file
            on disk) and ``filename`` (name to use inside the archive). The
            whole record is written to the metadata JSON.

    Returns:
        Path of the ZIP created in a fresh temporary directory, or ``None``
        when ``objects`` is empty.
    """
    if not objects:
        return None

    def _json_default(value):
        # The object records carry numpy scalars (bbox coordinates, mask
        # area, confidence) that the json module cannot encode by itself.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable"
        )

    temp_dir = tempfile.mkdtemp()
    zip_path = os.path.join(temp_dir, "sam3_extracted_objects.zip")

    metadata = {
        'model': 'SAM3 - Segment Anything Model 3',
        'extraction_time': datetime.now().isoformat(),
        'total_objects': len(objects),
        'objects': objects,
        'processing_method': 'SAM3_segmentation_with_detection_prompts'
    }

    with zipfile.ZipFile(zip_path, 'w') as zipf:
        zipf.writestr(
            "sam3_metadata.json",
            json.dumps(metadata, indent=2, default=_json_default),
        )

        # Images that have disappeared from disk are skipped, not fatal.
        for obj in objects:
            if os.path.exists(obj['image_path']):
                zipf.write(obj['image_path'], f"sam3_{obj['filename']}")

    return zip_path
382
 
383
  # Create Gradio interface
384
+ def create_sam3_interface():
385
  with gr.Blocks() as demo:
386
  gr.Markdown("""
387
  # 🎯 SAM3 Video Object Extractor
388
+ ### Advanced AI-powered object segmentation using Segment Anything Model 3
389
 
390
  [Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
391
+
392
+ **Features:**
393
+ - 🧠 SAM3 (Segment Anything Model 3) for precise object segmentation
394
+ - πŸ” Automatic object detection for smart prompting
395
+ - πŸ“Ή Video frame extraction and processing
396
+ - 🎨 High-quality mask-based object extraction
397
  """)
398
 
399
  with gr.Row():
 
412
  ("πŸͺ‘ Furniture", "furniture"),
413
  ("🏒 Building Elements", "building")
414
  ],
415
+ label="Choose object category for SAM3 detection",
416
  value=None
417
  )
418
 
419
  process_btn = gr.Button(
420
+ "πŸš€ Process with SAM3",
421
  variant="primary",
422
  size="lg"
423
  )
424
 
425
  with gr.Column(scale=1):
426
+ gr.Markdown("### 🧠 SAM3 Status")
427
  status_output = gr.Textbox(
428
+ label="Processing Status",
429
  interactive=False,
430
+ placeholder="SAM3 ready for processing..."
431
  )
432
 
433
+ with gr.Accordion("πŸ”¬ SAM3 Technology", open=False):
434
  gr.Markdown("""
435
+ **SAM3 Processing Pipeline:**
436
+ 1. **Frame Extraction** - Sample key frames from video
437
+ 2. **Object Detection** - Generate smart prompts with DETR
438
+ 3. **SAM3 Segmentation** - Precise mask generation
439
+ 4. **Object Extraction** - Clean mask-based cropping
440
+ 5. **Quality Filtering** - High-confidence results only
441
 
442
+ **Models Used:**
443
+ - SAM3 (Segment Anything Model 3)
444
+ - DETR for automatic prompting
445
  """)
446
 
447
  with gr.Row():
448
  with gr.Column():
449
+ gr.Markdown("### πŸ–ΌοΈ SAM3 Detection Results")
450
  result_image = gr.Image(
451
+ label="Frame with SAM3 Segmentation",
452
  type="filepath"
453
  )
454
 
455
  with gr.Column():
456
+ gr.Markdown("### πŸ“¦ SAM3 Extracted Objects")
457
  collage_image = gr.Image(
458
+ label="SAM3 Object Collage",
459
  type="filepath"
460
  )
461
 
462
  with gr.Row():
463
+ gr.Markdown("### πŸ“‹ SAM3 Object Gallery")
464
  objects_gallery = gr.Gallery(
465
+ label="SAM3 Extracted Objects",
466
  show_label=True,
467
+ elem_id="sam3_objects_gallery",
468
  columns=4,
469
  rows=2,
470
  height="auto",
471
  allow_preview=True
472
  )
473
 
474
+ # Hidden components
475
  objects_data = gr.State()
476
 
 
477
  with gr.Row():
478
  download_btn = gr.Button(
479
+ "πŸ“₯ Download SAM3 Results (ZIP)",
480
  variant="secondary",
481
  visible=False
482
  )
483
  download_file = gr.File(
484
+ label="SAM3 Download Package",
485
  visible=False
486
  )
487
 
488
+ # Process function
489
def handle_sam3_process(video, class_type):
    """Run the SAM3 pipeline for the uploaded video and feed the UI outputs.

    Returns a 6-tuple in the order the process button's outputs are wired:
    result frame, collage, object records (stored in state), status text,
    an update toggling the download button, and the gallery image paths.
    """
    if video is None:
        return None, None, None, "❌ Please upload a video file.", gr.update(visible=False), None

    if class_type is None:
        return None, None, None, "❌ Please select an object class for SAM3.", gr.update(visible=False), None

    result_frame, collage, objects, status = process_video_with_sam3(video, class_type)

    # process_video_with_sam3 hands back None for the object list on every
    # failure path; normalise it so the status message reaches the user
    # instead of a TypeError from len(None).
    objects = objects or []

    # The gallery previews at most the first eight extracted objects.
    gallery_images = [
        obj['image_path']
        for obj in objects[:8]
        if os.path.exists(obj['image_path'])
    ]

    download_visible = bool(objects)

    return result_frame, collage, objects, status, gr.update(visible=download_visible), gallery_images
509
 
510
+ # Download function
511
+ def handle_sam3_download(objects):
512
  if objects:
513
+ zip_path = create_sam3_download(objects)
514
  return zip_path
515
  return None
516
 
517
  # Wire up events
518
  process_btn.click(
519
+ fn=handle_sam3_process,
520
  inputs=[video_input, class_selector],
521
  outputs=[result_image, collage_image, objects_data, status_output, download_btn, objects_gallery]
522
  )
523
 
524
  download_btn.click(
525
+ fn=handle_sam3_download,
526
  inputs=[objects_data],
527
  outputs=[download_file]
528
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
 
530
  return demo
531
 
532
  # Launch the application
533
  if __name__ == "__main__":
534
+ demo = create_sam3_interface()
535
  demo.launch(
536
  theme=gr.themes.Soft(
537
+ primary_hue="green",
538
+ secondary_hue="blue",
539
  neutral_hue="slate",
540
  font=gr.themes.GoogleFont("Inter"),
541
  text_size="lg",