MiniMax-M3-MXFP8 / video_processor.py
xuebi
initial commit
2a60e16
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
"""
MiniMax VL family HuggingFace-compatible VideoProcessor.
"""
import math
from typing import List, Optional, Tuple, Union
import torch
import torchvision
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature
from transformers.image_utils import PILImageResampling, SizeDict
from transformers.processing_utils import (
Unpack,
VideosKwargs,
)
from transformers.utils import TensorType
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.video_utils import group_videos_by_shape, reorder_videos
MAX_RATIO = 200
def round_by_factor(number: int, factor: int) -> int:
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
return math.floor(number / factor) * factor
def smart_resize(
height: int,
width: int,
factor: int = 28,
min_pixels: int = 4 * 28 * 28,
max_pixels: int = 451584,
) -> tuple[int, int]:
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, "
f"got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
class MiniMaxM3VLVideoProcessorKwargs(VideosKwargs, total=False):
patch_size: int
temporal_patch_size: int
merge_size: int
min_pixels: int
max_pixels: int
total_pixels: int
min_frames: int
max_frames: int
fps: float | int
class MiniMaxM3VLVideoProcessor(BaseVideoProcessor):
do_resize = True
resample = PILImageResampling.BICUBIC
size = {"height": 672, "width": 672}
default_to_square = False
do_rescale = True
rescale_factor = 1 / 255
do_normalize = True
image_mean = [0.48145466, 0.4578275, 0.40821073]
image_std = [0.26862954, 0.26130258, 0.27577711]
do_convert_rgb = True
do_sample_frames = False
patch_size = 14
temporal_patch_size = 2
merge_size = 2
min_pixels = 4 * 28 * 28
max_pixels = 768 * 28 * 28 # 602,112
total_pixels = int(64000 * 28 * 28 * 0.9) # ~45M, ~64k tokens budget
fps = 1.0
min_frames = 4
max_frames = 768
valid_kwargs = MiniMaxM3VLVideoProcessorKwargs
model_input_names = ["pixel_values_videos", "video_grid_thw"]
def __init__(self, **kwargs: Unpack[MiniMaxM3VLVideoProcessorKwargs]):
super().__init__(**kwargs)
def _preprocess(
self,
videos: List[torch.Tensor],
do_convert_rgb: bool,
do_resize: bool,
size: SizeDict,
resample: PILImageResampling | InterpolationMode | int | None,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: float | List[float] | None,
image_std: float | List[float] | None,
patch_size: int,
temporal_patch_size: int,
merge_size: int,
min_pixels: int,
max_pixels: int,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
factor = patch_size * merge_size
for shape, stacked_videos in grouped_videos.items():
batch_size, num_frames, channels, height, width = stacked_videos.shape
resized_height, resized_width = height, width
if do_resize:
resized_height, resized_width = smart_resize(
height, width, factor=factor,
min_pixels=min_pixels, max_pixels=max_pixels,
)
stacked_videos = stacked_videos.view(
batch_size * num_frames, channels, height, width
)
stacked_videos = self.resize(
stacked_videos,
size=SizeDict(height=resized_height, width=resized_width),
resample=resample,
)
stacked_videos = stacked_videos.view(
batch_size,
num_frames,
channels,
resized_height,
resized_width,
)
resized_videos_grouped[shape] = stacked_videos
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
processed_videos_grouped = {}
processed_grids = {}
for shape, stacked_videos in grouped_videos.items():
resized_height, resized_width = stacked_videos.shape[-2:]
patches = self.rescale_and_normalize(
stacked_videos,
do_rescale,
rescale_factor,
do_normalize,
image_mean,
image_std,
)
if pad := -patches.shape[1] % temporal_patch_size:
repeats = patches[:, -1:].expand(-1, pad, -1, -1, -1)
patches = torch.cat([patches, repeats], dim=1)
batch_size, grid_t, channels = patches.shape[:3]
grid_t = grid_t // temporal_patch_size
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
patches = patches.view(
batch_size,
grid_t,
temporal_patch_size,
channels,
grid_h // merge_size,
merge_size,
patch_size,
grid_w // merge_size,
merge_size,
patch_size,
)
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
flatten_patches = patches.reshape(
batch_size,
grid_t * grid_h * grid_w,
channels * temporal_patch_size * patch_size * patch_size,
)
processed_videos_grouped[shape] = flatten_patches
processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
processed_videos = reorder_videos(
processed_videos_grouped, grouped_videos_index
)
processed_grids = reorder_videos(processed_grids, grouped_videos_index)
pixel_values_videos = torch.cat(processed_videos, dim=0)
video_grid_thw = torch.tensor(processed_grids, dtype=torch.long)
return BatchFeature(
data={
"pixel_values_videos": pixel_values_videos,
"video_grid_thw": video_grid_thw,
},
tensor_type=return_tensors,
)