Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import subprocess | |
| import tempfile | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from functools import lru_cache | |
| import threading | |
class VideoFrameCache:
    """Thread-safe, bounded FIFO cache of decoded RGB video frames.

    Frames are decoded lazily on first access and stored as RGB numpy
    arrays. A single persistent ``cv2.VideoCapture`` is reused for all
    reads (always accessed under ``self.lock``) instead of reopening the
    file for every frame, which makes preloading hundreds of frames far
    cheaper than open/seek/close per frame.
    """

    def __init__(self, video_path, max_cache_size=500):
        # Path of the video file on disk.
        self.video_path = video_path
        # frame index -> RGB ndarray. Python dicts preserve insertion
        # order, so FIFO eviction is simply "drop the first key".
        self.cache = {}
        self.max_cache_size = max_cache_size
        self.total_frames = self._get_frame_count()
        # Serializes capture access and cache mutation across threads.
        self.lock = threading.Lock()
        # Persistent capture handle, opened lazily on first frame read.
        self._cap = None

    def _get_frame_count(self):
        """Return the container's reported frame count (0 if the file can't be opened)."""
        cap = cv2.VideoCapture(self.video_path)
        try:
            if not cap.isOpened():
                return 0
            return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()

    def _read_frame_locked(self, frame_num):
        """Decode frame ``frame_num`` as BGR. Caller must hold ``self.lock``."""
        if self._cap is None:
            self._cap = cv2.VideoCapture(self.video_path)
        if not self._cap.isOpened():
            return None
        self._cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = self._cap.read()
        return frame if ret else None

    def get_frame(self, frame_num):
        """Return frame ``frame_num`` as an RGB array, or None if unreadable."""
        with self.lock:
            if frame_num in self.cache:
                return self.cache[frame_num]
            frame = self._read_frame_locked(frame_num)
            if frame is None:
                return None
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Cache management - evict the oldest entry (simple FIFO)
            # once the cache is full.
            if len(self.cache) >= self.max_cache_size:
                del self.cache[next(iter(self.cache))]
            self.cache[frame_num] = frame_rgb
            return frame_rgb

    def preload_range(self, start_frame, end_frame):
        """Preload frames in [start_frame, min(end_frame, total_frames))."""
        for i in range(start_frame, min(end_frame, self.total_frames)):
            # Unlocked membership test is only a cheap fast path;
            # get_frame() re-checks the cache under the lock.
            if i not in self.cache:
                self.get_frame(i)

    def preload_all(self):
        """Preload all frames (use with caution for long videos)."""
        print(f"Preloading {self.total_frames} frames...")
        for i in range(self.total_frames):
            self.get_frame(i)
            if i % 100 == 0:
                print(f"Loaded {i}/{self.total_frames} frames...")
        print("All frames loaded!")

    def close(self):
        """Release the persistent capture handle (safe to call repeatedly)."""
        with self.lock:
            if self._cap is not None:
                self._cap.release()
                self._cap = None
def get_frame_count(video_path):
    """Return the total number of frames reported for *video_path*."""
    capture = cv2.VideoCapture(video_path)
    frame_total = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()
    return int(frame_total)
def create_side_by_side_frame(cache1, cache2, frame_num, method1, method2, show_frame_numbers, total_frames):
    """Build a labeled side-by-side comparison image for one frame index.

    Returns an RGB ndarray, or None when either cache is missing or the
    frame cannot be decoded from one of the videos.
    """
    if cache1 is None or cache2 is None:
        return None

    left = cache1.get_frame(frame_num)
    right = cache2.get_frame(frame_num)
    if left is None or right is None:
        return None

    # Warm both caches for the next few frames in the background so that
    # slider navigation feels instant.
    def _warm_caches():
        for step in range(1, 20):
            target = frame_num + step
            if target < total_frames:
                cache1.get_frame(target)
                cache2.get_frame(target)

    threading.Thread(target=_warm_caches, daemon=True).start()

    # Compose the two frames onto one canvas via PIL.
    left_img = Image.fromarray(left)
    right_img = Image.fromarray(right)
    canvas = Image.new('RGB', (left_img.width + right_img.width,
                               max(left_img.height, right_img.height)))
    canvas.paste(left_img, (0, 0))
    canvas.paste(right_img, (left_img.width, 0))

    # Back to ndarray so OpenCV can draw the text overlays in place.
    composed = np.array(canvas)
    canvas_h = composed.shape[0]
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Method labels, one per half.
    cv2.putText(composed, method1, (10, 30), font, 0.8, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(composed, method2, (left_img.width + 10, 30), font, 0.8, (255, 255, 255), 2, cv2.LINE_AA)

    # Optional frame counter in the bottom-left of each half.
    if show_frame_numbers:
        frame_text = f'Frame {frame_num}/{total_frames}'
        for x_pos in (10, left_img.width + 10):
            cv2.putText(composed, frame_text, (x_pos, canvas_h - 20), font, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

    return composed
def create_side_by_side_video(video1, video2, method1="Method 1", method2="Method 2", show_frame_numbers=True, preload_all_frames=False):
    """Create a side-by-side comparison video with ffmpeg and set up frame caches.

    Returns a 7-tuple consumed by the Gradio outputs:
        (output video path, first preview frame, slider update,
         total frame count, cache for video1, cache for video2, status text).
    On failure the path/frame/caches are None and the status carries the error.
    """
    if video1 is None or video2 is None:
        return None, None, 0, 0, None, None, "No videos loaded"

    # Frame count (taken from video1) drives the navigation slider range.
    total_frames = get_frame_count(video1)

    print("Creating frame caches...")
    cache1 = VideoFrameCache(video1, max_cache_size=500)
    cache2 = VideoFrameCache(video2, max_cache_size=500)

    status_msg = f"Videos loaded. Total frames: {total_frames}\n"

    # Preload the first frames so navigation responds instantly.
    print("Preloading initial frames...")
    cache1.preload_range(0, 50)
    cache2.preload_range(0, 50)
    status_msg += "First 50 frames cached.\n"

    # Optionally decode everything in a background thread.
    if preload_all_frames and total_frames < 1000:  # Safety limit
        def preload_bg():
            cache1.preload_all()
            cache2.preload_all()
        thread = threading.Thread(target=preload_bg, daemon=True)
        thread.start()
        status_msg += "Loading all frames in background...\n"

    # Destination path for the rendered comparison video.
    output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name

    # Ask ffprobe for an exact decoded-frame count for the drawtext overlay.
    # Fall back to the OpenCV count when ffprobe is missing, fails, or prints
    # nothing (a bare `except:` here previously swallowed KeyboardInterrupt
    # and an empty stdout would corrupt the overlay text).
    total_frames_str = str(total_frames)
    try:
        result = subprocess.run(
            ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
             '-count_frames', '-show_entries', 'stream=nb_read_frames',
             '-of', 'csv=p=0', video1],
            capture_output=True, text=True
        )
        probed = result.stdout.strip()
        if result.returncode == 0 and probed:
            total_frames_str = probed
    except OSError:
        pass  # ffprobe binary not found/executable; keep the OpenCV count.

    # Build the ffmpeg filter graph: label each input, optionally burn in a
    # live frame counter (%{n} is ffmpeg's frame-number expansion), then
    # stack the two streams horizontally.
    # NOTE(review): method names containing ' or : would break the drawtext
    # expression — assumes benign labels; confirm if inputs are untrusted.
    if show_frame_numbers:
        frame_text = f'Frame %{{n}}/{total_frames_str}'
        filter_complex = f"""
[0:v]drawtext=text='{method1}':fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5:boxborderw=5:x=10:y=10,
drawtext=text='{frame_text}':fontsize=20:fontcolor=white:box=1:boxcolor=black@0.5:boxborderw=5:x=10:y=h-40[v0];
[1:v]drawtext=text='{method2}':fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5:boxborderw=5:x=10:y=10,
drawtext=text='{frame_text}':fontsize=20:fontcolor=white:box=1:boxcolor=black@0.5:boxborderw=5:x=10:y=h-40[v1];
[v0][v1]hstack=inputs=2[v]
"""
    else:
        filter_complex = f"""
[0:v]drawtext=text='{method1}':fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5:boxborderw=5:x=10:y=10[v0];
[1:v]drawtext=text='{method2}':fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5:boxborderw=5:x=10:y=10[v1];
[v0][v1]hstack=inputs=2[v]
"""

    # Render the comparison video (video from the stacked graph, audio from
    # input 1 if present, trimmed to the shorter stream).
    print("Creating side-by-side video...")
    cmd = [
        'ffmpeg', '-y', '-i', video1, '-i', video2,
        '-filter_complex', filter_complex,
        '-map', '[v]', '-map', '0:a?',
        '-c:v', 'libx264', '-crf', '18',
        '-c:a', 'aac', '-shortest',
        output_file
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
        # The first composited frame doubles as the preview image.
        first_frame = create_side_by_side_frame(cache1, cache2, 0, method1, method2, show_frame_numbers, total_frames)
        status_msg += "Video created successfully!"
        return output_file, first_frame, gr.Slider(maximum=total_frames-1, value=0), total_frames, cache1, cache2, status_msg
    except FileNotFoundError:
        # ffmpeg binary missing entirely — report instead of crashing the handler.
        return None, None, gr.Slider(maximum=100, value=0), 0, None, None, "Error: ffmpeg not found on PATH"
    except subprocess.CalledProcessError as e:
        print(f"Error: {e.stderr.decode()}")
        return None, None, gr.Slider(maximum=100, value=0), 0, None, None, f"Error: {e.stderr.decode()}"
def update_frame(cache1, cache2, frame_num, method1, method2, show_frames, total_frames):
    """Render the comparison image for the slider's current frame."""
    caches_ready = cache1 is not None and cache2 is not None and total_frames != 0
    if not caches_ready:
        return None
    return create_side_by_side_frame(cache1, cache2, int(frame_num), method1, method2, show_frames, total_frames)
# JavaScript injected into the Gradio page (via gr.Blocks(js=...)):
# a global keydown handler that drives the frame slider from the keyboard.
# Left/Right (or A/D) step by 1 frame; Up/Down (or W/S) jump by 10, clamped
# to [0, max]. Both 'input' and 'change' events are dispatched so Gradio's
# .change() listener actually fires on the synthetic update.
# NOTE(review): querySelector picks the FIRST range input on the page —
# assumes the frame slider is the only slider; verify if more are added.
js_code = """
function() {
document.addEventListener('keydown', function(event) {
const slider = document.querySelector('input[type="range"]');
if (!slider) return;
const currentValue = parseInt(slider.value);
const maxValue = parseInt(slider.max);
if (event.key === 'ArrowRight' || event.key === 'd' || event.key === 'D') {
event.preventDefault();
if (currentValue < maxValue) {
slider.value = currentValue + 1;
slider.dispatchEvent(new Event('input', { bubbles: true }));
slider.dispatchEvent(new Event('change', { bubbles: true }));
}
}
if (event.key === 'ArrowLeft' || event.key === 'a' || event.key === 'A') {
event.preventDefault();
if (currentValue > 0) {
slider.value = currentValue - 1;
slider.dispatchEvent(new Event('input', { bubbles: true }));
slider.dispatchEvent(new Event('change', { bubbles: true }));
}
}
if (event.key === 'ArrowUp' || event.key === 'w' || event.key === 'W') {
event.preventDefault();
const newValue = Math.min(currentValue + 10, maxValue);
slider.value = newValue;
slider.dispatchEvent(new Event('input', { bubbles: true }));
slider.dispatchEvent(new Event('change', { bubbles: true }));
}
if (event.key === 'ArrowDown' || event.key === 's' || event.key === 'S') {
event.preventDefault();
const newValue = Math.max(currentValue - 10, 0);
slider.value = newValue;
slider.dispatchEvent(new Event('input', { bubbles: true }));
slider.dispatchEvent(new Event('change', { bubbles: true }));
}
});
}
"""
# Create Gradio interface. Components built inside the Blocks context are
# registered with the layout in construction order; event handlers are
# wired up after all components exist.
with gr.Blocks(title="Video Side-by-Side Comparison", js=js_code) as demo:
    gr.Markdown("# Video Side-by-Side Comparison with Cached Frame Navigation")
    gr.Markdown("Upload two videos to compare them side by side. Frames are cached for instant navigation!")
    # Store video caches and frame count in per-session state so the
    # frame-navigation callbacks can reuse them across events.
    cache1_state = gr.State()
    cache2_state = gr.State()
    total_frames_state = gr.State(value=0)
    with gr.Row():
        with gr.Column():
            video1_input = gr.Video(label="Video 1")
            method1_input = gr.Textbox(label="Method 1 Name", value="Ground Truth")
        with gr.Column():
            video2_input = gr.Video(label="Video 2")
            method2_input = gr.Textbox(label="Method 2 Name", value="Generated")
    with gr.Row():
        show_frames_checkbox = gr.Checkbox(label="Show Frame Numbers", value=True)
        preload_checkbox = gr.Checkbox(label="Preload All Frames (for videos < 1000 frames)", value=False)
    compare_btn = gr.Button("Create Comparison Video & Load Frames", variant="primary")
    status_text = gr.Textbox(label="Status", interactive=False)
    output_video = gr.Video(label="Side-by-Side Comparison Video")
    gr.Markdown("## Frame-by-Frame Navigation (Cached)")
    gr.Markdown("""
**Keyboard shortcuts:**
- ← / → (or A / D): Previous/Next frame
- ↑ / ↓ (or W / S): Jump 10 frames backward/forward
**Note:** First 50 frames are preloaded immediately. Additional frames load on-demand and are cached for instant replay.
""")
    # maximum=100 is a placeholder; process_videos() returns a gr.Slider
    # update that resets the range to the real frame count.
    frame_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Frame Number")
    with gr.Row():
        prev_btn = gr.Button("← Previous Frame")
        next_btn = gr.Button("Next Frame →")
    frame_display = gr.Image(label="Current Frame Comparison")

    # Create video comparison
    def process_videos(video1, video2, method1, method2, show_frames, preload_all):
        """Run the full pipeline and fan the 7-tuple out to the outputs."""
        result = create_side_by_side_video(video1, video2, method1, method2, show_frames, preload_all)
        return result[0], result[1], result[2], result[3], result[4], result[5], result[6]

    compare_btn.click(
        fn=process_videos,
        inputs=[video1_input, video2_input, method1_input, method2_input, show_frames_checkbox, preload_checkbox],
        outputs=[output_video, frame_display, frame_slider, total_frames_state, cache1_state, cache2_state, status_text]
    )

    # Frame navigation: any slider change (manual, keyboard JS, or the
    # prev/next buttons writing back to the slider) re-renders the frame.
    frame_slider.change(
        fn=update_frame,
        inputs=[cache1_state, cache2_state, frame_slider, method1_input, method2_input, show_frames_checkbox, total_frames_state],
        outputs=frame_display
    )

    def prev_frame(current_frame):
        """Step the slider back one frame, clamped at 0."""
        return max(0, current_frame - 1)

    def next_frame(current_frame, total):
        """Step the slider forward one frame, clamped at the last frame."""
        return min(total - 1, current_frame + 1)

    prev_btn.click(
        fn=prev_frame,
        inputs=[frame_slider],
        outputs=frame_slider
    )
    next_btn.click(
        fn=next_frame,
        inputs=[frame_slider, total_frames_state],
        outputs=frame_slider
    )
# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()