Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import numpy as np | |
| import gradio as gr | |
| from moviepy import VideoFileClip | |
| import torch | |
| import clip | |
| import cv2 | |
| from PIL import Image | |
| from scenedetect import VideoManager, SceneManager | |
| from scenedetect.detectors import ContentDetector, AdaptiveDetector, ThresholdDetector, HistogramDetector, HashDetector | |
| # Device options | |
# Maps each user-selectable device name to the torch device string actually
# used on this machine: an accelerator that is not available resolves to "cpu".
DEVICE_OPTIONS = {
    "cpu": "cpu",
    **{
        name: name if available else "cpu"
        for name, available in (
            ("cuda", torch.cuda.is_available()),
            ("mps", torch.backends.mps.is_available()),
        )
    },
}
def load_clip_model(device):
    """Load the CLIP "ViT-B/32" checkpoint onto ``device``.

    Returns whatever ``clip.load`` returns; the caller in this file unpacks
    it as ``(model, preprocess)``.
    """
    model_name = "ViT-B/32"
    loaded = clip.load(model_name, device=device)
    return loaded
| # --- Video Processing Functions --- | |
def extract_frames(video_path, fps=2):
    """Sample frames from a video at roughly ``fps`` frames per second.

    Every ``step``-th decoded frame is kept, where ``step`` is the video's
    reported frame rate divided by ``fps``, rounded down. Because of the
    rounding, the effective spacing between kept frames can be slightly
    shorter than ``1 / fps`` seconds.

    Args:
        video_path: Path passed to ``cv2.VideoCapture``.
        fps: Target sampling rate, in frames per second.

    Returns:
        The kept frames, in decode order, exactly as returned by
        ``cap.read()``. Empty if no frame could be read.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        native_fps = cap.get(cv2.CAP_PROP_FPS)
        # Clamp to 1: when the reported rate is below `fps` (or is 0/NaN
        # because the container carries no rate) the truncated ratio would be
        # 0 and the modulo below would raise ZeroDivisionError. In that case
        # every frame is kept.
        step = max(1, int(native_fps / fps)) if native_fps > 0 else 1

        index = 0
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            if index % step == 0:
                frames.append(frame)
            index += 1
    finally:
        # Release the capture even if decoding raises part-way through.
        cap.release()
    return frames
def get_clip_features(frames, model, preprocess, device):
    """Encode each frame with ``model.encode_image`` and return the vectors.

    Args:
        frames: Iterable of frames; each is converted with
            ``cv2.COLOR_BGR2RGB`` before being handed to PIL.
        model: Object exposing ``encode_image``.
        preprocess: Callable turning a PIL image into a tensor.
        device: Device the preprocessed tensor is moved to.

    Returns:
        A list with one NumPy array per frame (row 0 of the encoder output
        for a batch of one image).
    """

    def embed(bgr_frame):
        # Convert to RGB, build a batch of one, and encode without gradients.
        rgb_image = Image.fromarray(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
        batch = preprocess(rgb_image).unsqueeze(0).to(device)
        with torch.no_grad():
            encoded = model.encode_image(batch)
        return encoded.cpu().numpy()[0]

    return [embed(frame) for frame in frames]
def compute_distance(a, b, method):
    """Compare two feature vectors using the named metric.

    Args:
        a: First vector (array-like).
        b: Second vector (array-like), same shape as ``a``.
        method: ``"cosine"``, ``"l2"`` or ``"l1"``. Any other value is
            treated as ``"l2"``.

    Returns:
        For ``"cosine"``: the cosine *similarity* (higher means more alike),
        or ``0.0`` when either vector has zero norm.
        For ``"l1"``: the sum of absolute differences.
        Otherwise: the Euclidean distance.
    """
    # Upcast so half/single-precision features are reduced in float64, and so
    # plain Python sequences are accepted as well as arrays.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)

    if method == "cosine":
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            # Direction is undefined for a zero vector; report "no
            # similarity" instead of dividing by zero and returning NaN.
            return 0.0
        return np.dot(a, b) / denom
    if method == "l1":
        return np.sum(np.abs(a - b))
    # "l2" and any unrecognised method.
    return np.linalg.norm(a - b)
def find_match(clip_feats, ref_feats, threshold=0.3, similarity="l2"):
    """Slide the clip's features over the reference and pick the best offset.

    For every offset at which the whole clip fits inside ``ref_feats``, the
    per-frame scores from ``compute_distance`` are averaged. With
    ``similarity == "cosine"`` a higher mean is better; with any other metric
    a lower mean is better.

    Args:
        clip_feats: Sequence of feature vectors for the query clip.
        ref_feats: Sequence of feature vectors for the reference video.
        threshold: The best mean score must be strictly above this value for
            ``"cosine"``, strictly below it otherwise.
        similarity: Metric name forwarded to ``compute_distance``.

    Returns:
        ``(index, best_score)``, where ``index`` is the best offset into
        ``ref_feats`` or ``-1`` when the best score does not pass the
        threshold. When no offset can be evaluated (empty clip, or clip
        longer than the reference) ``best_score`` is ``-inf`` for
        ``"cosine"`` and ``+inf`` otherwise.
    """
    higher_is_better = similarity == "cosine"
    best_match = -1
    best_score = -float('inf') if higher_is_better else float('inf')

    len_clip = len(clip_feats)
    if len_clip == 0:
        # Nothing to compare: avoid taking the mean of an empty list (NaN and
        # a RuntimeWarning) once per reference offset.
        return -1, best_score

    for start in range(len(ref_feats) - len_clip + 1):
        window = ref_feats[start:start + len_clip]
        score = np.mean(
            [compute_distance(a, b, similarity) for a, b in zip(clip_feats, window)]
        )
        improved = score > best_score if higher_is_better else score < best_score
        if improved:
            best_score = score
            best_match = start

    passes = best_score > threshold if higher_is_better else best_score < threshold
    if passes:
        return best_match, best_score
    return -1, best_score
| # Scene Detection | |
def get_detector(detector_name, threshold):
    """Build the scene detector selected by ``detector_name``.

    ``AdaptiveDetector`` is constructed with its defaults (``threshold`` is
    not forwarded to it). Every other known detector receives ``threshold``
    as a keyword argument, and an unrecognised name falls back to
    ``ContentDetector``.
    """
    if detector_name == "AdaptiveDetector":
        return AdaptiveDetector()

    thresholded_detectors = {
        "ContentDetector": ContentDetector,
        "ThresholdDetector": ThresholdDetector,
        "HashDetector": HashDetector,
        "HistogramDetector": HistogramDetector,
    }
    detector_cls = thresholded_detectors.get(detector_name, ContentDetector)
    return detector_cls(threshold=threshold)
def detect_scenes(video_path, detector_name, threshold):
    """Run scene detection on a video and return scene boundaries in seconds.

    Args:
        video_path: Path of the video to analyse.
        detector_name: Detector name understood by ``get_detector``.
        threshold: Threshold forwarded to ``get_detector``.

    Returns:
        A list of ``(start_seconds, end_seconds)`` tuples, one per entry of
        the scene manager's scene list.
    """
    video_manager = VideoManager([video_path])
    scene_manager = SceneManager()
    scene_manager.add_detector(get_detector(detector_name, threshold))

    try:
        # Called without an argument; presumably this lets the library choose
        # the downscale factor itself - confirm against the installed version.
        video_manager.set_downscale_factor()
        video_manager.start()
        scene_manager.detect_scenes(frame_source=video_manager)
        scene_list = scene_manager.get_scene_list()
    finally:
        # Free the underlying capture even when detection fails.
        video_manager.release()

    return [(start.get_seconds(), end.get_seconds()) for start, end in scene_list]
def find_scene_for_timestamp(scenes, match_time):
    """Return the first ``(start, end)`` scene that contains ``match_time``.

    Both ends of a scene are inclusive, so a timestamp lying exactly on a
    shared boundary resolves to the earlier scene. Returns ``None`` when no
    scene contains the timestamp.
    """
    containing = (
        (scene_start, scene_end)
        for scene_start, scene_end in scenes
        if scene_start <= match_time <= scene_end
    )
    return next(containing, None)
def extract_scene(video_path, scene_range, output_path):
    """Write the ``scene_range`` portion of a video to ``output_path``.

    Args:
        video_path: Source video path.
        scene_range: ``(start_time, end_time)`` pair passed to
            ``subclipped``.
        output_path: Destination file; encoded with ``libx264`` video and
            ``aac`` audio.

    Returns:
        ``output_path``, unchanged.
    """
    start_time, end_time = scene_range
    # The context manager closes the source clip after the write finishes, or
    # fails. The local is deliberately not called `clip`, which is also the
    # name of the imported CLIP module.
    with VideoFileClip(video_path) as source:
        segment = source.subclipped(start_time, end_time)
        segment.write_videofile(output_path, codec="libx264", audio_codec="aac")
    return output_path
| # Main logic | |
def process_videos(clip_path, ref_path, match_threshold, scene_threshold, detector_type, similarity_type, device_type, output_path):
    """Locate a clip inside a reference video and export the enclosing scene.

    Args:
        clip_path: Path of the short query clip.
        ref_path: Path of the reference video to search.
        match_threshold: Threshold forwarded to ``find_match``.
        scene_threshold: Threshold forwarded to ``detect_scenes``.
        detector_type: Scene detector name forwarded to ``detect_scenes``.
        similarity_type: Metric name forwarded to ``find_match``.
        device_type: Key into ``DEVICE_OPTIONS``; unknown keys use ``"cpu"``.
        output_path: Directory in which ``matched_scene.mp4`` is written.

    Returns:
        ``(message, video_path)``. ``video_path`` is ``None`` whenever no
        scene was exported.
    """
    # Either video input can arrive empty; fail with a readable message
    # instead of crashing inside the video decoder.
    if not clip_path or not ref_path:
        return "Please provide both a clip video and a reference video.", None

    # One constant drives both the frame sampling and the index-to-seconds
    # conversion below, so the two cannot drift apart.
    sample_fps = 2

    device = DEVICE_OPTIONS.get(device_type, "cpu")
    model, preprocess = load_clip_model(device)

    clip_frames = extract_frames(clip_path, fps=sample_fps)
    ref_frames = extract_frames(ref_path, fps=sample_fps)
    if not clip_frames or not ref_frames:
        return "Could not read any frames from one of the videos.", None

    clip_feats = get_clip_features(clip_frames, model, preprocess, device)
    ref_feats = get_clip_features(ref_frames, model, preprocess, device)

    match_index, score = find_match(clip_feats, ref_feats, match_threshold, similarity_type)
    if match_index == -1:
        return f"No match found (best score = {score:.4f})", None

    # Approximate: extract_frames keeps every int(native_fps / fps)-th frame,
    # so the real spacing can be slightly under 1 / sample_fps seconds.
    match_time = match_index / sample_fps

    scenes = detect_scenes(ref_path, detector_type, scene_threshold)
    matched_scene = find_scene_for_timestamp(scenes, match_time)
    if not matched_scene:
        return "Match found, but no scene boundaries detected.", None

    # The directory handed in may no longer exist (e.g. a temporary directory
    # that was already cleaned up), so (re)create it before writing.
    # NOTE(review): in this file the value arrives through a hidden Gradio
    # input, so it is client-controllable - consider binding it server-side.
    out_dir = output_path or tempfile.mkdtemp()
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, "matched_scene.mp4")

    result_path = extract_scene(ref_path, matched_scene, target)
    return f"Match found at ~{match_time:.2f}s (score = {score:.4f})\nScene from {matched_scene[0]:.2f}s to {matched_scene[1]:.2f}s", result_path
| # Gradio Interface | |
# Scratch directory for exported scenes. The TemporaryDirectory handle is kept
# at module level so the directory stays on disk for as long as the app is
# serving requests; a `with` block that ends before `launch()` runs would
# delete it before the first request. The directory is removed when the handle
# is finalized.
_tmpdir_handle = tempfile.TemporaryDirectory()
tmpdir = _tmpdir_handle.name

iface = gr.Interface(
    fn=process_videos,
    inputs=[
        gr.Video(label="Clip Video"),
        gr.Video(label="Reference Video"),
        gr.Slider(0.1, 100.0, value=0.3, label="Matching Threshold (lower = stricter, cosine = higher = better)"),
        gr.Slider(0.01, 100, value=30, step=1, label="Scene Detection Threshold"),
        gr.Dropdown([
            "ContentDetector", "AdaptiveDetector", "ThresholdDetector", "HistogramDetector", "HashDetector"
        ], value="ContentDetector", label="Scene Detector Type"),
        gr.Dropdown(["l2", "l1", "cosine"], value="l2", label="Similarity Metric"),
        gr.Dropdown(["cpu", "cuda", "mps"], value="cpu", label="Processing Device"),
        # Hidden field that carries the output directory into process_videos.
        gr.Text(value=tmpdir, visible=False),
    ],
    outputs=[
        gr.Text(label="Match Info"),
        gr.Video(label="Matched Scene"),
    ],
    title="AI Video Clip Matcher",
    description="Upload a short video clip and a reference video. The system will try to find where the clip appears in the reference video and extract the full scene around it.",
)

# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
    # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
    iface.launch()