Kacper Kania
Add caching
424a598
import gradio as gr
import subprocess
import tempfile
import cv2
import numpy as np
from PIL import Image
from functools import lru_cache
import threading
class VideoFrameCache:
    """Bounded cache of decoded RGB frames for a single video file.

    Frames are decoded with OpenCV the first time they are requested and kept
    in an insertion-ordered dict. Once ``max_cache_size`` frames are stored,
    the oldest inserted frame is evicted (FIFO, not LRU). ``get_frame`` holds
    ``self.lock`` for the whole lookup/decode/store sequence.
    """

    def __init__(self, video_path, max_cache_size=500):
        """Remember the video path and query its frame count.

        Args:
            video_path: Path handed to ``cv2.VideoCapture``.
            max_cache_size: Maximum number of frames kept in memory. A value
                of 0 or less disables caching (frames are decoded every time).
        """
        self.video_path = video_path
        self.cache = {}
        self.max_cache_size = max_cache_size
        self.lock = threading.Lock()
        self.total_frames = self._get_frame_count()

    def _get_frame_count(self):
        """Return the frame count OpenCV reports for the video."""
        cap = cv2.VideoCapture(self.video_path)
        try:
            return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # Release the capture handle even if the property query raises.
            cap.release()

    def _read_frame(self, frame_num):
        """Decode one frame from disk and return it as RGB, or None on failure."""
        cap = cv2.VideoCapture(self.video_path)
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
            ret, frame = cap.read()
        finally:
            cap.release()
        if not ret:
            return None
        # OpenCV decodes to BGR; the rest of the app works with RGB.
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def _store(self, frame_num, frame_rgb):
        """Insert a frame, evicting the oldest entries first. Caller holds the lock."""
        if self.max_cache_size <= 0:
            # Nothing may be stored; also avoids popping from an empty dict.
            return
        while len(self.cache) >= self.max_cache_size:
            # dicts preserve insertion order, so the first key is the oldest.
            del self.cache[next(iter(self.cache))]
        self.cache[frame_num] = frame_rgb

    def get_frame(self, frame_num):
        """Return frame ``frame_num`` as an RGB array, or None if unavailable.

        Negative indices are rejected up front instead of being passed to
        OpenCV and cached under a negative key.
        """
        if frame_num < 0:
            return None
        with self.lock:
            if frame_num in self.cache:
                return self.cache[frame_num]
            frame_rgb = self._read_frame(frame_num)
            if frame_rgb is not None:
                self._store(frame_num, frame_rgb)
            return frame_rgb

    def preload_range(self, start_frame, end_frame):
        """Load frames ``start_frame`` (inclusive) to ``end_frame`` (exclusive).

        The range is clipped to ``self.total_frames``. Already-cached frames
        are skipped inside ``get_frame`` under the lock, so no unlocked
        membership check is needed here.
        """
        for i in range(start_frame, min(end_frame, self.total_frames)):
            self.get_frame(i)

    def preload_all(self):
        """Preload all frames (use with caution for long videos)."""
        print(f"Preloading {self.total_frames} frames...")
        for i in range(self.total_frames):
            self.get_frame(i)
            if i % 100 == 0:
                print(f"Loaded {i}/{self.total_frames} frames...")
        print("All frames loaded!")
def get_frame_count(video_path):
    """Return the total number of frames OpenCV reports for ``video_path``.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        The ``CAP_PROP_FRAME_COUNT`` property converted to ``int``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Always free the capture handle, even if the property query raises.
        cap.release()
def create_side_by_side_frame(cache1, cache2, frame_num, method1, method2, show_frame_numbers, total_frames):
    """Compose one labelled side-by-side image for ``frame_num``.

    Args:
        cache1, cache2: Objects exposing ``get_frame(frame_num)``; either may
            be None, in which case nothing is rendered.
        frame_num: Index of the frame to fetch from both caches.
        method1, method2: Labels drawn at the top of the left/right halves.
        show_frame_numbers: When true, draw ``Frame <n>/<total>`` at the
            bottom of both halves.
        total_frames: Used for the overlay text and to bound preloading.

    Returns:
        The composed image as a numpy array, or None if a cache is missing or
        either frame could not be fetched.
    """
    if cache1 is None or cache2 is None:
        return None

    left_frame = cache1.get_frame(frame_num)
    right_frame = cache2.get_frame(frame_num)
    if left_frame is None or right_frame is None:
        return None

    # Warm the caches for the next 19 frames on a daemon thread so stepping
    # forward does not have to wait for decoding.
    def preload_nearby():
        for step in range(1, 20):
            upcoming = frame_num + step
            if upcoming < total_frames:
                cache1.get_frame(upcoming)
                cache2.get_frame(upcoming)

    threading.Thread(target=preload_nearby, daemon=True).start()

    # Paste both frames onto one canvas; the right frame starts where the
    # left one ends.
    left_img = Image.fromarray(left_frame)
    right_img = Image.fromarray(right_frame)
    split_x = left_img.width
    height = max(left_img.height, right_img.height)
    canvas = Image.new('RGB', (split_x + right_img.width, height))
    canvas.paste(left_img, (0, 0))
    canvas.paste(right_img, (split_x, 0))

    # OpenCV draws text on numpy arrays, so convert back.
    result_array = np.array(canvas)

    font = cv2.FONT_HERSHEY_SIMPLEX
    white = (255, 255, 255)
    thickness = 2

    # Method names along the top of each half.
    for label, x in ((method1, 10), (method2, split_x + 10)):
        cv2.putText(result_array, label, (x, 30), font, 0.8, white, thickness, cv2.LINE_AA)

    # Frame counter along the bottom of each half.
    if show_frame_numbers:
        frame_text = f'Frame {frame_num}/{total_frames}'
        for x in (10, split_x + 10):
            cv2.putText(result_array, frame_text, (x, height - 20), font, 0.6, white, thickness, cv2.LINE_AA)

    return result_array
def _escape_drawtext_text(text):
    """Escape literal text for a single-quoted ffmpeg drawtext ``text=`` value."""
    text = str(text)
    # drawtext text expansion: backslash and percent start escape sequences.
    text = text.replace("\\", "\\\\").replace("%", "\\%")
    # Filter option value level: backslash, quote and colon are special.
    text = text.replace("\\", "\\\\").replace("'", "\\'").replace(":", "\\:")
    # Filtergraph level, inside '...': a quote must close, be escaped, reopen.
    return text.replace("'", "'\\''")


def _drawtext_filter(escaped_text, fontsize, y):
    """Build one boxed white drawtext filter at x=10 and the given y expression."""
    return (
        f"drawtext=text='{escaped_text}':fontsize={fontsize}:fontcolor=white:"
        f"box=1:boxcolor=black@0.5:boxborderw=5:x=10:y={y}"
    )


def _probe_frame_total_text(video_path, fallback):
    """Return ffprobe's decoded frame count as text, or ``str(fallback)``."""
    try:
        result = subprocess.run(
            ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
             '-count_frames', '-show_entries', 'stream=nb_read_frames',
             '-of', 'csv=p=0', video_path],
            capture_output=True, text=True
        )
    except (OSError, subprocess.SubprocessError):
        # ffprobe missing or not runnable: use the count we already have.
        return str(fallback)
    counted = result.stdout.strip()
    # A failed probe prints nothing (or something non-numeric) on stdout;
    # never let that leak into the overlay text.
    return counted if counted.isdigit() else str(fallback)


def create_side_by_side_video(video1, video2, method1="Method 1", method2="Method 2", show_frame_numbers=True, preload_all_frames=False):
    """Create a side-by-side comparison video and initialise the frame caches.

    Args:
        video1, video2: Paths of the two input videos; if either is None
            nothing is done.
        method1, method2: Labels drawn on the left/right video. They are
            escaped before being placed in the ffmpeg filtergraph.
        show_frame_numbers: Also draw ``Frame <n>/<total>`` on both videos.
        preload_all_frames: Cache every frame in a background thread; only
            honoured for videos under 1000 frames.

    Returns:
        A 7-tuple ``(output_path, first_frame, slider, total_frames, cache1,
        cache2, status_message)``. When an input is missing or ffmpeg fails,
        the path, preview and caches are None and ``total_frames`` is 0.
    """
    import os  # local import: only needed to clean up after a failed encode

    if video1 is None or video2 is None:
        return None, None, 0, 0, None, None, "No videos loaded"

    # Get frame count for slider
    total_frames = get_frame_count(video1)

    # Create frame caches
    print("Creating frame caches...")
    cache1 = VideoFrameCache(video1, max_cache_size=500)
    cache2 = VideoFrameCache(video2, max_cache_size=500)
    status_msg = f"Videos loaded. Total frames: {total_frames}\n"

    # Preload first frames for instant display
    print("Preloading initial frames...")
    cache1.preload_range(0, 50)
    cache2.preload_range(0, 50)
    status_msg += "First 50 frames cached.\n"

    # Optionally preload all frames
    if preload_all_frames and total_frames < 1000:  # Safety limit
        def preload_bg():
            cache1.preload_all()
            cache2.preload_all()

        threading.Thread(target=preload_bg, daemon=True).start()
        status_msg += "Loading all frames in background...\n"

    # Reserve an output path. The handle is closed right away so ffmpeg can
    # write to the path and no file descriptor is left open.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
        output_file = tmp.name

    # Build filter complex: label each input, then stack them horizontally.
    frame_text = None
    if show_frame_numbers:
        # ffprobe decodes the whole stream to count frames, so only run it
        # when the number is actually drawn. %{n} is expanded by drawtext.
        total_frames_str = _probe_frame_total_text(video1, total_frames)
        frame_text = f'Frame %{{n}}/{total_frames_str}'

    chains = []
    for index, label in enumerate((method1, method2)):
        filters = [_drawtext_filter(_escape_drawtext_text(label), 24, "10")]
        if frame_text is not None:
            filters.append(_drawtext_filter(frame_text, 20, "h-40"))
        chains.append(f"[{index}:v]{','.join(filters)}[v{index}]")
    chains.append("[v0][v1]hstack=inputs=2[v]")
    filter_complex = ";".join(chains)

    # Run ffmpeg
    print("Creating side-by-side video...")
    cmd = [
        'ffmpeg', '-y', '-i', video1, '-i', video2,
        '-filter_complex', filter_complex,
        '-map', '[v]', '-map', '0:a?',
        '-c:v', 'libx264', '-crf', '18',
        '-c:a', 'aac', '-shortest',
        output_file
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError carries ffmpeg's stderr; OSError (e.g. ffmpeg not
        # installed) does not.
        stderr = getattr(e, "stderr", None)
        detail = stderr.decode(errors="replace") if stderr else str(e)
        print(f"Error: {detail}")
        try:
            # Best effort: do not leave the unusable output file behind.
            os.unlink(output_file)
        except OSError:
            pass
        return None, None, gr.Slider(maximum=100, value=0), 0, None, None, f"Error: {detail}"

    # Get first frame for preview
    first_frame = create_side_by_side_frame(cache1, cache2, 0, method1, method2, show_frame_numbers, total_frames)
    status_msg += "Video created successfully!"
    # Clamp so a zero-frame video does not produce a slider maximum of -1.
    slider = gr.Slider(maximum=max(total_frames - 1, 0), value=0)
    return output_file, first_frame, slider, total_frames, cache1, cache2, status_msg
def update_frame(cache1, cache2, frame_num, method1, method2, show_frames, total_frames):
    """Render the comparison image for the frame selected on the slider.

    Returns None when either cache is missing or ``total_frames`` is 0;
    otherwise delegates to ``create_side_by_side_frame`` with ``frame_num``
    converted to ``int``.
    """
    caches_ready = cache1 is not None and cache2 is not None
    if not caches_ready or total_frames == 0:
        return None

    frame_index = int(frame_num)
    return create_side_by_side_frame(
        cache1, cache2, frame_index, method1, method2, show_frames, total_frames
    )
# JavaScript for keyboard navigation.
# The handler is attached to the whole document, so it must stay out of the
# way while the user types in a text field (otherwise a/d/w/s could not be
# typed into the method-name boxes) and must not swallow browser shortcuts
# such as Ctrl+A or Ctrl+S.
js_code = """
function() {
    document.addEventListener('keydown', function(event) {
        // Leave modifier shortcuts (Ctrl/Cmd/Alt + key) to the browser.
        if (event.ctrlKey || event.metaKey || event.altKey) return;

        // Ignore keys typed into editable fields; the range slider itself
        // is still handled here.
        const target = event.target;
        if (target) {
            const tag = (target.tagName || '').toUpperCase();
            const isRange = tag === 'INPUT' && target.type === 'range';
            if (target.isContentEditable || tag === 'TEXTAREA' || tag === 'SELECT' ||
                (tag === 'INPUT' && !isRange)) {
                return;
            }
        }

        const slider = document.querySelector('input[type="range"]');
        if (!slider) return;
        const currentValue = parseInt(slider.value);
        const maxValue = parseInt(slider.max);

        // Next frame
        if (event.key === 'ArrowRight' || event.key === 'd' || event.key === 'D') {
            event.preventDefault();
            if (currentValue < maxValue) {
                slider.value = currentValue + 1;
                slider.dispatchEvent(new Event('input', { bubbles: true }));
                slider.dispatchEvent(new Event('change', { bubbles: true }));
            }
        }

        // Previous frame
        if (event.key === 'ArrowLeft' || event.key === 'a' || event.key === 'A') {
            event.preventDefault();
            if (currentValue > 0) {
                slider.value = currentValue - 1;
                slider.dispatchEvent(new Event('input', { bubbles: true }));
                slider.dispatchEvent(new Event('change', { bubbles: true }));
            }
        }

        // Jump 10 frames forward
        if (event.key === 'ArrowUp' || event.key === 'w' || event.key === 'W') {
            event.preventDefault();
            const newValue = Math.min(currentValue + 10, maxValue);
            slider.value = newValue;
            slider.dispatchEvent(new Event('input', { bubbles: true }));
            slider.dispatchEvent(new Event('change', { bubbles: true }));
        }

        // Jump 10 frames backward
        if (event.key === 'ArrowDown' || event.key === 's' || event.key === 'S') {
            event.preventDefault();
            const newValue = Math.max(currentValue - 10, 0);
            slider.value = newValue;
            slider.dispatchEvent(new Event('input', { bubbles: true }));
            slider.dispatchEvent(new Event('change', { bubbles: true }));
        }
    });
}
"""
# Create Gradio interface
with gr.Blocks(title="Video Side-by-Side Comparison", js=js_code) as demo:
    gr.Markdown("# Video Side-by-Side Comparison with Cached Frame Navigation")
    gr.Markdown("Upload two videos to compare them side by side. Frames are cached for instant navigation!")

    # Store video caches and frame count in state
    cache1_state = gr.State()
    cache2_state = gr.State()
    total_frames_state = gr.State(value=0)

    with gr.Row():
        with gr.Column():
            video1_input = gr.Video(label="Video 1")
            method1_input = gr.Textbox(label="Method 1 Name", value="Ground Truth")
        with gr.Column():
            video2_input = gr.Video(label="Video 2")
            method2_input = gr.Textbox(label="Method 2 Name", value="Generated")

    with gr.Row():
        show_frames_checkbox = gr.Checkbox(label="Show Frame Numbers", value=True)
        preload_checkbox = gr.Checkbox(label="Preload All Frames (for videos < 1000 frames)", value=False)

    compare_btn = gr.Button("Create Comparison Video & Load Frames", variant="primary")
    status_text = gr.Textbox(label="Status", interactive=False)
    output_video = gr.Video(label="Side-by-Side Comparison Video")

    gr.Markdown("## Frame-by-Frame Navigation (Cached)")
    # The keyboard handler maps Up/W to +10 and Down/S to -10, so the help
    # text lists the jump directions in that order.
    gr.Markdown(
        "\n"
        "**Keyboard shortcuts:**\n"
        "- ← / → (or A / D): Previous/Next frame\n"
        "- ↑ / ↓ (or W / S): Jump 10 frames forward/backward\n"
        "**Note:** First 50 frames are preloaded immediately. Additional frames load on-demand and are cached for instant replay.\n"
    )

    frame_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Frame Number")
    with gr.Row():
        prev_btn = gr.Button("← Previous Frame")
        next_btn = gr.Button("Next Frame →")
    frame_display = gr.Image(label="Current Frame Comparison")

    # Create video comparison
    def process_videos(video1, video2, method1, method2, show_frames, preload_all):
        """Build the comparison video; the 7-tuple maps onto the outputs below."""
        return create_side_by_side_video(video1, video2, method1, method2, show_frames, preload_all)

    compare_btn.click(
        fn=process_videos,
        inputs=[video1_input, video2_input, method1_input, method2_input, show_frames_checkbox, preload_checkbox],
        outputs=[output_video, frame_display, frame_slider, total_frames_state, cache1_state, cache2_state, status_text]
    )

    # Frame navigation: any slider change re-renders the comparison image.
    frame_slider.change(
        fn=update_frame,
        inputs=[cache1_state, cache2_state, frame_slider, method1_input, method2_input, show_frames_checkbox, total_frames_state],
        outputs=frame_display
    )

    def prev_frame(current_frame):
        """Step one frame back, stopping at frame 0."""
        return max(0, current_frame - 1)

    def next_frame(current_frame, total):
        """Step one frame forward, stopping at the last frame.

        The outer ``max`` keeps the result at 0 when no video is loaded
        (``total`` is 0); otherwise ``total - 1`` would push the slider to -1.
        """
        return max(0, min(total - 1, current_frame + 1))

    # The buttons only move the slider; its change event does the rendering.
    prev_btn.click(
        fn=prev_frame,
        inputs=[frame_slider],
        outputs=frame_slider
    )
    next_btn.click(
        fn=next_frame,
        inputs=[frame_slider, total_frames_state],
        outputs=frame_slider
    )

if __name__ == "__main__":
    demo.launch()