"""Video frame extraction helpers for LlavaOnevision2 native video input. Helpers (decord-first / opencv-fallback decoding) are used by ``LlavaOnevision2VideoProcessor`` defined below. The helpers were ported from the training pipeline with minor cleanups: - dropped wrapper-only imports - consolidated timestamp helpers - kept decord-first / opencv-fallback decoding identical Public API: - format_timestamp(seconds) -> "MM:SS.xx" - choose_target_frames(duration, max_frames, fixed_num_frames=None, target_fps=None) -> int - select_frame_indices(frame_count, target_count) -> list[int] - smart_resize(h, w, patch_size=14, min_pixels=None, max_pixels=None, align_patch_size=None) -> (h, w) - extract_video_frames(video_path, ...) -> (frames_np, frame_indices, timestamps_dict) - extract_video_frames_to_pil(video_path, ...) -> (frames_pil, frame_indices, timestamps_dict) """ from __future__ import annotations import logging import math from typing import List, Optional, Tuple import numpy as np import torch logger = logging.getLogger(__name__) # ============================================================================= # Timestamp helpers # ============================================================================= def format_timestamp(seconds: float) -> str: minutes = int(seconds // 60) sec = seconds - minutes * 60 return f"{minutes:02d}:{sec:09.6f}" def time_str_to_seconds(t: str) -> float: """Convert ``MM:SS.xx`` back to a float number of seconds. Inverse of :func:`format_timestamp`. """ minute, sec = t.split(":") return int(minute) * 60 + float(sec) # ============================================================================= # Frame-count / index selection # ============================================================================= def choose_target_frames( duration_seconds: float, max_frames: int, fixed_num_frames: Optional[int] = None, target_fps: Optional[float] = None, ) -> int: """Choose target frame count based on video duration in seconds. Sampling strategy: - if ``target_fps`` is set, sample at that fps (capped by ``max_frames``) - elif ``fixed_num_frames`` is set, use that exact count - else duration < 10s -> 8 frames - duration < 30s -> 16 frames - otherwise -> ``max_frames`` (default 32) """ if target_fps is not None and target_fps > 0: return min(max(1, int(duration_seconds * target_fps)), max_frames) if fixed_num_frames is not None: return fixed_num_frames if duration_seconds < 10: return 8 if duration_seconds < 30: return 16 return max_frames def select_frame_indices(frame_count: int, target_count: int) -> List[int]: if frame_count <= target_count: return list(range(frame_count)) return torch.linspace(0, frame_count - 1, target_count).round().long().tolist() # ============================================================================= # Spatial resize # ============================================================================= def smart_resize(height, width, patch_size=14, min_pixels=None, max_pixels=None, align_patch_size=None): if height <= 0 or width <= 0: raise ValueError(f"Invalid size: height={height}, width={width}") factor = align_patch_size or patch_size h_bar = max(factor, int(round(height / factor) * factor)) w_bar = max(factor, int(round(width / factor) * factor)) if max_pixels and h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) h_bar = math.floor(height / beta / factor) * factor w_bar = math.floor(width / beta / factor) * factor elif min_pixels and h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) h_bar = math.ceil(height * beta / factor) * factor w_bar = math.ceil(width * beta / factor) * factor return int(h_bar), int(w_bar) # ============================================================================= # Frame extraction (decord first, opencv fallback) # ============================================================================= def extract_video_frames( video_path: str, max_frames: int = 32, patch_size: int = 14, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, resize_frames: bool = True, fixed_num_frames: Optional[int] = None, target_fps: Optional[float] = None, ) -> Tuple[List[np.ndarray], torch.Tensor, dict]: """Extract frames from a video. Sampling rule matches :func:`choose_target_frames`. Decoding tries decord first (better codec coverage) and falls back to OpenCV. Args: video_path: path to the input video file. max_frames: cap for long videos. patch_size: vision tower patch size for alignment. min_pixels: minimum pixel budget for resize. max_pixels: maximum pixel budget for resize. resize_frames: whether to apply :func:`smart_resize` (with ``align_patch_size = patch_size * 2``, i.e. 28 for spatial_merge=2). fixed_num_frames: see :func:`choose_target_frames`. target_fps: see :func:`choose_target_frames`. Returns: Tuple of: - ``frames`` : list of RGB ``np.ndarray`` (H, W, 3), dtype uint8. - ``frame_indices`` : 1D ``torch.Tensor[int64]`` of selected indices. - ``timestamps`` : ``dict[str(frame_idx) -> "MM:SS.xx"]``. Notes: Lazy imports of ``decord`` and ``cv2`` keep the module importable in environments where neither is installed (e.g. unit tests that only exercise the helpers above). """ frames: List[np.ndarray] = [] timestamps: dict = {} frame_indices: List[int] = [] # Prefer decord because of broader codec support. try: import decord # type: ignore vr = decord.VideoReader(video_path) frame_count = len(vr) fps = vr.get_avg_fps() if not fps or fps <= 0: fps = 30.0 duration = frame_count / fps target_count = choose_target_frames( duration, max_frames, fixed_num_frames, target_fps ) selected_indices = select_frame_indices(frame_count, target_count) # One-shot batch decode + torchvision BICUBIC+antialias resize. # Mirrors qwen_vl_utils.fetch_video, replacing per-frame cv2 INTER_AREA/LINEAR. arr = vr.get_batch(selected_indices).asnumpy() # [N,H,W,3] uint8 RGB H, W = arr.shape[1], arr.shape[2] if resize_frames and (min_pixels or max_pixels): resized_h, resized_w = smart_resize( H, W, patch_size, min_pixels=min_pixels, max_pixels=max_pixels, align_patch_size=patch_size * 2, ) if (resized_h, resized_w) != (H, W): from torchvision import transforms as _T from torchvision.transforms import InterpolationMode as _IM video_t = torch.from_numpy(arr).permute(0, 3, 1, 2).contiguous() video_t = _T.functional.resize( video_t, [resized_h, resized_w], interpolation=_IM.BICUBIC, antialias=True, ) arr = video_t.permute(0, 2, 3, 1).contiguous().numpy() frames = list(arr) frame_indices = list(selected_indices) for frame_idx in selected_indices: timestamps[str(int(frame_idx))] = format_timestamp(int(frame_idx) / fps) return frames, torch.tensor(frame_indices, dtype=torch.int64), timestamps except Exception as e: logger.warning( f"decord failed to open {video_path}: {e}; falling back to OpenCV" ) # OpenCV fallback. import cv2 # type: ignore cap = cv2.VideoCapture(video_path) if not cap.isOpened(): logger.warning(f"OpenCV also failed to open video, skipped: {video_path}") return frames, torch.tensor(frame_indices, dtype=torch.int64), timestamps fps = cap.get(cv2.CAP_PROP_FPS) if not fps or fps <= 0: fps = 30.0 frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0) if frame_count > 0: duration = frame_count / fps target_count = choose_target_frames( duration, max_frames, fixed_num_frames, target_fps ) selected_indices = select_frame_indices(frame_count, target_count) for frame_idx in selected_indices: cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) ret, frame = cap.read() if not ret: continue frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if resize_frames and (min_pixels or max_pixels): resized_h, resized_w = smart_resize( frame.shape[0], frame.shape[1], patch_size, min_pixels, max_pixels, align_patch_size=patch_size * 2, ) if (resized_h, resized_w) != (frame.shape[0], frame.shape[1]): interp = ( cv2.INTER_AREA if resized_h < frame.shape[0] or resized_w < frame.shape[1] else cv2.INTER_LINEAR ) frame = cv2.resize(frame, (resized_w, resized_h), interpolation=interp) frames.append(frame) timestamps[str(frame_idx)] = format_timestamp(frame_idx / fps) frame_indices.append(frame_idx) else: # Unknown frame count: read sequentially then sample. frame_idx = 0 temp_frames: List[Tuple[int, np.ndarray]] = [] while True: ret, frame = cap.read() if not ret: break frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) temp_frames.append((frame_idx, frame)) frame_idx += 1 if temp_frames: duration = len(temp_frames) / fps target_count = choose_target_frames( duration, max_frames, fixed_num_frames, target_fps ) selected_indices = select_frame_indices(len(temp_frames), target_count) for idx in selected_indices: frame_idx, frame = temp_frames[idx] if resize_frames and (min_pixels or max_pixels): resized_h, resized_w = smart_resize( frame.shape[0], frame.shape[1], patch_size, min_pixels, max_pixels, align_patch_size=patch_size * 2, ) if (resized_h, resized_w) != (frame.shape[0], frame.shape[1]): interp = ( cv2.INTER_AREA if resized_h < frame.shape[0] or resized_w < frame.shape[1] else cv2.INTER_LINEAR ) frame = cv2.resize(frame, (resized_w, resized_h), interpolation=interp) frames.append(frame) timestamps[str(frame_idx)] = format_timestamp(frame_idx / fps) frame_indices.append(frame_idx) cap.release() return frames, torch.tensor(frame_indices, dtype=torch.int64), timestamps def extract_video_frames_to_pil( video_path: str, max_frames: int = 32, patch_size: int = 14, min_pixels: Optional[int] = None, max_pixels: Optional[int] = None, resize_frames: bool = True, fixed_num_frames: Optional[int] = None, target_fps: Optional[float] = None, ): """Same as :func:`extract_video_frames` but returns a list of PIL Images.""" from PIL import Image # local import: PIL is mandatory for the processor frames_np, frame_indices, timestamps = extract_video_frames( video_path=video_path, max_frames=max_frames, patch_size=patch_size, min_pixels=min_pixels, max_pixels=max_pixels, resize_frames=resize_frames, fixed_num_frames=fixed_num_frames, target_fps=target_fps, ) frames_pil = [Image.fromarray(frame) for frame in frames_np] return frames_pil, frame_indices, timestamps # ============================================================================= # patch_positions construction (row-major + 2x2 block-layout reorder) # ============================================================================= # Block-layout reorder mirroring the training pipeline, kept here so the # VideoProcessor is self-contained. def _convert_positions_to_block_layout( positions: torch.Tensor, t: int, h: int, w: int, spatial_merge_size: int = 2, ) -> torch.Tensor: """Reorder ``[t*h*w, 3]`` row-major positions to 2x2 block layout.""" sms = spatial_merge_size if sms == 1: return positions device = positions.device total = t * h * w indices = torch.arange(total, device=device).view(t, h, w) h_m, w_m = h // sms, w // sms indices = ( indices.view(t, h_m, sms, w_m, sms) .permute(0, 1, 3, 2, 4) .contiguous() .view(total) ) return positions[indices] def build_patch_positions( grid_thw: torch.Tensor, spatial_merge_size: int = 2, frame_indices: Optional[List[Optional[torch.Tensor]]] = None, ) -> torch.Tensor: """Build block-layout ``[t,h,w]`` patch positions for one or many videos/images. Args: grid_thw: ``[num_samples, 3]`` LongTensor (T, H_p, W_p) per sample. spatial_merge_size: vision tower spatial-merge size (default 2). frame_indices: optional list (one entry per row of ``grid_thw``) of real frame indices to use as the t-coordinate. Each entry should be a 1-D LongTensor of length ``T`` for that sample. When provided this matches the training pipeline, where ``t`` is the original frame number in the source video so the vision tower's 3-D RoPE encodes the actual temporal position rather than a 0..T-1 dense index. Pass ``None`` for an entry to fall back to dense ``arange(T)`` for that sample. Returns: ``[sum(T*H_p*W_p), 3]`` Int64Tensor in block layout, ready to feed ``forward(... patch_positions=...)``. """ out = [] for sample_idx, row in enumerate(grid_thw): t_v, h_v, w_v = int(row[0]), int(row[1]), int(row[2]) h_coords = torch.arange(h_v, dtype=torch.int64).repeat_interleave(w_v).repeat(t_v) w_coords = torch.arange(w_v, dtype=torch.int64).repeat(h_v).repeat(t_v) # t-coords: prefer real frame_indices (training convention) when given. sample_frame_idx = None if frame_indices is not None and sample_idx < len(frame_indices): sample_frame_idx = frame_indices[sample_idx] if sample_frame_idx is not None: fi = torch.as_tensor(sample_frame_idx, dtype=torch.int64) if fi.numel() != t_v: raise ValueError( f"frame_indices[{sample_idx}] has length {fi.numel()} but " f"grid_thw[{sample_idx}, 0] = {t_v}" ) t_coords = fi.repeat_interleave(h_v * w_v) else: # Each frame's t coordinate runs 0..t_v-1 (each value repeated h_v*w_v). t_coords = torch.arange(t_v, dtype=torch.int64).repeat_interleave(h_v * w_v) pp = torch.stack([t_coords, h_coords, w_coords], dim=1) pp = _convert_positions_to_block_layout(pp, t_v, h_v, w_v, spatial_merge_size) out.append(pp) return torch.cat(out, dim=0) # ============================================================================= # LlavaOnevision2VideoProcessor # ============================================================================= # A thin processor that wraps `Qwen2VLImageProcessor` to convert raw video # files (or pre-decoded frame lists) into the tensor bundle needed by the # LlavaOnevision2 model. # # Output (BatchFeature): # - pixel_values_videos : [sum(T*H_p*W_p), C, P, P] patch tensor # - video_grid_thw : [num_videos, 3] (T_eff, H_p, W_p) # - patch_positions : [sum(T*H_p*W_p), 3] block layout # - frame_timestamps : list[list[float]] per-video per-frame seconds # # Aligned with the modeling code, we deliberately # DO NOT emit `second_per_grid_ts`. class LlavaOnevision2VideoProcessor: """Decode + sample + patch-ify videos for LlavaOnevision2. Designed to be standalone (does not inherit ``transformers.ProcessorMixin``) so it can be unit-tested without the full Processor stack. """ # Canonical defaults. DEFAULT_MAX_FRAMES = 384 DEFAULT_PATCH_SIZE = 14 DEFAULT_SPATIAL_MERGE_SIZE = 2 DEFAULT_TEMPORAL_PATCH_SIZE = 1 # this checkpoint ships tps=1 DEFAULT_MIN_PIXELS = 256 * 28 * 28 DEFAULT_MAX_PIXELS = 1605632 def __init__( self, image_processor=None, max_frames: int = DEFAULT_MAX_FRAMES, fixed_num_frames: Optional[int] = None, target_fps: Optional[float] = None, patch_size: int = DEFAULT_PATCH_SIZE, spatial_merge_size: int = DEFAULT_SPATIAL_MERGE_SIZE, temporal_patch_size: int = DEFAULT_TEMPORAL_PATCH_SIZE, min_pixels: int = DEFAULT_MIN_PIXELS, max_pixels: int = DEFAULT_MAX_PIXELS, resize_frames: bool = True, ): """ Args: image_processor: a `Qwen2VLImageProcessor` instance. If ``None`` an instance is built from the other kwargs at first call. max_frames / fixed_num_frames / target_fps: see :func:`choose_target_frames`. patch_size: vision tower patch size (default 14). spatial_merge_size: vision tower spatial merge factor (default 2). temporal_patch_size: temporal-patch grouping; this checkpoint ships ``temporal_patch_size=1`` so each pv row is one single patch (3*14*14=588) and ``Σ t·h·w == total_patches`` naturally. Override only if loading a non-default processor. min_pixels / max_pixels: smart_resize budget. resize_frames: whether to resize frames before patching. """ self._image_processor = image_processor self.max_frames = max_frames self.fixed_num_frames = fixed_num_frames self.target_fps = target_fps self.patch_size = patch_size self.spatial_merge_size = spatial_merge_size self.temporal_patch_size = temporal_patch_size self.min_pixels = min_pixels self.max_pixels = max_pixels self.resize_frames = resize_frames # ------------------------------------------------------------------ utils @property def image_processor(self): """Lazy-build the underlying `Qwen2VLImageProcessor`.""" if self._image_processor is None: from transformers import Qwen2VLImageProcessor self._image_processor = Qwen2VLImageProcessor( min_pixels=self.min_pixels, max_pixels=self.max_pixels, patch_size=self.patch_size, merge_size=self.spatial_merge_size, temporal_patch_size=self.temporal_patch_size, ) return self._image_processor @staticmethod def _coerce_video_input(video): """Normalise a single video input to ``(frames_pil, timestamps_seconds)``. Accepts: - ``str`` path to a video file, - ``list[PIL.Image]`` (already decoded; timestamps default to None), - ``list[np.ndarray]`` (RGB uint8; converted to PIL). """ from PIL import Image if isinstance(video, str): return None # signal: use video path through extract_video_frames_to_pil if isinstance(video, list) and len(video) > 0: first = video[0] if isinstance(first, Image.Image): return list(video), None if isinstance(first, np.ndarray): return [Image.fromarray(f) for f in video], None raise TypeError( f"Unsupported video input type: {type(video).__name__}. " "Expected file path, list[PIL.Image], or list[np.ndarray]." ) # ---------------------------------------------------------------- __call__ def __call__( self, videos, return_tensors: Optional[str] = "pt", **kwargs, ): """Process one or several videos. Args: videos: a single video or a list of videos. Each video may be a path, a list of PIL frames, or a list of np.ndarray RGB frames. return_tensors: only ``"pt"`` is supported (mirrors the underlying image processor). **kwargs: ignored / reserved for transformers ProcessorMixin compatibility (e.g. ``do_rescale``). Returns: A dict-like object with keys: - ``pixel_values_videos`` : Tensor ``[N_total_patches, C, P, P]`` - ``video_grid_thw`` : Tensor ``[num_videos, 3]`` (T, H_p, W_p) - ``patch_positions`` : Tensor ``[N_total_patches, 3]`` block layout - ``frame_timestamps`` : ``list[list[float]]`` per video """ if return_tensors not in (None, "pt"): raise ValueError( f"return_tensors={return_tensors!r} not supported; only 'pt' is." ) # Normalise to a list of videos. if not isinstance(videos, (list, tuple)) or ( len(videos) > 0 and (isinstance(videos[0], str) is False) and not isinstance(videos[0], list) ): # Heuristic: a single video as `list[PIL.Image]` should not be # treated as a batch of single-frame videos. We detect that case # by checking the inner element type. from PIL import Image if isinstance(videos, list) and len(videos) > 0 and isinstance( videos[0], (Image.Image, np.ndarray) ): videos = [videos] elif isinstance(videos, str): videos = [videos] if not isinstance(videos, (list, tuple)): videos = [videos] per_video_pixel_values = [] per_video_grid_thw = [] per_video_patch_positions = [] frame_timestamps_all: List[List[float]] = [] for video in videos: # 1) Decode + sample if isinstance(video, str): frames_pil, frame_indices, timestamps = extract_video_frames_to_pil( video_path=video, max_frames=self.max_frames, patch_size=self.patch_size, min_pixels=self.min_pixels, max_pixels=self.max_pixels, resize_frames=self.resize_frames, fixed_num_frames=self.fixed_num_frames, target_fps=self.target_fps, ) # Reconstruct fps from any two timestamps, fall back to 30. seconds_seq: List[float] = [] if len(frames_pil) > 0: fi_list = frame_indices.tolist() for fi in fi_list: ts = timestamps.get(str(int(fi))) if ts is None: seconds_seq.append(0.0) else: seconds_seq.append(time_str_to_seconds(ts)) # Real frame indices in the source video (training convention # for the t-axis of patch_positions). frame_indices_t = frame_indices.to(torch.int64) else: pre_decoded = self._coerce_video_input(video) frames_pil, _ = pre_decoded seconds_seq = [float(i) for i in range(len(frames_pil))] # Without the original video we have no real indices; fall back # to dense ``arange(T)``. frame_indices_t = torch.arange(len(frames_pil), dtype=torch.int64) if len(frames_pil) == 0: raise ValueError(f"No frames decoded from video: {video!r}") # 2) Patch-ify via Qwen2VLImageProcessor. # Video frames go # through the *image* path, one frame == one image. The # resulting `image_grid_thw` has shape ``[N, 3]`` with each row # ``[1, H_p, W_p]``. We then merge into a single video grid # ``[1, T=N, H_p, W_p]`` (smart_resize guarantees same H/W). # # Important: this checkpoint ships an image processor with # ``temporal_patch_size=1``, so each pv row encodes ONE single # patch (3*14*14 = 588). The OneVision encoder's embedding # layer reshapes pv via ``view(-1, 3, 14, 14)`` and produces # exactly ``pv.shape[0]`` patches, so the cu_seqlens check # ``Σ t·h·w == total_patches`` is satisfied with the natural # per-frame grid below. The lazy-built fallback in # ``image_processor`` honors ``temporal_patch_size=1`` to keep # standalone tests aligned with the checkpoint convention. ip = self.image_processor data = ip(images=frames_pil, return_tensors="pt") pixel_values = data["pixel_values"] image_grid_thw = data["image_grid_thw"] # [N, 3] if not torch.all(image_grid_thw[:, 1] == image_grid_thw[0, 1]) or not torch.all( image_grid_thw[:, 2] == image_grid_thw[0, 2] ): raise RuntimeError( "Frames yielded inconsistent (H_p, W_p); smart_resize should " f"prevent this. Got grid_thw={image_grid_thw.tolist()}" ) T_eff = int(image_grid_thw[:, 0].sum().item()) # sum of per-frame t (each is 1) H_p = int(image_grid_thw[0, 1].item()) W_p = int(image_grid_thw[0, 2].item()) video_grid_thw = torch.tensor( [[T_eff, H_p, W_p]], dtype=image_grid_thw.dtype ) pixel_values_videos = pixel_values # already [T_eff*H_p*W_p, C, P, P] # 3) patch_positions in block layout (over the merged video grid). # Use REAL frame_indices for the t-axis (training convention). patch_positions = build_patch_positions( video_grid_thw, spatial_merge_size=self.spatial_merge_size, frame_indices=[frame_indices_t], ) per_video_pixel_values.append(pixel_values_videos) per_video_grid_thw.append(video_grid_thw) per_video_patch_positions.append(patch_positions) frame_timestamps_all.append(seconds_seq) out_pixel_values = torch.cat(per_video_pixel_values, dim=0) out_grid_thw = torch.cat(per_video_grid_thw, dim=0) out_patch_positions = torch.cat(per_video_patch_positions, dim=0) try: from transformers.feature_extraction_utils import BatchFeature return BatchFeature( data={ "pixel_values_videos": out_pixel_values, "video_grid_thw": out_grid_thw, "patch_positions": out_patch_positions, "frame_timestamps": frame_timestamps_all, } ) except Exception: return { "pixel_values_videos": out_pixel_values, "video_grid_thw": out_grid_thw, "patch_positions": out_patch_positions, "frame_timestamps": frame_timestamps_all, } __all__ = [ "format_timestamp", "time_str_to_seconds", "choose_target_frames", "select_frame_indices", "smart_resize", "extract_video_frames", "extract_video_frames_to_pil", "build_patch_positions", "LlavaOnevision2VideoProcessor", ]