| """Video decoding and processing helpers for Yasa2.""" |
|
|
| from __future__ import annotations |
|
|
| import io |
| import urllib.request |
| from typing import Callable, Dict, List, Union |
|
|
| import imageio.v3 as iio |
| import numpy as np |
| import torch |
| from transformers.video_processing_utils import BaseVideoProcessor |
|
|
| from .image_processing_yasa2 import Yasa2ImageProcessor |
|
|
|
|
def frame_sampling_uniform(num_frames: int, total_frames: int) -> List[int]:
    """Pick equally spaced frame indices across the whole timeline.

    Each index sits at the midpoint of one of ``num_frames`` equal spans,
    so sampling is centred rather than biased toward the start.

    Args:
        num_frames: Number of frames to sample.
        total_frames: Total number of frames.

    Returns:
        List[int]: Equally spaced frame indices; every frame when the video
        has no more than ``num_frames`` frames.
    """
    if num_frames >= total_frames:
        return list(range(total_frames))
    span = total_frames / num_frames
    midpoints = np.arange(span / 2, total_frames, span)
    return [int(m) for m in midpoints]
|
|
|
|
def frame_sampling_random(num_frames: int, total_frames: int) -> List[int]:
    """Pick a random subset of frame indices, without duplicates.

    Args:
        num_frames: Number of frames to sample.
        total_frames: Total number of frames.

    Returns:
        List[int]: Unique frame indices drawn uniformly at random; every
        frame when the video has no more than ``num_frames`` frames.
    """
    if num_frames >= total_frames:
        return list(range(total_frames))
    # An int first argument is shorthand for np.arange(total_frames).
    chosen = np.random.choice(total_frames, num_frames, replace=False)
    return chosen.tolist()
|
|
|
|
def frame_sampling_chunked(num_frames: int, total_frames: int) -> List[int]:
    """Sample frames by dividing the video into chunks and picking one index per chunk.

    The chunks partition ``[0, total_frames)``: the first
    ``total_frames % num_frames`` chunks are one frame longer than the rest.

    Args:
        num_frames: Number of frames to sample.
        total_frames: Total number of frames.

    Returns:
        List[int]: One randomly chosen index per chunk, in ascending chunk
        order, as plain Python ints (consistent with the other samplers).
    """
    if total_frames <= num_frames:
        return list(range(total_frames))
    chunk_size = total_frames // num_frames
    extra_frames = total_frames % num_frames
    sampled_frames = []
    for i in range(num_frames):
        # Chunks 0 .. extra_frames-1 each absorb one of the leftover frames.
        start = i * chunk_size + min(i, extra_frames)
        end = start + chunk_size + (1 if i < extra_frames else 0)
        # int(): np.random.randint returns a numpy scalar; the sibling
        # samplers return plain Python ints (via .tolist()), and numpy
        # scalars are not JSON-serializable.
        sampled_frames.append(int(np.random.randint(start, end)))
    return sampled_frames
|
|
|
|
| def _read_bytes_from_uri(uri: str) -> bytes: |
| """Read bytes from a local path or HTTP(S) URL. |
| |
| Args: |
| uri: Local file path or HTTP(S) URL. |
| |
| Returns: |
| Raw bytes content. |
| """ |
| if uri.startswith("http://") or uri.startswith("https://"): |
| with urllib.request.urlopen(uri) as response: |
| return response.read() |
| with open(uri, "rb") as f: |
| return f.read() |
|
|
|
|
def video_rgb_decoder_iio(
    video_bytes: bytes,
    num_frames: int,
    frame_sampler: Callable[[int, int], List[int]],
    plugin: str = "pyav",
    skip_errors: bool = False,
) -> Dict[str, Union[np.ndarray, float, List[int]]]:
    """Decode video bytes into sampled RGB frames together with metadata.

    Args:
        video_bytes: Raw encoded video bytes.
        num_frames: Number of frames to sample.
        frame_sampler: Callable mapping ``(num_frames, total_frames)`` to the
            list of frame indices to keep.
        plugin: ImageIO plugin name.
        skip_errors: Whether to return error info instead of raising.

    Returns:
        Dict[str, Union[np.ndarray, float, List[int]]]: ``pixel_values``
        (frames stacked on axis 0), ``fps``, ``frame_idxs`` actually read,
        and ``num_frames`` actually sampled. When ``skip_errors`` is set and
        decoding fails, ``{"error": <message>}`` is returned instead.
    """
    try:
        with io.BytesIO(video_bytes) as video_io:
            # Probe container-level properties (frame count and frame shape)
            # before decoding anything.
            properties = iio.improps(video_io, plugin=plugin)
            total_frames, height, width, channels = properties.shape
            if channels != 3:
                raise NotImplementedError(
                    f"Video with {channels} channels not supported."
                )
            # improps consumed the stream; rewind before reading metadata.
            video_io.seek(0)
            metadata = iio.immeta(video_io, plugin=plugin)
            fps = metadata["fps"]
            # Some containers report 0 frames; estimate from fps * duration.
            if total_frames == 0:
                total_frames = int(fps * metadata["duration"])

            # NOTE(review): sampling over total_frames - 1 excludes the last
            # index — presumably a guard against containers over-reporting
            # their frame count; confirm this off-by-one is intentional.
            frame_idxs = set(frame_sampler(num_frames, total_frames - 1))
            frame_idxs_actual = []
            pixel_values = []
            # Rewind again and decode sequentially, keeping only the sampled
            # indices; stop early once every requested frame was collected.
            video_io.seek(0)
            for idx, frame in enumerate(
                iio.imiter(video_io, plugin=plugin, thread_type="FRAME")
            ):
                if idx in frame_idxs:
                    frame_idxs.remove(idx)
                    frame_idxs_actual.append(idx)
                    pixel_values.append(frame)
                    if not frame_idxs:
                        break
            # Leftover indices mean the stream ended before reaching them
            # (reported frame count was too large).
            if frame_idxs and not skip_errors:
                raise ValueError(f"Failed to read frames {frame_idxs}.")

            # Stack into a single (n, height, width, 3) array.
            pixel_values = np.stack(pixel_values, axis=0)
            return {
                "pixel_values": pixel_values,
                "fps": fps,
                "frame_idxs": frame_idxs_actual,
                "num_frames": len(frame_idxs_actual),
            }
    except Exception as exc:
        # Broad catch is deliberate: with skip_errors the caller gets a
        # best-effort error record instead of a crash; otherwise re-raise.
        if not skip_errors:
            raise
        return {"error": str(exc)}
|
|
|
|
def video_rgb_decoder_factory(
    num_frames: int, sampling: str = "uniform", skip_errors: bool = False
) -> Callable[[bytes], Dict[str, Union[np.ndarray, float, List[int]]]]:
    """Create a decoder that samples frames according to the chosen strategy.

    Args:
        num_frames: Number of frames to sample.
        sampling: Sampling strategy name ("uniform", "random", or "chunk").
        skip_errors: Whether to return error info instead of raising.

    Returns:
        Callable[[bytes], Dict[str, Union[np.ndarray, float, List[int]]]]:
            Decoder that maps raw bytes to decoded frames/metadata.
    """
    # Validate eagerly so a bad strategy fails at factory-call time,
    # exactly like the original if/elif chain did.
    if sampling not in ("uniform", "random", "chunk"):
        raise NotImplementedError(
            f"Frame sampling method {sampling} not implemented."
        )
    sampler = {
        "uniform": frame_sampling_uniform,
        "random": frame_sampling_random,
        "chunk": frame_sampling_chunked,
    }[sampling]

    def _decode(
        video_bytes: bytes,
    ) -> Dict[str, Union[np.ndarray, float, List[int]]]:
        return video_rgb_decoder_iio(
            video_bytes,
            num_frames=num_frames,
            frame_sampler=sampler,
            skip_errors=skip_errors,
        )

    return _decode
|
|
|
|
class Yasa2VideoProcessor(BaseVideoProcessor):
    """Video processor that samples frames and applies the ConvNeXt image processor."""

    model_input_names = ["pixel_values", "patch_attention_mask"]

    def __init__(
        self,
        num_frames: int = 6,
        frame_sample_mode: str = "chunk",
        patch_size: int = 14,
        size: int = 512,
        vision_patch_stride: int = 32,
        image_mean: List[float] | None = None,
        image_std: List[float] | None = None,
        max_num_frames: int | None = None,
        **kwargs,
    ) -> None:
        """Initialize the video processor.

        Args:
            num_frames: Number of frames to sample per video.
            frame_sample_mode: Sampling strategy for frames
                ("uniform", "random", or "chunk").
            patch_size: Vision patch size for attention mask.
            size: Input resolution for the image processor.
            vision_patch_stride: Effective stride of the vision encoder.
            image_mean: Mean values for normalization; defaults to the
                ImageNet statistics when omitted.
            image_std: Std values for normalization; defaults to the
                ImageNet statistics when omitted.
            max_num_frames: Optional padding target for frames.
            **kwargs: Passed to BaseVideoProcessor.
        """
        super().__init__(**kwargs)
        self.num_frames = num_frames
        self.frame_sample_mode = frame_sample_mode
        self.patch_size = patch_size
        self.size = size
        self.vision_patch_stride = vision_patch_stride
        # ImageNet normalization statistics are the conventional defaults.
        self.image_mean = image_mean or [0.485, 0.456, 0.406]
        self.image_std = image_std or [0.229, 0.224, 0.225]
        self.max_num_frames = max_num_frames
        self.image_processor = Yasa2ImageProcessor(
            size={"shortest_edge": size},
            crop_size={"height": size, "width": size},
            do_resize=True,
            do_normalize=True,
            image_mean=self.image_mean,
            image_std=self.image_std,
            patch_size=patch_size,
        )

    def decode_video(
        self, video: Union[str, bytes]
    ) -> Dict[str, Union[np.ndarray, float, List[int]]]:
        """Decode a video path or raw bytes into sampled frames.

        Args:
            video: Video path/URL or raw bytes.

        Returns:
            Dict[str, Union[np.ndarray, float, List[int]]]: Decoded frames,
            fps, sampled indices, and frame count.
        """
        if isinstance(video, str):
            video_bytes = _read_bytes_from_uri(video)
        else:
            video_bytes = video
        # Decoder is rebuilt per call; errors raise (skip_errors defaults
        # to False in the factory).
        decoder = video_rgb_decoder_factory(
            num_frames=self.num_frames, sampling=self.frame_sample_mode
        )
        return decoder(video_bytes)

    def to_dict(
        self,
    ) -> Dict[str, Union[int, str, float, List[float], None, Dict[str, str]]]:
        """Return a JSON-serializable config for logging and saving.

        Returns:
            Dict[str, Union[int, str, float, List[float], None, Dict[str, str]]]:
                Processor attributes without None values.
        """
        output = super().to_dict()
        # The nested image processor is reconstructed in __init__, so it is
        # not serialized here.
        output.pop("image_processor", None)
        output["vision_patch_stride"] = self.vision_patch_stride
        return {
            key: value for key, value in output.items() if value is not None
        }

    def preprocess(
        self,
        videos: Union[str, bytes, np.ndarray, List[np.ndarray]],
        return_tensors: str | None = "pt",
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Preprocess videos into pixel values and patch attention masks.

        Args:
            videos: Video path/URL, raw bytes, or frame array(s).
            return_tensors: Tensor type to return. NOTE(review): the image
                processor is always called with "pt"; this argument is
                currently not forwarded.
            **kwargs: Unused extra arguments.

        Returns:
            Dict[str, torch.Tensor]: `pixel_values` and
            `patch_attention_mask` tensors, zero-padded along the frame axis
            up to `max_num_frames` when configured.
        """
        if isinstance(videos, (str, bytes)):
            video_datum = self.decode_video(videos)
            pixel_values = video_datum["pixel_values"]
        else:
            pixel_values = videos

        image_outputs = self.image_processor(
            images=pixel_values, return_tensors="pt"
        )
        img_tensor = image_outputs["pixel_values"]
        if "patch_attention_mask" in image_outputs:
            patch_attention_mask = image_outputs["patch_attention_mask"]
        else:
            # Fall back to an all-ones mask over the encoder's patch grid.
            grid_size = max(1, self.size // self.vision_patch_stride)
            patch_attention_mask = torch.ones(
                (
                    img_tensor.shape[0],
                    grid_size,
                    grid_size,
                ),
                dtype=torch.bool,
            )
        if (
            self.max_num_frames is not None
            and img_tensor.shape[0] < self.max_num_frames
        ):
            # NOTE(review): videos with more than max_num_frames frames are
            # not truncated here — confirm callers guarantee num_frames <=
            # max_num_frames.
            pad_frames = self.max_num_frames - img_tensor.shape[0]
            img_tensor = torch.cat(
                [
                    img_tensor,
                    torch.zeros(
                        (
                            pad_frames,
                            img_tensor.shape[1],
                            img_tensor.shape[2],
                            img_tensor.shape[3],
                        ),
                        # Match the frames' dtype: torch.cat otherwise
                        # promotes (e.g. float16 frames + float32 padding ->
                        # float32 output), silently changing the result dtype.
                        dtype=img_tensor.dtype,
                    ),
                ],
                dim=0,
            )
            patch_attention_mask = torch.cat(
                [
                    patch_attention_mask,
                    # Padded frames are masked out entirely.
                    torch.zeros(
                        (
                            pad_frames,
                            patch_attention_mask.shape[1],
                            patch_attention_mask.shape[2],
                        ),
                        dtype=torch.bool,
                    ),
                ],
                dim=0,
            )
        return {
            "pixel_values": img_tensor,
            "patch_attention_mask": patch_attention_mask,
        }
|
|
|
|
# Expose this custom processor to transformers' auto-class machinery
# (presumably so AutoVideoProcessor can resolve it from saved checkpoints).
Yasa2VideoProcessor.register_for_auto_class()
|
|