File size: 6,724 Bytes
24f3fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
import cv2
import numpy as np
import os
import argparse
import torch
import urllib

# Import SAM modules
#from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# URL for the default SAM-ViT-H checkpoint
DEFAULT_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

def download_checkpoint(checkpoint_path, url=DEFAULT_CHECKPOINT_URL):
    """
    Download the checkpoint file if it does not already exist.

    The data is first written to ``<checkpoint_path>.part`` and only renamed
    into place once the transfer has completed. An interrupted download
    therefore never leaves a truncated file that a later run would mistake
    for a valid checkpoint.

    Args:
        checkpoint_path (str | os.PathLike): Path where the checkpoint file
            will be saved. Missing parent directories are created.
        url (str): URL to download the checkpoint from.

    Raises:
        urllib.error.URLError: If the URL cannot be fetched.
        OSError: If the destination cannot be created or written.
    """
    # `import urllib` alone does not load the `request` submodule; import it
    # explicitly rather than relying on some other package having done so.
    import urllib.request

    checkpoint_path = os.fspath(checkpoint_path)
    if os.path.exists(checkpoint_path):
        print(f"Checkpoint already exists at {checkpoint_path}.")
        return

    parent_dir = os.path.dirname(checkpoint_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    print(f"Downloading SAM checkpoint from {url}...")
    tmp_path = checkpoint_path + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic rename: the final path only ever holds a complete file.
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Drop partial data (also on Ctrl-C), then propagate the real error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print(f"Downloaded checkpoint to {checkpoint_path}.")

def process_video(video_path, output_dir, sam_checkpoint, sam_model_type="vit_h", 
                  frame_skip=5, motion_threshold=20, key_frame_diff_threshold=100000):
    """
    Extract key frames from a video based on Farneback optical flow.

    After the first frame, the loop skips ``frame_skip`` frames and then reads
    one. Dense optical flow is computed between consecutive sampled frames.
    The number of pixels whose flow magnitude exceeds ``motion_threshold`` is
    the frame's motion score. A sampled frame becomes a key frame when it is
    the first one compared, or when its motion score differs from the last key
    frame's score by more than ``key_frame_diff_threshold``. Key frames are
    written as decoded (BGR) to ``output_dir/key_frame_NNN.jpg``.

    NOTE: SAM-based object masking and background blurring are currently
    disabled. ``sam_checkpoint`` and ``sam_model_type`` are accepted for
    interface compatibility but are not used until that stage is re-enabled.

    Args:
        video_path (str): Path to the input video file.
        output_dir (str): Directory to save key frames (created if missing).
        sam_checkpoint (str): Path to the SAM model checkpoint (currently unused).
        sam_model_type (str): SAM model type (currently unused; default: "vit_h").
        frame_skip (int): Number of frames to skip between processed frames.
        motion_threshold (float): Per-pixel flow magnitude threshold.
        key_frame_diff_threshold (float): Minimum change in motion score,
            relative to the last key frame, for a new key frame.

    Raises:
        ValueError: If the video cannot be opened or its first frame cannot be read.
        OSError: If a key frame image cannot be written.
    """
    os.makedirs(output_dir, exist_ok=True)

    # TODO: re-enable SAM masking. Intended model setup:
    #   device = "cuda" if torch.cuda.is_available() else "cpu"
    #   sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
    #   sam.to(device=device)
    #   mask_generator = SamAutomaticMaskGenerator(sam)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Error opening video file: {video_path}")

        ret, prev_frame = cap.read()
        if not ret:
            raise ValueError("Could not read the first frame from the video.")

        prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
        frame_count = 0
        key_frame_count = 0
        last_key_frame_motion = None

        while True:
            # Skip frames for efficiency; grab() advances without a full decode.
            for _ in range(frame_skip):
                cap.grab()

            ret, frame = cap.read()
            if not ret:
                break

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

            # Dense Farneback flow: pyr_scale=0.5, levels=3, winsize=15,
            # iterations=3, poly_n=5, poly_sigma=1.2, flags=0.
            flow = cv2.calcOpticalFlowFarneback(prev_gray, gray,
                                                None, 0.5, 3, 15, 3, 5, 1.2, 0)
            magnitude, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            # Motion score: number of pixels moving faster than the threshold.
            motion_score = int(np.count_nonzero(magnitude > motion_threshold))

            if (last_key_frame_motion is None
                    or abs(motion_score - last_key_frame_motion) > key_frame_diff_threshold):
                # TODO: SAM stage (disabled). SAM consumes RGB, but whatever is
                # passed to cv2.imwrite must stay BGR:
                #   masks = mask_generator.generate(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                #   best_mask = (max(masks, key=lambda m: m["area"])["segmentation"]
                #                if masks else np.ones(frame.shape[:2], dtype=bool))
                #   blurred = cv2.GaussianBlur(frame, (21, 21), 0)
                #   frame = np.where(best_mask[:, :, np.newaxis], frame, blurred)

                key_frame_path = os.path.join(output_dir, f"key_frame_{key_frame_count:03d}.jpg")
                # cv2.imwrite expects BGR channel order, which is what cap.read() returns.
                if not cv2.imwrite(key_frame_path, frame):
                    raise OSError(f"Failed to write key frame: {key_frame_path}")
                print(f"Saved key frame: {key_frame_path}")
                key_frame_count += 1
                last_key_frame_motion = motion_score

            # cvtColor allocates a fresh array each iteration, so no copy is needed.
            prev_gray = gray
            frame_count += 1
    finally:
        # Release the capture on every path, including errors.
        cap.release()

    print(f"Finished processing {frame_count} frames. Total key frames saved: {key_frame_count}.")

def main():
    """
    Command-line entry point.

    Parses the CLI options, ensures the SAM checkpoint is present locally
    (downloading it if needed), then runs key-frame extraction.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Extract key frames from a video using Farneback optical flow and SAM2 for masking. "
            "The moving object remains sharp while the background is blurred."
        )
    )

    # (flag, add_argument keyword arguments), registered in order.
    option_specs = (
        ("--video", dict(type=str, required=True,
                         help="Path to the input video file (e.g., .mp4 or .mov).")),
        ("--output", dict(type=str, required=True,
                          help="Directory to save key frames.")),
        ("--sam_checkpoint", dict(type=str, default="../checkpoints/sam_vit_h.pth",
                                  help="Path to the SAM model checkpoint file. If not present, the default checkpoint will be downloaded.")),
        ("--sam_model_type", dict(type=str, default="vit_h",
                                  help="SAM model type (default: vit_h).")),
        ("--frame_skip", dict(type=int, default=5,
                              help="Number of frames to skip between processing (default: 5).")),
        ("--motion_threshold", dict(type=float, default=20,
                                    help="Threshold for pixel motion magnitude (default: 20).")),
        ("--key_frame_diff_threshold", dict(type=float, default=100000,
                                            help="Motion score difference threshold for a new key frame (default: 100000).")),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    # Make sure the model weights are on disk before any processing starts.
    download_checkpoint(opts.sam_checkpoint)

    process_video(
        opts.video,
        opts.output,
        opts.sam_checkpoint,
        sam_model_type=opts.sam_model_type,
        frame_skip=opts.frame_skip,
        motion_threshold=opts.motion_threshold,
        key_frame_diff_threshold=opts.key_frame_diff_threshold,
    )

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()