Spaces:
Configuration error
Configuration error
File size: 6,724 Bytes
24f3fb6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
#!/usr/bin/env python3
import argparse
import os
import urllib
# BUG FIX: `import urllib` alone does not make the `request` submodule
# available in Python 3; it must be imported explicitly.
import urllib.request

import cv2
import numpy as np
import torch
# Import SAM modules
#from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
# URL for the default SAM-ViT-H checkpoint
DEFAULT_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
def download_checkpoint(checkpoint_path, url=DEFAULT_CHECKPOINT_URL):
    """
    Download the SAM checkpoint file if it does not already exist.

    Args:
        checkpoint_path (str): Path where the checkpoint file will be saved.
            The parent directory is created if it is missing.
        url (str): URL to download the checkpoint from
            (default: the SAM-ViT-H checkpoint).
    """
    if os.path.exists(checkpoint_path):
        print(f"Checkpoint already exists at {checkpoint_path}.")
        return
    # The default path is ../checkpoints/sam_vit_h.pth; urlretrieve fails
    # with FileNotFoundError if the directory does not exist, so create it.
    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    print(f"Downloading SAM checkpoint from {url}...")
    urllib.request.urlretrieve(url, checkpoint_path)
    print(f"Downloaded checkpoint to {checkpoint_path}.")
def process_video(video_path, output_dir, sam_checkpoint, sam_model_type="vit_h",
                  frame_skip=5, motion_threshold=20, key_frame_diff_threshold=100000):
    """
    Extract key frames from a video based on Farneback optical flow.

    A frame is declared a key frame when its motion score (the count of
    pixels whose flow magnitude exceeds ``motion_threshold``) differs from
    the previous key frame's score by more than ``key_frame_diff_threshold``.
    NOTE(review): the SAM masking / background-blur pipeline is currently
    disabled (commented out), so key frames are saved unmodified; the
    ``sam_checkpoint`` and ``sam_model_type`` arguments are accepted but
    unused until that code is re-enabled.

    Args:
        video_path (str): Path to the input video file.
        output_dir (str): Directory to save key frames (created if missing).
        sam_checkpoint (str): Path to the SAM model checkpoint.
        sam_model_type (str): SAM model type to use (default: "vit_h").
        frame_skip (int): Number of frames to skip between processed frames.
        motion_threshold (float): Pixel motion magnitude threshold.
        key_frame_diff_threshold (float): Motion-score difference required
            to declare a new key frame.

    Raises:
        ValueError: If the video cannot be opened or the first frame cannot
            be read.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Device selection for the (currently disabled) SAM model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # SAM loading is disabled; re-enable once segment_anything is installed.
    #sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
    #sam.to(device=device)
    #mask_generator = SamAutomaticMaskGenerator(sam)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Error opening video file: {video_path}")
    ret, prev_frame = cap.read()
    if not ret:
        # BUG FIX: release the capture before raising so the handle is not leaked.
        cap.release()
        raise ValueError("Could not read the first frame from the video.")
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
    frame_count = 0
    key_frame_count = 0
    last_key_frame_motion = None
    # BUG FIX: try/finally guarantees cap.release() even if OpenCV raises mid-loop.
    try:
        while True:
            # Skip frames cheaply: grab() advances without decoding.
            for _ in range(frame_skip):
                cap.grab()
            ret, frame = cap.read()
            if not ret:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            # Dense optical flow between the previous processed frame and this one.
            flow = cv2.calcOpticalFlowFarneback(prev_gray, gray,
                                                None, 0.5, 3, 15, 3, 5, 1.2, 0)
            magnitude, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            # Binary mask of "moving" pixels, then a scalar motion score.
            motion_mask = magnitude > motion_threshold
            motion_score = np.sum(motion_mask)
            # Key frame when motion changed significantly since the last key frame.
            if last_key_frame_motion is None or abs(motion_score - last_key_frame_motion) > key_frame_diff_threshold:
                # RGB copy kept only for the (disabled) SAM pipeline below.
                image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Generate masks automatically with SAM2
                #masks = mask_generator.generate(image_rgb)
                # Choose the best mask (e.g., the one with the largest area)
                #if masks:
                #    best_mask = max(masks, key=lambda m: m["area"])["segmentation"]
                #else:
                # Fallback: use a mask that selects the entire frame if no mask is found
                #    best_mask = np.ones(frame.shape[:2], dtype=bool)
                # Create a blurred version of the frame for the background
                #blurred_frame = cv2.GaussianBlur(frame, (21, 21), 0)
                # Expand the mask to three channels
                #mask_3c = np.repeat(best_mask[:, :, np.newaxis], 3, axis=2)
                # Composite the image: keep object region sharp, blur the rest
                #masked_frame = np.where(mask_3c, frame, blurred_frame)
                key_frame_path = os.path.join(output_dir, f"key_frame_{key_frame_count:03d}.jpg")
                # BUG FIX: cv2.imwrite expects BGR channel order; the original
                # wrote image_rgb, which produced red/blue-swapped JPEGs.
                cv2.imwrite(key_frame_path, frame)
                print(f"Saved key frame: {key_frame_path}")
                key_frame_count += 1
                last_key_frame_motion = motion_score
            # Current frame becomes the reference for the next flow computation.
            prev_gray = gray.copy()
            frame_count += 1
    finally:
        cap.release()
    print(f"Finished processing {frame_count} frames. Total key frames saved: {key_frame_count}.")
def main():
    """Command-line entry point: parse arguments, ensure the SAM checkpoint
    is present, then run key-frame extraction on the given video."""
    cli = argparse.ArgumentParser(
        description="Extract key frames from a video using Farneback optical flow and SAM2 for masking. "
                    "The moving object remains sharp while the background is blurred."
    )
    cli.add_argument("--video", type=str, required=True, help="Path to the input video file (e.g., .mp4 or .mov).")
    cli.add_argument("--output", type=str, required=True, help="Directory to save key frames.")
    cli.add_argument("--sam_checkpoint", type=str, default="../checkpoints/sam_vit_h.pth",
                     help="Path to the SAM model checkpoint file. If not present, the default checkpoint will be downloaded.")
    cli.add_argument("--sam_model_type", type=str, default="vit_h", help="SAM model type (default: vit_h).")
    cli.add_argument("--frame_skip", type=int, default=5, help="Number of frames to skip between processing (default: 5).")
    cli.add_argument("--motion_threshold", type=float, default=20, help="Threshold for pixel motion magnitude (default: 20).")
    cli.add_argument("--key_frame_diff_threshold", type=float, default=100000, help="Motion score difference threshold for a new key frame (default: 100000).")
    opts = cli.parse_args()
    # Fetch the default checkpoint when the file is missing.
    download_checkpoint(opts.sam_checkpoint)
    process_video(opts.video, opts.output, opts.sam_checkpoint, opts.sam_model_type,
                  opts.frame_skip, opts.motion_threshold, opts.key_frame_diff_threshold)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|