#!/usr/bin/env python3
"""Extract key frames from a video using Farneback optical flow.

For each key frame, SAM2 (currently disabled) would generate an object mask
so that the moving object stays sharp while the background is blurred.
"""

import argparse
import os
# NOTE: `import urllib` alone does NOT make urllib.request available;
# the submodule must be imported explicitly or urlretrieve raises AttributeError.
import urllib.request

import cv2
import numpy as np
import torch

# Import SAM modules (disabled until the dependency is wired in)
#from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# URL for the default SAM-ViT-H checkpoint
DEFAULT_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"


def download_checkpoint(checkpoint_path, url=DEFAULT_CHECKPOINT_URL):
    """
    Downloads the checkpoint file if it does not exist.

    Args:
        checkpoint_path (str): Path where the checkpoint file will be saved.
        url (str): URL to download the checkpoint from.
    """
    if os.path.exists(checkpoint_path):
        print(f"Checkpoint already exists at {checkpoint_path}.")
        return
    # urlretrieve does not create intermediate directories; make sure the
    # destination directory (e.g. ../checkpoints/) exists before downloading.
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    print(f"Downloading SAM checkpoint from {url}...")
    urllib.request.urlretrieve(url, checkpoint_path)
    print(f"Downloaded checkpoint to {checkpoint_path}.")


def process_video(video_path, output_dir, sam_checkpoint, sam_model_type="vit_h",
                  frame_skip=5, motion_threshold=20, key_frame_diff_threshold=100000):
    """
    Process the video to extract key frames based on Farneback optical flow.
    For each key frame, SAM2 is used to generate an object mask. The background
    is blurred and the object is kept sharp, then saved to disk.

    Args:
        video_path (str): Path to the input video file.
        output_dir (str): Directory to save key frames.
        sam_checkpoint (str): Path to the SAM model checkpoint.
        sam_model_type (str): SAM model type to use (default: "vit_h").
        frame_skip (int): Number of frames to skip between processing.
        motion_threshold (float): Pixel motion magnitude threshold.
        key_frame_diff_threshold (float): Threshold for difference in motion
            score for a new key frame.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Set device for SAM model
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the SAM model and create an automatic mask generator
    #sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
    #sam.to(device=device)
    #mask_generator = SamAutomaticMaskGenerator(sam)

    # Open the video file
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file: {video_path}")

    ret, prev_frame = cap.read()
    if not ret:
        raise ValueError("Could not read the first frame from the video.")
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)

    frame_count = 0
    key_frame_count = 0
    last_key_frame_motion = None

    while True:
        # Skip frames for efficiency (grab() decodes nothing, so it is cheap)
        for _ in range(frame_skip):
            cap.grab()
        ret, frame = cap.read()
        if not ret:
            break

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Calculate Farneback optical flow between the previous and current frame
        flow = cv2.calcOpticalFlowFarneback(prev_gray, gray, None,
                                            0.5, 3, 15, 3, 5, 1.2, 0)
        magnitude, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])

        # Create a binary mask where motion exceeds the threshold
        motion_mask = magnitude > motion_threshold

        # Compute a motion score (total number of "moving" pixels)
        motion_score = np.sum(motion_mask)

        # Decide if this frame is a key frame based on a significant change in motion
        if last_key_frame_motion is None or abs(motion_score - last_key_frame_motion) > key_frame_diff_threshold:
            # Convert BGR to RGB for SAM processing
            image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Generate masks automatically with SAM2
            #masks = mask_generator.generate(image_rgb)
            # Choose the best mask (e.g., the one with the largest area)
            #if masks:
            #    best_mask = max(masks, key=lambda m: m["area"])["segmentation"]
            #else:
            #    # Fallback: use a mask that selects the entire frame if no mask is found
            #    best_mask = np.ones(frame.shape[:2], dtype=bool)

            # Create a blurred version of the frame for the background
            #blurred_frame = cv2.GaussianBlur(frame, (21, 21), 0)
            # Expand the mask to three channels
            #mask_3c = np.repeat(best_mask[:, :, np.newaxis], 3, axis=2)
            # Composite the image: keep object region sharp, blur the rest
            #masked_frame = np.where(mask_3c, frame, blurred_frame)

            # Save the processed key frame.
            # BUGFIX: cv2.imwrite expects BGR channel order; writing the
            # RGB-converted array swapped red/blue in every saved JPEG.
            # Save the original BGR frame instead.
            key_frame_path = os.path.join(output_dir, f"key_frame_{key_frame_count:03d}.jpg")
            cv2.imwrite(key_frame_path, frame)
            print(f"Saved key frame: {key_frame_path}")
            key_frame_count += 1
            last_key_frame_motion = motion_score

        # Update previous frame for optical flow calculation
        prev_gray = gray.copy()
        frame_count += 1

    cap.release()
    print(f"Finished processing {frame_count} frames. Total key frames saved: {key_frame_count}.")


def main():
    """CLI entry point: parse arguments, fetch the checkpoint, process the video."""
    parser = argparse.ArgumentParser(
        description="Extract key frames from a video using Farneback optical flow and SAM2 for masking. "
                    "The moving object remains sharp while the background is blurred."
    )
    parser.add_argument("--video", type=str, required=True,
                        help="Path to the input video file (e.g., .mp4 or .mov).")
    parser.add_argument("--output", type=str, required=True,
                        help="Directory to save key frames.")
    parser.add_argument("--sam_checkpoint", type=str, default="../checkpoints/sam_vit_h.pth",
                        help="Path to the SAM model checkpoint file. If not present, the default checkpoint will be downloaded.")
    parser.add_argument("--sam_model_type", type=str, default="vit_h",
                        help="SAM model type (default: vit_h).")
    parser.add_argument("--frame_skip", type=int, default=5,
                        help="Number of frames to skip between processing (default: 5).")
    parser.add_argument("--motion_threshold", type=float, default=20,
                        help="Threshold for pixel motion magnitude (default: 20).")
    parser.add_argument("--key_frame_diff_threshold", type=float, default=100000,
                        help="Motion score difference threshold for a new key frame (default: 100000).")
    args = parser.parse_args()

    # Download checkpoint if needed
    download_checkpoint(args.sam_checkpoint)

    process_video(args.video, args.output, args.sam_checkpoint, args.sam_model_type,
                  args.frame_skip, args.motion_threshold, args.key_frame_diff_threshold)


if __name__ == "__main__":
    main()