"""Gradio app for comparing two videos side by side.

Renders a labelled side-by-side comparison video with FFmpeg and offers
cached, keyboard-driven frame-by-frame navigation backed by OpenCV.
"""
import os
import subprocess
import tempfile
import threading
from collections import OrderedDict
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
from PIL import Image

# Videos with this many frames or more are never fully preloaded (memory guard).
PRELOAD_ALL_LIMIT = 1000
# Number of frames (starting at the current one) covered by the background preload.
PRELOAD_AHEAD = 20
# Frames decoded eagerly right after the videos are loaded.
INITIAL_PRELOAD = 50


class VideoFrameCache:
    """Lazily decoded, bounded LRU cache of one video's frames.

    Frames are decoded with OpenCV on first access and stored as RGB numpy
    arrays. At most ``max_cache_size`` frames are kept; the least recently
    used frame is evicted first. All access to ``cache`` goes through
    ``lock`` so the UI thread and background preloaders can share an instance.
    """

    def __init__(self, video_path, max_cache_size=500):
        self.video_path = video_path
        self.cache = OrderedDict()  # frame index -> RGB frame, least recently used first
        self.max_cache_size = max_cache_size
        # Protects ``cache`` only. Decoding happens outside the lock so a
        # background preload never blocks an interactive frame request.
        self.lock = threading.Lock()
        self.total_frames = self._get_frame_count()

    def _get_frame_count(self):
        """Return OpenCV's frame count for this video (0 if it cannot be opened)."""
        return get_frame_count(self.video_path)

    def _store(self, frame_num, frame_rgb):
        """Insert a frame as most recently used and evict overflow.

        The caller must hold ``self.lock``.
        """
        self.cache[frame_num] = frame_rgb
        self.cache.move_to_end(frame_num)
        while len(self.cache) > self.max_cache_size:
            self.cache.popitem(last=False)

    def get_frame(self, frame_num):
        """Return frame ``frame_num`` as an RGB array, or None if it cannot be read."""
        if frame_num < 0:
            return None
        with self.lock:
            cached = self.cache.get(frame_num)
            if cached is not None:
                self.cache.move_to_end(frame_num)  # mark as recently used
                return cached

        cap = cv2.VideoCapture(self.video_path)
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
            ok, frame = cap.read()
        finally:
            cap.release()
        if not ok:
            return None

        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        with self.lock:
            self._store(frame_num, frame_rgb)
        return frame_rgb

    def preload_range(self, start_frame, end_frame):
        """Decode and cache the not-yet-cached frames in ``[start_frame, end_frame)``.

        Seeks once and then reads sequentially from a single capture, which is
        far cheaper than reopening and seeking the file for every frame.
        """
        start_frame = max(0, start_frame)
        end_frame = min(end_frame, self.total_frames)
        with self.lock:
            missing = [i for i in range(start_frame, end_frame) if i not in self.cache]
        if not missing:
            return

        wanted = set(missing)
        cap = cv2.VideoCapture(self.video_path)
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, missing[0])
            for frame_num in range(missing[0], missing[-1] + 1):
                ok, frame = cap.read()
                if not ok:
                    break  # frame count was an overestimate, or a decode error
                if frame_num in wanted:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    with self.lock:
                        self._store(frame_num, frame_rgb)
        finally:
            cap.release()

    def preload_all(self, chunk_size=100):
        """Preload every frame (use with caution for long videos).

        Frames are only retained if ``max_cache_size`` is at least
        ``total_frames``; otherwise earlier frames are evicted as later ones load.
        """
        print(f"Preloading {self.total_frames} frames...")
        for start in range(0, self.total_frames, chunk_size):
            end = min(start + chunk_size, self.total_frames)
            self.preload_range(start, end)
            print(f"Loaded {end}/{self.total_frames} frames...")
        print("All frames loaded!")


def get_frame_count(video_path):
    """Return OpenCV's frame count for ``video_path`` (0 if it cannot be opened)."""
    cap = cv2.VideoCapture(video_path)
    try:
        # Some containers report -1 when the count is unknown.
        return max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
    finally:
        cap.release()


def _put_label(image, text, origin, font_scale, thickness=2):
    """Draw white text with a black outline so it stays legible on any background."""
    if not text:
        return
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(image, str(text), origin, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
    cv2.putText(image, str(text), origin, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)


def create_side_by_side_frame(cache1, cache2, frame_num, method1, method2, show_frame_numbers, total_frames):
    """Render frame ``frame_num`` of both videos side by side as an RGB array.

    Returns None if either cache is missing or the frame cannot be decoded.
    Also starts a background preload of the following frames so that stepping
    forward stays smooth.
    """
    if cache1 is None or cache2 is None:
        return None
    frame1 = cache1.get_frame(frame_num)
    frame2 = cache2.get_frame(frame_num)
    if frame1 is None or frame2 is None:
        return None

    preload_end = min(frame_num + PRELOAD_AHEAD, total_frames)
    if frame_num + 1 < preload_end:
        def preload_nearby():
            cache1.preload_range(frame_num + 1, preload_end)
            cache2.preload_range(frame_num + 1, preload_end)

        threading.Thread(target=preload_nearby, daemon=True).start()

    # Place both frames on a black canvas tall enough for the taller one.
    height1, width1 = frame1.shape[:2]
    height2, width2 = frame2.shape[:2]
    height = max(height1, height2)
    result = np.zeros((height, width1 + width2, 3), dtype=np.uint8)
    result[:height1, :width1] = frame1
    result[:height2, width1:] = frame2

    _put_label(result, method1, (10, 30), 0.8)
    _put_label(result, method2, (width1 + 10, 30), 0.8)
    if show_frame_numbers:
        frame_text = f'Frame {frame_num}/{total_frames}'
        _put_label(result, frame_text, (10, height - 20), 0.6)
        _put_label(result, frame_text, (width1 + 10, height - 20), 0.6)
    return result


def _ffmpeg_quote(value):
    """Wrap ``value`` in single quotes for one level of FFmpeg parsing.

    Inside single quotes FFmpeg takes every character literally; an embedded
    quote is written as close-quote, backslash-quote, reopen-quote.
    """
    return "'" + value.replace("'", "'\\''") + "'"


def _drawtext_literal(text):
    """Escape ``text`` so drawtext's expansion prints it verbatim (handles backslash and %)."""
    return text.replace("\\", "\\\\").replace("%", "\\%")


def _drawtext_option(text):
    """Encode ``text`` as a drawtext ``text=`` value embedded in a filtergraph.

    The value is parsed twice before drawtext sees it (filtergraph level, then
    filter-option level), each stripping one layer of quoting, so it is quoted
    twice. ``text`` must already be escaped for drawtext's own expansion.
    """
    return _ffmpeg_quote(_ffmpeg_quote(text))


def _build_filter_complex(method1, method2, show_frame_numbers, total_frames_str):
    """Build the filtergraph that labels both inputs and stacks them horizontally."""
    style = "fontcolor=white:box=1:boxcolor=black@0.5:boxborderw=5:x=10"
    # %{n} is intentionally left unescaped: drawtext expands it to the frame index.
    frame_label = "Frame %{n}/" + _drawtext_literal(total_frames_str)
    chains = []
    for index, name in enumerate((method1, method2)):
        filters = []
        if name:
            label = _drawtext_option(_drawtext_literal(str(name)))
            filters.append(f"drawtext=text={label}:fontsize=24:{style}:y=10")
        if show_frame_numbers:
            filters.append(f"drawtext=text={_drawtext_option(frame_label)}:fontsize=20:{style}:y=h-40")
        # 'null' is a pass-through filter, keeping the chain valid when no label is drawn.
        chains.append(f"[{index}:v]{','.join(filters) or 'null'}[v{index}]")
    chains.append("[v0][v1]hstack=inputs=2[v]")
    return ";".join(chains)


def _probe_frame_count(video_path):
    """Return ffprobe's exact decoded frame count as a digit string, or None on failure."""
    cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-count_frames',
           '-show_entries', 'stream=nb_read_frames', '-of', 'csv=p=0', video_path]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError):  # e.g. ffprobe not installed
        return None
    value = result.stdout.strip().split(",")[0].strip()
    if result.returncode != 0 or not value.isdigit():
        return None
    return value


def create_side_by_side_video(video1, video2, method1="Method 1", method2="Method 2",
                              show_frame_numbers=True, preload_all_frames=False):
    """Render the side-by-side comparison video and initialise the frame caches.

    Returns a 7-tuple matching the UI outputs:
    (output video path, first comparison frame, slider update, total frames,
    cache for video1, cache for video2, status message). On failure the path,
    frame and caches are None and the status carries the error.
    """
    if video1 is None or video2 is None:
        return None, None, 0, 0, None, None, "No videos loaded"

    print("Creating frame caches...")
    cache1 = VideoFrameCache(video1, max_cache_size=500)
    cache2 = VideoFrameCache(video2, max_cache_size=500)
    # Only frames present in both videos can be shown side by side.
    total_frames = min(cache1.total_frames, cache2.total_frames)
    status_msg = f"Videos loaded. Total frames: {total_frames}\n"
    if cache1.total_frames != cache2.total_frames:
        status_msg += (f"Warning: frame counts differ ({cache1.total_frames} vs "
                       f"{cache2.total_frames}); navigation is limited to the shorter video.\n")

    print("Preloading initial frames...")
    cache1.preload_range(0, INITIAL_PRELOAD)
    cache2.preload_range(0, INITIAL_PRELOAD)
    status_msg += f"First {INITIAL_PRELOAD} frames cached.\n"

    if preload_all_frames:
        # Check the longer video: both are preloaded in full.
        if max(cache1.total_frames, cache2.total_frames) < PRELOAD_ALL_LIMIT:
            # A cache smaller than the video would evict frames as fast as
            # they are preloaded, so grow it to hold the whole video.
            for cache in (cache1, cache2):
                cache.max_cache_size = max(cache.max_cache_size, cache.total_frames)

            def preload_bg():
                cache1.preload_all()
                cache2.preload_all()

            threading.Thread(target=preload_bg, daemon=True).start()
            status_msg += "Loading all frames in background...\n"
        else:
            status_msg += (f"Too many frames to preload all (limit {PRELOAD_ALL_LIMIT}); "
                           f"frames will load on demand.\n")

    # Close the handle immediately so ffmpeg can overwrite the file (required on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
        output_file = tmp.name

    total_frames_str = _probe_frame_count(video1) or str(cache1.total_frames)
    filter_complex = _build_filter_complex(method1, method2, show_frame_numbers, total_frames_str)

    print("Creating side-by-side video...")
    cmd = [
        'ffmpeg', '-y',
        '-i', video1,
        '-i', video2,
        '-filter_complex', filter_complex,
        '-map', '[v]',
        '-map', '0:a?',
        '-c:v', 'libx264',
        '-crf', '18',
        '-c:a', 'aac',
        '-shortest',
        output_file,
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except (subprocess.CalledProcessError, OSError) as e:
        stderr = getattr(e, "stderr", None)
        error = stderr.decode(errors="replace") if stderr else str(e)
        print(f"Error: {error}")
        try:
            os.remove(output_file)  # don't leave an empty temp file behind
        except OSError:
            pass
        return None, None, gr.Slider(maximum=100, value=0), 0, None, None, f"Error: {error}"

    first_frame = create_side_by_side_frame(cache1, cache2, 0, method1, method2, show_frame_numbers, total_frames)
    status_msg += "Video created successfully!"
    slider = gr.Slider(maximum=max(total_frames - 1, 0), value=0)
    return output_file, first_frame, slider, total_frames, cache1, cache2, status_msg


def update_frame(cache1, cache2, frame_num, method1, method2, show_frames, total_frames):
    """Render the comparison image for slider position ``frame_num`` (clamped to range)."""
    if cache1 is None or cache2 is None or not total_frames or frame_num is None:
        return None
    frame_num = min(max(int(frame_num), 0), total_frames - 1)
    return create_side_by_side_frame(cache1, cache2, frame_num, method1, method2, show_frames, total_frames)


# JavaScript for keyboard navigation of the frame slider.
js_code = """
function() {
    const steps = {
        ArrowRight: 1, d: 1, D: 1,
        ArrowLeft: -1, a: -1, A: -1,
        ArrowUp: 10, w: 10, W: 10,
        ArrowDown: -10, s: -10, S: -10
    };
    document.addEventListener('keydown', function(event) {
        // Leave browser shortcuts (Ctrl+S, Cmd+D, ...) alone.
        if (event.ctrlKey || event.metaKey || event.altKey) return;
        // Don't hijack keys while typing in a field (e.g. the method names)
        // or while a form control such as the slider itself has focus.
        const target = event.target;
        if (target && (target.isContentEditable ||
                       ['INPUT', 'TEXTAREA', 'SELECT'].includes(target.tagName))) {
            return;
        }
        const step = steps[event.key];
        if (step === undefined) return;

        const slider = document.querySelector('#frame-slider input[type="range"]') ||
                       document.querySelector('input[type="range"]');
        if (!slider) return;
        event.preventDefault();

        const currentValue = parseInt(slider.value, 10);
        const maxValue = parseInt(slider.max, 10);
        const newValue = Math.min(Math.max(currentValue + step, 0), maxValue);
        if (newValue === currentValue) return;
        slider.value = newValue;
        slider.dispatchEvent(new Event('input', { bubbles: true }));
        slider.dispatchEvent(new Event('change', { bubbles: true }));
    });
}
"""

# Create Gradio interface
with gr.Blocks(title="Video Side-by-Side Comparison", js=js_code) as demo:
    gr.Markdown("# Video Side-by-Side Comparison with Cached Frame Navigation")
    gr.Markdown("Upload two videos to compare them side by side. Frames are cached for instant navigation!")

    # Per-session video caches and the navigable frame count.
    cache1_state = gr.State()
    cache2_state = gr.State()
    total_frames_state = gr.State(value=0)

    with gr.Row():
        with gr.Column():
            video1_input = gr.Video(label="Video 1")
            method1_input = gr.Textbox(label="Method 1 Name", value="Ground Truth")

        with gr.Column():
            video2_input = gr.Video(label="Video 2")
            method2_input = gr.Textbox(label="Method 2 Name", value="Generated")

    with gr.Row():
        show_frames_checkbox = gr.Checkbox(label="Show Frame Numbers", value=True)
        preload_checkbox = gr.Checkbox(label="Preload All Frames (for videos < 1000 frames)", value=False)

    compare_btn = gr.Button("Create Comparison Video & Load Frames", variant="primary")

    status_text = gr.Textbox(label="Status", interactive=False)
    output_video = gr.Video(label="Side-by-Side Comparison Video")

    gr.Markdown("## Frame-by-Frame Navigation (Cached)")
    gr.Markdown("""
    **Keyboard shortcuts** (ignored while typing in a text box):
    - ← / → (or A / D): Previous/Next frame
    - ↑ / ↓ (or W / S): Jump 10 frames forward/backward

    **Note:** First 50 frames are preloaded immediately. Additional frames load on-demand and are cached for instant replay.
    """)

    frame_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Frame Number",
                             elem_id="frame-slider")

    with gr.Row():
        prev_btn = gr.Button("← Previous Frame")
        next_btn = gr.Button("Next Frame →")

    frame_display = gr.Image(label="Current Frame Comparison")

    def process_videos(video1, video2, method1, method2, show_frames, preload_all):
        """Build the comparison video; returns the 7 outputs wired below."""
        return create_side_by_side_video(video1, video2, method1, method2, show_frames, preload_all)

    compare_btn.click(
        fn=process_videos,
        inputs=[video1_input, video2_input, method1_input, method2_input, show_frames_checkbox, preload_checkbox],
        outputs=[output_video, frame_display, frame_slider, total_frames_state, cache1_state, cache2_state, status_text]
    )

    # Frame navigation
    frame_slider.change(
        fn=update_frame,
        inputs=[cache1_state, cache2_state, frame_slider, method1_input, method2_input, show_frames_checkbox, total_frames_state],
        outputs=frame_display
    )

    def prev_frame(current_frame):
        """Step the slider back one frame, stopping at 0."""
        return max(0, int(current_frame or 0) - 1)

    def next_frame(current_frame, total):
        """Step the slider forward one frame, stopping at the last frame.

        Before any video is loaded (``total`` is 0) the position is unchanged.
        """
        current = int(current_frame or 0)
        if not total:
            return current
        return min(total - 1, current + 1)

    prev_btn.click(
        fn=prev_frame,
        inputs=[frame_slider],
        outputs=frame_slider
    )

    next_btn.click(
        fn=next_frame,
        inputs=[frame_slider, total_frames_state],
        outputs=frame_slider
    )

if __name__ == "__main__":
    demo.launch()