"""Gradio app that extracts objects from video frames.

Frames are sampled from an uploaded video, DETR detections (or a fallback
grid) are turned into point prompts, and a Segment Anything predictor
produces masks that are cropped, saved and offered as a ZIP download.
"""

import json
import os
import tempfile
import time
import zipfile
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamPredictor
from transformers import pipeline


class SAM3ObjectExtractor:
    """Wraps a SAM predictor plus a DETR detector used to generate prompts."""

    # Substrings looked up in detector labels, keyed by the UI category value.
    # 'couch' is included because that is the label DETR's COCO vocabulary
    # uses; 'sofa' alone never matched.
    CATEGORY_KEYWORDS = {
        'home-objects': ['cup', 'bottle', 'bowl', 'vase', 'book', 'phone', 'laptop'],
        'furniture': ['chair', 'table', 'sofa', 'couch', 'bed', 'desk', 'cabinet'],
        'building': ['door', 'window', 'wall', 'column', 'stairs', 'ceiling'],
    }

    def __init__(self, model_type="vit_h", checkpoint_path="sam_vit_h_4b8939.pth"):
        """Load the SAM predictor and the object detector.

        Loading failures are reported on stdout and leave ``self.predictor``
        and/or ``self.detector`` set to ``None`` instead of raising, so
        callers must check those attributes.

        Args:
            model_type: Key into ``sam_model_registry``.
            checkpoint_path: Path to the SAM checkpoint file.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load SAM model
        try:
            sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
            sam.to(device=self.device)
            self.predictor = SamPredictor(sam)
            print("SAM3 model loaded successfully!")
        except Exception as e:
            print(f"Error loading SAM3 model: {e}")
            self.predictor = None

        # Load object detection model for automatic prompts
        try:
            self.detector = pipeline(
                "object-detection",
                model="facebook/detr-resnet-50",
                device=0 if torch.cuda.is_available() else -1
            )
            print("Object detection model loaded!")
        except Exception as e:
            print(f"Error loading detection model: {e}")
            self.detector = None

    def extract_frames(self, video_path: str, max_frames: int = 10) -> List[Tuple[np.ndarray, float]]:
        """Sample up to ``max_frames`` evenly spaced frames from a video.

        Args:
            video_path: Path of the video file to read.
            max_frames: Maximum number of frames to return.

        Returns:
            A list of ``(frame, timestamp_seconds)`` tuples. The list is empty
            when the video cannot be opened or reports no frames. Timestamps
            are ``0.0`` when the container does not report a usable FPS.
        """
        cap = cv2.VideoCapture(video_path)
        frames: List[Tuple[np.ndarray, float]] = []
        try:
            if not cap.isOpened():
                return frames

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            if total_frames <= 0 or max_frames <= 0:
                return frames

            if total_frames <= max_frames:
                frame_indices = range(total_frames)
            else:
                frame_indices = np.linspace(0, total_frames - 1, max_frames, dtype=int)

            for frame_idx in frame_indices:
                frame_idx = int(frame_idx)  # plain int for OpenCV and for division
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if ret:
                    # Some containers report 0 (or NaN) FPS; avoid dividing by it.
                    timestamp = frame_idx / fps if fps > 0 else 0.0
                    frames.append((frame, timestamp))
        finally:
            # Release the capture even if reading raised.
            cap.release()
        return frames

    def generate_prompts_with_detection(self, frame: np.ndarray, category: str) -> List[Tuple[np.ndarray, str]]:
        """Build point prompts from detections matching ``category``.

        Falls back to a uniform grid of prompts when no detector is loaded,
        when detection raises, or when nothing matches the category.

        Args:
            frame: BGR image as returned by OpenCV.
            category: One of the keys of ``CATEGORY_KEYWORDS``.

        Returns:
            A list of ``(point, description)`` tuples where ``point`` is an
            ``[x, y]`` array.
        """
        if self.detector is None:
            return self._generate_grid_prompts(frame)

        try:
            # Convert frame to RGB for detection
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)

            # Run object detection
            detections = self.detector(pil_image)

            keywords = self.CATEGORY_KEYWORDS.get(category, [])
            prompts = []
            for detection in detections:
                label = detection['label'].lower()
                confidence = detection['score']

                # Keep confident detections whose label matches our category
                if confidence > 0.5 and any(keyword in label for keyword in keywords):
                    # Use the bounding box center as the point prompt
                    box = detection['box']
                    center_x = box['xmin'] + (box['xmax'] - box['xmin']) // 2
                    center_y = box['ymin'] + (box['ymax'] - box['ymin']) // 2
                    prompts.append((
                        np.array([center_x, center_y]),
                        f"{label}: {confidence:.2f}"
                    ))

            if not prompts:
                return self._generate_grid_prompts(frame)
            return prompts

        except Exception as e:
            print(f"Detection failed: {e}")
            return self._generate_grid_prompts(frame)

    def _generate_grid_prompts(self, frame: np.ndarray) -> List[Tuple[np.ndarray, str]]:
        """Return a 4x4 grid of point prompts centered in equal-sized cells."""
        h, w = frame.shape[:2]
        grid_size = 4
        prompts = []
        for i in range(grid_size):
            for j in range(grid_size):
                x = (i + 0.5) * w / grid_size
                y = (j + 0.5) * h / grid_size
                prompts.append((np.array([x, y]), f"Grid point ({i},{j})"))
        return prompts

    def segment_with_sam3(self, frame: np.ndarray, prompts: List[Tuple[np.ndarray, str]]) -> List[Dict]:
        """Segment one object per prompt and keep the high-scoring masks.

        Args:
            frame: BGR image as returned by OpenCV.
            prompts: ``(point, description)`` tuples.

        Returns:
            A list of dicts with keys ``mask``, ``bbox`` (inclusive
            ``(x_min, y_min, x_max, y_max)``), ``confidence``, ``label`` and
            ``center``. Numeric values are native Python numbers so they can
            be JSON-serialized. Returns ``[]`` when no predictor is loaded or
            when segmentation raises.
        """
        if self.predictor is None:
            return []

        try:
            # Set the image for SAM3
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self.predictor.set_image(frame_rgb)

            segments = []
            for point, label in prompts:
                # NOTE: SamPredictor.predict has no `model_version` argument;
                # passing one raised TypeError and emptied every result.
                masks, scores, _logits = self.predictor.predict(
                    point_coords=np.array([point]),
                    point_labels=np.array([1]),  # 1 for positive point
                    multimask_output=True,
                )
                if len(masks) == 0:
                    continue

                # Use the best mask
                best_mask_idx = int(np.argmax(scores))
                best_mask = masks[best_mask_idx]
                best_score = float(scores[best_mask_idx])

                # Only keep high-quality masks
                if best_score <= 0.7:
                    continue

                y_indices, x_indices = np.where(best_mask)
                if x_indices.size == 0:
                    continue

                segments.append({
                    'mask': best_mask,
                    'bbox': (
                        int(x_indices.min()), int(y_indices.min()),
                        int(x_indices.max()), int(y_indices.max()),
                    ),
                    'confidence': best_score,
                    'label': label,
                    'center': (float(np.mean(x_indices)), float(np.mean(y_indices))),
                })

            return segments

        except Exception as e:
            print(f"SAM3 segmentation failed: {e}")
            return []

    def extract_object_from_mask(self, frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Return the masked object on a black background, cropped to its box.

        Args:
            frame: Source image, shape ``(H, W, C)``.
            mask: Mask of shape ``(H, W)``; non-zero pixels belong to the object.

        Returns:
            The cropped image. When the mask is empty, an all-black image
            the size of ``frame`` is returned.
        """
        mask_bool = np.asarray(mask).astype(bool)

        result = np.zeros_like(frame)
        result[mask_bool] = frame[mask_bool]

        y_indices, x_indices = np.where(mask_bool)
        if x_indices.size == 0:
            return result

        # +1 because slice ends are exclusive; without it the last row/column
        # is lost and a one-pixel-wide mask produces an empty crop.
        y_min, y_max = int(y_indices.min()), int(y_indices.max()) + 1
        x_min, x_max = int(x_indices.min()), int(x_indices.max()) + 1
        return result[y_min:y_max, x_min:x_max]

    def draw_segments(self, frame: np.ndarray, segments: List[Dict]) -> np.ndarray:
        """Draw mask overlays, bounding boxes and score labels on a copy.

        Args:
            frame: BGR image; it is not modified.
            segments: Dicts as produced by ``segment_with_sam3``.

        Returns:
            An annotated copy of ``frame``.
        """
        frame_copy = frame.copy()
        overlay_color = np.array([0, 255, 0], dtype=np.float32)  # Green overlay

        for segment in segments:
            mask = np.asarray(segment['mask']).astype(bool)
            confidence = segment['confidence']

            # Blend only inside the mask. Blending the whole frame against a
            # mostly black overlay darkened the picture once per segment.
            blended = 0.7 * frame_copy[mask].astype(np.float32) + 0.3 * overlay_color
            frame_copy[mask] = np.clip(np.rint(blended), 0, 255).astype(frame_copy.dtype)

            # Draw bounding box (plain ints keep OpenCV's argument parsing happy)
            x_min, y_min, x_max, y_max = (int(v) for v in segment['bbox'])
            color = (0, 255, 0) if confidence > 0.8 else (0, 165, 255)
            cv2.rectangle(frame_copy, (x_min, y_min), (x_max, y_max), color, 2)

            # Draw label
            label_text = f"SAM3: {confidence:.2f}"
            label_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            cv2.rectangle(frame_copy, (x_min, y_min - label_size[1] - 10),
                          (x_min + label_size[0], y_min), color, -1)
            cv2.putText(frame_copy, label_text, (x_min, y_min - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        return frame_copy


# Holds the extractor after its first successful load, so the model weights
# are not read from disk again on every button click.
_EXTRACTOR_CACHE: Dict[str, SAM3ObjectExtractor] = {}


def _get_extractor() -> SAM3ObjectExtractor:
    """Return a shared extractor, caching it only when SAM loaded correctly."""
    extractor = _EXTRACTOR_CACHE.get("default")
    if extractor is None:
        extractor = SAM3ObjectExtractor()
        if extractor.predictor is not None:
            _EXTRACTOR_CACHE["default"] = extractor
    return extractor


def process_video_with_sam3(video_file, target_class):
    """Run the full extraction pipeline on a video.

    Args:
        video_file: Path to the uploaded video, or ``None``.
        target_class: Category value selected in the UI, or ``None``.

    Returns:
        A tuple ``(result_frame_path, collage, objects, status_message)``.
        ``collage`` is an RGB array or ``None``; ``objects`` is a list of
        JSON-serializable dicts describing each saved crop. On any failure
        the first three items are ``None`` and the message explains why.
    """
    if video_file is None or target_class is None:
        return None, None, None, "Please upload a video and select an object class."

    try:
        # Initialize SAM3 extractor
        extractor = _get_extractor()

        if extractor.predictor is None:
            return None, None, None, "❌ SAM3 model failed to load. Please check installation."

        # Create temporary directory
        temp_dir = tempfile.mkdtemp()

        # Extract frames
        frames = extractor.extract_frames(video_file, max_frames=6)
        if not frames:
            return None, None, None, "Could not extract frames from video."

        all_objects = []
        processed_frames = []
        extracted_objects = []

        # Process each frame
        for i, (frame, timestamp) in enumerate(frames):
            print(f"Processing frame {i+1}/{len(frames)} at timestamp {timestamp:.2f}s")

            # Generate prompts using object detection
            prompts = extractor.generate_prompts_with_detection(frame, target_class)

            # Use SAM3 for segmentation
            segments = extractor.segment_with_sam3(frame, prompts)

            # Draw SAM3 results on frame
            processed_frames.append(extractor.draw_segments(frame, segments))

            # Extract individual objects using SAM3 masks
            for j, segment in enumerate(segments):
                obj_roi = extractor.extract_object_from_mask(frame, segment['mask'])
                if obj_roi.size == 0:
                    continue  # nothing to save

                # Save extracted object
                obj_filename = f"sam3_object_{i}_{j}_{int(timestamp*1000)}.jpg"
                obj_path = os.path.join(temp_dir, obj_filename)
                cv2.imwrite(obj_path, obj_roi)

                # Native Python numbers only: this dict is later dumped to JSON.
                obj_data = {
                    'frame_index': i,
                    'timestamp': float(timestamp),
                    'class_name': target_class,
                    'confidence': float(segment['confidence']),
                    'bbox': tuple(int(v) for v in segment['bbox']),
                    'mask_area': int(np.sum(segment['mask'])),
                    'image_path': obj_path,
                    'filename': obj_filename,
                    'label': segment['label']
                }

                all_objects.append(obj_data)
                extracted_objects.append((obj_roi, obj_data))

        # Create results summary
        summary = {
            'total_objects': len(all_objects),
            'avg_confidence': float(np.mean([obj['confidence'] for obj in all_objects])) if all_objects else 0,
            'avg_mask_area': float(np.mean([obj['mask_area'] for obj in all_objects])) if all_objects else 0,
            'frames_processed': len(frames),
            'target_class': target_class,
            'model_used': 'SAM3 (Segment Anything Model 3)'
        }

        # Create a result collage of SAM3 extractions
        collage = None
        if extracted_objects:
            grid_size = min(4, int(np.ceil(np.sqrt(len(extracted_objects)))))
            collage = create_sam3_collage(
                [obj[0] for obj in extracted_objects[:grid_size * grid_size]],
                grid_size,
            )
            if collage is not None:
                # The crops are BGR (OpenCV) but the image component expects RGB.
                collage = cv2.cvtColor(collage, cv2.COLOR_BGR2RGB)

        # Save processed video frame with SAM3 results
        result_frame = None
        if processed_frames:
            result_frame = os.path.join(temp_dir, "sam3_result_frame.jpg")
            cv2.imwrite(result_frame, processed_frames[0])

        status_message = (
            f"✅ SAM3 Processing complete! Found {summary['total_objects']} objects "
            f"with avg confidence {summary['avg_confidence']:.2f}"
        )

        return result_frame, collage, all_objects, status_message

    except Exception as e:
        return None, None, None, f"❌ SAM3 processing error: {str(e)}"


def create_sam3_collage(objects: List[np.ndarray], grid_size: int) -> Optional[np.ndarray]:
    """Tile object crops into a white-background grid image.

    Args:
        objects: Crops to place; ``None`` or empty entries are skipped.
        grid_size: Number of columns, and the maximum number of rows.

    Returns:
        The collage as a ``uint8`` array with 3 channels, or ``None`` when
        there is nothing to show.
    """
    if not objects or grid_size <= 0:
        return None

    target_size = (150, 150)
    resized_objects = []

    for obj in objects:
        if obj is not None and obj.size > 0:
            resized = cv2.resize(obj, target_size)
            # Add SAM3 watermark/indicator
            cv2.putText(resized, "SAM3", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            resized_objects.append(resized)

    if not resized_objects:
        return None

    cols = grid_size
    # Only as many rows as the tiles need, so no blank rows are allocated.
    rows = min(grid_size, -(-len(resized_objects) // cols))
    padding = 10

    collage = np.full(
        (rows * target_size[1] + (rows + 1) * padding,
         cols * target_size[0] + (cols + 1) * padding,
         3),
        255,
        dtype=np.uint8,
    )

    for i, obj in enumerate(resized_objects[:rows * cols]):
        row, col = divmod(i, cols)

        y_start = row * target_size[1] + (row + 1) * padding
        y_end = y_start + target_size[1]
        x_start = col * target_size[0] + (col + 1) * padding
        x_end = x_start + target_size[0]

        collage[y_start:y_end, x_start:x_end] = obj

    return collage


def _json_default(value):
    """Convert numpy scalars/arrays for ``json.dumps``; reject anything else."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def create_sam3_download(objects: List[Dict]) -> Optional[str]:
    """Bundle the saved crops and a metadata JSON file into a ZIP archive.

    Args:
        objects: Object dicts as returned by ``process_video_with_sam3``.

    Returns:
        Path of the created ZIP file, or ``None`` when ``objects`` is empty.
    """
    if not objects:
        return None

    temp_dir = tempfile.mkdtemp()
    zip_path = os.path.join(temp_dir, "sam3_extracted_objects.zip")

    with zipfile.ZipFile(zip_path, 'w') as zipf:
        # Add SAM3 metadata
        metadata = {
            'model': 'SAM3 - Segment Anything Model 3',
            'extraction_time': datetime.now().isoformat(),
            'total_objects': len(objects),
            'objects': objects,
            'processing_method': 'SAM3_segmentation_with_detection_prompts'
        }
        # `default` keeps stray numpy values from aborting the download.
        zipf.writestr("sam3_metadata.json", json.dumps(metadata, indent=2, default=_json_default))

        # Add SAM3 objects
        for obj in objects:
            if os.path.exists(obj['image_path']):
                zipf.write(obj['image_path'], f"sam3_{obj['filename']}")

    return zip_path


# Create Gradio interface
def create_sam3_interface():
    """Build the Gradio Blocks UI and wire up its event handlers.

    Returns:
        The ``gr.Blocks`` instance, ready to be launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("""
        # 🎯 SAM3 Video Object Extractor
        ### Advanced AI-powered object segmentation using Segment Anything Model 3

        [Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder)

        **Features:**
        - 🧠 SAM3 (Segment Anything Model 3) for precise object segmentation
        - 🔍 Automatic object detection for smart prompting
        - 📹 Video frame extraction and processing
        - 🎨 High-quality mask-based object extraction
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📹 Upload Video")
                # gr.Video has no `type` argument; it already yields a file path.
                video_input = gr.Video(
                    label="Select Video File",
                    sources=["upload"]
                )

                gr.Markdown("### 🏷️ Select Object Class")
                class_selector = gr.Radio(
                    choices=[
                        ("🏠 Home Objects", "home-objects"),
                        ("🪑 Furniture", "furniture"),
                        ("🏢 Building Elements", "building")
                    ],
                    label="Choose object category for SAM3 detection",
                    value=None
                )

                process_btn = gr.Button(
                    "🚀 Process with SAM3",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                gr.Markdown("### 🧠 SAM3 Status")
                status_output = gr.Textbox(
                    label="Processing Status",
                    interactive=False,
                    placeholder="SAM3 ready for processing..."
                )

                with gr.Accordion("🔬 SAM3 Technology", open=False):
                    gr.Markdown("""
                    **SAM3 Processing Pipeline:**
                    1. **Frame Extraction** - Sample key frames from video
                    2. **Object Detection** - Generate smart prompts with DETR
                    3. **SAM3 Segmentation** - Precise mask generation
                    4. **Object Extraction** - Clean mask-based cropping
                    5. **Quality Filtering** - High-confidence results only

                    **Models Used:**
                    - SAM3 (Segment Anything Model 3)
                    - DETR for automatic prompting
                    """)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🖼️ SAM3 Detection Results")
                result_image = gr.Image(
                    label="Frame with SAM3 Segmentation",
                    type="filepath"
                )

            with gr.Column():
                gr.Markdown("### 📦 SAM3 Extracted Objects")
                collage_image = gr.Image(
                    label="SAM3 Object Collage",
                    type="filepath"
                )

        with gr.Row():
            gr.Markdown("### 📋 SAM3 Object Gallery")
            objects_gallery = gr.Gallery(
                label="SAM3 Extracted Objects",
                show_label=True,
                elem_id="sam3_objects_gallery",
                columns=4,
                rows=2,
                height="auto",
                allow_preview=True
            )

        # Hidden components
        objects_data = gr.State()

        with gr.Row():
            download_btn = gr.Button(
                "📥 Download SAM3 Results (ZIP)",
                variant="secondary",
                visible=False
            )
            download_file = gr.File(
                label="SAM3 Download Package",
                visible=False
            )

        # Process function
        def handle_sam3_process(video, class_type):
            """Validate inputs, run the pipeline and map results to outputs."""
            if video is None:
                return None, None, None, "❌ Please upload a video file.", gr.update(visible=False), None

            if class_type is None:
                return None, None, None, "❌ Please select an object class for SAM3.", gr.update(visible=False), None

            # Process with SAM3
            result_frame, collage, objects, status = process_video_with_sam3(video, class_type)

            # `objects` is None when processing failed; treat that as "no objects".
            objects = objects or []

            # Prepare gallery
            gallery_images = [
                obj['image_path']
                for obj in objects[:8]
                if os.path.exists(obj['image_path'])
            ]

            download_visible = bool(objects)

            return result_frame, collage, objects, status, gr.update(visible=download_visible), gallery_images

        # Download function
        def handle_sam3_download(objects):
            """Create the ZIP and reveal the file component that serves it."""
            zip_path = create_sam3_download(objects) if objects else None
            if zip_path is None:
                return gr.update(value=None, visible=False)
            # The file component starts hidden, so it must be shown explicitly.
            return gr.update(value=zip_path, visible=True)

        # Wire up events
        process_btn.click(
            fn=handle_sam3_process,
            inputs=[video_input, class_selector],
            outputs=[result_image, collage_image, objects_data, status_output, download_btn, objects_gallery]
        )

        download_btn.click(
            fn=handle_sam3_download,
            inputs=[objects_data],
            outputs=[download_file]
        )

    return demo


# Launch the application
if __name__ == "__main__":
    demo = create_sam3_interface()
    demo.launch(
        theme=gr.themes.Soft(
            primary_hue="green",
            secondary_hue="blue",
            neutral_hue="slate",
            font=gr.themes.GoogleFont("Inter"),
            text_size="lg",
            spacing_size="lg",
            radius_size="md"
        ).set(
            button_primary_background_fill="*primary_600",
            button_primary_background_fill_hover="*primary_700",
            block_title_text_weight="600",
        ),
        footer_links=[
            {"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"}
        ]
    )