Kacper Kania committed on
Commit
424a598
·
1 Parent(s): c78c6b1

Add caching

Browse files
Files changed (1) hide show
  1. app.py +128 -45
app.py CHANGED
@@ -1,20 +1,67 @@
1
  import gradio as gr
2
  import subprocess
3
- import os
4
  import tempfile
5
  import cv2
6
  import numpy as np
7
  from PIL import Image
 
 
8
 
9
- def extract_frame(video_path, frame_num):
10
- """Extract a specific frame from video"""
11
- cap = cv2.VideoCapture(video_path)
12
- cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
13
- ret, frame = cap.read()
14
- cap.release()
15
- if ret:
16
- return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
17
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def get_frame_count(video_path):
20
  """Get total number of frames in video"""
@@ -23,17 +70,27 @@ def get_frame_count(video_path):
23
  cap.release()
24
  return total_frames
25
 
26
- def create_side_by_side_frame(video1, video2, frame_num, method1, method2, show_frame_numbers, total_frames):
27
- """Create side-by-side image for a specific frame"""
28
- if video1 is None or video2 is None:
29
  return None
30
 
31
- frame1 = extract_frame(video1, frame_num)
32
- frame2 = extract_frame(video2, frame_num)
33
 
34
  if frame1 is None or frame2 is None:
35
  return None
36
 
 
 
 
 
 
 
 
 
 
 
37
  # Convert to PIL for drawing text
38
  img1 = Image.fromarray(frame1)
39
  img2 = Image.fromarray(frame2)
@@ -45,7 +102,7 @@ def create_side_by_side_frame(video1, video2, frame_num, method1, method2, show_
45
  result.paste(img1, (0, 0))
46
  result.paste(img2, (img1.width, 0))
47
 
48
- # Convert back to array for text overlay (using cv2 for simplicity)
49
  result_array = np.array(result)
50
 
51
  # Add text overlays
@@ -65,14 +122,37 @@ def create_side_by_side_frame(video1, video2, frame_num, method1, method2, show_
65
 
66
  return result_array
67
 
68
- def create_side_by_side_video(video1, video2, method1="Method 1", method2="Method 2", show_frame_numbers=True):
69
- """Create side-by-side video comparison"""
70
  if video1 is None or video2 is None:
71
- return None, None, 0, 0
72
 
73
  # Get frame count for slider
74
  total_frames = get_frame_count(video1)
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # Create output file
77
  output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
78
 
@@ -106,6 +186,7 @@ drawtext=text='{frame_text}':fontsize=20:fontcolor=white:box=1:boxcolor=black@0.
106
  """
107
 
108
  # Run ffmpeg
 
109
  cmd = [
110
  'ffmpeg', '-y', '-i', video1, '-i', video2,
111
  '-filter_complex', filter_complex,
@@ -118,31 +199,30 @@ drawtext=text='{frame_text}':fontsize=20:fontcolor=white:box=1:boxcolor=black@0.
118
  try:
119
  subprocess.run(cmd, check=True, capture_output=True)
120
  # Get first frame for preview
121
- first_frame = create_side_by_side_frame(video1, video2, 0, method1, method2, show_frame_numbers, total_frames)
122
- return output_file, first_frame, gr.Slider(maximum=total_frames-1, value=0), total_frames
 
123
  except subprocess.CalledProcessError as e:
124
  print(f"Error: {e.stderr.decode()}")
125
- return None, None, gr.Slider(maximum=100, value=0), 0
126
 
127
- def update_frame(video1, video2, frame_num, method1, method2, show_frames, total_frames):
128
- """Update the displayed frame"""
129
- if video1 is None or video2 is None or total_frames == 0:
130
  return None
131
 
132
- return create_side_by_side_frame(video1, video2, int(frame_num), method1, method2, show_frames, total_frames)
133
 
134
  # JavaScript for keyboard navigation
135
  js_code = """
136
  function() {
137
  document.addEventListener('keydown', function(event) {
138
- // Get the slider element
139
  const slider = document.querySelector('input[type="range"]');
140
  if (!slider) return;
141
 
142
  const currentValue = parseInt(slider.value);
143
  const maxValue = parseInt(slider.max);
144
 
145
- // Arrow Right or D - Next frame
146
  if (event.key === 'ArrowRight' || event.key === 'd' || event.key === 'D') {
147
  event.preventDefault();
148
  if (currentValue < maxValue) {
@@ -152,7 +232,6 @@ function() {
152
  }
153
  }
154
 
155
- // Arrow Left or A - Previous frame
156
  if (event.key === 'ArrowLeft' || event.key === 'a' || event.key === 'A') {
157
  event.preventDefault();
158
  if (currentValue > 0) {
@@ -162,7 +241,6 @@ function() {
162
  }
163
  }
164
 
165
- // Arrow Up or W - Jump forward 10 frames
166
  if (event.key === 'ArrowUp' || event.key === 'w' || event.key === 'W') {
167
  event.preventDefault();
168
  const newValue = Math.min(currentValue + 10, maxValue);
@@ -171,7 +249,6 @@ function() {
171
  slider.dispatchEvent(new Event('change', { bubbles: true }));
172
  }
173
 
174
- // Arrow Down or S - Jump backward 10 frames
175
  if (event.key === 'ArrowDown' || event.key === 's' || event.key === 'S') {
176
  event.preventDefault();
177
  const newValue = Math.max(currentValue - 10, 0);
@@ -185,12 +262,12 @@ function() {
185
 
186
  # Create Gradio interface
187
  with gr.Blocks(title="Video Side-by-Side Comparison", js=js_code) as demo:
188
- gr.Markdown("# Video Side-by-Side Comparison with Frame Navigation")
189
- gr.Markdown("Upload two videos to compare them side by side with labels and frame numbers.")
190
 
191
- # Store video paths and frame count in state
192
- video1_state = gr.State()
193
- video2_state = gr.State()
194
  total_frames_state = gr.State(value=0)
195
 
196
  with gr.Row():
@@ -202,17 +279,23 @@ with gr.Blocks(title="Video Side-by-Side Comparison", js=js_code) as demo:
202
  video2_input = gr.Video(label="Video 2")
203
  method2_input = gr.Textbox(label="Method 2 Name", value="Generated")
204
 
205
- show_frames_checkbox = gr.Checkbox(label="Show Frame Numbers", value=True)
 
 
 
 
206
 
207
- compare_btn = gr.Button("Create Comparison Video", variant="primary")
208
 
209
  output_video = gr.Video(label="Side-by-Side Comparison Video")
210
 
211
- gr.Markdown("## Frame-by-Frame Navigation")
212
  gr.Markdown("""
213
  **Keyboard shortcuts:**
214
  - ← / → (or A / D): Previous/Next frame
215
  - ↑ / ↓ (or W / S): Jump 10 frames backward/forward
 
 
216
  """)
217
 
218
  frame_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Frame Number")
@@ -224,20 +307,20 @@ with gr.Blocks(title="Video Side-by-Side Comparison", js=js_code) as demo:
224
  frame_display = gr.Image(label="Current Frame Comparison")
225
 
226
  # Create video comparison
227
- def process_videos(video1, video2, method1, method2, show_frames):
228
- result = create_side_by_side_video(video1, video2, method1, method2, show_frames)
229
- return result[0], result[1], result[2], video1, video2, result[3]
230
 
231
  compare_btn.click(
232
  fn=process_videos,
233
- inputs=[video1_input, video2_input, method1_input, method2_input, show_frames_checkbox],
234
- outputs=[output_video, frame_display, frame_slider, video1_state, video2_state, total_frames_state]
235
  )
236
 
237
  # Frame navigation
238
  frame_slider.change(
239
  fn=update_frame,
240
- inputs=[video1_state, video2_state, frame_slider, method1_input, method2_input, show_frames_checkbox, total_frames_state],
241
  outputs=frame_display
242
  )
243
 
 
1
  import gradio as gr
2
  import subprocess
 
3
  import tempfile
4
  import cv2
5
  import numpy as np
6
  from PIL import Image
7
+ from functools import lru_cache
8
+ import threading
9
 
10
class VideoFrameCache:
    """Bounded in-memory cache of decoded RGB frames for one video file.

    Frames are decoded lazily with OpenCV the first time they are requested
    and stored in a dict keyed by frame index.  When the cache is full, the
    entry that was inserted first is evicted (FIFO).  ``get_frame`` runs
    under ``self.lock`` so background preloading threads and the caller can
    share one instance.
    """

    def __init__(self, video_path, max_cache_size=500):
        """Remember the video path and query its frame count.

        Args:
            video_path: Path passed to ``cv2.VideoCapture``.
            max_cache_size: Maximum number of frames kept in memory; a value
                of zero or less disables caching.
        """
        self.video_path = video_path
        self.cache = {}
        self.max_cache_size = max_cache_size
        self.total_frames = self._get_frame_count()
        self.lock = threading.Lock()

    def _get_frame_count(self):
        """Return the frame count reported by ``CAP_PROP_FRAME_COUNT``."""
        cap = cv2.VideoCapture(self.video_path)
        try:
            return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # Release the handle even if the property query raises.
            cap.release()

    def _store(self, frame_num, frame_rgb):
        """Insert a frame, evicting the oldest entries first.

        The caller must already hold ``self.lock``.
        """
        if self.max_cache_size <= 0:
            # Nothing may be kept; avoids evicting from an empty dict.
            return
        # Dicts preserve insertion order, so the first key is the oldest.
        while len(self.cache) >= self.max_cache_size:
            del self.cache[next(iter(self.cache))]
        self.cache[frame_num] = frame_rgb

    def get_frame(self, frame_num):
        """Return frame ``frame_num`` as an RGB array, decoding it if needed.

        Returns:
            The cached or freshly decoded frame, or ``None`` when the index
            is negative or the frame cannot be read.
        """
        if frame_num < 0:
            # Never seek to (or cache under) a negative index.
            return None

        with self.lock:
            if frame_num in self.cache:
                return self.cache[frame_num]

            cap = cv2.VideoCapture(self.video_path)
            try:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
                ret, frame = cap.read()
            finally:
                # Always free the decoder handle, even on an OpenCV error.
                cap.release()

            if not ret:
                return None

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self._store(frame_num, frame_rgb)
            return frame_rgb

    def preload_range(self, start_frame, end_frame):
        """Load frames ``start_frame`` .. ``end_frame - 1`` into the cache.

        The range is clipped to ``[0, total_frames)``.  Frames that are
        already cached are returned by ``get_frame`` without decoding.
        """
        for i in range(max(start_frame, 0), min(end_frame, self.total_frames)):
            self.get_frame(i)

    def preload_all(self):
        """Preload as many leading frames as the cache can hold.

        Loading more than ``max_cache_size`` frames would only evict the
        frames loaded first, so the preload stops at the cache capacity.
        """
        limit = min(self.total_frames, self.max_cache_size)
        print(f"Preloading {limit} frames...")
        for i in range(limit):
            self.get_frame(i)
            if i % 100 == 0:
                print(f"Loaded {i}/{limit} frames...")
        if limit < self.total_frames:
            print(f"Cache full: loaded first {limit} of {self.total_frames} frames.")
        else:
            print("All frames loaded!")
65
 
66
  def get_frame_count(video_path):
67
  """Get total number of frames in video"""
 
70
  cap.release()
71
  return total_frames
72
 
73
+ def create_side_by_side_frame(cache1, cache2, frame_num, method1, method2, show_frame_numbers, total_frames):
74
+ """Create side-by-side image for a specific frame using cache"""
75
+ if cache1 is None or cache2 is None:
76
  return None
77
 
78
+ frame1 = cache1.get_frame(frame_num)
79
+ frame2 = cache2.get_frame(frame_num)
80
 
81
  if frame1 is None or frame2 is None:
82
  return None
83
 
84
+ # Preload next few frames in background for smooth navigation
85
+ def preload_nearby():
86
+ for offset in range(1, 20):
87
+ if frame_num + offset < total_frames:
88
+ cache1.get_frame(frame_num + offset)
89
+ cache2.get_frame(frame_num + offset)
90
+
91
+ thread = threading.Thread(target=preload_nearby, daemon=True)
92
+ thread.start()
93
+
94
  # Convert to PIL for drawing text
95
  img1 = Image.fromarray(frame1)
96
  img2 = Image.fromarray(frame2)
 
102
  result.paste(img1, (0, 0))
103
  result.paste(img2, (img1.width, 0))
104
 
105
+ # Convert back to array for text overlay
106
  result_array = np.array(result)
107
 
108
  # Add text overlays
 
122
 
123
  return result_array
124
 
125
+ def create_side_by_side_video(video1, video2, method1="Method 1", method2="Method 2", show_frame_numbers=True, preload_all_frames=False):
126
+ """Create side-by-side video comparison and initialize caches"""
127
  if video1 is None or video2 is None:
128
+ return None, None, 0, 0, None, None, "No videos loaded"
129
 
130
  # Get frame count for slider
131
  total_frames = get_frame_count(video1)
132
 
133
+ # Create frame caches
134
+ print("Creating frame caches...")
135
+ cache1 = VideoFrameCache(video1, max_cache_size=500)
136
+ cache2 = VideoFrameCache(video2, max_cache_size=500)
137
+
138
+ status_msg = f"Videos loaded. Total frames: {total_frames}\n"
139
+
140
+ # Preload first frames for instant display
141
+ print("Preloading initial frames...")
142
+ cache1.preload_range(0, 50)
143
+ cache2.preload_range(0, 50)
144
+ status_msg += "First 50 frames cached.\n"
145
+
146
+ # Optionally preload all frames
147
+ if preload_all_frames and total_frames < 1000: # Safety limit
148
+ def preload_bg():
149
+ cache1.preload_all()
150
+ cache2.preload_all()
151
+
152
+ thread = threading.Thread(target=preload_bg, daemon=True)
153
+ thread.start()
154
+ status_msg += "Loading all frames in background...\n"
155
+
156
  # Create output file
157
  output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
158
 
 
186
  """
187
 
188
  # Run ffmpeg
189
+ print("Creating side-by-side video...")
190
  cmd = [
191
  'ffmpeg', '-y', '-i', video1, '-i', video2,
192
  '-filter_complex', filter_complex,
 
199
  try:
200
  subprocess.run(cmd, check=True, capture_output=True)
201
  # Get first frame for preview
202
+ first_frame = create_side_by_side_frame(cache1, cache2, 0, method1, method2, show_frame_numbers, total_frames)
203
+ status_msg += "Video created successfully!"
204
+ return output_file, first_frame, gr.Slider(maximum=total_frames-1, value=0), total_frames, cache1, cache2, status_msg
205
  except subprocess.CalledProcessError as e:
206
  print(f"Error: {e.stderr.decode()}")
207
+ return None, None, gr.Slider(maximum=100, value=0), 0, None, None, f"Error: {e.stderr.decode()}"
208
 
209
def update_frame(cache1, cache2, frame_num, method1, method2, show_frames, total_frames):
    """Render the side-by-side image for the frame selected on the slider.

    Returns ``None`` when either frame cache is missing or no frames are
    available; otherwise delegates to ``create_side_by_side_frame`` with the
    slider value converted to an integer index.
    """
    caches_ready = cache1 is not None and cache2 is not None
    if not caches_ready or total_frames == 0:
        return None

    frame_index = int(frame_num)
    return create_side_by_side_frame(
        cache1,
        cache2,
        frame_index,
        method1,
        method2,
        show_frames,
        total_frames,
    )
215
 
216
  # JavaScript for keyboard navigation
217
  js_code = """
218
  function() {
219
  document.addEventListener('keydown', function(event) {
 
220
  const slider = document.querySelector('input[type="range"]');
221
  if (!slider) return;
222
 
223
  const currentValue = parseInt(slider.value);
224
  const maxValue = parseInt(slider.max);
225
 
 
226
  if (event.key === 'ArrowRight' || event.key === 'd' || event.key === 'D') {
227
  event.preventDefault();
228
  if (currentValue < maxValue) {
 
232
  }
233
  }
234
 
 
235
  if (event.key === 'ArrowLeft' || event.key === 'a' || event.key === 'A') {
236
  event.preventDefault();
237
  if (currentValue > 0) {
 
241
  }
242
  }
243
 
 
244
  if (event.key === 'ArrowUp' || event.key === 'w' || event.key === 'W') {
245
  event.preventDefault();
246
  const newValue = Math.min(currentValue + 10, maxValue);
 
249
  slider.dispatchEvent(new Event('change', { bubbles: true }));
250
  }
251
 
 
252
  if (event.key === 'ArrowDown' || event.key === 's' || event.key === 'S') {
253
  event.preventDefault();
254
  const newValue = Math.max(currentValue - 10, 0);
 
262
 
263
  # Create Gradio interface
264
  with gr.Blocks(title="Video Side-by-Side Comparison", js=js_code) as demo:
265
+ gr.Markdown("# Video Side-by-Side Comparison with Cached Frame Navigation")
266
+ gr.Markdown("Upload two videos to compare them side by side. Frames are cached for instant navigation!")
267
 
268
+ # Store video caches and frame count in state
269
+ cache1_state = gr.State()
270
+ cache2_state = gr.State()
271
  total_frames_state = gr.State(value=0)
272
 
273
  with gr.Row():
 
279
  video2_input = gr.Video(label="Video 2")
280
  method2_input = gr.Textbox(label="Method 2 Name", value="Generated")
281
 
282
+ with gr.Row():
283
+ show_frames_checkbox = gr.Checkbox(label="Show Frame Numbers", value=True)
284
+ preload_checkbox = gr.Checkbox(label="Preload All Frames (for videos < 1000 frames)", value=False)
285
+
286
+ compare_btn = gr.Button("Create Comparison Video & Load Frames", variant="primary")
287
 
288
+ status_text = gr.Textbox(label="Status", interactive=False)
289
 
290
  output_video = gr.Video(label="Side-by-Side Comparison Video")
291
 
292
+ gr.Markdown("## Frame-by-Frame Navigation (Cached)")
293
  gr.Markdown("""
294
  **Keyboard shortcuts:**
295
  - ← / → (or A / D): Previous/Next frame
296
  - ↑ / ↓ (or W / S): Jump 10 frames backward/forward
297
+
298
+ **Note:** First 50 frames are preloaded immediately. Additional frames load on-demand and are cached for instant replay.
299
  """)
300
 
301
  frame_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Frame Number")
 
307
  frame_display = gr.Image(label="Current Frame Comparison")
308
 
309
  # Create video comparison
310
+ def process_videos(video1, video2, method1, method2, show_frames, preload_all):
311
+ result = create_side_by_side_video(video1, video2, method1, method2, show_frames, preload_all)
312
+ return result[0], result[1], result[2], result[3], result[4], result[5], result[6]
313
 
314
  compare_btn.click(
315
  fn=process_videos,
316
+ inputs=[video1_input, video2_input, method1_input, method2_input, show_frames_checkbox, preload_checkbox],
317
+ outputs=[output_video, frame_display, frame_slider, total_frames_state, cache1_state, cache2_state, status_text]
318
  )
319
 
320
  # Frame navigation
321
  frame_slider.change(
322
  fn=update_frame,
323
+ inputs=[cache1_state, cache2_state, frame_slider, method1_input, method2_input, show_frames_checkbox, total_frames_state],
324
  outputs=frame_display
325
  )
326