Spaces:
Running
Running
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import tempfile | |
| import os | |
| import json | |
| import zipfile | |
| import torch | |
| from segment_anything import sam_model_registry, SamPredictor | |
| from transformers import pipeline | |
| import supervision as sv | |
| from datetime import datetime | |
| import time | |
| from typing import List, Tuple, Dict, Optional | |
class SAM3ObjectExtractor:
    """Segment objects in video frames with SAM, prompted by DETR detections.

    Despite the "SAM3" naming used throughout this app, the segmentation model
    is loaded through the ``segment_anything`` package (``sam_model_registry`` /
    ``SamPredictor``) with a ``vit_h`` checkpoint by default. Frames handled by
    this class are OpenCV BGR arrays as produced by :meth:`extract_frames`.
    """

    # Substrings matched against lower-cased DETR labels for each UI category.
    # facebook/detr-resnet-50 uses COCO labels, where the sofa class is "couch".
    CATEGORY_KEYWORDS: Dict[str, List[str]] = {
        'home-objects': ['cup', 'bottle', 'bowl', 'vase', 'book', 'phone', 'laptop'],
        'furniture': ['chair', 'table', 'sofa', 'couch', 'bed', 'desk', 'cabinet'],
        'building': ['door', 'window', 'wall', 'column', 'stairs', 'ceiling'],
    }

    def __init__(self, model_type="vit_h", checkpoint_path="sam_vit_h_4b8939.pth"):
        """Load the SAM predictor and the DETR object-detection pipeline.

        Args:
            model_type: key into ``sam_model_registry`` (e.g. ``"vit_h"``).
            checkpoint_path: path to the matching SAM checkpoint file.

        A load failure of either model is printed and leaves the corresponding
        attribute (``self.predictor`` / ``self.detector``) set to ``None``;
        callers must check for that before use.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        try:
            sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
            sam.to(device=self.device)
            self.predictor = SamPredictor(sam)
            print("SAM3 model loaded successfully!")
        except Exception as e:
            print(f"Error loading SAM3 model: {e}")
            self.predictor = None

        # DETR supplies point prompts automatically; without it we fall back to a grid.
        try:
            self.detector = pipeline(
                "object-detection",
                model="facebook/detr-resnet-50",
                device=0 if torch.cuda.is_available() else -1
            )
            print("Object detection model loaded!")
        except Exception as e:
            print(f"Error loading detection model: {e}")
            self.detector = None

    def extract_frames(self, video_path: str, max_frames: int = 10) -> List[Tuple[np.ndarray, float]]:
        """Sample up to ``max_frames`` evenly spaced frames from a video.

        Returns a list of ``(bgr_frame, timestamp_seconds)`` tuples. Frames that
        fail to decode are skipped. When the container reports no FPS the
        timestamp falls back to ``0.0`` instead of dividing by zero. A video that
        cannot be opened (or ``max_frames <= 0``) yields an empty list.
        """
        frames: List[Tuple[np.ndarray, float]] = []
        if max_frames <= 0:
            return frames

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return frames
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            if total_frames <= 0:
                return frames

            if total_frames <= max_frames:
                frame_indices = range(total_frames)
            else:
                frame_indices = np.linspace(0, total_frames - 1, max_frames, dtype=int)

            for frame_idx in frame_indices:
                frame_idx = int(frame_idx)
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if ret:
                    timestamp = frame_idx / fps if fps > 0 else 0.0
                    frames.append((frame, timestamp))
        finally:
            # Release the capture on every path, including early returns and errors.
            cap.release()
        return frames

    def generate_prompts_with_detection(self, frame: np.ndarray, category: str,
                                        min_score: float = 0.5) -> List[Tuple[np.ndarray, str]]:
        """Turn DETR detections of ``category`` into point prompts for SAM.

        Each detection whose lower-cased label contains one of the category's
        keywords and whose score exceeds ``min_score`` contributes the centre of
        its bounding box as an ``(x, y)`` point, labelled ``"<label>: <score>"``.
        Falls back to :meth:`_generate_grid_prompts` when no detector is loaded,
        when nothing matches, or when detection raises.
        """
        if self.detector is None:
            return self._generate_grid_prompts(frame)

        try:
            # The detector receives an RGB PIL image; frames are OpenCV BGR.
            pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            detections = self.detector(pil_image)

            keywords = self.CATEGORY_KEYWORDS.get(category, [])
            prompts = []
            for detection in detections:
                label = detection['label'].lower()
                confidence = detection['score']
                if confidence <= min_score or not any(keyword in label for keyword in keywords):
                    continue
                box = detection['box']
                center_x = box['xmin'] + (box['xmax'] - box['xmin']) // 2
                center_y = box['ymin'] + (box['ymax'] - box['ymin']) // 2
                prompts.append((np.array([center_x, center_y]), f"{label}: {confidence:.2f}"))

            return prompts or self._generate_grid_prompts(frame)
        except Exception as e:
            print(f"Detection failed: {e}")
            return self._generate_grid_prompts(frame)

    def _generate_grid_prompts(self, frame: np.ndarray, grid_size: int = 4) -> List[Tuple[np.ndarray, str]]:
        """Return ``grid_size**2`` point prompts at the centres of a uniform grid.

        Points are ``(x, y)`` float arrays ordered column-major (outer loop over
        x index ``i``, inner over y index ``j``) and labelled ``"Grid point (i,j)"``.
        """
        h, w = frame.shape[:2]
        prompts = []
        for i in range(grid_size):
            for j in range(grid_size):
                x = (i + 0.5) * w / grid_size
                y = (j + 0.5) * h / grid_size
                prompts.append((np.array([x, y]), f"Grid point ({i},{j})"))
        return prompts

    def segment_with_sam3(self, frame: np.ndarray, prompts: List[Tuple[np.ndarray, str]],
                          min_score: float = 0.7) -> List[Dict]:
        """Run SAM once per point prompt and keep confident masks.

        For each ``(point, label)`` prompt, SAM is asked for multiple masks and
        the highest-scoring one is kept if its score exceeds ``min_score`` and
        it is non-empty. Each result dict has keys ``mask`` (as returned by SAM),
        ``bbox`` (``x_min, y_min, x_max, y_max`` as Python ints, inclusive),
        ``confidence`` (float), ``label`` (the prompt label) and ``center``
        (mean ``x, y`` of mask pixels). Returns ``[]`` if no predictor is loaded
        or if segmentation raises.
        """
        if self.predictor is None:
            return []

        try:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self.predictor.set_image(frame_rgb)

            segments = []
            for point, label in prompts:
                # predict() takes no model-version argument; the model type is fixed at load time.
                masks, scores, _logits = self.predictor.predict(
                    point_coords=np.array([point]),
                    point_labels=np.array([1]),  # 1 marks a positive (foreground) point
                    multimask_output=True,
                )
                if len(masks) == 0:
                    continue

                best_mask_idx = int(np.argmax(scores))
                best_mask = masks[best_mask_idx]
                best_score = float(scores[best_mask_idx])
                if best_score <= min_score:
                    continue

                y_indices, x_indices = np.where(best_mask)
                if x_indices.size == 0:
                    continue

                segments.append({
                    'mask': best_mask,
                    'bbox': (int(x_indices.min()), int(y_indices.min()),
                             int(x_indices.max()), int(y_indices.max())),
                    'confidence': best_score,
                    'label': label,
                    'center': (float(np.mean(x_indices)), float(np.mean(y_indices))),
                })
            return segments
        except Exception as e:
            print(f"SAM3 segmentation failed: {e}")
            return []

    def extract_object_from_mask(self, frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Return the masked object cropped to its tight bounding box.

        Pixels outside ``mask`` (any non-zero value counts as inside) are zeroed.
        The crop bounds are inclusive, so the last masked row and column are kept
        and a single-pixel mask yields a 1x1 crop. An empty mask returns the full
        all-black frame-sized array.
        """
        mask = np.asarray(mask).astype(bool)
        result = np.zeros_like(frame)
        result[mask] = frame[mask]

        y_indices, x_indices = np.nonzero(mask)
        if x_indices.size == 0:
            return result
        return result[y_indices.min():y_indices.max() + 1,
                      x_indices.min():x_indices.max() + 1]

    def draw_segments(self, frame: np.ndarray, segments: List[Dict]) -> np.ndarray:
        """Return a copy of ``frame`` annotated with each segment.

        Masked pixels are tinted green (70% frame, 30% green). Each segment also
        gets a bounding box and a ``"SAM3: <confidence>"`` label. Box and label
        are green when confidence > 0.8, orange (BGR ``(0, 165, 255)``) otherwise.
        """
        canvas = frame.copy()
        tint = np.array([0, 255, 0], dtype=np.float32)

        for segment in segments:
            mask = np.asarray(segment['mask']).astype(bool)
            confidence = segment['confidence']
            # OpenCV drawing functions want plain Python ints for coordinates.
            x_min, y_min, x_max, y_max = (int(v) for v in segment['bbox'])

            # Blend only the masked pixels so drawing several segments does not
            # progressively darken the whole frame.
            if mask.any():
                blended = canvas[mask].astype(np.float32) * 0.7 + tint * 0.3
                canvas[mask] = np.clip(np.rint(blended), 0, 255).astype(canvas.dtype)

            color = (0, 255, 0) if confidence > 0.8 else (0, 165, 255)
            cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), color, 2)

            label_text = f"SAM3: {confidence:.2f}"
            (text_w, text_h), _baseline = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            # Keep the label inside the image when the box touches the top edge.
            label_bottom = max(y_min, text_h + 10)
            cv2.rectangle(canvas, (x_min, label_bottom - text_h - 10),
                          (x_min + text_w, label_bottom), color, -1)
            cv2.putText(canvas, label_text, (x_min, label_bottom - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
        return canvas
# Lazily created extractor shared across requests, so the SAM checkpoint and the
# DETR pipeline are loaded once per process rather than on every button click.
_SAM3_EXTRACTOR = None


def _get_sam3_extractor():
    """Return the shared SAM3ObjectExtractor, (re)creating it while SAM is unavailable."""
    global _SAM3_EXTRACTOR
    if _SAM3_EXTRACTOR is None or _SAM3_EXTRACTOR.predictor is None:
        _SAM3_EXTRACTOR = SAM3ObjectExtractor()
    return _SAM3_EXTRACTOR


def process_video_with_sam3(video_file, target_class):
    """Run the full sample -> detect -> segment -> crop pipeline on a video.

    Args:
        video_file: path to the uploaded video.
        target_class: category key (e.g. ``"furniture"``) used to filter detections.

    Returns:
        ``(result_frame_path, collage, objects, status_message)``:
        - ``result_frame_path``: JPEG of the first processed frame with overlays, or ``None``.
        - ``collage``: RGB numpy image of up to 16 crops, or ``None``.
        - ``objects``: list of JSON-safe metadata dicts, one per saved crop, or ``None`` on error.
        - ``status_message``: human-readable status text.

    Errors are not raised; they are reported through ``status_message``.
    """
    if video_file is None or target_class is None:
        return None, None, None, "Please upload a video and select an object class."

    try:
        # NOTE(review): the shared SamPredictor is stateful (set_image, then predict).
        # This is safe while this Gradio event keeps its default concurrency limit
        # of 1; raising that limit would need a lock around the frame loop.
        extractor = _get_sam3_extractor()
        if extractor.predictor is None:
            return None, None, None, "β SAM3 model failed to load. Please check installation."

        temp_dir = tempfile.mkdtemp()

        frames = extractor.extract_frames(video_file, max_frames=6)
        if not frames:
            return None, None, None, "Could not extract frames from video."

        all_objects = []
        processed_frames = []
        extracted_objects = []

        for i, (frame, timestamp) in enumerate(frames):
            print(f"Processing frame {i+1}/{len(frames)} at timestamp {timestamp:.2f}s")

            prompts = extractor.generate_prompts_with_detection(frame, target_class)
            segments = extractor.segment_with_sam3(frame, prompts)
            processed_frames.append(extractor.draw_segments(frame, segments))

            for j, segment in enumerate(segments):
                obj_roi = extractor.extract_object_from_mask(frame, segment['mask'])
                if obj_roi is None or obj_roi.size == 0:
                    # cv2.imwrite raises on empty images; skip the crop instead of
                    # aborting the whole run.
                    continue

                obj_filename = f"sam3_object_{i}_{j}_{int(timestamp*1000)}.jpg"
                obj_path = os.path.join(temp_dir, obj_filename)
                cv2.imwrite(obj_path, obj_roi)

                # Plain Python types only: this dict is later written with json.dumps.
                obj_data = {
                    'frame_index': i,
                    'timestamp': float(timestamp),
                    'class_name': target_class,
                    'confidence': float(segment['confidence']),
                    'bbox': tuple(int(v) for v in segment['bbox']),
                    'mask_area': int(np.count_nonzero(segment['mask'])),
                    'image_path': obj_path,
                    'filename': obj_filename,
                    'label': segment['label']
                }
                all_objects.append(obj_data)
                extracted_objects.append((obj_roi, obj_data))

        summary = {
            'total_objects': len(all_objects),
            'avg_confidence': float(np.mean([obj['confidence'] for obj in all_objects])) if all_objects else 0,
            'avg_mask_area': float(np.mean([obj['mask_area'] for obj in all_objects])) if all_objects else 0,
            'frames_processed': len(frames),
            'target_class': target_class,
            'model_used': 'SAM3 (Segment Anything Model 3)'
        }

        collage = None
        if extracted_objects:
            grid_size = min(4, int(np.ceil(np.sqrt(len(extracted_objects)))))
            crops = [roi for roi, _ in extracted_objects[:grid_size * grid_size]]
            collage = create_sam3_collage(crops, grid_size)
            if collage is not None:
                # Gradio renders numpy images as RGB; the crops are OpenCV BGR.
                collage = cv2.cvtColor(collage, cv2.COLOR_BGR2RGB)

        result_frame = None
        if processed_frames:
            result_frame = os.path.join(temp_dir, "sam3_result_frame.jpg")
            cv2.imwrite(result_frame, processed_frames[0])

        status_message = f"β SAM3 Processing complete! Found {summary['total_objects']} objects with avg confidence {summary['avg_confidence']:.2f}"
        return result_frame, collage, all_objects, status_message

    except Exception as e:
        return None, None, None, f"β SAM3 processing error: {str(e)}"
def create_sam3_collage(objects: List[np.ndarray], grid_size: int) -> Optional[np.ndarray]:
    """Tile object crops into a white-background grid image.

    Each non-empty crop is resized to 150x150 (aspect ratio is not preserved)
    and stamped with a green "SAM3" tag. Crops are placed left-to-right,
    top-to-bottom in a grid of ``grid_size`` columns with 10 px padding.
    At most ``grid_size ** 2`` crops are used, and only as many rows as needed
    are allocated. Grayscale crops are promoted to 3 channels and a 4th (alpha)
    channel is dropped so every tile fits the 3-channel canvas.

    Returns:
        A ``uint8`` (H, W, 3) image, or ``None`` when there is nothing to draw.
    """
    if not objects:
        return None

    cols = max(1, int(grid_size))
    max_tiles = cols * cols
    tile_w, tile_h = 150, 150
    padding = 10

    tiles = []
    for obj in objects:
        if obj is None or obj.size == 0:
            continue
        tile = cv2.resize(obj, (tile_w, tile_h))
        if tile.ndim == 2:
            tile = np.stack([tile] * 3, axis=-1)
        elif tile.shape[2] > 3:
            tile = tile[:, :, :3]
        # cv2.putText needs a contiguous buffer; stacking/slicing may not give one.
        tile = np.ascontiguousarray(tile)
        cv2.putText(tile, "SAM3", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        tiles.append(tile)
        if len(tiles) == max_tiles:
            break

    if not tiles:
        return None

    rows = (len(tiles) + cols - 1) // cols  # ceil: no empty trailing rows
    collage = np.full((rows * tile_h + (rows + 1) * padding,
                       cols * tile_w + (cols + 1) * padding, 3), 255, dtype=np.uint8)

    for index, tile in enumerate(tiles):
        row, col = divmod(index, cols)
        y_start = row * tile_h + (row + 1) * padding
        x_start = col * tile_w + (col + 1) * padding
        collage[y_start:y_start + tile_h, x_start:x_start + tile_w] = tile

    return collage
def create_sam3_download(objects: List[Dict]) -> str:
    """Package extracted-object images and their metadata into a ZIP file.

    The archive, written to a fresh temporary directory, contains
    ``sam3_metadata.json`` (model info, timestamp, object count and the
    ``objects`` list itself) plus every object image whose ``image_path``
    still exists, stored as ``sam3_<filename>``.

    Args:
        objects: metadata dicts; each must provide ``image_path`` and ``filename``.

    Returns:
        Path to the ZIP file, or ``None`` if ``objects`` is empty.
    """
    if not objects:
        return None

    def _json_default(value):
        # Object metadata can carry NumPy scalars/arrays (e.g. float32 scores,
        # int64 bbox coordinates), which the json module cannot encode natively.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    temp_dir = tempfile.mkdtemp()
    zip_path = os.path.join(temp_dir, "sam3_extracted_objects.zip")

    with zipfile.ZipFile(zip_path, 'w') as zipf:
        metadata = {
            'model': 'SAM3 - Segment Anything Model 3',
            'extraction_time': datetime.now().isoformat(),
            'total_objects': len(objects),
            'objects': objects,
            'processing_method': 'SAM3_segmentation_with_detection_prompts'
        }
        zipf.writestr("sam3_metadata.json", json.dumps(metadata, indent=2, default=_json_default))

        for obj in objects:
            if os.path.exists(obj['image_path']):
                zipf.write(obj['image_path'], f"sam3_{obj['filename']}")

    return zip_path
# Create Gradio interface
def create_sam3_interface():
    """Build (but do not launch) the Gradio Blocks UI for the SAM3 extractor.

    Components:
    - Inputs: video upload, category radio and process button.
    - Status: status textbox plus a collapsible pipeline description.
    - Results: first processed frame, crop collage and a gallery of up to 8 crops.
    - Download: a button that builds a ZIP and reveals a file component.
    The raw object metadata is kept in a ``gr.State`` between the two events.
    """
    with gr.Blocks() as demo:
        gr.Markdown("""
        # π― SAM3 Video Object Extractor
        ### Advanced AI-powered object segmentation using Segment Anything Model 3
        [Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
        **Features:**
        - π§ SAM3 (Segment Anything Model 3) for precise object segmentation
        - π Automatic object detection for smart prompting
        - πΉ Video frame extraction and processing
        - π¨ High-quality mask-based object extraction
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### πΉ Upload Video")
                video_input = gr.Video(
                    label="Select Video File",
                    sources=["upload"],
                    type="filepath"
                )

                gr.Markdown("### π·οΈ Select Object Class")
                class_selector = gr.Radio(
                    choices=[
                        ("π Home Objects", "home-objects"),
                        ("πͺ Furniture", "furniture"),
                        ("π’ Building Elements", "building")
                    ],
                    label="Choose object category for SAM3 detection",
                    value=None
                )

                process_btn = gr.Button(
                    "π Process with SAM3",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                gr.Markdown("### π§ SAM3 Status")
                status_output = gr.Textbox(
                    label="Processing Status",
                    interactive=False,
                    placeholder="SAM3 ready for processing..."
                )

                with gr.Accordion("π¬ SAM3 Technology", open=False):
                    gr.Markdown("""
                    **SAM3 Processing Pipeline:**
                    1. **Frame Extraction** - Sample key frames from video
                    2. **Object Detection** - Generate smart prompts with DETR
                    3. **SAM3 Segmentation** - Precise mask generation
                    4. **Object Extraction** - Clean mask-based cropping
                    5. **Quality Filtering** - High-confidence results only
                    **Models Used:**
                    - SAM3 (Segment Anything Model 3)
                    - DETR for automatic prompting
                    """)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### πΌοΈ SAM3 Detection Results")
                result_image = gr.Image(
                    label="Frame with SAM3 Segmentation",
                    type="filepath"
                )
            with gr.Column():
                gr.Markdown("### π¦ SAM3 Extracted Objects")
                collage_image = gr.Image(
                    label="SAM3 Object Collage",
                    type="filepath"
                )

        with gr.Row():
            gr.Markdown("### π SAM3 Object Gallery")
            objects_gallery = gr.Gallery(
                label="SAM3 Extracted Objects",
                show_label=True,
                elem_id="sam3_objects_gallery",
                columns=4,
                rows=2,
                height="auto",
                allow_preview=True
            )

        # Hidden components
        objects_data = gr.State()
        with gr.Row():
            download_btn = gr.Button(
                "π₯ Download SAM3 Results (ZIP)",
                variant="secondary",
                visible=False
            )
            download_file = gr.File(
                label="SAM3 Download Package",
                visible=False
            )

        def handle_sam3_process(video, class_type):
            """Validate inputs, run the pipeline and map its results onto the UI outputs."""
            # Every run hides the previous ZIP so a stale package is never offered.
            hide_file = gr.update(value=None, visible=False)
            if video is None:
                return None, None, None, "β Please upload a video file.", gr.update(visible=False), None, hide_file
            if class_type is None:
                return None, None, None, "β Please select an object class for SAM3.", gr.update(visible=False), None, hide_file

            result_frame, collage, objects, status = process_video_with_sam3(video, class_type)

            gallery_images = [
                obj['image_path'] for obj in (objects or [])[:8]
                if os.path.exists(obj['image_path'])
            ]
            # `objects` is None when processing failed; len(None) would raise here.
            download_visible = bool(objects)
            return (result_frame, collage, objects, status,
                    gr.update(visible=download_visible), gallery_images, hide_file)

        def handle_sam3_download(objects):
            """Build the ZIP for the current results and reveal the (initially hidden) file component."""
            zip_path = create_sam3_download(objects) if objects else None
            return gr.update(value=zip_path, visible=zip_path is not None)

        process_btn.click(
            fn=handle_sam3_process,
            inputs=[video_input, class_selector],
            outputs=[result_image, collage_image, objects_data, status_output,
                     download_btn, objects_gallery, download_file]
        )

        download_btn.click(
            fn=handle_sam3_download,
            inputs=[objects_data],
            outputs=[download_file]
        )

    return demo
# Launch the application
if __name__ == "__main__":
    demo = create_sam3_interface()

    # Visual theme and footer links are supplied to launch(), not to gr.Blocks().
    soft_green_theme = gr.themes.Soft(
        primary_hue="green",
        secondary_hue="blue",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("Inter"),
        text_size="lg",
        spacing_size="lg",
        radius_size="md",
    ).set(
        button_primary_background_fill="*primary_600",
        button_primary_background_fill_hover="*primary_700",
        block_title_text_weight="600",
    )
    anycoder_links = [
        {"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"},
    ]

    demo.launch(theme=soft_green_theme, footer_links=anycoder_links)