Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from ultralytics import SAM, YOLOWorld | |
| import os | |
| # Initialize models with proper error handling and auto-download | |
def _load_yolo_detector():
    """Return a YOLO-World detector, or a plain YOLO model if that fails.

    Raises:
        Exception: Re-raised from the fallback load when both attempts fail.
    """
    try:
        # Small open-vocabulary model; the original notes say it auto-downloads.
        world_detector = YOLOWorld("yolov8s-world.pt")
        print("✅ YOLO-World model loaded successfully")
        return world_detector
    except Exception as world_error:
        print(f"❌ Error loading YOLO-World model: {world_error}")
        try:
            # Imported lazily: only needed when YOLO-World is unavailable.
            from ultralytics import YOLO
            plain_detector = YOLO("yolov8n.pt")
            print("⚠️ Using regular YOLO model as fallback")
            return plain_detector
        except Exception as fallback_error:
            print(f"❌ Fallback YOLO model also failed: {fallback_error}")
            raise


def initialize_models():
    """Load the segmentation model and the object detector.

    Returns:
        A ``(sam_model, yolo_model)`` tuple. The second element is a
        ``YOLOWorld`` instance when it loads, otherwise ``YOLO("yolov8n.pt")``.

    Raises:
        Exception: Re-raised unchanged when the SAM model cannot be loaded, or
            when neither detector can be loaded.
    """
    try:
        # MobileSAM checkpoint; the original notes say it auto-downloads.
        segmenter = SAM("mobile_sam.pt")
        print("✅ SAM model loaded successfully")
    except Exception as sam_error:
        print(f"❌ Error loading SAM model: {sam_error}")
        raise
    return segmenter, _load_yolo_detector()


# Load both models once at import time; the functions below read these globals.
sam_model, yolo_model = initialize_models()
def detect_motorcycles(first_frame, prompt="motorcycle"):
    """Run the global ``yolo_model`` on one frame and collect its boxes.

    Args:
        first_frame: Image passed straight to ``yolo_model.predict``.
        prompt: Class name handed to ``set_classes`` when the model offers
            that method; ignored otherwise.

    Returns:
        The ``xyxy`` boxes of all results stacked into one array, or an empty
        array when nothing was detected or prediction raised.
    """
    try:
        # Only models exposing ``set_classes`` can be restricted to the prompt.
        takes_prompt = hasattr(yolo_model, 'set_classes')
        if takes_prompt:
            yolo_model.set_classes([prompt])
        detections = yolo_model.predict(
            first_frame,
            device="cpu",
            max_det=2 if takes_prompt else 5,
            imgsz=320,
            verbose=False,
        )
        if not takes_prompt:
            print("⚠️ Using regular YOLO - detecting all objects, not just the specified prompt")
    except Exception as err:
        print(f"Error in YOLO prediction: {err}")
        return np.array([])

    # One array of boxes per result that actually contains detections.
    box_groups = [
        item.boxes.xyxy.cpu().numpy()
        for item in detections
        if item.boxes is not None and len(item.boxes.xyxy) > 0
    ]
    if not box_groups:
        print("No objects detected")
        return np.array([])

    stacked = np.vstack(box_groups)
    print(f"Detected {len(stacked)} objects")
    return stacked
def segment_and_highlight_video(video_path, prompt="motorcycle", highlight_color="red"):
    """Highlight prompted objects in a video using YOLO boxes and SAM masks.

    Boxes are detected once, on the first frame downscaled to 320x180, and the
    same boxes are passed to the global ``sam_model`` as prompts for every
    frame. At most 150 frames are processed and written to ``output.mp4``.

    Args:
        video_path: Path of the input video file.
        prompt: Text label forwarded to ``detect_motorcycles``.
        highlight_color: "red", "green" or "blue" (case-insensitive); any
            other value falls back to red.

    Returns:
        Path of the highlighted video, or ``video_path`` unchanged when no
        object was detected in the first frame.

    Raises:
        ValueError: If no video path is given or the first frame cannot be read.
        RuntimeError: If the output video writer cannot be opened.
    """
    if not video_path:
        # Fail with a clear message instead of an opaque OpenCV error when no
        # video is supplied (the example rows use None for the video).
        raise ValueError("No video provided.")

    target_width, target_height = 320, 180
    target_size = (target_width, target_height)

    # Probe the video: properties plus the first frame used for detection.
    cap = cv2.VideoCapture(video_path)
    try:
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        ret, first_frame = cap.read()
    finally:
        cap.release()
    if not ret:
        raise ValueError("Could not read first frame from video.")

    # Detection runs on the resized frame, so the boxes are already in the
    # coordinate system of the frames processed below.
    first_frame_resized = cv2.resize(first_frame, target_size)
    boxes = detect_motorcycles(first_frame_resized, prompt)
    if len(boxes) == 0:
        return video_path  # Nothing detected, return the original video
    print(f"Detected {len(boxes)} objects with boxes: {boxes}")

    # OpenCV frames are BGR, so these tuples are (blue, green, red).
    color_map = {"red": (0, 0, 255), "green": (0, 255, 0), "blue": (255, 0, 0)}
    highlight_bgr = color_map.get(highlight_color.lower(), (0, 0, 255))

    # Some containers report no usable FPS or frame count; fall back to
    # defaults rather than writing a broken or empty file.
    fps = original_fps if original_fps > 0 else 30.0
    max_frames = min(total_frames, 150) if total_frames > 0 else 150
    print(f"Processing {max_frames} frames...")

    output_path = "output.mp4"
    cap = cv2.VideoCapture(video_path)
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, target_size)
    frame_count = 0
    try:
        if not out.isOpened():
            raise RuntimeError(f"Could not open video writer for {output_path}")

        while frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame_resized = cv2.resize(frame, target_size)
            highlighted_frame = frame_resized
            try:
                # Segment each frame independently, prompted by the
                # first-frame boxes.
                sam_results = sam_model.predict(
                    source=frame_resized,
                    bboxes=boxes,
                    device="cpu",
                    imgsz=320,
                    conf=0.25,
                    verbose=False,
                )
                if len(sam_results) > 0 and sam_results[0].masks is not None:
                    masks = sam_results[0].masks.data.cpu().numpy()
                    if len(masks) > 0:
                        # Union of all object masks as a boolean image.
                        combined_mask = np.any(masks, axis=0)
                        overlay = np.zeros_like(frame_resized)
                        overlay[combined_mask] = highlight_bgr
                        blended = cv2.addWeighted(frame_resized, 0.7, overlay, 0.3, 0)
                        # Tint only the masked pixels; using the blend for
                        # the whole frame would dim everything else to 70%.
                        highlighted_frame = frame_resized.copy()
                        highlighted_frame[combined_mask] = blended[combined_mask]
            except Exception as e:
                # Best effort: keep the frame unhighlighted and carry on.
                print(f"Error processing frame {frame_count}: {e}")
                highlighted_frame = frame_resized

            out.write(highlighted_frame)
            frame_count += 1
            if frame_count % 30 == 0:
                print(f"Processed {frame_count}/{max_frames} frames")
    finally:
        # Release both handles even if something above raised.
        cap.release()
        out.release()

    print(f"Video processing complete. Output saved to {output_path}")
    return output_path
| # Gradio interface | |
# Build the web UI: a video upload, a text prompt and a colour choice go in,
# the highlighted video comes out.
iface = gr.Interface(
    fn=segment_and_highlight_video,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Textbox(label="Prompt", placeholder="e.g., motorcycle", value="motorcycle"),
        gr.Dropdown(choices=["red", "green", "blue"], label="Highlight Color", value="red"),
    ],
    outputs=gr.Video(label="Highlighted Video"),
    title="Video Segmentation with MobileSAM and YOLO (CPU Optimized)",
    description="Upload a short video (5-10 seconds), specify a text prompt (e.g., 'motorcycle'), and choose a highlight color. Uses MobileSAM + YOLO for CPU processing at 320x180 resolution.",
    examples=[
        [None, "motorcycle", "red"],
        [None, "car", "green"],
        [None, "person", "blue"],
    ],
    # The example rows carry no video, so executing the handler on them to
    # pre-compute outputs would raise. Keep them as click-to-fill presets.
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch()