# video_processing_yasa2.py
"""Video decoding and processing helpers for Yasa2."""
from __future__ import annotations
import io
import urllib.request
from typing import Callable, Dict, List, Union
import imageio.v3 as iio
import numpy as np
import torch
from transformers.video_processing_utils import BaseVideoProcessor
from .image_processing_yasa2 import Yasa2ImageProcessor
def frame_sampling_uniform(num_frames: int, total_frames: int) -> List[int]:
    """Pick evenly spaced frame indices across the whole video.

    Args:
        num_frames: How many frames to select.
        total_frames: Number of frames available in the video.

    Returns:
        List[int]: Evenly spaced indices, one per interval midpoint; every
        frame index when the request covers the entire video.
    """
    if total_frames <= num_frames:
        return list(range(total_frames))
    step = total_frames / num_frames
    # Sample at the midpoint of each equal-length interval.
    midpoints = np.arange(step / 2, total_frames, step)
    return midpoints.astype(int).tolist()
def frame_sampling_random(num_frames: int, total_frames: int) -> List[int]:
    """Draw unique frame indices uniformly at random.

    Args:
        num_frames: How many frames to select.
        total_frames: Number of frames available in the video.

    Returns:
        List[int]: Unique random indices (unsorted); every frame index when
        fewer than ``num_frames`` are available.
    """
    if num_frames >= total_frames:
        return list(range(total_frames))
    chosen = np.random.choice(
        np.arange(total_frames), num_frames, replace=False
    )
    return chosen.tolist()
def frame_sampling_chunked(num_frames: int, total_frames: int) -> List[int]:
    """Split the timeline into near-equal chunks and draw one index per chunk.

    Args:
        num_frames: How many frames (and therefore chunks) to select.
        total_frames: Number of frames available in the video.

    Returns:
        List[int]: One random index per chunk, in ascending chunk order;
        every frame index when fewer than ``num_frames`` are available.
    """
    if num_frames >= total_frames:
        return list(range(total_frames))
    base_size, remainder = divmod(total_frames, num_frames)
    picks = []
    lower = 0
    for chunk in range(num_frames):
        # The first `remainder` chunks absorb one extra frame each.
        upper = lower + base_size + (1 if chunk < remainder else 0)
        picks.append(np.random.randint(lower, upper))
        lower = upper
    return picks
def _read_bytes_from_uri(uri: str) -> bytes:
"""Read bytes from a local path or HTTP(S) URL.
Args:
uri: Local file path or HTTP(S) URL.
Returns:
Raw bytes content.
"""
if uri.startswith("http://") or uri.startswith("https://"):
with urllib.request.urlopen(uri) as response:
return response.read()
with open(uri, "rb") as f:
return f.read()
def video_rgb_decoder_iio(
    video_bytes: bytes,
    num_frames: int,
    frame_sampler: Callable[[int, int], List[int]],
    plugin: str = "pyav",
    skip_errors: bool = False,
) -> Dict[str, Union[np.ndarray, float, List[int]]]:
    """Decode video bytes into sampled RGB frames together with metadata.

    Args:
        video_bytes: Raw encoded video bytes (any container the imageio
            plugin can read).
        num_frames: Number of frames to sample.
        frame_sampler: Callable mapping ``(num_frames, total_frames)`` to
            the list of frame indices to keep.
        plugin: ImageIO plugin name used for probing and decoding.
        skip_errors: When True, any failure is returned as
            ``{"error": str}`` instead of raising.

    Returns:
        Dict[str, Union[np.ndarray, float, List[int]]]: ``pixel_values``
        stacked along axis 0, the stream ``fps``, the ``frame_idxs``
        actually read (in decode order), and ``num_frames`` actually kept.

    Raises:
        NotImplementedError: If the stream is not 3-channel RGB.
        ValueError: If some sampled indices could not be read and
            ``skip_errors`` is False.
    """
    try:
        with io.BytesIO(video_bytes) as video_io:
            # Probe container properties (frame count and frame shape)
            # without decoding every frame.
            properties = iio.improps(video_io, plugin=plugin)
            total_frames, height, width, channels = properties.shape
            if channels != 3:
                raise NotImplementedError(
                    f"Video with {channels} channels not supported."
                )
            # Rewind: each imageio call consumes the in-memory stream.
            video_io.seek(0)
            metadata = iio.immeta(video_io, plugin=plugin)
            fps = metadata["fps"]
            if total_frames == 0:
                # Some containers report no frame count; estimate it from
                # fps * duration instead.
                total_frames = int(fps * metadata["duration"])
            # Mirror training-time sampling behavior (total_frames - 1).
            frame_idxs = set(frame_sampler(num_frames, total_frames - 1))
            frame_idxs_actual = []
            pixel_values = []
            video_io.seek(0)
            # Single sequential pass over the stream: keep only sampled
            # indices and stop early once all requested frames are read.
            for idx, frame in enumerate(
                iio.imiter(video_io, plugin=plugin, thread_type="FRAME")
            ):
                if idx in frame_idxs:
                    frame_idxs.remove(idx)
                    frame_idxs_actual.append(idx)
                    pixel_values.append(frame)
                    if not frame_idxs:
                        break
            # Any leftover indices were beyond the end of the decodable
            # stream (e.g. the fps*duration estimate overshot).
            if frame_idxs and not skip_errors:
                raise ValueError(f"Failed to read frames {frame_idxs}.")
            pixel_values = np.stack(pixel_values, axis=0)
            return {
                "pixel_values": pixel_values,
                "fps": fps,
                "frame_idxs": frame_idxs_actual,
                "num_frames": len(frame_idxs_actual),
            }
    except Exception as exc:
        if not skip_errors:
            raise
        # Best-effort mode: surface the failure to the caller as data.
        return {"error": str(exc)}
def video_rgb_decoder_factory(
    num_frames: int, sampling: str = "uniform", skip_errors: bool = False
) -> Callable[[bytes], Dict[str, Union[np.ndarray, float, List[int]]]]:
    """Create a decoder that samples frames according to the chosen strategy.

    Args:
        num_frames: Number of frames to sample.
        sampling: Sampling strategy name ("uniform", "random", or "chunk").
        skip_errors: Whether to return error info instead of raising.

    Returns:
        Callable[[bytes], Dict[str, Union[np.ndarray, float, List[int]]]]:
        Decoder mapping raw video bytes to decoded frames and metadata.

    Raises:
        NotImplementedError: For an unrecognized ``sampling`` name.
    """
    if sampling == "uniform":
        sampler = frame_sampling_uniform
    elif sampling == "random":
        sampler = frame_sampling_random
    elif sampling == "chunk":
        sampler = frame_sampling_chunked
    else:
        raise NotImplementedError(
            f"Frame sampling method {sampling} not implemented."
        )

    def _decode(video_bytes: bytes) -> Dict[str, Union[np.ndarray, float, List[int]]]:
        # Bind the chosen sampler and options into a bytes-only decoder.
        return video_rgb_decoder_iio(
            video_bytes,
            num_frames=num_frames,
            frame_sampler=sampler,
            skip_errors=skip_errors,
        )

    return _decode
class Yasa2VideoProcessor(BaseVideoProcessor):
    """Video processor that samples frames and applies the ConvNeXt image processor."""

    model_input_names = ["pixel_values", "patch_attention_mask"]

    def __init__(
        self,
        num_frames: int = 6,
        frame_sample_mode: str = "chunk",
        patch_size: int = 14,
        size: int = 512,
        vision_patch_stride: int = 32,
        image_mean: List[float] | None = None,
        image_std: List[float] | None = None,
        max_num_frames: int | None = None,
        **kwargs,
    ) -> None:
        """Initialize the video processor.

        Args:
            num_frames: Number of frames to sample per video.
            frame_sample_mode: Sampling strategy ("uniform", "random", or "chunk").
            patch_size: Vision patch size forwarded to the image processor.
            size: Input resolution for the image processor.
            vision_patch_stride: Effective stride of the vision encoder,
                used to build the fallback patch attention mask.
            image_mean: Mean values for normalization (ImageNet defaults).
            image_std: Std values for normalization (ImageNet defaults).
            max_num_frames: Optional target length; outputs are zero-padded
                (and mask-padded) up to this many frames.
            **kwargs: Passed to BaseVideoProcessor.
        """
        super().__init__(**kwargs)
        self.num_frames = num_frames
        self.frame_sample_mode = frame_sample_mode
        self.patch_size = patch_size
        self.size = size
        self.vision_patch_stride = vision_patch_stride
        # Fall back to ImageNet statistics when none are provided.
        self.image_mean = image_mean or [0.485, 0.456, 0.406]
        self.image_std = image_std or [0.229, 0.224, 0.225]
        self.max_num_frames = max_num_frames
        self.image_processor = Yasa2ImageProcessor(
            size={"shortest_edge": size},
            crop_size={"height": size, "width": size},
            do_resize=True,
            do_normalize=True,
            image_mean=self.image_mean,
            image_std=self.image_std,
            patch_size=patch_size,
        )

    def decode_video(
        self, video: Union[str, bytes]
    ) -> Dict[str, Union[np.ndarray, float, List[int]]]:
        """Decode a video path or raw bytes into sampled frames.

        Args:
            video: Video path/URL or raw encoded bytes.

        Returns:
            Dict[str, Union[np.ndarray, float, List[int]]]: Decoded frames,
            fps, sampled indices, and the sampled frame count.
        """
        if isinstance(video, str):
            video_bytes = _read_bytes_from_uri(video)
        else:
            video_bytes = video
        decoder = video_rgb_decoder_factory(
            num_frames=self.num_frames, sampling=self.frame_sample_mode
        )
        return decoder(video_bytes)

    def to_dict(
        self,
    ) -> Dict[str, Union[int, str, float, List[float], None, Dict[str, str]]]:
        """Return a JSON-serializable config for logging and saving.

        Returns:
            Dict[str, Union[int, str, float, List[float], None, Dict[str, str]]]:
            Processor attributes without None values.
        """
        output = super().to_dict()
        # The nested image processor is reconstructed in __init__ and must
        # not be serialized as an attribute.
        output.pop("image_processor", None)
        # Ensure the encoder stride used for patch masking is serialized.
        output["vision_patch_stride"] = self.vision_patch_stride
        # Drop unset values to avoid serializing nulls into the config.
        return {
            key: value for key, value in output.items() if value is not None
        }

    def preprocess(
        self,
        videos: Union[str, bytes, np.ndarray, List[np.ndarray]],
        return_tensors: str | None = "pt",
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Preprocess videos into pixel values and patch attention masks.

        Args:
            videos: Video path/URL, raw bytes, or already-decoded frame
                array(s).
            return_tensors: Tensor type to return.
            **kwargs: Unused extra arguments.

        Returns:
            Dict[str, torch.Tensor]: `pixel_values` and `patch_attention_mask`
            tensors, zero-padded to ``max_num_frames`` when configured.
        """
        # NOTE(review): return_tensors is currently unused; outputs are
        # always torch tensors ("pt").
        if isinstance(videos, (str, bytes)):
            video_datum = self.decode_video(videos)
            pixel_values = video_datum["pixel_values"]
        else:
            pixel_values = videos
        image_outputs = self.image_processor(
            images=pixel_values, return_tensors="pt"
        )
        img_tensor = image_outputs["pixel_values"]
        if "patch_attention_mask" in image_outputs:
            patch_attention_mask = image_outputs["patch_attention_mask"]
        else:
            # ConvNeXt outputs features at stride 32 (512 -> 16 grid), so build a grid mask at that resolution.
            grid_size = max(1, self.size // self.vision_patch_stride)
            patch_attention_mask = torch.ones(
                (img_tensor.shape[0], grid_size, grid_size),
                dtype=torch.bool,
            )
        if (
            self.max_num_frames is not None
            and img_tensor.shape[0] < self.max_num_frames
        ):
            pad_frames = self.max_num_frames - img_tensor.shape[0]
            img_tensor = torch.cat(
                [
                    img_tensor,
                    # Match the dtype/device of the real frames: torch.cat
                    # raises on mixed dtypes (e.g. half-precision outputs)
                    # or mixed devices, which the previous default-float32
                    # CPU zeros would trigger.
                    torch.zeros(
                        (pad_frames, *img_tensor.shape[1:]),
                        dtype=img_tensor.dtype,
                        device=img_tensor.device,
                    ),
                ],
                dim=0,
            )
            patch_attention_mask = torch.cat(
                [
                    patch_attention_mask,
                    # Mask out padded frames to avoid leaking zero-padded inputs.
                    torch.zeros(
                        (pad_frames, *patch_attention_mask.shape[1:]),
                        dtype=torch.bool,
                        device=patch_attention_mask.device,
                    ),
                ],
                dim=0,
            )
        return {
            "pixel_values": img_tensor,
            "patch_attention_mask": patch_attention_mask,
        }
# Register the class with transformers' auto-class machinery so it can be
# loaded from a saved checkpoint (e.g. via trust_remote_code).
Yasa2VideoProcessor.register_for_auto_class()