# MiniMax-M3-MXFP8 / processing_minimax.py
# xuebi
# initial commit
# 2a60e16
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
"""
MiniMax VL family HuggingFace-compatible Processor, ImageProcessor, VideoProcessor.
"""
import math
import re
from typing import List, Optional, Tuple, Union
import torch
import torchvision
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import PILImageResampling, SizeDict
from transformers.processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
VideosKwargs,
)
from transformers.utils import TensorType
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.video_utils import group_videos_by_shape, reorder_videos
class MiniMaxVLProcessorKwargs(ProcessingKwargs, total=False):
    # Keyword-argument schema for MiniMaxVLProcessor. The processor hands this
    # class to `_merge_kwargs`, so the values below act as the defaults for
    # the video processor call unless the caller overrides them.
    #
    # `return_metadata` is enabled because MiniMaxVLProcessor.__call__ reads
    # the "video_metadata" entry of the video processor output.
    _defaults = {
        "videos_kwargs": {"do_resize": False, "return_metadata": True},
    }
# Processor that bundles an image processor, a video processor and a
# tokenizer, and expands the image/video placeholders found in the prompt
# text before the text is tokenized.
class MiniMaxVLProcessor(ProcessorMixin):
# Marks where an image belongs in the prompt. During __call__ each occurrence
# is replaced by `image_grid_thw[i].prod() // merge_size**2` copies of itself,
# wrapped in the vision start/end tokens.
IMAGE_TOKEN = "]<]image[>["
# Marks where a video belongs in the prompt. During __call__ each occurrence
# is replaced by one start/end-wrapped run of video tokens per temporal grid
# step, optionally preceded by a "]<]X.X seconds[>[" timestamp.
VIDEO_TOKEN = "]<]video[>["
# Delimiters placed around every expanded image and every expanded video
# frame block (videos reuse the image delimiters).
VISION_START_TOKEN = "]<]start of image[>["
VISION_END_TOKEN = "]<]end of image[>["
def __init__(
    self, image_processor=None, tokenizer=None, video_processor=None, **kwargs
):
    """Set up the processor and cache the ids of the special vision tokens.

    Args:
        image_processor: Image processor forwarded to ``ProcessorMixin``.
        tokenizer: Tokenizer forwarded to ``ProcessorMixin``; it is also
            used here to resolve the special-token ids, so it must provide
            ``convert_tokens_to_ids``.
        video_processor: Video processor forwarded to ``ProcessorMixin``.
        **kwargs: Accepted but not used by this method.
    """
    # NOTE(review): **kwargs is never forwarded to ProcessorMixin.__init__;
    # confirm that silently dropping extra arguments is intended.
    to_id = tokenizer.convert_tokens_to_ids

    self.image_token_id = to_id(self.IMAGE_TOKEN)
    self.video_token_id = to_id(self.VIDEO_TOKEN)

    super().__init__(image_processor, tokenizer, video_processor)

    # Video expansion reuses the image start/end tokens. The tokenizer does
    # ship dedicated video start/end tokens, but the original MiniMax
    # serving path never used them, so that behavior is kept for
    # compatibility.
    self.vision_start_token_id = to_id(self.VISION_START_TOKEN)
    self.vision_end_token_id = to_id(self.VISION_END_TOKEN)
def _prune_video_tokens(
    self,
    input_text: str,
    video_segments: List[int],
    video_token: str,
) -> str:
    """Keep one video token out of every ``temporal_patch_size`` per segment.

    The prompt is expected to carry exactly ``sum(video_segments)`` video
    tokens, i.e. one token per *sampled* frame. Within each segment the
    tokens at frame positions 0, temporal_patch_size,
    2 * temporal_patch_size, ... are kept and every other token is dropped.
    When a dropped token is directly preceded by a timestamp such as
    ``]<]0.5 seconds[>[`` that timestamp is removed as well.

    Args:
        input_text: Prompt containing one ``video_token`` per sampled frame.
        video_segments: Sampled frame count of each video segment, in
            prompt order. Segments with a count of 0 own no tokens.
        video_token: The video token string, e.g. ``']<]video[>['``.

    Returns:
        The prompt with roughly ``N / temporal_patch_size`` video tokens
        left per segment. ``input_text`` is returned unchanged when there
        are no segments or ``temporal_patch_size <= 1``. Video tokens that
        appear after all segments are exhausted are kept as they are.

    Raises:
        ValueError: If a text chunk ends like a timestamp but contains no
            ``]<]`` opener to cut from.
    """
    temporal_patch_size = self.video_processor.temporal_patch_size
    if not video_segments or temporal_patch_size <= 1:
        return input_text

    # Accepted spellings of the tail of a timestamp marker.
    timestamp_suffixes = (
        "seconds[>[",
        "seconds[>[ ",
        "seconds [>[",
        "seconds [>[ ",
    )

    def trailing_timestamp_len(text: str) -> int:
        """Return the length of the timestamp ending ``text``, or 0."""
        if not text.endswith(timestamp_suffixes):
            return 0
        start_index = text.rfind("]<]")
        if start_index == -1:
            raise ValueError(f"Failed to extract timestamp: {text}")
        return len(text) - start_index

    # The capturing group makes re.split keep the video tokens as parts.
    parts = re.split(f"({re.escape(video_token)})", input_text)

    num_segments = len(video_segments)
    final_parts = []
    seg_idx = 0  # Segment the next video token belongs to
    frame_in_seg = 0  # Frame position of that token inside the segment
    last_timestamp_len = 0  # Timestamp length at the end of final_parts[-1]

    for part in parts:
        if part != video_token:
            final_parts.append(part)
            last_timestamp_len = trailing_timestamp_len(part)
            continue

        # Step over finished segments before classifying this token. Doing
        # it here (rather than after consuming a token) also skips segments
        # with zero frames, which would otherwise swallow a token that
        # belongs to the following segment.
        while seg_idx < num_segments and frame_in_seg >= video_segments[seg_idx]:
            seg_idx += 1
            frame_in_seg = 0

        if seg_idx >= num_segments:
            # No segment left to account for this token: keep it untouched.
            final_parts.append(part)
        elif frame_in_seg % temporal_patch_size == 0:
            final_parts.append(part)
            frame_in_seg += 1
        else:
            frame_in_seg += 1
            if last_timestamp_len > 0:
                # last_timestamp_len is only non-zero right after a text
                # part was appended, so final_parts[-1] is that text and it
                # ends with the timestamp of the token being dropped.
                final_parts[-1] = final_parts[-1][:-last_timestamp_len]
        last_timestamp_len = 0

    return "".join(final_parts)
def __call__(
    self,
    images=None,
    text=None,
    videos=None,
    **kwargs: Unpack[MiniMaxVLProcessorKwargs],
) -> BatchFeature:
    """Run the vision processors, expand placeholders and tokenize.

    Every ``IMAGE_TOKEN`` in ``text`` is replaced, in order, by
    ``VISION_START_TOKEN + IMAGE_TOKEN * n + VISION_END_TOKEN`` where
    ``n = image_grid_thw[k].prod() // merge_size**2`` for the k-th image.
    Every ``VIDEO_TOKEN`` is replaced by one such start/end-wrapped run of
    ``VIDEO_TOKEN`` per temporal grid step of the matching video; when the
    video metadata provides both ``fps`` and ``frames_indices`` each run is
    preceded by a ``]<]X.X seconds[>[`` timestamp.

    Args:
        images: Input for ``self.image_processor``; ``None`` skips image
            processing and image-token expansion.
        text: A single prompt or a list of prompts. The caller's list is
            not modified.
        videos: Input for ``self.video_processor``; ``None`` skips video
            processing and video-token expansion.
        **kwargs: Merged with the ``MiniMaxVLProcessorKwargs`` defaults and
            the tokenizer init kwargs through ``_merge_kwargs``.

    Returns:
        A ``BatchFeature`` holding the tokenizer output plus the image and
        video processor outputs, converted according to ``return_tensors``.
        ``video_metadata`` is removed from the result unless the caller
        passed a truthy ``return_metadata`` keyword.

    Raises:
        IndexError: If the prompts contain more image/video placeholders
            than there are processed images/videos.
    """
    output_kwargs = self._merge_kwargs(
        MiniMaxVLProcessorKwargs,
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )

    image_inputs = {}
    image_grid_thw = None
    if images is not None:
        image_inputs = self.image_processor(
            images=images, **output_kwargs["images_kwargs"]
        )
        image_grid_thw = image_inputs["image_grid_thw"]

    video_inputs = {}
    video_grid_thw = None
    video_metadata = None
    if videos is not None:
        video_inputs = self.video_processor(
            videos=videos, **output_kwargs["videos_kwargs"]
        )
        video_grid_thw = video_inputs["video_grid_thw"]
        if kwargs.get("return_metadata"):
            video_metadata = video_inputs["video_metadata"]
        else:
            # The metadata is needed below for timestamps, but the caller
            # did not ask for it, so keep it out of the returned features.
            video_metadata = video_inputs.pop("video_metadata")

    if not isinstance(text, list):
        text = [text]
    text = text.copy()

    # The expansions themselves contain the image/video token, which would
    # make the "while token in prompt" loops below run forever. A temporary
    # stand-in is inserted instead and swapped for the real token once a
    # prompt is fully expanded.
    placeholder = "]<]placeholder[>["

    # Expand image tokens
    if image_grid_thw is not None:
        merge_length = self.image_processor.merge_size**2
        index = 0
        for i, prompt in enumerate(text):
            while self.IMAGE_TOKEN in prompt:
                num_tokens = int(image_grid_thw[index].prod() // merge_length)
                prompt = prompt.replace(
                    self.IMAGE_TOKEN,
                    self.VISION_START_TOKEN
                    + placeholder * num_tokens
                    + self.VISION_END_TOKEN,
                    1,
                )
                index += 1
            text[i] = prompt.replace(placeholder, self.IMAGE_TOKEN)

    # Expand video tokens
    if video_grid_thw is not None:
        # NOTE(review): the merge size is read from the image processor even
        # for videos - confirm both processors share the same merge_size.
        merge_length = self.image_processor.merge_size**2
        temporal_patch_size = self.video_processor.temporal_patch_size
        index = 0
        for i, prompt in enumerate(text):
            while self.VIDEO_TOKEN in prompt:
                metadata = video_metadata[index]
                grid_t = int(video_grid_thw[index][0])
                frame_seqlen = int(video_grid_thw[index][1:].prod() // merge_length)
                # Identical for every temporal step of this video.
                frame_block = (
                    self.VISION_START_TOKEN
                    + placeholder * frame_seqlen
                    + self.VISION_END_TOKEN
                )
                has_timestamps = (
                    metadata.fps is not None
                    and metadata.frames_indices is not None
                )
                chunks = []
                for frame_idx in range(grid_t):
                    if has_timestamps:
                        # Timestamp of the first frame in this temporal
                        # patch, clamped to the last sampled frame.
                        ts = (
                            metadata.frames_indices[
                                min(
                                    frame_idx * temporal_patch_size,
                                    len(metadata.frames_indices) - 1,
                                )
                            ]
                            / metadata.fps
                        )
                        chunks.append(f"]<]{ts:.1f} seconds[>[")
                    chunks.append(frame_block)
                prompt = prompt.replace(self.VIDEO_TOKEN, "".join(chunks), 1)
                index += 1
            text[i] = prompt.replace(placeholder, self.VIDEO_TOKEN)

    # Tokenize
    return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
    text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
    return BatchFeature(
        data={**text_inputs, **image_inputs, **video_inputs},
        tensor_type=return_tensors,
    )