Spaces:
Configuration error
Configuration error
#!/usr/bin/env python3
import argparse
import os
import urllib.request  # bare `import urllib` does not load the `request` submodule in Python 3

import cv2
import numpy as np
import torch
| # Import SAM modules | |
| #from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
| # URL for the default SAM-ViT-H checkpoint | |
# URL for the default SAM-ViT-H checkpoint
DEFAULT_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"


def download_checkpoint(checkpoint_path, url=DEFAULT_CHECKPOINT_URL):
    """Download the SAM checkpoint to ``checkpoint_path`` if it is not already present.

    Args:
        checkpoint_path (str): Destination path for the checkpoint file.
        url (str): URL to download the checkpoint from. Defaults to the
            official SAM ViT-H checkpoint.

    Returns:
        None. Skips the download (with a message) when the file already exists.
    """
    if os.path.exists(checkpoint_path):
        print(f"Checkpoint already exists at {checkpoint_path}.")
        return
    # urlretrieve fails with FileNotFoundError if the destination directory
    # does not exist (the default "../checkpoints/..." path often doesn't).
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    print(f"Downloading SAM checkpoint from {url}...")
    # Local import: a bare top-level `import urllib` does not bind the
    # `request` submodule in Python 3, so make sure it is loaded here.
    import urllib.request
    urllib.request.urlretrieve(url, checkpoint_path)
    print(f"Downloaded checkpoint to {checkpoint_path}.")
def process_video(video_path, output_dir, sam_checkpoint, sam_model_type="vit_h",
                  frame_skip=5, motion_threshold=20, key_frame_diff_threshold=100000):
    """Extract key frames from a video based on Farneback optical flow.

    Every ``frame_skip + 1``-th frame is decoded; its dense optical flow
    against the previously processed frame yields a motion score (the count
    of pixels whose flow magnitude exceeds ``motion_threshold``). A frame is
    saved as a key frame when its motion score differs from the last key
    frame's score by more than ``key_frame_diff_threshold``. The SAM-based
    object masking / background blur step is currently disabled (kept below
    as commented-out placeholder code).

    Args:
        video_path (str): Path to the input video file.
        output_dir (str): Directory to save key frames (created if missing).
        sam_checkpoint (str): Path to the SAM model checkpoint (unused while
            SAM is disabled).
        sam_model_type (str): SAM model type to use (default: "vit_h").
        frame_skip (int): Number of frames to skip between processed frames.
        motion_threshold (float): Per-pixel flow magnitude threshold.
        key_frame_diff_threshold (float): Motion-score difference required to
            emit a new key frame.

    Raises:
        ValueError: If the video cannot be opened or its first frame cannot
            be read.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Device selection for the (currently disabled) SAM model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
    # sam.to(device=device)
    # mask_generator = SamAutomaticMaskGenerator(sam)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file: {video_path}")
    try:
        ret, prev_frame = cap.read()
        if not ret:
            raise ValueError("Could not read the first frame from the video.")
        prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)

        frame_count = 0          # frames actually processed (after skipping)
        key_frame_count = 0
        last_key_frame_motion = None

        while True:
            # grab() skips frames cheaply (no full decode, unlike read()).
            for _ in range(frame_skip):
                cap.grab()
            ret, frame = cap.read()
            if not ret:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

            # Dense Farneback flow between the previous processed frame and
            # this one; parameters are the common OpenCV defaults.
            flow = cv2.calcOpticalFlowFarneback(prev_gray, gray,
                                                None, 0.5, 3, 15, 3, 5, 1.2, 0)
            magnitude, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            # Motion score: number of pixels moving faster than the threshold.
            motion_score = np.sum(magnitude > motion_threshold)

            # Key frame when motion changes significantly since the last one.
            if (last_key_frame_motion is None
                    or abs(motion_score - last_key_frame_motion) > key_frame_diff_threshold):
                # Disabled SAM pipeline: would convert the frame to RGB, pick
                # the largest mask, blur the background, and composite:
                # image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # masks = mask_generator.generate(image_rgb)
                # best_mask = max(masks, key=lambda m: m["area"])["segmentation"] if masks \
                #     else np.ones(frame.shape[:2], dtype=bool)
                # blurred_frame = cv2.GaussianBlur(frame, (21, 21), 0)
                # mask_3c = np.repeat(best_mask[:, :, np.newaxis], 3, axis=2)
                # frame = np.where(mask_3c, frame, blurred_frame)
                key_frame_path = os.path.join(output_dir, f"key_frame_{key_frame_count:03d}.jpg")
                # BUGFIX: cv2.imwrite expects a BGR image. The original saved
                # the RGB-converted copy, producing color-swapped JPEGs.
                cv2.imwrite(key_frame_path, frame)
                print(f"Saved key frame: {key_frame_path}")
                key_frame_count += 1
                last_key_frame_motion = motion_score

            # Advance the flow reference frame.
            prev_gray = gray.copy()
            frame_count += 1
    finally:
        # Release the capture even if decoding/flow raises mid-loop.
        cap.release()
    print(f"Finished processing {frame_count} frames. Total key frames saved: {key_frame_count}.")
def main():
    """Command-line entry point: parse options, fetch the checkpoint, run extraction."""
    cli = argparse.ArgumentParser(
        description="Extract key frames from a video using Farneback optical flow and SAM2 for masking. "
                    "The moving object remains sharp while the background is blurred."
    )
    cli.add_argument("--video", type=str, required=True, help="Path to the input video file (e.g., .mp4 or .mov).")
    cli.add_argument("--output", type=str, required=True, help="Directory to save key frames.")
    cli.add_argument("--sam_checkpoint", type=str, default="../checkpoints/sam_vit_h.pth",
                     help="Path to the SAM model checkpoint file. If not present, the default checkpoint will be downloaded.")
    cli.add_argument("--sam_model_type", type=str, default="vit_h", help="SAM model type (default: vit_h).")
    cli.add_argument("--frame_skip", type=int, default=5, help="Number of frames to skip between processing (default: 5).")
    cli.add_argument("--motion_threshold", type=float, default=20, help="Threshold for pixel motion magnitude (default: 20).")
    cli.add_argument("--key_frame_diff_threshold", type=float, default=100000, help="Motion score difference threshold for a new key frame (default: 100000).")
    opts = cli.parse_args()

    # Make sure the checkpoint exists before process_video tries to load it.
    download_checkpoint(opts.sam_checkpoint)
    process_video(opts.video, opts.output, opts.sam_checkpoint, opts.sam_model_type,
                  opts.frame_skip, opts.motion_threshold, opts.key_frame_diff_threshold)


if __name__ == "__main__":
    main()