Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Shared video processing utilities for Gradio demos. | |
| Provides validation, frame-by-frame inference processing, and Gradio UI helpers | |
| that are common across all video-based demos. | |
| """ | |
| import os | |
| import tempfile | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Callable | |
| from fractions import Fraction | |
| import av | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
# Signature of a per-frame inference callback: it receives one frame as a
# NumPy array and returns the processed frame as a NumPy array.
InferenceFn = Callable[[np.ndarray], np.ndarray]
def get_example_videos(examples_dir: str) -> list[list[str]]:
    """Collect the ``*.mp4`` files directly inside *examples_dir* as Gradio example rows.

    Pass an absolute path. A relative ``"./examples"`` default used to be
    tolerated, but EveWrapper now leaves the process cwd pointing at the EVE
    bin directory (required for object-detection model loading), so a relative
    path would silently resolve against the wrong directory. Callers should
    anchor on ``Path(__file__).resolve().parent`` of the calling demo.

    Args:
        examples_dir: Absolute path to a directory containing ``.mp4`` files.

    Returns:
        Sorted list of single-element lists, each holding one video path.
    """
    # Sort the plain path strings first, then wrap each one in its own row.
    clip_paths = sorted(str(clip) for clip in Path(examples_dir).glob("*.mp4"))
    return [[clip_path] for clip_path in clip_paths]
@dataclass
class VideoLimits:
    """Constraints applied to uploaded videos.

    Declared as a dataclass so individual limits can be overridden by keyword
    at construction time, e.g. ``VideoLimits(max_duration_seconds=10)``.
    """

    # Longest accepted video, in seconds.
    max_duration_seconds: int = 30
    # Largest accepted frame size, in pixels.
    max_width: int = 3840
    max_height: int = 2160
    # Largest accepted file size, in megabytes.
    max_file_size_mb: int = 5000

    @property
    def max_file_size_bytes(self) -> int:
        """Return the file-size limit in bytes (one MB counted as 1024 * 1024 bytes).

        Exposed as a property because callers read it as a plain attribute
        (``limits.max_file_size_bytes``) when comparing against a file size.
        """
        return self.max_file_size_mb * 1024 * 1024


# Limits used whenever a caller does not supply its own ``VideoLimits``.
DEFAULT_LIMITS = VideoLimits()
def build_video_constraints_accordion(limits: VideoLimits) -> None:
    """Render a collapsible accordion showing the video upload constraints.

    Must be called inside a Gradio layout context (e.g. a Tab or Column).

    Args:
        limits: The constraints to display.
    """
    if limits.max_file_size_mb >= 1000:
        # ``:g`` keeps a fractional part when there is one (1500 MB -> "1.5 GB")
        # and drops it otherwise (5000 MB -> "5 GB"). Rounding to a whole number
        # could advertise a limit larger than the one actually enforced.
        file_size_text = f"{limits.max_file_size_mb / 1000:g} GB"
    else:
        file_size_text = f"{limits.max_file_size_mb} MB"

    with gr.Accordion("Video Constraints", open=False):
        gr.Markdown(
            f"- Maximum video length: **{limits.max_duration_seconds} seconds**\n"
            f"- Maximum video resolution: **{limits.max_width} x {limits.max_height}**\n"
            f"- Maximum file size: **{file_size_text}**\n\n"
            "Videos exceeding these limits will not be uploaded."
        )
def validate_video(video_path: str, limits: VideoLimits = DEFAULT_LIMITS) -> None:
    """Validate a video file against size, resolution, and duration limits.

    Checks run in order: file size, frame resolution (orientation-agnostic),
    then duration. Duration comes from container metadata when it looks sane
    and from a full decode scan otherwise.

    Args:
        video_path: Path to the video file.
        limits: Constraints to validate against.

    Raises:
        gr.Error: If any limit is exceeded or the video cannot be opened.
    """
    try:
        size = os.path.getsize(video_path)
    except OSError as error:
        # A missing or unreadable file is reported the same way as a file
        # OpenCV cannot open, so callers only have to handle gr.Error.
        raise gr.Error("Could not open the uploaded video for validation.") from error

    if size > limits.max_file_size_bytes:
        megabytes = size / (1024 * 1024)
        raise gr.Error(
            f"File size ({megabytes:.1f} MB) exceeds the maximum allowed size "
            f"of {limits.max_file_size_mb} MB."
        )

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error("Could not open the uploaded video for validation.")

    try:
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Allow both landscape and portrait orientations
        long_side = max(width, height)
        short_side = min(width, height)
        max_long = max(limits.max_width, limits.max_height)
        max_short = min(limits.max_width, limits.max_height)
        if long_side > max_long or short_side > max_short:
            raise gr.Error(
                f"Video resolution ({width}x{height}) exceeds the maximum allowed "
                f"resolution of {limits.max_width}x{limits.max_height}."
            )

        fps = video_capture.get(cv2.CAP_PROP_FPS)
        total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        # Apply the same sanity window as the processing functions: an fps
        # above 240 (e.g. 1000 from ms-based timestamps) is treated as bogus,
        # because dividing by it would badly under-estimate the duration.
        if 0 < fps <= 240 and total_frames > 0:
            duration = total_frames / fps
        else:
            # Metadata is unusable: decode every frame and take the last
            # reported timestamp as the duration.
            duration_ms = 0.0
            while True:
                ret, _ = video_capture.read()
                if not ret:
                    break
                duration_ms = video_capture.get(cv2.CAP_PROP_POS_MSEC)
            duration = duration_ms / 1000.0 if duration_ms > 0 else 0.0

        if duration > limits.max_duration_seconds:
            raise gr.Error(
                f"Video duration ({duration:.1f}s) exceeds the maximum allowed length of {limits.max_duration_seconds} seconds."
            )
    finally:
        video_capture.release()
def validate_and_update(
    video_path: str, limits: VideoLimits = DEFAULT_LIMITS
) -> tuple[dict[str, Any] | None, dict[str, Any]]:
    """Gradio upload handler: validate a video and toggle the run button.

    A valid video leaves the video component untouched and enables the button.
    A missing path, or any validation failure, clears the video component and
    disables the button; failures are additionally surfaced as a persistent
    Gradio warning.

    Args:
        video_path: Path to the uploaded video file.
        limits: Constraints to validate against.

    Returns:
        Tuple of (video component update, button component update).
    """
    if video_path:
        try:
            validate_video(video_path, limits)
            return gr.update(), gr.update(interactive=True)
        except Exception as problem:
            # Show the reason to the user, then fall through to the shared
            # "rejected" return below.
            gr.Warning(str(problem), duration=None)
    return None, gr.update(interactive=False)
def load_example(sample_index: list[Any]) -> tuple[str, dict[str, Any]]:
    """Gradio dataset click handler: load the chosen example and enable the run button.

    Args:
        sample_index: List where the first element is the video path.

    Returns:
        Tuple of (video path, button component update).
    """
    example_path = sample_index[0]
    enable_button = gr.update(interactive=True)
    return example_path, enable_button
def process_video(
    input_video: str,
    inference_fn: InferenceFn,
    progress: gr.Progress = gr.Progress(),
    output_dir: str | None = None,
) -> str | None:
    """Process a video frame-by-frame through an inference function.

    The input is probed for frame rate, size and frame count, then re-read
    from the start; each frame is passed through *inference_fn* and encoded
    into a new H.264 MP4. If processing fails, the partially written output
    file is removed before the exception propagates.

    Args:
        input_video: Path to the input video file.
        inference_fn: Callable that takes a BGR frame and returns a processed BGR frame.
        progress: Gradio progress tracker.
        output_dir: Directory for the output video. Falls back to the system temp dir.

    Returns:
        Path to the output video file, or None if input is None.

    Raises:
        gr.Error: If the video cannot be opened or contains no frames.
    """
    if input_video is None:
        return None

    video_capture = cv2.VideoCapture(input_video)
    if not video_capture.isOpened():
        raise gr.Error("Could not open the uploaded video.")

    try:
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))

        # Webcam-recorded videos (especially WebM from browsers) can report bogus
        # metadata — fps=0, fps=1000 (ms-based timestamps), or missing frame counts.
        fps_is_bogus = fps <= 0 or fps > 240
        if fps_is_bogus or total_frames <= 0 or width <= 0 or height <= 0:
            # Decode the whole file once to recover the real frame count,
            # frame size and (if needed) an average frame rate.
            total_frames = 0
            duration_ms = 0.0
            while True:
                return_value, frame = video_capture.read()
                if not return_value:
                    break
                if total_frames == 0:
                    height, width = frame.shape[:2]
                total_frames += 1
                duration_ms = video_capture.get(cv2.CAP_PROP_POS_MSEC)
            if total_frames == 0:
                raise gr.Error("No frames were read from the video.")
            if fps_is_bogus:
                fps = total_frames / (duration_ms / 1000.0) if duration_ms > 0 else 30.0
    finally:
        video_capture.release()

    # round() yields 0 for frame rates below 0.5, which would make
    # Fraction(1, 0) raise ZeroDivisionError; clamp to at least 1 fps.
    encoder_rate = max(1, round(fps))

    output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False, dir=output_dir)
    output_path = output_file.name
    output_file.close()

    frame_count = 0
    succeeded = False
    try:
        # H.264 via PyAV (wraps FFmpeg) — works on all platforms without codec hunting.
        # movflags=faststart moves the moov atom to the beginning of the file so
        # browsers can read duration/metadata without needing HTTP Range requests.
        container = av.open(output_path, mode="w", options={"movflags": "faststart"})
        try:
            stream = container.add_stream("libx264", rate=encoder_rate)
            stream.time_base = Fraction(1, encoder_rate)
            stream.width = width
            stream.height = height
            stream.pix_fmt = "yuv420p"
            # Baseline profile for maximum browser compatibility
            stream.codec_context.options = {"profile": "baseline", "level": "3.1"}

            # Re-open the input so decoding starts at the first frame again,
            # regardless of whether the probe above consumed the stream.
            video_capture = cv2.VideoCapture(input_video)
            try:
                while True:
                    return_value, frame = video_capture.read()
                    if not return_value:
                        break
                    output_image = inference_fn(frame)
                    progress(
                        (frame_count, total_frames),
                        desc=f"Processing Frame {frame_count + 1}/{total_frames}",
                    )
                    video_frame = av.VideoFrame.from_ndarray(output_image, format="bgr24")
                    # With time_base = 1/encoder_rate, the frame index is its timestamp.
                    video_frame.pts = frame_count
                    for packet in stream.encode(video_frame):
                        container.mux(packet)
                    frame_count += 1
                # Flush the encoder
                for packet in stream.encode():
                    container.mux(packet)
            finally:
                video_capture.release()
        finally:
            container.close()

        if frame_count == 0:
            raise gr.Error("No frames were read from the video.")
        succeeded = True
    finally:
        if not succeeded:
            # Do not leave a truncated or empty temp file behind on failure.
            try:
                os.remove(output_path)
            except OSError:
                pass

    return output_path
def reencode_video(input_path: str, output_path: str) -> None:
    """Re-encode a video to H.264 MP4 with proper time_base metadata.

    Browser webcam recordings (WebM via MediaRecorder) carry unreliable
    metadata: ``fps`` is often 0 or 1000 (ms-based timestamps), frame count
    is missing, and duration reports as Infinity in HTML5 players — which
    blocks scrubbing because ``currentTime`` never advances past 0.
    Re-encoding to CFR H.264 with an explicit ``stream.time_base`` gives
    the output a well-defined duration so players can seek normally.

    Args:
        input_path: Source video file.
        output_path: Destination MP4 path.

    Raises:
        gr.Error: If the source cannot be opened or contains no frames.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise gr.Error(f"Could not open recorded video: {input_path}")

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        fps_is_bogus = fps <= 0 or fps > 240
        if fps_is_bogus or total_frames <= 0 or width <= 0 or height <= 0:
            # Metadata cannot be trusted: decode everything once to learn the
            # frame size, frame count and last timestamp.
            scanned = 0
            duration_ms = 0.0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if scanned == 0:
                    height, width = frame.shape[:2]
                scanned += 1
                duration_ms = cap.get(cv2.CAP_PROP_POS_MSEC)
            if scanned == 0:
                raise gr.Error("No frames were read from the recorded video.")
            if fps_is_bogus:
                fps = scanned / (duration_ms / 1000.0) if duration_ms > 0 else 30.0
    finally:
        cap.release()

    # round() yields 0 for frame rates below 0.5, which would make
    # Fraction(1, 0) raise ZeroDivisionError; clamp to at least 1 fps.
    encoder_rate = max(1, round(fps))

    frame_count = 0
    container = av.open(output_path, mode="w", options={"movflags": "faststart"})
    try:
        stream = container.add_stream("libx264", rate=encoder_rate)
        stream.time_base = Fraction(1, encoder_rate)
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"
        stream.codec_context.options = {"profile": "baseline", "level": "3.1"}

        # Re-open the source so encoding starts from the first frame even if
        # the probe scan above consumed the stream.
        cap = cv2.VideoCapture(input_path)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                video_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
                # With time_base = 1/encoder_rate, the frame index is its timestamp.
                video_frame.pts = frame_count
                for packet in stream.encode(video_frame):
                    container.mux(packet)
                frame_count += 1
            # Flush any frames still buffered in the encoder.
            for packet in stream.encode():
                container.mux(packet)
        finally:
            cap.release()
    finally:
        container.close()

    if frame_count == 0:
        raise gr.Error("No frames were read from the recorded video.")
| def _build_recording_timer_js(elem_id: str, max_duration_seconds: int) -> tuple[str, str]: | |
| """Build client-side JS for a recording elapsed-time overlay and auto-stop. | |
| Returns a ``(start_js, stop_js)`` pair of zero-argument JS function | |
| strings suitable for the ``js`` parameter of Gradio event handlers. | |
| The timer is displayed as a pill overlay inside the webcam ``<video>`` | |
| wrapper. When *max_duration_seconds* elapses the stop-recording button | |
| is clicked programmatically to enforce the limit on the client side. | |
| Args: | |
| elem_id: The ``elem_id`` of the ``gr.Video`` component. | |
| max_duration_seconds: Recording ceiling in seconds. | |
| Returns: | |
| Tuple of (start_js, stop_js) function body strings. | |
| """ | |
| start_js = ( | |
| "() => {" | |
| " const E = '%(eid)s', M = %(max)d;" | |
| " const C = document.getElementById(E);" | |
| " if (!C) return;" | |
| " window.__recT = window.__recT || {};" | |
| " if (window.__recT[E]) { clearInterval(window.__recT[E]); delete window.__recT[E]; }" | |
| " const W = C.querySelector('[data-testid=\"video\"] .wrap') || C.querySelector('.wrap');" | |
| " if (!W) return;" | |
| " let t = document.getElementById(E + '-rt');" | |
| " if (!t) {" | |
| " t = document.createElement('div');" | |
| " t.id = E + '-rt';" | |
| " t.style.cssText = " | |
| " 'position:absolute;top:8px;left:50%%;transform:translateX(-50%%);z-index:100;" | |
| " background:rgba(220,38,38,0.85);color:#fff;padding:4px 12px;" | |
| " border-radius:20px;font:bold 14px/1.4 monospace;" | |
| " display:flex;align-items:center;gap:6px;pointer-events:none';" | |
| " W.appendChild(t);" | |
| " }" | |
| " t.style.display = 'flex';" | |
| " const s = Date.now();" | |
| " const pad = n => String(Math.floor(n)).padStart(2, '0');" | |
| " const mm = Math.floor(M / 60), ms = M %% 60;" | |
| " const lbl = mm + ':' + pad(ms);" | |
| " const up = () => {" | |
| " const e = (Date.now() - s) / 1000;" | |
| " const em = Math.floor(e / 60), es = Math.floor(e %% 60);" | |
| " const warn = (M - e) <= Math.min(10, M * 0.2);" | |
| " t.style.background = warn" | |
| " ? 'rgba(185,28,28,0.95)' : 'rgba(220,38,38,0.85)';" | |
| " t.textContent = '\\u25CF ' + em + ':' + pad(es) + ' / ' + lbl;" | |
| " if (e >= M) {" | |
| " const si = C.querySelector('[title=\"stop recording\"]');" | |
| " if (si) { const b = si.closest('button'); if (b) b.click(); }" | |
| " clearInterval(window.__recT[E]); delete window.__recT[E];" | |
| " t.style.display = 'none';" | |
| " }" | |
| " };" | |
| " up();" | |
| " window.__recT[E] = setInterval(up, 250);" | |
| "}" | |
| ) % {"eid": elem_id, "max": max_duration_seconds} | |
| stop_js = ( | |
| "() => {" | |
| " const E = '%(eid)s';" | |
| " if (window.__recT && window.__recT[E])" | |
| " { clearInterval(window.__recT[E]); delete window.__recT[E]; }" | |
| " const t = document.getElementById(E + '-rt');" | |
| " if (t) t.style.display = 'none';" | |
| "}" | |
| ) % {"eid": elem_id} | |
| return start_js, stop_js | |
def wire_recording_limits(video: gr.Video, max_duration_seconds: int) -> None:
    """Attach a client-side recording timer and auto-stop to a Video component.

    Must be called inside the same ``gr.Blocks`` context as *video*.

    The injected JavaScript:

    * Displays a recording timer pill (e.g. ``0:15 / 2:00``) over the
      webcam feed while recording.
    * Automatically clicks the stop-recording button when the limit is
      reached so the user cannot exceed it.

    Args:
        video: The ``gr.Video`` component to augment.
        max_duration_seconds: Maximum recording duration in seconds.
    """
    # The scripts locate the component by DOM id, so assign one if the
    # caller did not set it.
    if not video.elem_id:
        video.elem_id = f"_rv_{id(video)}"

    on_start_js, on_stop_js = _build_recording_timer_js(video.elem_id, max_duration_seconds)

    # Start shows the timer; both stopping and clearing tear it down.
    bindings = (
        (video.start_recording, on_start_js),
        (video.stop_recording, on_stop_js),
        (video.clear, on_stop_js),
    )
    for attach, script in bindings:
        attach(fn=None, js=script)
def wire_video_upload(
    input_video: gr.Video,
    output_video: gr.Video,
    process_btn: gr.Button,
    example_dataset: gr.Dataset | None = None,
    limits: VideoLimits = DEFAULT_LIMITS,
) -> None:
    """Wire the standard video upload/clear/example events.

    Connects validation on upload and on finished webcam recordings, reset on
    clear, and example loading. Also adds a client-side recording timer and
    auto-stop based on *limits.max_duration_seconds*.

    The caller is responsible for wiring ``process_btn.click`` to their
    own inference function.

    Args:
        input_video: The upload video component.
        output_video: The output video component (cleared on input clear).
        process_btn: The run button (enabled/disabled automatically).
        example_dataset: Optional example dataset component.
        limits: Video validation limits.
    """
    wire_recording_limits(input_video, limits.max_duration_seconds)

    def _validate_with_limits(video_path: str):
        # Bind the caller's limits so uploads AND webcam recordings are
        # validated against the same constraints, not the module defaults.
        return validate_and_update(video_path, limits)

    input_video.upload(
        fn=_validate_with_limits,
        inputs=input_video,
        outputs=[input_video, process_btn],
    )
    input_video.stop_recording(
        fn=_validate_with_limits,
        inputs=input_video,
        outputs=[input_video, process_btn],
    )
    input_video.clear(
        # Clearing the input disables the run button and empties the output.
        fn=lambda: (gr.update(interactive=False), None),
        outputs=[process_btn, output_video],
    )
    if example_dataset is not None:
        example_dataset.click(
            fn=load_example,
            inputs=[example_dataset],
            outputs=[input_video, process_btn],
        )