# Source: sensAI-Generic-Object-Detection/shared/video_processing.py
# (uploaded by beaupreda with upload_repo.py, commit 84ad3cf, 18.2 kB).
"""Shared video processing utilities for Gradio demos.
Provides validation, frame-by-frame inference processing, and Gradio UI helpers
that are common across all video-based demos.
"""
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
from fractions import Fraction
import av
import cv2
import gradio as gr
import numpy as np
# Per-frame callback consumed by ``process_video``: it receives one decoded
# frame as a BGR ``np.ndarray`` (as returned by OpenCV) and must return the
# frame to encode, also in BGR channel order (it is encoded as ``bgr24``).
InferenceFn = Callable[[np.ndarray], np.ndarray]
def get_example_videos(examples_dir: str) -> list[list[str]]:
    """List every ``*.mp4`` directly inside *examples_dir* as Gradio example rows.

    *examples_dir* must be an absolute path. A ``"./examples"`` default used
    to be accepted, but EveWrapper now leaves the process cwd pointing at the
    EVE bin directory (required for object-detection model loading), so a
    relative path would silently resolve against the wrong directory. Anchor
    the path on ``Path(__file__).resolve().parent`` of the calling demo.

    Args:
        examples_dir: Absolute path to a directory containing ``.mp4`` files.

    Returns:
        One single-element ``[path]`` row per video, ordered by path.
    """
    video_paths = sorted(str(path) for path in Path(examples_dir).glob("*.mp4"))
    return [[video_path] for video_path in video_paths]
@dataclass
class VideoLimits:
    """Upper bounds enforced on user-supplied videos.

    Sizes are expressed in megabytes where 1 MB = 1024 * 1024 bytes
    (see ``max_file_size_bytes``).
    """

    max_duration_seconds: int = 30  # longest accepted clip, in seconds
    max_width: int = 3840  # pixels
    max_height: int = 2160  # pixels
    max_file_size_mb: int = 5000

    @property
    def max_file_size_bytes(self) -> int:
        """The file-size ceiling converted from megabytes to bytes."""
        return self.max_file_size_mb * 1024**2


# Shared default used by every helper that takes a ``limits`` argument.
# The dataclass is mutable, so treat this instance as read-only.
DEFAULT_LIMITS = VideoLimits()
def _format_file_size_limit(size_mb: int) -> str:
    """Format a megabyte limit for display, switching to (decimal) GB from 1000 MB.

    ``:g`` drops a trailing ``.0`` (5000 -> ``"5 GB"``) but keeps real
    fractions (1500 -> ``"1.5 GB"``). The previous ``:.0f`` rounded such
    values to a whole number that misstated the actual limit.
    """
    if size_mb >= 1000:
        return f"{size_mb / 1000:g} GB"
    return f"{size_mb} MB"


def build_video_constraints_accordion(limits: VideoLimits) -> None:
    """Render a collapsible accordion showing the video upload constraints.

    Must be called inside a Gradio layout context (e.g. a Tab or Column).

    Args:
        limits: The constraints to display.
    """
    file_size_text = _format_file_size_limit(limits.max_file_size_mb)
    with gr.Accordion("Video Constraints", open=False):
        gr.Markdown(
            f"- Maximum video length: **{limits.max_duration_seconds} seconds**\n"
            f"- Maximum video resolution: **{limits.max_width} x {limits.max_height}**\n"
            f"- Maximum file size: **{file_size_text}**\n\n"
            "Videos exceeding these limits will not be uploaded."
        )
def validate_video(video_path: str, limits: VideoLimits = DEFAULT_LIMITS) -> None:
    """Validate a video file against size, resolution, and duration limits.

    Resolution is checked orientation-agnostically. The longer side is
    compared with the larger of ``max_width``/``max_height``, and the shorter
    side with the smaller one, so portrait videos are accepted too.

    Duration is computed as ``frame_count / fps`` only when that metadata is
    plausible: ``0 < fps <= 240`` and a positive frame count, which is the
    same plausibility window ``process_video`` uses. Browser webcam
    recordings often report ``fps=0`` or ``fps=1000``. With the latter,
    ``frame_count / fps`` would understate the duration by more than an order
    of magnitude and let over-long recordings through. For such files the
    video is decoded once instead, and the duration is taken from OpenCV's
    position timestamp after the last decoded frame.

    Args:
        video_path: Path to the video file.
        limits: Constraints to validate against.

    Raises:
        gr.Error: If any limit is exceeded or the video cannot be opened.
        OSError: If *video_path* cannot be stat'ed (e.g. it does not exist).
    """
    size = os.path.getsize(video_path)
    if size > limits.max_file_size_bytes:
        megabytes = size / (1024 * 1024)
        raise gr.Error(
            f"File size ({megabytes:.1f} MB) exceeds the maximum allowed size "
            f"of {limits.max_file_size_mb} MB."
        )

    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        raise gr.Error("Could not open the uploaded video for validation.")
    try:
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Allow both landscape and portrait orientations
        long_side = max(width, height)
        short_side = min(width, height)
        max_long = max(limits.max_width, limits.max_height)
        max_short = min(limits.max_width, limits.max_height)
        if long_side > max_long or short_side > max_short:
            raise gr.Error(
                f"Video resolution ({width}x{height}) exceeds the maximum allowed "
                f"resolution of {limits.max_width}x{limits.max_height}."
            )

        fps = video_capture.get(cv2.CAP_PROP_FPS)
        total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # fps=1000 (ms-based timestamps) is a known bogus value; trusting it
        # would make frames/fps far smaller than the real duration.
        fps_is_plausible = 0 < fps <= 240
        if fps_is_plausible and total_frames > 0:
            duration = total_frames / fps
        else:
            scanned_frames = 0
            duration_ms = 0.0
            while True:
                ret, _ = video_capture.read()
                if not ret:
                    break
                scanned_frames += 1
                duration_ms = video_capture.get(cv2.CAP_PROP_POS_MSEC)
            if duration_ms > 0:
                duration = duration_ms / 1000.0
            elif fps_is_plausible:
                # No usable timestamps, but the frame rate itself looks sane.
                duration = scanned_frames / fps
            else:
                duration = 0.0
        if duration > limits.max_duration_seconds:
            raise gr.Error(
                f"Video duration ({duration:.1f}s) exceeds the maximum allowed length of {limits.max_duration_seconds} seconds."
            )
    finally:
        video_capture.release()
def validate_and_update(
    video_path: str, limits: VideoLimits = DEFAULT_LIMITS
) -> tuple[dict[str, Any] | None, dict[str, Any]]:
    """Check a newly provided video and toggle the run button accordingly.

    Used as a Gradio event handler. Any exception raised by
    ``validate_video`` is shown to the user as a persistent ``gr.Warning``
    (``duration=None``) instead of propagating out of the handler.

    Args:
        video_path: Path to the uploaded video file; falsy when nothing was provided.
        limits: Constraints to validate against.

    Returns:
        ``(video_update, button_update)``. On success the video component is
        left untouched (an empty ``gr.update()``) and the button is enabled.
        Otherwise the video is cleared (``None``) and the button is disabled.
    """
    if not video_path:
        return None, gr.update(interactive=False)
    try:
        validate_video(video_path, limits)
    except Exception as exc:  # UI boundary: report to the user, never crash
        gr.Warning(str(exc), duration=None)
        return None, gr.update(interactive=False)
    else:
        return gr.update(), gr.update(interactive=True)
def load_example(sample_index: list[Any]) -> tuple[str, dict[str, Any]]:
    """Handle a click on the examples dataset.

    Args:
        sample_index: The clicked dataset row; its first element is the video path.

    Returns:
        ``(video_path, button_update)``: the path to load into the input
        component, and an update that enables the run button.
    """
    example_path = sample_index[0]
    enable_run_button = gr.update(interactive=True)
    return example_path, enable_run_button
def process_video(
    input_video: str,
    inference_fn: InferenceFn,
    progress: gr.Progress = gr.Progress(),
    output_dir: str | None = None,
) -> str | None:
    """Process a video frame-by-frame through an inference function.

    Frames are decoded with OpenCV and passed through *inference_fn*. The
    results are encoded to an H.264 (baseline, ``yuv420p``) MP4 via PyAV.

    Implausible container metadata triggers one full decode pass to recover
    the frame count, the frame size and, if needed, the frame rate.
    Implausible means ``fps <= 0`` or ``fps > 240``, or a missing frame count
    or frame size, as produced by browser WebM recordings.

    Encoding constraints handled here:

    * libx264 with ``yuv420p`` requires even dimensions, so odd sizes are
      rounded down by one pixel. PyAV's encoder reformats every frame to the
      stream's size and pixel format.
    * The integer frame rate is clamped to at least 1, so a very low
      measured rate cannot produce a zero ``time_base`` denominator.

    On failure the partially written output file is removed.

    Args:
        input_video: Path to the input video file.
        inference_fn: Callable that takes a BGR frame and returns a processed BGR frame.
        progress: Gradio progress tracker.
        output_dir: Directory for the output video. Falls back to the system temp dir.

    Returns:
        Path to the output video file, or None if input is None.

    Raises:
        gr.Error: If the video cannot be opened or contains no frames.
    """
    if input_video is None:
        return None

    video_capture = cv2.VideoCapture(input_video)
    if not video_capture.isOpened():
        raise gr.Error("Could not open the uploaded video.")
    try:
        fps = video_capture.get(cv2.CAP_PROP_FPS)
        width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        # Webcam-recorded videos (especially WebM from browsers) can report bogus
        # metadata — fps=0, fps=1000 (ms-based timestamps), or missing frame counts.
        if fps <= 0 or fps > 240 or total_frames <= 0 or width <= 0 or height <= 0:
            total_frames = 0
            duration_ms = 0.0
            while True:
                return_value, frame = video_capture.read()
                if not return_value:
                    break
                if total_frames == 0:
                    height, width = frame.shape[:2]
                total_frames += 1
                duration_ms = video_capture.get(cv2.CAP_PROP_POS_MSEC)
            if total_frames == 0:
                raise gr.Error("No frames were read from the video.")
            # Only replace the frame rate when the reported one is implausible.
            if fps <= 0 or fps > 240:
                fps = total_frames / (duration_ms / 1000.0) if duration_ms > 0 else 30.0
    finally:
        video_capture.release()

    rate = max(1, round(fps))
    encode_width = max(2, width - width % 2)
    encode_height = max(2, height - height % 2)

    output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False, dir=output_dir)
    output_path = output_file.name
    output_file.close()

    video_capture = cv2.VideoCapture(input_video)
    container = None
    frame_count = 0
    finished = False
    try:
        # H.264 via PyAV (wraps FFmpeg) — works on all platforms without codec hunting.
        # movflags=faststart moves the moov atom to the beginning of the file so
        # browsers can read duration/metadata without needing HTTP Range requests.
        container = av.open(output_path, mode="w", options={"movflags": "faststart"})
        stream = container.add_stream("libx264", rate=rate)
        stream.time_base = Fraction(1, rate)
        stream.width = encode_width
        stream.height = encode_height
        stream.pix_fmt = "yuv420p"
        # Baseline profile for maximum browser compatibility
        stream.codec_context.options = {"profile": "baseline", "level": "3.1"}
        while True:
            return_value, frame = video_capture.read()
            if not return_value:
                break
            output_image = inference_fn(frame)
            progress(
                (frame_count, total_frames),
                desc=f"Processing Frame {frame_count + 1}/{total_frames}",
            )
            video_frame = av.VideoFrame.from_ndarray(output_image, format="bgr24")
            video_frame.pts = frame_count
            for packet in stream.encode(video_frame):
                container.mux(packet)
            frame_count += 1
        if frame_count > 0:
            # Flush the encoder; skipped when it never received a frame so the
            # intended "No frames" error is raised below instead of a codec error.
            for packet in stream.encode():
                container.mux(packet)
        finished = frame_count > 0
    finally:
        video_capture.release()
        try:
            if container is not None:
                container.close()
        finally:
            if not finished:
                # Best-effort removal of the partial/empty output file.
                try:
                    os.remove(output_path)
                except OSError:
                    pass

    if frame_count == 0:
        raise gr.Error("No frames were read from the video.")
    return output_path
def reencode_video(input_path: str, output_path: str) -> None:
    """Re-encode a video to H.264 MP4 with proper time_base metadata.

    Browser webcam recordings (WebM via MediaRecorder) carry unreliable
    metadata: ``fps`` is often 0 or 1000 (ms-based timestamps), the frame
    count is missing, and the duration reports as Infinity in HTML5 players.
    That blocks scrubbing, because ``currentTime`` never advances past 0.
    Re-encoding to CFR H.264 with an explicit ``stream.time_base`` gives the
    output a well-defined duration, so players can seek normally.

    Encoding constraints handled here:

    * libx264 with ``yuv420p`` requires even dimensions, so odd sizes are
      rounded down by one pixel. PyAV's encoder reformats every frame to the
      stream's size and pixel format.
    * The integer frame rate is clamped to at least 1, so a very low
      measured rate cannot produce a zero ``time_base`` denominator.

    Args:
        input_path: Source video file.
        output_path: Destination MP4 path.

    Raises:
        gr.Error: If the source cannot be opened or contains no frames.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise gr.Error(f"Could not open recorded video: {input_path}")
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if fps <= 0 or fps > 240 or total_frames <= 0 or width <= 0 or height <= 0:
        # Metadata is unreliable: decode once to recover frame size and rate.
        scanned = 0
        duration_ms = 0.0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if scanned == 0:
                height, width = frame.shape[:2]
            scanned += 1
            duration_ms = cap.get(cv2.CAP_PROP_POS_MSEC)
        cap.release()
        if scanned == 0:
            raise gr.Error("No frames were read from the recorded video.")
        if fps <= 0 or fps > 240:
            fps = scanned / (duration_ms / 1000.0) if duration_ms > 0 else 30.0
        # The scan consumed the capture; reopen it for the encoding pass.
        cap = cv2.VideoCapture(input_path)

    rate = max(1, round(fps))
    encode_width = max(2, width - width % 2)
    encode_height = max(2, height - height % 2)

    container = None
    frame_count = 0
    try:
        container = av.open(output_path, mode="w", options={"movflags": "faststart"})
        stream = container.add_stream("libx264", rate=rate)
        stream.time_base = Fraction(1, rate)
        stream.width = encode_width
        stream.height = encode_height
        stream.pix_fmt = "yuv420p"
        # Baseline profile for maximum browser compatibility
        stream.codec_context.options = {"profile": "baseline", "level": "3.1"}
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            video_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
            video_frame.pts = frame_count
            for packet in stream.encode(video_frame):
                container.mux(packet)
            frame_count += 1
        if frame_count > 0:
            # Flush packets still buffered in the encoder; skipped when it never
            # received a frame so the gr.Error below is what the caller sees.
            for packet in stream.encode():
                container.mux(packet)
    finally:
        cap.release()
        if container is not None:
            container.close()
    if frame_count == 0:
        raise gr.Error("No frames were read from the recorded video.")
def _build_recording_timer_js(elem_id: str, max_duration_seconds: int) -> tuple[str, str]:
    """Return ``(start_js, stop_js)`` client-side handlers for a recording timer.

    Both strings are zero-argument JS arrow functions meant for the ``js``
    parameter of Gradio event handlers.

    * ``start_js`` draws a red pill inside the component's ``.wrap`` element.
      The pill shows elapsed time against *max_duration_seconds*, refreshes
      every 250 ms, and clicks the ``[title="stop recording"]`` button once
      the limit is reached.
    * ``stop_js`` cancels the interval and hides the pill.

    The interval handles are kept in ``window.__recT`` keyed by *elem_id*.

    Args:
        elem_id: The ``elem_id`` of the ``gr.Video`` component.
        max_duration_seconds: Recording ceiling in seconds.

    Returns:
        Tuple of (start_js, stop_js) function body strings.
    """
    params = {"eid": elem_id, "max": max_duration_seconds}

    # %-style templates: "%%" renders as a literal "%" in the emitted JS.
    start_template = "".join((
        "() => {",
        " const E = '%(eid)s', M = %(max)d;",
        " const C = document.getElementById(E);",
        " if (!C) return;",
        " window.__recT = window.__recT || {};",
        " if (window.__recT[E]) { clearInterval(window.__recT[E]); delete window.__recT[E]; }",
        " const W = C.querySelector('[data-testid=\"video\"] .wrap') || C.querySelector('.wrap');",
        " if (!W) return;",
        " let t = document.getElementById(E + '-rt');",
        " if (!t) {",
        " t = document.createElement('div');",
        " t.id = E + '-rt';",
        " t.style.cssText = ",
        " 'position:absolute;top:8px;left:50%%;transform:translateX(-50%%);z-index:100;",
        " background:rgba(220,38,38,0.85);color:#fff;padding:4px 12px;",
        " border-radius:20px;font:bold 14px/1.4 monospace;",
        " display:flex;align-items:center;gap:6px;pointer-events:none';",
        " W.appendChild(t);",
        " }",
        " t.style.display = 'flex';",
        " const s = Date.now();",
        " const pad = n => String(Math.floor(n)).padStart(2, '0');",
        " const mm = Math.floor(M / 60), ms = M %% 60;",
        " const lbl = mm + ':' + pad(ms);",
        " const up = () => {",
        " const e = (Date.now() - s) / 1000;",
        " const em = Math.floor(e / 60), es = Math.floor(e %% 60);",
        " const warn = (M - e) <= Math.min(10, M * 0.2);",
        " t.style.background = warn",
        " ? 'rgba(185,28,28,0.95)' : 'rgba(220,38,38,0.85)';",
        " t.textContent = '\\u25CF ' + em + ':' + pad(es) + ' / ' + lbl;",
        " if (e >= M) {",
        " const si = C.querySelector('[title=\"stop recording\"]');",
        " if (si) { const b = si.closest('button'); if (b) b.click(); }",
        " clearInterval(window.__recT[E]); delete window.__recT[E];",
        " t.style.display = 'none';",
        " }",
        " };",
        " up();",
        " window.__recT[E] = setInterval(up, 250);",
        "}",
    ))
    stop_template = "".join((
        "() => {",
        " const E = '%(eid)s';",
        " if (window.__recT && window.__recT[E])",
        " { clearInterval(window.__recT[E]); delete window.__recT[E]; }",
        " const t = document.getElementById(E + '-rt');",
        " if (t) t.style.display = 'none';",
        "}",
    ))
    # The stop template only references "eid"; the extra "max" key is ignored.
    return start_template % params, stop_template % params
def wire_recording_limits(video: gr.Video, max_duration_seconds: int) -> None:
    """Attach a recording timer overlay and client-side auto-stop to *video*.

    Call this inside the same ``gr.Blocks`` context that owns *video*. It
    registers JavaScript-only handlers (``fn=None``) that:

    * show an elapsed/limit pill (e.g. ``0:15 / 2:00``) over the webcam feed
      while recording, and
    * click the stop-recording button once *max_duration_seconds* have
      elapsed, so the user cannot exceed the limit.

    The pill is hidden again when recording stops or the video is cleared.
    If *video* has no ``elem_id``, one is generated, because the injected
    script looks the component up by its DOM id.

    Args:
        video: The ``gr.Video`` component to augment.
        max_duration_seconds: Maximum recording duration in seconds.
    """
    video.elem_id = video.elem_id or f"_rv_{id(video)}"
    start_js, stop_js = _build_recording_timer_js(video.elem_id, max_duration_seconds)

    video.start_recording(fn=None, js=start_js)
    for hide_timer_event in (video.stop_recording, video.clear):
        hide_timer_event(fn=None, js=stop_js)
def wire_video_upload(
    input_video: gr.Video,
    output_video: gr.Video,
    process_btn: gr.Button,
    example_dataset: gr.Dataset | None = None,
    limits: VideoLimits = DEFAULT_LIMITS,
) -> None:
    """Wire the standard video upload/record/clear/example events.

    Connects:

    * validation on upload and on webcam stop-recording, both against
      *limits*;
    * reset on clear;
    * example loading.

    Also adds a client-side recording timer and auto-stop based on
    *limits.max_duration_seconds*.

    The caller is responsible for wiring ``process_btn.click`` to their
    own inference function.

    Args:
        input_video: The upload video component.
        output_video: The output video component (cleared on input clear).
        process_btn: The run button (enabled/disabled automatically).
        example_dataset: Optional example dataset component.
        limits: Video validation limits.
    """
    wire_recording_limits(input_video, limits.max_duration_seconds)
    input_video.upload(
        fn=lambda v: validate_and_update(v, limits),
        inputs=input_video,
        outputs=[input_video, process_btn],
    )
    # Recordings must honour the caller's limits too; passing
    # validate_and_update directly would silently fall back to DEFAULT_LIMITS.
    input_video.stop_recording(
        fn=lambda v: validate_and_update(v, limits),
        inputs=input_video,
        outputs=[input_video, process_btn],
    )
    input_video.clear(
        fn=lambda: (gr.update(interactive=False), None),
        outputs=[process_btn, output_video],
    )
    if example_dataset is not None:
        example_dataset.click(
            fn=load_example,
            inputs=[example_dataset],
            outputs=[input_video, process_btn],
        )