| |
| |
| """ |
| MiniMax VL family HuggingFace-compatible Processor, ImageProcessor, VideoProcessor. |
| """ |
|
|
| import math |
| import re |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torchvision |
| from torchvision.transforms import InterpolationMode |
| from transformers import BatchFeature |
| from transformers.image_processing_utils_fast import ( |
| BaseImageProcessorFast, |
| group_images_by_shape, |
| reorder_images, |
| ) |
| from transformers.image_utils import PILImageResampling, SizeDict |
| from transformers.processing_utils import ( |
| ImagesKwargs, |
| ProcessingKwargs, |
| ProcessorMixin, |
| Unpack, |
| VideosKwargs, |
| ) |
| from transformers.utils import TensorType |
| from transformers.video_processing_utils import BaseVideoProcessor |
| from transformers.video_utils import group_videos_by_shape, reorder_videos |
|
|
|
|
class MiniMaxVLProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema used when calling ``MiniMaxVLProcessor``.

    The processor passes this class to ``_merge_kwargs`` to split caller
    keyword arguments into per-modality groups. Only the video group
    carries overrides here.
    """

    # Video defaults: frames are not resized, and the video processor is
    # asked to return per-video metadata alongside the pixel data.
    _defaults = {
        "videos_kwargs": {"do_resize": False, "return_metadata": True},
    }
|
|
|
|
class MiniMaxVLProcessor(ProcessorMixin):
    """Prepare text, image and video inputs for MiniMax VL models.

    Pixel preprocessing is delegated to ``image_processor`` and
    ``video_processor``. This class expands each image/video marker in the
    prompt into the number of vision tokens implied by the corresponding
    ``*_grid_thw`` entry, then tokenizes the expanded prompt.
    """

    IMAGE_TOKEN = "]<]image[>["
    VIDEO_TOKEN = "]<]video[>["
    VISION_START_TOKEN = "]<]start of image[>["
    VISION_END_TOKEN = "]<]end of image[>["

    # Temporary stand-in used while expanding markers, so that freshly
    # inserted vision tokens are not mistaken for markers still to expand.
    _PLACEHOLDER_TOKEN = "]<]placeholder[>["
    # Endings that identify a text chunk as finishing with a timestamp such
    # as ']<]0.0 seconds[>[' (with optional spaces).
    _TIMESTAMP_SUFFIXES = (
        "seconds[>[",
        "seconds[>[ ",
        "seconds [>[",
        "seconds [>[ ",
    )

    def __init__(
        self, image_processor=None, tokenizer=None, video_processor=None, **kwargs
    ):
        """Register the sub-processors and cache the special-token ids.

        Args:
            image_processor: processor applied to ``images`` in ``__call__``.
            tokenizer: tokenizer used for the prompt. Despite the ``None``
                default it is required: ``convert_tokens_to_ids`` is called
                on it unconditionally.
            video_processor: processor applied to ``videos`` in ``__call__``.
            **kwargs: accepted for signature compatibility.
        """
        # NOTE(review): **kwargs is accepted but never forwarded to
        # ProcessorMixin.__init__ -- confirm that dropping it is intended.
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
        self.video_token_id = tokenizer.convert_tokens_to_ids(self.VIDEO_TOKEN)
        super().__init__(image_processor, tokenizer, video_processor)

        self.vision_start_token_id = tokenizer.convert_tokens_to_ids(
            self.VISION_START_TOKEN
        )
        self.vision_end_token_id = tokenizer.convert_tokens_to_ids(
            self.VISION_END_TOKEN
        )

    def _trailing_timestamp_len(self, text: str) -> int:
        """Return the length of a timestamp at the end of ``text``, or 0.

        A timestamp is recognised by one of ``_TIMESTAMP_SUFFIXES`` and is
        taken to start at the last ``']<]'`` in ``text``.

        Raises:
            ValueError: if ``text`` ends like a timestamp but contains no
                ``']<]'`` opener.
        """
        if not text.endswith(self._TIMESTAMP_SUFFIXES):
            return 0
        start_index = text.rfind("]<]")
        if start_index == -1:
            raise ValueError(f"Failed to extract timestamp: {text}")
        return len(text) - start_index

    def _prune_video_tokens(
        self,
        input_text: str,
        video_segments: List[int],
        video_token: str,
    ) -> str:
        """Keep one video token per ``temporal_patch_size`` frames.

        The prompt is expected to carry one ``video_token`` per *sampled*
        frame, i.e. ``sum(video_segments)`` tokens in total. Within each
        segment only the frames at positions 0, ``temporal_patch_size``,
        ``2 * temporal_patch_size``, ... keep their token. Every other
        token is dropped, and a timestamp immediately preceding a dropped
        token is dropped with it. Video tokens beyond the frames declared
        in ``video_segments`` pass through unchanged.

        Args:
            input_text: prompt with N video tokens per segment.
            video_segments: sampled frame count of each video segment.
            video_token: the video token string, e.g. ']<]video[>['.

        Returns:
            ``input_text`` with roughly N / temporal_patch_size video
            tokens per segment. Returned unchanged when ``video_segments``
            is empty or ``temporal_patch_size`` is at most 1.

        Raises:
            ValueError: if a chunk ends like a timestamp but has no
                ``']<]'`` opener.
        """
        if not video_segments:
            return input_text
        temporal_patch_size = self.video_processor.temporal_patch_size
        if temporal_patch_size <= 1:
            return input_text

        # The capturing group makes re.split keep the video tokens as their
        # own list entries, interleaved with the surrounding text.
        parts = re.split(f"({re.escape(video_token)})", input_text)

        kept_parts: List[str] = []
        num_segments = len(video_segments)
        segment_idx = 0
        frame_in_segment = 0
        # Length of the timestamp ending the most recently kept text chunk;
        # 0 when that chunk does not end with a timestamp.
        pending_timestamp_len = 0

        for part in parts:
            if part != video_token:
                kept_parts.append(part)
                pending_timestamp_len = self._trailing_timestamp_len(part)
                continue

            if segment_idx >= num_segments:
                # More video tokens than declared frames: keep the extras.
                kept_parts.append(part)
                pending_timestamp_len = 0
                continue

            if frame_in_segment % temporal_patch_size == 0:
                kept_parts.append(part)
            elif pending_timestamp_len > 0:
                # Dropped frame: also remove its timestamp. A non-zero
                # pending length implies the previous entry is the text
                # chunk that ends with that timestamp.
                kept_parts[-1] = kept_parts[-1][:-pending_timestamp_len]
            pending_timestamp_len = 0

            # Advance the frame cursor, rolling over to the next segment.
            frame_in_segment += 1
            if frame_in_segment >= video_segments[segment_idx]:
                segment_idx += 1
                frame_in_segment = 0

        return "".join(kept_parts)

    def _expand_image_tokens(self, text: List[str], image_grid_thw) -> None:
        """Expand each image marker in ``text`` in place.

        The k-th marker across the batch becomes
        ``VISION_START_TOKEN + IMAGE_TOKEN * n + VISION_END_TOKEN`` with
        ``n = prod(image_grid_thw[k]) // merge_size**2``.
        """
        merge_length = self.image_processor.merge_size**2
        placeholder = self._PLACEHOLDER_TOKEN
        index = 0
        for i, prompt in enumerate(text):
            while self.IMAGE_TOKEN in prompt:
                num_tokens = int(image_grid_thw[index].prod() // merge_length)
                prompt = prompt.replace(
                    self.IMAGE_TOKEN,
                    self.VISION_START_TOKEN
                    + placeholder * num_tokens
                    + self.VISION_END_TOKEN,
                    1,
                )
                index += 1
            # Swap the stand-ins for real image tokens only after every
            # marker in this prompt has been consumed.
            text[i] = prompt.replace(placeholder, self.IMAGE_TOKEN)

    def _expand_video_tokens(
        self, text: List[str], video_grid_thw, video_metadata
    ) -> None:
        """Expand each video marker in ``text`` in place.

        For the k-th marker, ``video_grid_thw[k][0]`` frame blocks are
        emitted, each ``VISION_START_TOKEN + VIDEO_TOKEN * n +
        VISION_END_TOKEN`` with ``n = prod(video_grid_thw[k][1:]) //
        merge_size**2``. When the metadata provides both ``fps`` and
        ``frames_indices``, each block is preceded by
        ``']<]{t:.1f} seconds[>['``.
        """
        # NOTE(review): merge_size is read from the *image* processor here,
        # as in the original code -- confirm the video processor shares it.
        merge_length = self.image_processor.merge_size**2
        placeholder = self._PLACEHOLDER_TOKEN
        index = 0
        for i, prompt in enumerate(text):
            while self.VIDEO_TOKEN in prompt:
                metadata = video_metadata[index]
                grid = video_grid_thw[index]
                grid_t = int(grid[0])
                frame_seqlen = int(grid[1:].prod() // merge_length)
                frame_block = (
                    self.VISION_START_TOKEN
                    + placeholder * frame_seqlen
                    + self.VISION_END_TOKEN
                )

                has_timestamps = (
                    metadata.fps is not None
                    and metadata.frames_indices is not None
                )
                pieces = []
                if has_timestamps:
                    frames_indices = metadata.frames_indices
                    last_frame = len(frames_indices) - 1
                    temporal_patch_size = self.video_processor.temporal_patch_size
                    for frame_idx in range(grid_t):
                        # Timestamp of the first sampled frame in this
                        # temporal patch, clamped to the last sampled frame.
                        sampled = min(frame_idx * temporal_patch_size, last_frame)
                        ts = frames_indices[sampled] / metadata.fps
                        pieces.append(f"]<]{ts:.1f} seconds[>[")
                        pieces.append(frame_block)
                else:
                    pieces = [frame_block] * grid_t

                prompt = prompt.replace(self.VIDEO_TOKEN, "".join(pieces), 1)
                index += 1
            text[i] = prompt.replace(placeholder, self.VIDEO_TOKEN)

    def __call__(
        self,
        images=None,
        text=None,
        videos=None,
        **kwargs: Unpack[MiniMaxVLProcessorKwargs],
    ) -> BatchFeature:
        """Preprocess images/videos, expand their markers and tokenize.

        Args:
            images: optional images forwarded to ``image_processor``.
            text: a prompt string or a list of prompt strings. Required.
            videos: optional videos forwarded to ``video_processor``.
            **kwargs: processing options, split per modality by
                ``_merge_kwargs``. ``video_metadata`` stays in the output
                only when ``return_metadata`` is passed here as truthy.

        Returns:
            A ``BatchFeature`` merging the tokenizer output with the image
            and video processor outputs, converted according to the
            ``return_tensors`` text option.

        Raises:
            ValueError: if ``text`` is ``None``.
        """
        if text is None:
            # Fail early with a clear message rather than a TypeError from
            # the marker search (or a tokenizer error) further down.
            raise ValueError("`text` must be a string or a list of strings.")

        output_kwargs = self._merge_kwargs(
            MiniMaxVLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        image_inputs = {}
        image_grid_thw = None
        if images is not None:
            image_inputs = self.image_processor(
                images=images, **output_kwargs["images_kwargs"]
            )
            image_grid_thw = image_inputs["image_grid_thw"]

        video_inputs = {}
        video_grid_thw = None
        video_metadata = None
        if videos is not None:
            video_inputs = self.video_processor(
                videos=videos, **output_kwargs["videos_kwargs"]
            )
            video_grid_thw = video_inputs["video_grid_thw"]
            # The metadata is needed below to build timestamps; it stays in
            # the returned features only if the caller explicitly asked.
            if kwargs.get("return_metadata"):
                video_metadata = video_inputs["video_metadata"]
            else:
                video_metadata = video_inputs.pop("video_metadata")

        # Work on a private list so the caller's prompts are not mutated.
        text = list(text) if isinstance(text, list) else [text]

        if image_grid_thw is not None:
            self._expand_image_tokens(text, image_grid_thw)
        if video_grid_thw is not None:
            self._expand_video_tokens(text, video_grid_thw, video_metadata)

        # Tensor conversion is deferred to BatchFeature so that all
        # modalities are converted together.
        text_kwargs = output_kwargs["text_kwargs"]
        return_tensors = text_kwargs.pop("return_tensors", None)
        text_inputs = self.tokenizer(text, **text_kwargs)

        return BatchFeature(
            data={**text_inputs, **image_inputs, **video_inputs},
            tensor_type=return_tensors,
        )
|
|