| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """video processor class for Moss-VL.""" |
|
|
| import json |
| import logging as system_logging |
| import math |
| import os |
| import re |
| import subprocess |
| import traceback |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import numpy as np |
| import torch |
| from joblib import Parallel, delayed |
| from torchcodec.decoders import VideoDecoder |
|
|
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.image_utils import ChannelDimension, PILImageResampling, SizeDict, get_image_size, validate_kwargs |
| from transformers.processing_utils import Unpack, VideosKwargs |
| from transformers.utils import TensorType, add_start_docstrings, logging |
| from transformers.video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor |
| from transformers.video_utils import VideoMetadata, group_videos_by_shape, reorder_videos |
|
|
|
|
# Module logger from transformers' logging utilities (L37 import), not the stdlib
# `logging` module (which is imported separately as `system_logging`).
logger = logging.get_logger(__name__)
|
|
|
|
| |
| |
| |
|
|
def check_video_for_extra_streams_and_errors(video_path: str) -> dict:
    """
    Probe a video file with ffprobe and report abnormal streams or errors.

    Args:
        video_path: Path to the video file.

    Returns:
        A dictionary containing:
            - 'has_extra_streams': bool, whether there are streams other than video and audio.
            - 'unsupported_codec_errors': list, all "Unsupported codec" error messages.
            - 'ffprobe_output_error': str, other errors/warnings from ffprobe stderr.
            - 'ffprobe_successful': bool, whether ffprobe command executed successfully (return code 0).
            - 'stream_details': list, codec_type and index for each stream.
            - 'num_streams': int, total number of streams identified in the video file.
    """
    report = {
        'has_extra_streams': False,
        'unsupported_codec_errors': [],
        'ffprobe_output_error': '',
        'ffprobe_successful': False,
        'stream_details': [],
        'num_streams': 0,
    }

    cmd = ["ffprobe", "-v", "error", "-show_streams", "-show_format", "-of", "json", video_path]

    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
        report['ffprobe_successful'] = proc.returncode == 0

        if proc.stderr:
            report['ffprobe_output_error'] = proc.stderr
            # Collect every "Unsupported codec" complaint ffprobe emitted.
            report['unsupported_codec_errors'] = re.findall(
                r"Unsupported codec with id \d+ for input stream \d+", proc.stderr
            )

        if proc.stdout:
            probe_info = json.loads(proc.stdout)
            if 'streams' in probe_info:
                report['num_streams'] = len(probe_info['streams'])
                for stream in probe_info['streams']:
                    codec_type = stream.get('codec_type')
                    report['stream_details'].append({'index': stream.get('index'), 'codec_type': codec_type})
                    # Anything that is neither video nor audio counts as "extra".
                    if codec_type not in ['video', 'audio']:
                        report['has_extra_streams'] = True

            if 'format' in probe_info and 'nb_streams' in probe_info['format']:
                if report['num_streams'] == 0:
                    report['num_streams'] = probe_info['format']['nb_streams']
                elif report['num_streams'] != probe_info['format']['nb_streams']:
                    logger.warning(
                        f"Number of streams in 'streams' list ({report['num_streams']}) "
                        f"differs from 'nb_streams' in 'format' ({probe_info['format']['nb_streams']})."
                    )
    except FileNotFoundError:
        report['ffprobe_output_error'] = "ffprobe command not found. Please ensure FFmpeg is installed and in your PATH."
        report['ffprobe_successful'] = False
    except json.JSONDecodeError:
        report['ffprobe_output_error'] = "Failed to parse ffprobe JSON output. Check ffprobe installation or video file."
        report['ffprobe_successful'] = False
    except Exception as e:
        report['ffprobe_output_error'] = f"An unexpected error occurred: {e}"
        report['ffprobe_successful'] = False

    return report
|
|
|
|
def remove_video_extra_stream_ffmpeg(input_video: str, output_video: str) -> bool:
    """
    Strip everything but the first video stream from a file via ffmpeg stream copy.

    Drops audio/subtitle/data streams, metadata, and chapters, and moves the
    moov atom to the front (`faststart`) without re-encoding.

    Args:
        input_video: Path to input video.
        output_video: Path to output video.

    Returns:
        bool: True if successful, False otherwise.
    """
    cmd = [
        "ffmpeg", "-y", "-i", input_video,
        "-map", "0:v:0",
        "-c", "copy",
        "-an",
        "-sn",
        "-dn",
        "-map_metadata", "-1",
        "-map_chapters", "-1",
        "-movflags", "faststart",
        output_video,
    ]

    try:
        # List argv + shell=False: no shell-injection surface.
        subprocess.run(cmd, shell=False, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        system_logging.error(f"Command execution failed with return code: {e.returncode}, video: {input_video}")
        system_logging.error(f"Error output:\n{e.stderr}")
        return False
    except FileNotFoundError:
        system_logging.error("Error: ffmpeg command not found. Please ensure ffmpeg is installed and in PATH.")
        return False
    except Exception as e:
        system_logging.error(f"Unexpected error executing command: {e}, video: {input_video}", exc_info=True)
        return False
    return True
|
|
|
|
def clean_video_streams(video_path: str) -> str:
    """
    Produce a video containing only the primary video stream when extras exist.

    Probes the file first; if streams other than video/audio are found, a
    "<name>_fix<ext>" copy is written next to the original and its path is
    returned. Otherwise (or if the ffmpeg cleanup fails) the original path
    is returned unchanged.

    Args:
        video_path: Path to the video file.

    Returns:
        str: Path to cleaned video (or original if no cleaning needed).
    """
    probe = check_video_for_extra_streams_and_errors(video_path)
    if not probe['has_extra_streams']:
        return video_path

    directory, base_name = os.path.split(video_path)
    stem, ext = os.path.splitext(base_name)
    cleaned_path = os.path.join(directory, f"{stem}_fix{ext}")

    if remove_video_extra_stream_ffmpeg(video_path, cleaned_path):
        return cleaned_path
    logger.warning("Failed to remove extra streams with ffmpeg")
    return video_path
|
|
|
|
def split_indices(indices: List[Union[int, float]], num_chunks: int) -> List[List[Union[int, float]]]:
    """
    Split an index list into contiguous, balanced chunks.

    Chunk sizes differ by at most one (the remainder is spread over the leading
    chunks), which balances work across decoding threads better than dumping the
    whole remainder into the final chunk. Empty chunks are never emitted, so
    fewer than `num_chunks` lists are returned when there are fewer indices than
    chunks. Concatenating the chunks in order reproduces `indices` exactly.

    Args:
        indices: List of indices to split.
        num_chunks: Number of chunks to create (values < 1 are treated as 1).

    Returns:
        List of non-empty index chunks, in original order.
    """
    num_chunks = max(1, num_chunks)
    base_size, remainder = divmod(len(indices), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # The first `remainder` chunks take one extra element.
        size = base_size + (1 if i < remainder else 0)
        if size == 0:
            break  # fewer indices than chunks: stop instead of emitting empties
        chunks.append(indices[start:start + size])
        start += size
    return chunks
|
|
|
|
def decode_sequentially(indices: List[int], video_path: str, ffmpeg_threads: int = 0):
    """
    Decode the frames at the given indices from a video in one pass.

    Args:
        indices: List of frame indices to decode.
        video_path: Path to the video file.
        ffmpeg_threads: Number of ffmpeg threads to use.

    Returns:
        FrameBatch from torchcodec.
    """
    video_decoder = VideoDecoder(video_path, num_ffmpeg_threads=ffmpeg_threads)
    try:
        frame_batch = video_decoder.get_frames_at(indices)
    finally:
        # Drop the decoder reference promptly so its native resources are freed.
        del video_decoder
    return frame_batch
|
|
|
|
def decode_with_multithreading(indices: List[int], num_threads: int, video_path: str) -> dict:
    """
    Decode frames using multithreading with joblib.

    The index list is split into chunks, each chunk is decoded in its own
    thread by `decode_sequentially`, and the per-chunk FrameBatch results are
    concatenated in order.

    Args:
        indices: List of frame indices to decode.
        num_threads: Number of threads to use.
        video_path: Path to the video file.

    Returns:
        dict: Contains 'data', 'duration_seconds', 'pts_seconds' tensors.
    """
    chunks = split_indices(indices, num_chunks=num_threads)
    results = Parallel(n_jobs=num_threads, prefer="threads", verbose=0)(
        delayed(decode_sequentially)(chunk, video_path) for chunk in chunks
    )

    data_list = [frame_batch.data for frame_batch in results]
    duration_list = [frame_batch.duration_seconds for frame_batch in results]
    pts_list = [frame_batch.pts_seconds for frame_batch in results]

    # Same guard as timestamp_decode_with_multithreading: torch.cat on an
    # empty list raises, so return empty tensors when nothing was decoded.
    if not data_list:
        logger.warning("No frames were successfully decoded.")
        return {"data": torch.empty(0), "duration_seconds": torch.empty(0), "pts_seconds": torch.empty(0)}

    return {
        "data": torch.cat(data_list, dim=0),
        "duration_seconds": torch.cat(duration_list, dim=0),
        "pts_seconds": torch.cat(pts_list, dim=0),
    }
|
|
|
|
def decode_sequentially_timestamp(timestamp_list: List[float], video_path: str, ffmpeg_threads: int = 0):
    """
    Decode the frames played at the given timestamps from a video.

    Timestamps are clamped into the stream's valid presentation range (when it
    can be determined from the metadata) before decoding, so slightly
    out-of-range requests do not fail.

    Args:
        timestamp_list: List of timestamps (in seconds) to decode.
        video_path: Path to the video file.
        ffmpeg_threads: Number of ffmpeg threads to use.

    Returns:
        FrameBatch from torchcodec.
    """
    video_decoder = VideoDecoder(video_path, num_ffmpeg_threads=ffmpeg_threads)
    try:
        meta = video_decoder.metadata

        # Earliest decodable timestamp; fall back to 0.0 when unknown.
        start_pts = meta.begin_stream_seconds_from_content
        if start_pts is None:
            start_pts = 0.0

        # Latest decodable timestamp, derived from frame count / fps when
        # available, otherwise from the stream end or the container duration.
        if meta.num_frames_from_content and meta.average_fps:
            end_pts = (meta.num_frames_from_content - 1) / meta.average_fps + start_pts
        elif meta.end_stream_seconds_from_content is not None:
            end_pts = meta.end_stream_seconds_from_content
        else:
            end_pts = meta.duration_seconds

        # Clamp into [start_pts, end_pts] when the upper bound is usable,
        # otherwise only enforce the lower bound.
        if end_pts is not None and end_pts > 0:
            timestamp_list = [max(start_pts, min(t, end_pts)) for t in timestamp_list]
        elif start_pts > 0:
            timestamp_list = [max(start_pts, t) for t in timestamp_list]

        return video_decoder.get_frames_played_at(timestamp_list)
    finally:
        del video_decoder
|
|
|
|
def timestamp_decode_with_multithreading(timestamp_list: List[float], num_threads: int, video_path: str) -> dict:
    """
    Decode frames at the given timestamps using multiple threads.

    The timestamp list is split into chunks, each decoded by
    `decode_sequentially_timestamp` in its own thread, and the per-chunk
    results are concatenated in order.

    Args:
        timestamp_list: List of timestamps (in seconds) to decode.
        num_threads: Number of threads to use.
        video_path: Path to the video file.

    Returns:
        dict: Contains 'data', 'duration_seconds', 'pts_seconds' tensors.
    """
    chunks = split_indices(timestamp_list, num_chunks=num_threads)
    batches = Parallel(n_jobs=num_threads, prefer="threads", verbose=0)(
        delayed(decode_sequentially_timestamp)(chunk, video_path) for chunk in chunks
    )

    data_parts, duration_parts, pts_parts = [], [], []
    for batch in batches:
        data_parts.append(batch.data)
        duration_parts.append(batch.duration_seconds)
        pts_parts.append(batch.pts_seconds)

    # torch.cat on an empty list raises; return empty tensors instead.
    if not data_parts:
        logger.warning("No frames were successfully decoded.")
        return {"data": torch.empty(0), "duration_seconds": torch.empty(0), "pts_seconds": torch.empty(0)}

    return {
        "data": torch.cat(data_parts, dim=0),
        "duration_seconds": torch.cat(duration_parts, dim=0),
        "pts_seconds": torch.cat(pts_parts, dim=0),
    }
|
|
|
|
def extract_frames_with_torchcodec(
    video_path: str,
    sample_frames_count: int,
    num_threads: int = 4,
) -> Optional[dict]:
    """
    Extract uniformly spaced frames from a video using torchcodec.

    The video is cleaned of extra streams if needed, frame indices are sampled
    uniformly over the whole video, and the frames are decoded with
    `num_threads` worker threads.

    Args:
        video_path: Path to the video file.
        sample_frames_count: Number of frames to sample.
        num_threads: Number of threads to use for extraction.

    Returns:
        dict: Contains 'data' (N, C, H, W), 'duration_seconds' (N,),
            'pts_seconds' (N,) tensors plus 'frame_indices' (np.ndarray of the
            sampled indices).
        Returns None if extraction fails.
    """
    try:
        video_path = clean_video_streams(video_path)

        # Only the metadata is needed here; release the decoder immediately
        # (decode_with_multithreading opens its own decoder per thread), matching
        # the try/finally-del pattern used everywhere else in this module.
        decoder = VideoDecoder(video_path, num_ffmpeg_threads=0)
        try:
            total_frames_in_video = decoder.metadata.num_frames_from_content
        finally:
            del decoder

        # `num_frames_from_content` may be None; treat that as "no frames"
        # instead of letting min() raise a TypeError.
        effective_sample_count = min(sample_frames_count, total_frames_in_video or 0)
        if effective_sample_count <= 0:
            logger.error("Cannot extract frames: video has 0 frames or specified frame count is 0")
            return None

        # Uniform sampling over [0, total_frames - 1]; clip + unique drops
        # duplicate indices when sampling close to the total frame count.
        frame_indices = np.linspace(0, total_frames_in_video - 1, effective_sample_count).astype(np.int32)
        frame_indices = np.unique(np.clip(frame_indices, 0, total_frames_in_video - 1))

        result = decode_with_multithreading(frame_indices.tolist(), num_threads=num_threads, video_path=video_path)
        result["frame_indices"] = frame_indices
        return result
    except Exception:
        # Best-effort helper: log the traceback and signal failure with None.
        traceback.print_exc()
        return None
|
|
|
|
def smart_resize(
    num_frames: int,
    height: int,
    width: int,
    temporal_factor: int = 1,
    factor: int = 32,
    min_pixels: int = 128 * 128,
    max_pixels: int = 16 * 16 * 2 * 2 * 2 * 6144,
    per_frame_min_pixels: Optional[int] = None,
    per_frame_max_pixels: Optional[int] = None,
):
    """
    Compute a (height, width) pair, rounded to multiples of `factor`, that keeps
    the video's per-frame and total pixel counts within the given budgets.

    The constraints are applied in order:
      1. Round height/width to the nearest multiple of `factor`.
      2. Shrink to respect `per_frame_max_pixels` (per-frame cap).
      3. Shrink/grow so the total over `num_frames` frames respects
         `max_pixels`/`min_pixels`.
      4. Grow to respect `per_frame_min_pixels` (per-frame floor; applied last,
         so it can override the total budget).

    Args:
        num_frames: Number of frames in the video.
        height: Original frame height in pixels.
        width: Original frame width in pixels.
        temporal_factor: Temporal patching factor; `num_frames` must be >= this.
        factor: Spatial rounding factor (patch size * merge size).
        min_pixels: Minimum total pixels across all frames.
        max_pixels: Maximum total pixels across all frames.
        per_frame_min_pixels: Optional minimum pixels per frame.
        per_frame_max_pixels: Optional maximum pixels per frame.

    Returns:
        Tuple[int, int]: The resized (height, width), both multiples of `factor`.

    Raises:
        ValueError: If `num_frames` < `temporal_factor`, a side is smaller than
            `factor`, or the aspect ratio exceeds 200.
    """
    if num_frames < temporal_factor:
        raise ValueError(f"t:{num_frames} must be larger than temporal_factor:{temporal_factor}")
    if height < factor or width < factor:
        raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    t_bar = round(num_frames / temporal_factor) * temporal_factor

    # Per-frame upper bound: shrink while staying on the factor grid.
    if per_frame_max_pixels is not None and h_bar * w_bar > per_frame_max_pixels:
        beta = math.sqrt((height * width) / per_frame_max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)

    # Total (all-frames) pixel budget.
    if t_bar * h_bar * w_bar > max_pixels:
        beta = math.sqrt((num_frames * height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif t_bar * h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (num_frames * height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    # Per-frame lower bound (applied last).
    if per_frame_min_pixels is not None and h_bar * w_bar < per_frame_min_pixels:
        beta = math.sqrt(per_frame_min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    return h_bar, w_bar
|
|
|
|
class MossVLVideoProcessorInitKwargs(VideosKwargs):
    """Keyword arguments accepted by `MossVLVideoProcessor.__init__` (all optional)."""

    # Vision-encoder patching configuration.
    patch_size: Optional[int]
    temporal_patch_size: Optional[int]
    merge_size: Optional[int]
    # Frame-sampling bounds and target sampling rate.
    min_frames: Optional[int]
    max_frames: Optional[int]
    video_fps: Optional[Union[int, float]]
    # Worker threads used for torchcodec frame extraction.
    num_extract_threads: Optional[int]

    # Total pixel budget across all frames of a video (None = unlimited).
    video_max_pixels: Optional[int]
|
|
|
|
| @add_start_docstrings( |
| "Constructs a fast Moss-VL video processor that dynamically resizes videos based on the original videos.", |
| BASE_VIDEO_PROCESSOR_DOCSTRING, |
| """ |
| patch_size (`int`, *optional*, defaults to 16): |
| The spacial patch size of the vision encoder. |
| temporal_patch_size (`int`, *optional*, defaults to 1): |
| The temporal patch size of the vision encoder. |
| merge_size (`int`, *optional*, defaults to 2): |
| The merge size of the vision encoder to llm encoder. |
| video_fps (`float`, *optional*, defaults to 1.0): |
| Target frames per second for video sampling. |
| min_frames (`int`, *optional*, defaults to 1): |
| Minimum number of frames to sample from a video. |
| max_frames (`int`, *optional*, defaults to 256): |
| Maximum number of frames to sample from a video. |
| num_extract_threads (`int`, *optional*, defaults to 4): |
| Number of threads to use for frame extraction. |
| """, |
| ) |
| class MossVLVideoProcessor(BaseVideoProcessor): |
    # Resampling filter used when resizing frames.
    resample = PILImageResampling.BICUBIC
    # Per-frame pixel budgets: shortest_edge is the minimum and longest_edge the
    # maximum pixel count fed to smart_resize in _preprocess.
    size = {"shortest_edge": 128 * 32 * 32, "longest_edge": 32 * 32 * 768}
    # Per-channel normalization statistics.
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    # Default preprocessing toggles.
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    # Vision-encoder patching: spatial patch edge, temporal patch depth, and
    # the vision-to-LLM merge factor.
    patch_size = 16
    temporal_patch_size = 1
    merge_size = 2
    # Frame-sampling defaults: target sampling fps and per-video frame bounds.
    video_fps = 1.0
    min_frames = 1
    max_frames = 256
    # Worker threads used for torchcodec frame extraction.
    num_extract_threads = 4
    do_sample_frames = True
    # Optional total pixel budget across all frames (None = fall back to `size`).
    video_max_pixels = None
    valid_kwargs = MossVLVideoProcessorInitKwargs
    model_input_names = ["pixel_values_videos", "video_grid_thw"]
|
|
| def __init__(self, **kwargs: Unpack[MossVLVideoProcessorInitKwargs]): |
| super().__init__(**kwargs) |
| if self.size is not None and ( |
| self.size.get("shortest_edge", None) is None or self.size.get("longest_edge", None) is None |
| ): |
| raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") |
|
|
| def _further_process_kwargs( |
| self, |
| size: Optional[SizeDict] = None, |
| **kwargs, |
| ) -> dict: |
| """ |
| Update kwargs that need further processing before being validated |
| Can be overridden by subclasses to customize the processing of kwargs. |
| """ |
| if size is not None and ("shortest_edge" not in size or "longest_edge" not in size): |
| raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.") |
|
|
| return super()._further_process_kwargs(size=size, **kwargs) |
|
|
| def _get_video_path_from_input(self, video_input: Union[str, Dict[str, Any]]) -> str: |
| """Normalize a video input into a video path.""" |
| if isinstance(video_input, dict): |
| return video_input["video_path"] |
| return video_input |
|
|
| def _get_video_duration_seconds(self, video_input: Union[str, Dict[str, Any]]) -> float: |
| """Get video duration in seconds for weighted frame-budget allocation.""" |
| video_path = clean_video_streams(self._get_video_path_from_input(video_input)) |
| decoder = VideoDecoder(video_path, num_ffmpeg_threads=0) |
| try: |
| metadata = decoder.metadata |
| duration = None |
| if ( |
| metadata.end_stream_seconds_from_content is not None |
| and metadata.begin_stream_seconds_from_content is not None |
| ): |
| duration = metadata.end_stream_seconds_from_content - metadata.begin_stream_seconds_from_content |
| if duration is None or duration <= 0: |
| duration = metadata.duration_seconds |
| return max(0.0, float(duration or 0.0)) |
| finally: |
| del decoder |
|
|
    def _allocate_max_frames_for_multiple_videos(
        self,
        video_inputs: List[Union[str, Dict[str, Any]]],
        total_max_frames: Optional[int],
    ) -> List[Optional[int]]:
        """
        Treat max_frames as a total budget for multi-video input and allocate it by duration.

        Uses largest-remainder apportionment: every video is guaranteed at least
        one frame, and the remaining budget is distributed proportionally to each
        video's duration (flooring leftovers go to the largest fractional parts,
        ties broken by input order).

        The returned values are per-video max_frames. Segment dict inputs still keep their
        existing per-segment weighting logic after receiving the video-level allocation.
        """
        if not video_inputs:
            return []
        # No budget, or a single video: nothing to apportion.
        if total_max_frames is None or len(video_inputs) == 1:
            return [total_max_frames] * len(video_inputs)

        total_max_frames = int(total_max_frames)
        num_videos = len(video_inputs)
        if total_max_frames < num_videos:
            # Budget cannot honor one frame per video: degrade gracefully.
            logger.warning(
                "Received max_frames=%s for %s videos. At least one frame per video is required, "
                "so falling back to 1 frame per video.",
                total_max_frames,
                num_videos,
            )
            return [1] * num_videos

        video_durations = [self._get_video_duration_seconds(video_input) for video_input in video_inputs]
        total_duration = sum(video_durations)

        # Reserve one frame per video, then apportion the remainder by duration.
        allocations = [1] * num_videos
        remaining_budget = total_max_frames - num_videos
        if remaining_budget == 0:
            return allocations

        if total_duration <= 0:
            # Durations unknown or zero: split the remainder evenly.
            raw_extra_allocations = [remaining_budget / num_videos] * num_videos
        else:
            raw_extra_allocations = [
                remaining_budget * (duration / total_duration) for duration in video_durations
            ]

        base_extra_allocations = [int(math.floor(value)) for value in raw_extra_allocations]
        allocations = [base + extra for base, extra in zip(allocations, base_extra_allocations)]

        # Largest-remainder step: hand the flooring leftovers to the videos with
        # the biggest fractional parts (earlier index wins ties).
        remainder = remaining_budget - sum(base_extra_allocations)
        if remainder > 0:
            fractional_parts = [
                (raw_value - base_value, index)
                for index, (raw_value, base_value) in enumerate(zip(raw_extra_allocations, base_extra_allocations))
            ]
            fractional_parts.sort(key=lambda item: (-item[0], item[1]))
            for _, index in fractional_parts[:remainder]:
                allocations[index] += 1

        return allocations
|
|
| def calculate_num_frames( |
| self, |
| metadata: VideoMetadata, |
| num_frames: Optional[int] = None, |
| fps: Optional[Union[int, float]] = None, |
| min_frames: Optional[int] = None, |
| max_frames: Optional[int] = None, |
| **kwargs, |
| ) -> int: |
| """ |
| Calculate the number of frames to sample using fps-based logic with min/max constraints. |
| |
| Logic: |
| 1. Calculate target_frames based on fps and video duration |
| 2. Apply min_frames and max_frames constraints |
| 3. Apply max_allowed_frames protection (rough cap from total video_max_pixels budget) |
| 4. Return the number of frames to sample |
| |
| Args: |
| metadata (`VideoMetadata`): |
| Metadata of the video containing information about total duration, fps and total number of frames. |
| num_frames (`int`, *optional*): |
| Maximum number of frames to sample. If provided, overrides fps-based calculation. |
| fps (`int` or `float`, *optional*): |
| Target frames to sample per second. Defaults to `self.video_fps`. |
| min_frames (`int`, *optional*): |
| Minimum number of frames to sample. If None, uses self.min_frames. |
| max_frames (`int`, *optional*): |
| Maximum number of frames to sample. If None, uses self.max_frames. |
| Returns: |
| int: |
| Number of frames to sample. |
| """ |
| if fps is not None and num_frames is not None: |
| raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!") |
|
|
| total_num_frames = metadata.total_num_frames |
| |
| |
| effective_min_frames = min_frames if min_frames is not None else self.min_frames |
| effective_max_frames = max_frames if max_frames is not None else self.max_frames |
| |
| |
| |
| per_frame_min_pixels = self.size.get("shortest_edge", None) if self.size else None |
| video_max_pixels = getattr(self, "video_max_pixels", None) |
| if per_frame_min_pixels is not None and video_max_pixels is not None and per_frame_min_pixels > 0: |
| max_allowed_frames = video_max_pixels // per_frame_min_pixels |
| effective_max_frames = min(effective_max_frames, max_allowed_frames) |
| |
| |
| if hasattr(metadata, 'duration') and metadata.duration is not None: |
| duration = metadata.duration |
| else: |
| video_fps = metadata.fps |
| if video_fps is not None and video_fps > 0: |
| duration = total_num_frames / video_fps |
| else: |
| |
| video_fps = 24.0 |
| duration = total_num_frames / video_fps |
| logger.warning_once( |
| "Could not determine video fps from metadata, defaulting to 24 fps for duration calculation." |
| ) |
|
|
| |
| target_fps = fps if fps is not None else self.video_fps |
| |
| |
| if num_frames is None: |
| |
| target_total_frames = int(math.ceil(duration * target_fps - 1e-6)) |
| |
| |
| sample_frames = max(target_total_frames, effective_min_frames) |
| sample_frames = min(sample_frames, effective_max_frames, total_num_frames) |
| else: |
| |
| sample_frames = min(max(num_frames, effective_min_frames), effective_max_frames, total_num_frames) |
|
|
| return sample_frames |
|
|
|
|
    def _fetch_video_segment(
        self,
        video_path: str,
        segment: List[float],
        min_frames: Optional[int] = None,
        max_frames: Optional[int] = None,
        video_fps: Optional[float] = None,
    ):
        """
        Fetch video frames for a specific segment.

        Args:
            video_path: Path to the video file
            segment: [start, end] for a segment (left-closed, right-open) or [time] for a single frame
            min_frames: Minimum frames for this segment (weighted). Defaults to self.min_frames. Must be >= 1.
            max_frames: Maximum frames for this segment (weighted). Defaults to self.max_frames. Must be >= 1.
            video_fps: Target frames per second for video sampling. If None, uses self.video_fps.

        Returns:
            Tuple of (video_tensor, video_metadata)
        """
        # Clamp the per-segment frame bounds to at least one frame each.
        min_frames = max(1, min_frames if min_frames is not None else self.min_frames)
        max_frames = max(1, max_frames if max_frames is not None else self.max_frames)

        target_video_fps = video_fps if video_fps is not None else self.video_fps

        # Strip any non-video/audio streams before decoding.
        video_path = clean_video_streams(video_path)
        decoder = VideoDecoder(video_path, num_ffmpeg_threads=0)
        try:
            torchcodec_metadata = decoder.metadata
            # NOTE: `video_fps` is rebound here to the source fps for metadata reporting.
            video_fps = torchcodec_metadata.average_fps

            # Prefer the content-derived stream span; fall back to container duration.
            duration = None
            if torchcodec_metadata.end_stream_seconds_from_content is not None and torchcodec_metadata.begin_stream_seconds_from_content is not None:
                duration = torchcodec_metadata.end_stream_seconds_from_content - torchcodec_metadata.begin_stream_seconds_from_content
            if duration is None or duration <= 0:
                duration = torchcodec_metadata.duration_seconds

            if len(segment) == 1:
                # Single-timestamp segment: decode exactly one frame at that time.
                timestamp = segment[0]
                frame_batch = decoder.get_frames_played_at([timestamp])
                video_tensor = frame_batch.data
                actual_timestamps = [timestamp]
                sample_count = 1
            else:
                # [start, end) range: sample at target fps, clamped to [min_frames, max_frames].
                start_time, end_time = segment
                segment_duration = end_time - start_time

                target_frames = int(math.ceil(segment_duration * target_video_fps))
                target_frames = max(target_frames, min_frames)
                target_frames = min(target_frames, max_frames)

                if target_frames == 1:
                    actual_timestamps = [start_time]
                else:
                    # endpoint=False keeps the interval right-open.
                    actual_timestamps = np.linspace(start_time, end_time, target_frames, endpoint=False).tolist()

                result = timestamp_decode_with_multithreading(actual_timestamps, self.num_extract_threads, video_path)
                video_tensor = result["data"]
                sample_count = len(actual_timestamps)

            video_metadata = VideoMetadata(
                total_num_frames=sample_count,
                fps=video_fps,
                duration=duration,
                video_backend="torchcodec",
                height=torchcodec_metadata.height,
                width=torchcodec_metadata.width,
                frames_indices=None
            )

            # Dynamically attached attribute carrying the sampled timestamps —
            # presumably read by downstream prompt construction; verify against callers.
            video_metadata.actual_timestamps = actual_timestamps

            return video_tensor, video_metadata
        finally:
            del decoder
|
|
    def fetch_videos(
        self,
        video_url_or_urls: Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]],
        sample_indices_fn=None,
        video_fps: Optional[float] = None,
        min_frames: Optional[int] = None,
        max_frames: Optional[int] = None,
    ):
        """
        Override fetch_videos to use torchcodec for frame extraction.

        This method uses torchcodec with multithreading for efficient frame extraction.
        Frame count is calculated by the calculate_num_frames method
        (fps-based with min/max constraints).

        Args:
            video_url_or_urls: Can be one of:
                - str: Single video path
                - Dict: Video with segments {"video_path": str, "segments": List[List[float]]}
                - List[Union[str, Dict]]: List of video paths or segment dicts
            sample_indices_fn: (Not used) Kept for compatibility with base class signature.
            video_fps: Target frames per second for video sampling. If None, uses self.video_fps.
            min_frames: Minimum number of frames to sample. If None, uses self.min_frames.
            max_frames: Maximum number of frames to sample. If None, uses self.max_frames.

        Returns:
            Tuple of (videos, metadata) where videos are torch.Tensors and metadata are VideoMetadata objects.
        """
        # Resolve per-call overrides against the processor defaults.
        effective_video_fps = video_fps if video_fps is not None else self.video_fps
        effective_min_frames = min_frames if min_frames is not None else self.min_frames
        effective_max_frames = max_frames if max_frames is not None else self.max_frames

        if isinstance(video_url_or_urls, list):
            # Multi-video input: treat max_frames as a shared budget and recurse
            # once per video with its allocated share.
            all_videos = []
            all_metadata = []
            if len(video_url_or_urls) == 1:
                per_video_max_frames = [effective_max_frames]
            else:
                per_video_max_frames = self._allocate_max_frames_for_multiple_videos(
                    video_url_or_urls,
                    effective_max_frames,
                )
            for x, allocated_max_frames in zip(video_url_or_urls, per_video_max_frames):
                result = self.fetch_videos(
                    x,
                    video_fps=effective_video_fps,
                    min_frames=effective_min_frames,
                    max_frames=allocated_max_frames,
                )
                # Segment dicts return lists; plain paths return single tensors.
                if isinstance(result[0], list):
                    all_videos.extend(result[0])
                    all_metadata.extend(result[1])
                else:
                    all_videos.append(result[0])
                    all_metadata.append(result[1])
            return all_videos, all_metadata

        if isinstance(video_url_or_urls, dict):
            # Segment-dict input: {"video_path": ..., "segments": [[start, end], [t], ...]}.
            video_path = video_url_or_urls["video_path"]
            segments = video_url_or_urls["segments"]

            # Duration of each [start, end] segment; None marks single-frame segments.
            segment_durations = []
            for seg in segments:
                if len(seg) == 2:
                    segment_durations.append(seg[1] - seg[0])
                else:
                    segment_durations.append(None)

            total_segment_duration = sum(d for d in segment_durations if d is not None)

            videos = []
            metadata = []
            for i, segment in enumerate(segments):
                if len(segment) == 1:
                    # Single-frame segment: the frame-budget weighting does not apply.
                    video, meta = self._fetch_video_segment(video_path, segment, video_fps=effective_video_fps)
                else:
                    # Weight this segment's share of the frame budget by its duration.
                    if total_segment_duration > 0:
                        weight = segment_durations[i] / total_segment_duration
                    else:
                        # Degenerate (zero-length) ranges: split the budget evenly.
                        num_range_segments = sum(1 for d in segment_durations if d is not None)
                        weight = 1.0 / num_range_segments if num_range_segments > 0 else 1.0

                    weighted_min_frames = max(1, int(round(effective_min_frames * weight)))
                    weighted_max_frames = max(1, int(round(effective_max_frames * weight)))

                    video, meta = self._fetch_video_segment(
                        video_path, segment,
                        min_frames=weighted_min_frames,
                        max_frames=weighted_max_frames,
                        video_fps=effective_video_fps,
                    )
                videos.append(video)
                metadata.append(meta)
            return videos, metadata

        # Plain path input: sample uniformly over the whole video.
        video_path = video_url_or_urls

        # Strip any non-video/audio streams before decoding.
        video_path = clean_video_streams(video_path)

        decoder = None
        try:
            decoder = VideoDecoder(video_path, num_ffmpeg_threads=0)
            torchcodec_metadata = decoder.metadata

            # Prefer the content-derived stream span; fall back to container duration.
            duration = None
            if torchcodec_metadata.end_stream_seconds_from_content is not None and torchcodec_metadata.begin_stream_seconds_from_content is not None:
                duration = torchcodec_metadata.end_stream_seconds_from_content - torchcodec_metadata.begin_stream_seconds_from_content

            if duration is None or duration <= 0:
                duration = torchcodec_metadata.duration_seconds

            total_frames_in_video = torchcodec_metadata.num_frames_from_content

            # Temporary metadata object used only to compute the frame count.
            temp_metadata = VideoMetadata(
                total_num_frames=total_frames_in_video,
                fps=torchcodec_metadata.average_fps,
                duration=duration,
                video_backend="torchcodec",
                height=torchcodec_metadata.height,
                width=torchcodec_metadata.width,
                frames_indices=None
            )

            sample_frames_count = self.calculate_num_frames(
                temp_metadata,
                fps=effective_video_fps,
                min_frames=effective_min_frames,
                max_frames=effective_max_frames,
            )

            effective_sample_count = min(sample_frames_count, total_frames_in_video)
            if effective_sample_count == 0:
                raise ValueError(f"Cannot extract frames: video has 0 frames or specified frame count is 0")

            # Uniform sampling over the full frame range; clip + unique drops
            # duplicate indices near the boundaries.
            frame_indices = np.linspace(0, total_frames_in_video - 1, effective_sample_count).astype(np.int32)

            frame_indices = np.unique(np.clip(frame_indices, 0, total_frames_in_video - 1))

            result = decode_with_multithreading(frame_indices.tolist(), num_threads=self.num_extract_threads, video_path=video_path)

            frames_tensor = result["data"]

            video_metadata = VideoMetadata(
                total_num_frames=len(frame_indices),
                fps=torchcodec_metadata.average_fps,
                duration=duration,
                video_backend="torchcodec",
                height=torchcodec_metadata.height,
                width=torchcodec_metadata.width,
                frames_indices=frame_indices
            )

            # Expect a 4-D (frames, channels, height, width) tensor from the decoder.
            if frames_tensor.dim() == 4:
                video_tensor = frames_tensor
            else:
                raise ValueError(f"Unexpected frame tensor shape: {frames_tensor.shape}")

            return video_tensor, video_metadata

        except Exception as e:
            # Re-wrap any decode failure as a ValueError with the offending path.
            logger.error(f"Error loading video {video_path}: {e}")
            traceback.print_exc()
            raise ValueError(f"Failed to load video {video_path}: {e}")
        finally:
            if decoder is not None:
                del decoder
|
|
| def _preprocess( |
| self, |
| videos: list[torch.Tensor], |
| do_convert_rgb: bool = True, |
| do_resize: bool = True, |
| size: Optional[SizeDict] = None, |
| interpolation: PILImageResampling = PILImageResampling.BICUBIC, |
| do_rescale: bool = True, |
| rescale_factor: float = 1 / 255.0, |
| do_normalize: bool = True, |
| image_mean: Optional[Union[float, list[float]]] = None, |
| image_std: Optional[Union[float, list[float]]] = None, |
| patch_size: Optional[int] = None, |
| temporal_patch_size: Optional[int] = None, |
| merge_size: Optional[int] = None, |
| return_tensors: Optional[Union[str, TensorType]] = None, |
| **kwargs, |
| ): |
| grouped_videos, grouped_videos_index = group_videos_by_shape(videos) |
| resized_videos_grouped = {} |
|
|
| video_max_pixels = getattr(self, "video_max_pixels", None) |
| if video_max_pixels is not None: |
| total_volume = sum( |
| sv.shape[0] * sv.shape[1] * sv.shape[3] * sv.shape[4] |
| for sv in grouped_videos.values() |
| ) |
| else: |
| total_volume = 0 |
|
|
| for shape, stacked_videos in grouped_videos.items(): |
| B, T, C, H, W = stacked_videos.shape |
| num_frames, height, width = T, H, W |
| |
| if do_convert_rgb: |
| stacked_videos = self.convert_to_rgb(stacked_videos) |
| if do_resize: |
| if video_max_pixels is not None and total_volume > 0: |
| allocated_max_pixels = int(video_max_pixels * (T * H * W) / total_volume) |
| else: |
| allocated_max_pixels = size.longest_edge |
| resized_height, resized_width = smart_resize( |
| num_frames=num_frames, |
| height=height, |
| width=width, |
| temporal_factor=temporal_patch_size, |
| factor=patch_size * merge_size, |
| min_pixels=size.shortest_edge, |
| max_pixels=allocated_max_pixels, |
| per_frame_min_pixels=size.shortest_edge, |
| per_frame_max_pixels=size.longest_edge, |
| ) |
| stacked_videos = stacked_videos.view(B * T, C, H, W) |
| stacked_videos = self.resize( |
| stacked_videos, |
| size=SizeDict(height=resized_height, width=resized_width), |
| interpolation=interpolation, |
| ) |
| stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width) |
| resized_videos_grouped[shape] = stacked_videos |
| resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) |
|
|
| |
| |
| grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos) |
| processed_videos_grouped = {} |
| processed_grids = {} |
| for shape, stacked_videos in grouped_videos.items(): |
| resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST) |
|
|
| |
| stacked_videos = self.rescale_and_normalize( |
| stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std |
| ) |
| patches = stacked_videos |
|
|
| |
| if patches.shape[1] % temporal_patch_size != 0: |
| repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1) |
| patches = torch.cat([patches, repeats], dim=1) |
| batch_size, grid_t, channel = patches.shape[:3] |
| grid_t = grid_t // temporal_patch_size |
| grid_h, grid_w = resized_height // patch_size, resized_width // patch_size |
|
|
| patches = patches.view( |
| batch_size, |
| grid_t, |
| temporal_patch_size, |
| channel, |
| grid_h // merge_size, |
| merge_size, |
| patch_size, |
| grid_w // merge_size, |
| merge_size, |
| patch_size, |
| ) |
| patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) |
| flatten_patches = patches.reshape( |
| batch_size, |
| grid_t * grid_h * grid_w, |
| channel * temporal_patch_size * patch_size * patch_size, |
| ) |
|
|
| processed_videos_grouped[shape] = flatten_patches |
| processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size |
|
|
| processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index) |
| processed_grids = reorder_videos(processed_grids, grouped_videos_index) |
| pixel_values_videos = torch.cat(processed_videos, dim=0) |
| video_grid_thw = torch.tensor(processed_grids) |
| data = { |
| "pixel_values_videos": pixel_values_videos, |
| "video_grid_thw": video_grid_thw, |
| } |
|
|
| return BatchFeature(data=data, tensor_type=return_tensors) |
|
|
| def preprocess( |
| self, |
| videos: Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]], |
| **kwargs, |
| ) -> BatchFeature: |
| """ |
| Preprocess videos for the model. |
| |
| This method overrides the base class to handle two video input formats: |
| 1. String path: "path/to/video.mp4" |
| 2. Dict with segments: {"video_path": "...", "segment": [[start, end], [time], ...]} |
| |
| Args: |
| videos: Video input(s) in one of the supported formats. |
| **kwargs: Additional arguments passed to _preprocess. |
| |
| Returns: |
| BatchFeature with pixel_values_videos, video_grid_thw, and optionally video_metadata. |
| """ |
| |
| validate_kwargs( |
| captured_kwargs=kwargs.keys(), |
| valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"], |
| ) |
| |
| |
| for kwarg_name in self.valid_kwargs.__annotations__: |
| kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) |
| |
| |
| return_tensors = kwargs.pop("return_tensors", None) |
| return_metadata = kwargs.pop("return_metadata", False) |
| input_data_format = kwargs.pop("input_data_format", None) |
| device = kwargs.pop("device", None) |
| kwargs.pop("video_metadata", None) |
| kwargs.pop("do_sample_frames", None) |
| kwargs.pop("data_format", None) |
| |
| |
| if not isinstance(videos, list): |
| videos = [videos] |
| |
| |
| video_fps = kwargs.pop("video_fps", None) |
| min_frames = kwargs.pop("min_frames", None) |
| max_frames = kwargs.pop("max_frames", None) |
| |
| |
| video_tensors, video_metadata = self.fetch_videos( |
| videos, |
| video_fps=video_fps, |
| min_frames=min_frames, |
| max_frames=max_frames, |
| ) |
| |
| |
| prepared_videos = self._prepare_input_videos( |
| videos=video_tensors, |
| input_data_format=input_data_format, |
| device=device, |
| ) |
| |
| |
| kwargs = self._further_process_kwargs(**kwargs) |
| self._validate_preprocess_kwargs(**kwargs) |
| |
| |
| result = self._preprocess(videos=prepared_videos, return_tensors=return_tensors, **kwargs) |
| |
| |
| if return_metadata: |
| result["video_metadata"] = video_metadata |
| |
| return result |
|
|
|
|
# Explicit public API: only the processor class is exported from this module.
__all__ = ["MossVLVideoProcessor"]
|
|
|
|