import math
from typing import Optional, Union, Iterable
import numpy as np
import torch
from torchvision.transforms.v2 import functional as F
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ChannelDimension, PILImageResampling, SizeDict, get_image_size
from transformers.processing_utils import Unpack, VideosKwargs
from transformers.utils.generic import TensorType
from transformers.utils.doc import add_start_docstrings
from transformers.utils import logging
from transformers.video_processing_utils import BASE_VIDEO_PROCESSOR_DOCSTRING, BaseVideoProcessor
from transformers.video_utils import VideoMetadata, group_videos_by_shape, reorder_videos, load_video, VideoInput
from transformers.image_transforms import to_channel_dimension_format
# Module-level logger; used below to emit the one-time warning when video fps metadata is missing.
logger = logging.get_logger(__name__)
def smart_resize(
    num_frames: int,
    height: int,
    width: int,
    temporal_factor: int = 2,
    factor: int = 32,
    min_pixels: int = 128 * 128,
    max_pixels: int = 16 * 16 * 2 * 2 * 2 * 6144,
) -> tuple[int, int]:
    """
    Compute a target (height, width) for a video so that both sides are multiples of `factor`
    and the total volume (frames * height * width) stays within `[min_pixels, max_pixels]`.

    Args:
        num_frames (`int`):
            Number of frames in the video.
        height (`int`):
            Original frame height.
        width (`int`):
            Original frame width.
        temporal_factor (`int`, *optional*, defaults to 2):
            The frame count is rounded to a multiple of this value when measuring the volume.
        factor (`int`, *optional*, defaults to 32):
            The returned height and width are multiples of this value.
        min_pixels (`int`, *optional*, defaults to `128 * 128`):
            Lower bound on the rounded volume; smaller videos are scaled up.
        max_pixels (`int`, *optional*, defaults to `16 * 16 * 2 * 2 * 2 * 6144`):
            Upper bound on the rounded volume; larger videos are scaled down.

    Returns:
        `tuple[int, int]`: The resized `(height, width)`.

    Raises:
        ValueError: If `factor` or `temporal_factor` is not positive, if `num_frames` is below
            `temporal_factor`, if `height` or `width` is below `factor`, or if the aspect ratio
            exceeds 200.
    """
    if factor <= 0 or temporal_factor <= 0:
        # Guard the divisions below; a zero factor would otherwise surface as ZeroDivisionError.
        raise ValueError(f"factor:{factor} and temporal_factor:{temporal_factor} must be positive")
    if num_frames < temporal_factor:
        raise ValueError(f"t:{num_frames} must be at least temporal_factor:{temporal_factor}")
    if height < factor or width < factor:
        raise ValueError(f"height:{height} or width:{width} must be at least factor:{factor}")

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(f"absolute aspect ratio must not exceed 200, got {aspect_ratio}")

    # Snap every dimension to the nearest multiple of its factor.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    t_bar = round(num_frames / temporal_factor) * temporal_factor
    rounded_volume = t_bar * h_bar * w_bar

    # The thresholds are checked against the rounded volume, while the scale `beta` is derived
    # from the unrounded one (num_frames * height * width).
    if rounded_volume > max_pixels:
        beta = math.sqrt((num_frames * height * width) / max_pixels)
        # Round down so the result stays under the budget, but never below one `factor`.
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif rounded_volume < min_pixels:
        beta = math.sqrt(min_pixels / (num_frames * height * width))
        # Round up so the result reaches the minimum budget.
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    return h_bar, w_bar
class Qwen3VLVideoProcessorInitKwargs(VideosKwargs):
    """Keyword arguments declared by this module's video processor in addition to `VideosKwargs`."""

    # Spatial patch size of the vision encoder.
    patch_size: Optional[int]
    # Temporal patch size of the vision encoder.
    temporal_patch_size: Optional[int]
    # Merge size of the vision encoder to llm encoder.
    merge_size: Optional[int]
    # Extra multiplier: the resize factor is patch_size * merge_size * focus_size and the frame
    # count is padded to a multiple of temporal_patch_size * focus_size.
    focus_size: Optional[int]
    # Lower bound applied to the sampled frame count in `sample_frames`.
    min_frames: Optional[int]
    # Upper bound applied to the sampled frame count in `sample_frames`.
    max_frames: Optional[int]
    # Forwarded as `device` to `load_video` in `fetch_videos`.
    processor_device: Optional[str]
@add_start_docstrings(
    "Constructs a fast Qwen3-VL video processor that dynamically resizes videos based on the original videos.",
    BASE_VIDEO_PROCESSOR_DOCSTRING,
    """
        patch_size (`int`, *optional*, defaults to 16):
            The spatial patch size of the vision encoder.
        temporal_patch_size (`int`, *optional*, defaults to 2):
            The temporal patch size of the vision encoder.
        merge_size (`int`, *optional*, defaults to 2):
            The merge size of the vision encoder to llm encoder.
        focus_size (`int`, *optional*, defaults to 2):
            Extra multiplier applied to the spatial resize factor (`patch_size * merge_size * focus_size`)
            and to the temporal padding multiple (`temporal_patch_size * focus_size`).
        min_frames (`int`, *optional*, defaults to 4):
            Lower bound on the number of frames selected by `sample_frames`.
        max_frames (`int`, *optional*, defaults to 768):
            Upper bound on the number of frames selected by `sample_frames`.
        processor_device (`str`, *optional*, defaults to `"cpu"`):
            Device passed to the video loader in `fetch_videos`.
    """,
)
class ZFQwen3VLVideoProcessor(BaseVideoProcessor):
    resample = PILImageResampling.BICUBIC
    size = {"shortest_edge": 128 * 32 * 32, "longest_edge": 32 * 32 * 768}
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    patch_size = 16
    temporal_patch_size = 2
    merge_size = 2
    focus_size = 2
    fps = 2
    min_frames = 4
    max_frames = 768
    do_sample_frames = True
    processor_device: str = "cpu"
    valid_kwargs = Qwen3VLVideoProcessorInitKwargs
    model_input_names = ["pixel_values_videos", "video_grid_thw"]

    def __init__(self, **kwargs: Unpack[Qwen3VLVideoProcessorInitKwargs]):
        """
        Initialize the processor and check that `size` carries both required keys.

        Raises:
            ValueError: If `size` is set but lacks `shortest_edge` or `longest_edge`.
        """
        super().__init__(**kwargs)
        if self.size is not None and (
            self.size.get("shortest_edge", None) is None or self.size.get("longest_edge", None) is None
        ):
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")

    def _further_process_kwargs(  # type: ignore
        self,
        size: Optional[SizeDict] = None,
        **kwargs,
    ) -> dict:
        """
        Update kwargs that need further processing before being validated.
        Can be overridden by subclasses to customize the processing of kwargs.

        Raises:
            ValueError: If `size` is given but lacks `shortest_edge` or `longest_edge`.
        """
        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")

        return super()._further_process_kwargs(size=size, **kwargs)

    def sample_frames(  # type: ignore
        self,
        metadata: VideoMetadata,
        num_frames: Optional[int] = None,
        fps: Optional[Union[int, float]] = None,
        **kwargs,
    ):
        """
        Compute uniformly spaced frame indices between 0 and the last frame of the video.
        If `num_frames` is not given, the frame count is derived from `fps` (falling back to `self.fps`)
        and the video's own fps. Arguments `num_frames` and `fps` are mutually exclusive.

        Args:
            metadata (`VideoMetadata`):
                Metadata of the video containing information about total duration, fps and total number of frames.
                If `metadata.fps` is `None` it is set to 24 on this object and a warning is logged once.
            num_frames (`int`, *optional*):
                Exact number of frames to sample. Used as given, without clamping.
            fps (`int` or `float`, *optional*):
                Target frames to sample per second. Defaults to `self.fps`.

        Returns:
            `np.ndarray`: Integer indices of the frames to keep.

        Raises:
            ValueError: If both `num_frames` and `fps` are passed.
        """
        if fps is not None and num_frames is not None:
            raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")

        total_num_frames = metadata.total_num_frames
        fps = fps if fps is not None else self.fps

        # If num_frames is not given but fps is, calculate num_frames from fps
        if num_frames is None and fps is not None:
            if metadata.fps is None:
                metadata.fps = 24
                logger.warning_once(  # type: ignore
                    "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
                    "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
                )
            num_frames = int(total_num_frames / metadata.fps * fps)
            # Clamp to [min_frames, max_frames], then never ask for more frames than exist.
            num_frames = min(max(num_frames, self.min_frames), self.max_frames, total_num_frames)

        # Only reached when neither `num_frames`, `fps` nor `self.fps` is set.
        if num_frames is None:
            num_frames = min(max(total_num_frames, self.min_frames), self.max_frames)

        indices = np.linspace(0, total_num_frames - 1, num_frames).round().astype(int)
        return indices

    def _preprocess(  # type: ignore
        self,
        videos: list[torch.Tensor],
        do_convert_rgb: bool = True,
        do_resize: bool = True,
        size: Optional[SizeDict] = None,
        interpolation: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255.0,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        patch_size: Optional[int] = None,
        temporal_patch_size: Optional[int] = None,
        merge_size: Optional[int] = None,
        focus_size: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ):
        """
        Resize, rescale/normalize and patchify a list of videos.

        Videos are grouped by shape, optionally resized to the size chosen by `smart_resize`,
        rescaled and normalized, padded along the frame axis by repeating the last frame, and
        rearranged into flattened patches.

        Returns:
            `BatchFeature` with `pixel_values_videos` (the per-video flattened patches concatenated
            along dim 0) and `video_grid_thw` (one `[grid_t, grid_h, grid_w]` entry per video).
        """
        grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
        resized_videos_grouped = {}
        for shape, stacked_videos in grouped_videos.items():
            B, T, C, H, W = stacked_videos.shape
            num_frames, height, width = T, H, W
            if do_resize:
                resized_height, resized_width = smart_resize(
                    num_frames=num_frames,
                    height=height,
                    width=width,
                    temporal_factor=temporal_patch_size,  # type: ignore
                    factor=patch_size * merge_size * focus_size,  # type: ignore
                    min_pixels=size.shortest_edge,  # type: ignore
                    max_pixels=size.longest_edge,  # type: ignore
                )
                # Fold the frame axis into the batch axis so all frames are resized in one call.
                stacked_videos = stacked_videos.view(B * T, C, H, W)
                stacked_videos = self.resize(
                    stacked_videos,
                    size=SizeDict(height=resized_height, width=resized_width),
                    interpolation=interpolation,  # type: ignore
                )
                stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width)
            resized_videos_grouped[shape] = stacked_videos
        resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)

        # Group videos by size for further processing
        # Needed in case do_resize is False, or resize returns videos with different sizes
        grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
        processed_videos_grouped = {}
        processed_grids = {}
        for shape, stacked_videos in grouped_videos.items():
            resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)  # type: ignore

            # Fused rescale and normalize
            stacked_videos = self.rescale_and_normalize(
                stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std  # type: ignore
            )
            patches = stacked_videos
            temporal_focus_size = temporal_patch_size * focus_size  # type: ignore

            # Pad the frame axis to a multiple of `temporal_patch_size * focus_size`
            # by repeating the last frame.
            if res := patches.shape[1] % temporal_focus_size:
                repeats = patches[:, -1:].repeat(1, temporal_focus_size - res, 1, 1, 1)
                patches = torch.cat([patches, repeats], dim=1)

            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size  # type: ignore
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size  # type: ignore

            # Split frames into (grid_t, temporal_patch_size) and each spatial axis into
            # (merge blocks, merge_size, patch_size).
            patches = patches.view(
                batch_size,
                grid_t,
                temporal_patch_size,  # type: ignore
                channel,
                grid_h // merge_size,  # type: ignore
                merge_size,  # type: ignore
                patch_size,  # type: ignore
                grid_w // merge_size,  # type: ignore
                merge_size,  # type: ignore
                patch_size,  # type: ignore
            )
            # Reorder to (batch, grid_t, h_blocks, w_blocks, merge_h, merge_w,
            # channel, temporal_patch, patch_h, patch_w).
            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
            flatten_patches = patches.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,  # type: ignore
            )

            processed_videos_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

        processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
        processed_grids = reorder_videos(processed_grids, grouped_videos_index)
        pixel_values_videos = torch.cat(processed_videos, dim=0)
        video_grid_thw = torch.tensor(processed_grids)
        data = {
            "pixel_values_videos": pixel_values_videos,
            "video_grid_thw": video_grid_thw,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)

    def fetch_videos(  # type: ignore
        self,
        video_url_or_urls: Union[str, list[str], list[list[str]]],
        sample_indices_fn=None
    ):
        """
        Load a single video or a (possibly nested) list of videos with the `torchcodec` backend.

        A single url is passed to `load_video` on `self.processor_device` and its result is
        returned as is. A list is fetched element by element and the per-element results are
        transposed with `zip(*...)`, so values at the same position in each result are grouped
        together (presumably videos in one tuple and their metadata in another; confirm against
        `load_video`'s return value).
        """
        if isinstance(video_url_or_urls, list):
            return list(zip(*[self.fetch_videos(x, sample_indices_fn=sample_indices_fn) for x in video_url_or_urls]))
        else:
            return load_video(
                video_url_or_urls,  # type: ignore
                backend="torchcodec",
                sample_indices_fn=sample_indices_fn,
                device=self.processor_device
            )  # type: ignore

    def normalize(
        self,
        image: "torch.Tensor",
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        **kwargs,
    ) -> "torch.Tensor":
        """
        Normalize an image in place. image = (image - image_mean) / image_std.

        Args:
            image (`torch.Tensor`):
                Image to normalize. It is modified in place (`inplace=True`).
            mean (`torch.Tensor`, `float` or `Iterable[float]`):
                Image mean to use for normalization.
            std (`torch.Tensor`, `float` or `Iterable[float]`):
                Image standard deviation to use for normalization.

        Returns:
            `torch.Tensor`: The normalized image.
        """
        return F.normalize(image, mean, std, inplace=True)  # type: ignore

    def rescale(
        self,
        image: "torch.Tensor",
        scale: float,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Rescale an image by a scale factor. image = image * scale.

        Floating-point tensors are scaled in place. Integer tensors (e.g. uint8 frames) cannot
        hold the scaled values, so for them a new tensor is returned instead.

        Args:
            image (`torch.Tensor`):
                Image to rescale.
            scale (`float`):
                The scaling factor to rescale pixel values by.

        Returns:
            `torch.Tensor`: The rescaled image.
        """
        if image.is_floating_point():
            return image.mul_(scale)
        # In-place `mul_` with a float scale raises on integer tensors; multiply out of place.
        return image * scale

    def _prepare_input_videos(
        self,
        videos: VideoInput,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional[str] = None,
    ) -> list["torch.Tensor"]:
        """
        Prepare the input videos for processing.

        NumPy videos are converted to channels-first and wrapped as contiguous torch tensors;
        any other input is passed through unchanged.

        Raises:
            ValueError: If `device` is given; `processor_device` must be used instead.
        """
        # The check does not depend on the individual video, so fail before doing any work.
        if device is not None:
            raise ValueError("The `device` argument is not supported. Please use `processor_device` instead.")

        processed_videos = []
        for video in videos:
            # `make_batched_videos` always returns a 4D array per video
            if isinstance(video, np.ndarray):
                video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_data_format)
                # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
                video = torch.from_numpy(video).contiguous()
            processed_videos.append(video)
        return processed_videos
# Public API of this module: only the video processor class is exported.
__all__ = ["ZFQwen3VLVideoProcessor"]
|