| import concurrent | |
| import concurrent.futures | |
| import dataclasses | |
| import multiprocessing as mp | |
| import os | |
| import re | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict, Iterator, List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from transformers import BaseImageProcessorFast | |
| from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem | |
| from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger | |
# Evaluated once at import time. Read by BaseMultimodalProcessor.process_mm_data
# to choose the device ("npu" vs "cuda") handed to fast image processors.
_is_npu = is_npu()
@dataclasses.dataclass
class BaseMultiModalProcessorOutput:
    """Container for a prompt and the multimodal data loaded for it.

    The class must be a dataclass: the list fields are declared with
    ``dataclasses.field(default_factory=list)``, which only turns into a
    fresh per-instance list (and a keyword constructor) when the
    ``@dataclass`` decorator processes the class.
    """

    # input_text, with each frame of video/image represented with a image_token
    input_text: str

    # frames loaded from image, in given order
    images: Optional[list[Union[Image.Image, dict]]] = dataclasses.field(
        default_factory=list
    )

    # videos
    videos: Optional[list[Union[torch.Tensor, dict]]] = dataclasses.field(
        default_factory=list
    )

    # audios
    audios: Optional[list[Union[np.ndarray, dict]]] = dataclasses.field(
        default_factory=list
    )

    def organize_results(self) -> List[Tuple[Modality, Any]]:
        """Flatten the stored data into ``(modality, item)`` pairs.

        Items are emitted grouped by modality in the fixed order images,
        videos, audios; within a group the stored order is kept. A field
        that is ``None`` (allowed by the annotations) is treated as empty.

        :return: a list of results, with their corresponding modalities
        """
        return (
            [(Modality.IMAGE, data) for data in (self.images or [])]
            + [(Modality.VIDEO, data) for data in (self.videos or [])]
            + [(Modality.AUDIO, data) for data in (self.audios or [])]
        )
# eq=False keeps identity-based equality/hashing; the decorator is only used
# to get a keyword constructor for the fields declared below.
@dataclasses.dataclass(eq=False)
class MultimodalSpecialTokens:
    """Special tokens (string, id and regex form) marking multimodal inputs.

    For each of image / video / audio the token may be given as a string,
    as a token id (converted to a string by ``build``), and/or as a
    pre-compiled regex. ``build`` fills in whatever is missing and prepares
    ``combined_regex``, which splits a prompt into plain text and
    multimodal tokens.
    """

    image_token: Optional[Union[str, List[str]]] = None
    video_token: Optional[Union[str, List[str]]] = None
    audio_token: Optional[Union[str, List[str]]] = None

    image_token_id: Optional[int] = None
    video_token_id: Optional[int] = None
    audio_token_id: Optional[int] = None

    image_token_regex: Optional[re.Pattern] = None
    video_token_regex: Optional[re.Pattern] = None
    audio_token_regex: Optional[re.Pattern] = None

    combined_regex: Optional[re.Pattern] = None

    def build(self, processor):
        """Derive token strings, per-modality regexes and the combined regex.

        :param processor: object exposing ``tokenizer.convert_ids_to_tokens``;
            only consulted when a token string must be derived from an id
        :return: ``self``, so the call can be chained
        """
        self.convert_to_strs(processor)
        self.parse_regex()
        self.get_combined_regex()
        return self

    def convert_to_str(
        self, token: Optional[Union[str, int]], processor
    ) -> Optional[str]:
        """Return ``token`` as a string.

        ``None`` and strings are returned unchanged; anything else is treated
        as a token id and looked up through the processor's tokenizer.
        """
        if token is None or isinstance(token, str):
            return token
        return processor.tokenizer.convert_ids_to_tokens([token])[0]

    def convert_to_strs(self, processor):
        """Fill every unset (falsy) token string from its token id."""
        if not self.image_token:
            self.image_token = self.convert_to_str(self.image_token_id, processor)
        if not self.video_token:
            self.video_token = self.convert_to_str(self.video_token_id, processor)
        if not self.audio_token:
            self.audio_token = self.convert_to_str(self.audio_token_id, processor)

    def get_modality_of_token(self, token: str) -> Optional[Modality]:
        """
        :return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex
        """
        # Exact match against the configured token strings first.
        modality = {
            self.image_token: Modality.IMAGE,
            self.video_token: Modality.VIDEO,
            self.audio_token: Modality.AUDIO,
        }.get(token)
        if modality:
            return modality

        # Otherwise try the per-modality regexes (anchored at the start of
        # ``token`` because ``Pattern.match`` is used).
        for regex, modality in [
            (self.image_token_regex, Modality.IMAGE),
            (self.video_token_regex, Modality.VIDEO),
            (self.audio_token_regex, Modality.AUDIO),
        ]:
            if regex and regex.match(token):
                return modality

        return None

    def get_token_id_by_modality(self, modality: Modality) -> Optional[int]:
        """Return the token id configured for ``modality`` (``None`` if unknown).

        ``Modality.MULTI_IMAGES`` shares the image token id.
        """
        return {
            Modality.IMAGE: self.image_token_id,
            Modality.MULTI_IMAGES: self.image_token_id,
            Modality.VIDEO: self.video_token_id,
            Modality.AUDIO: self.audio_token_id,
        }.get(modality)

    def parse_regex(self):
        """Compile a literal-match regex for each token lacking an explicit one."""
        if self.image_token_regex is None and self.image_token is not None:
            self.image_token_regex = re.compile(re.escape(self.image_token))
        if self.video_token_regex is None and self.video_token is not None:
            self.video_token_regex = re.compile(re.escape(self.video_token))
        if self.audio_token_regex is None and self.audio_token is not None:
            self.audio_token_regex = re.compile(re.escape(self.audio_token))

    def get_combined_regex(self) -> re.Pattern:
        """
        Builds and returns a regex, used to split input str into tokens (with mm special tokens)

        The result is cached in ``combined_regex``. The alternatives are
        wrapped in one capturing group so that ``re.split`` keeps the matched
        tokens in its output. When no per-modality regex is available the
        returned pattern never matches, so a prompt is left as a single
        plain-text part.
        """
        if self.combined_regex is not None:
            return self.combined_regex

        available = [
            regex
            for regex in (
                self.image_token_regex,
                self.video_token_regex,
                self.audio_token_regex,
            )
            if regex is not None
        ]

        if not available:
            # "()" would match the empty string at every position, which
            # shreds the prompt on split and marks every part as a
            # multimodal token. "(?!)" is an always-failing lookahead.
            self.combined_regex = re.compile(r"(?!)")
            return self.combined_regex

        flags = 0
        for regex in available:
            flags |= regex.flags
        combined = "(" + "|".join(f"(?:{regex.pattern})" for regex in available) + ")"
        self.combined_regex = re.compile(combined, flags)
        return self.combined_regex
class BaseMultimodalProcessor(ABC):
    """Base class for model-specific multimodal processors.

    It loads raw image/video/audio inputs referenced by special tokens in a
    prompt, runs them through the wrapped HuggingFace-style processor, and
    packages the processor output into ``MultimodalDataItem`` objects with
    token offsets.
    """

    models = []

    def __init__(
        self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
    ):
        """Store configuration and create the I/O and CPU worker pools.

        Pool sizes come from the ``SGLANG_IO_WORKERS`` (default 4) and
        ``SGLANG_CPU_WORKERS`` (default: CPU count) environment variables.
        Extra positional/keyword arguments are accepted and ignored.
        """
        self.hf_config = hf_config
        self._processor = _processor
        self.server_args = server_args
        self.transport_mode = transport_mode

        # FIXME: not accurate, model and image specific
        self.NUM_TOKEN_PER_FRAME = 330

        self.io_executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
        )
        self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
            mp_context=mp.get_context("fork"),
            # os.cpu_count() may return None; int(None) would raise.
            max_workers=int(
                os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count() or 1)
            ),
        )

        # Mapping from attribute names to modality types
        self.ATTR_NAME_TO_MODALITY = {
            # Image-related attributes
            "pixel_values": Modality.IMAGE,
            "image_sizes": Modality.IMAGE,
            "image_grid_thw": Modality.IMAGE,
            "image_attention_mask": Modality.IMAGE,
            "image_emb_mask": Modality.IMAGE,
            "images_spatial_crop": Modality.IMAGE,
            "images_crop": Modality.IMAGE,
            "tgt_size": Modality.IMAGE,
            "image_grid_hws": Modality.IMAGE,
            "aspect_ratio_ids": Modality.IMAGE,
            "aspect_ratio_mask": Modality.IMAGE,
            "num_patches": Modality.IMAGE,
            "patch_pixel_values": Modality.IMAGE,
            # Audio-related attributes
            "audio_features": Modality.AUDIO,
            "audio_feature_lens": Modality.AUDIO,
            "input_features": Modality.AUDIO,
            "input_features_mask": Modality.AUDIO,
            "audio_attention_mask": Modality.AUDIO,
            "feature_attention_mask": Modality.AUDIO,
            # Video-related attributes
            "pixel_values_videos": Modality.VIDEO,
            "second_per_grid_ts": Modality.VIDEO,
            "video_grid_thw": Modality.VIDEO,
            # Generic attributes that could apply to multiple modalities
            # "precomputed_embeddings" - handled specially as it can be any modality
        }

        # Names of the feature fields in the processor output.
        # TODO: pass from processors
        self.FEATURE_NAMES = [
            "pixel_values",
            "pixel_values_videos",
            "audio_features",
            "input_features",
        ]

    def process_mm_data(
        self, input_text, images=None, videos=None, audios=None, **kwargs
    ) -> dict:
        """
        process multimodal data with transformers AutoProcessor

        Only non-empty modalities are forwarded. Audio is passed under the
        keyword ``audio`` for the processor classes listed below and under
        ``audios`` otherwise. Unless ``server_args.keep_mm_feature_on_device``
        is set, tensors stored under ``FEATURE_NAMES`` are moved to CPU.

        :return: the processor output (with feature tensors possibly moved)
        """
        if images:
            kwargs["images"] = images
        if videos:
            kwargs["videos"] = videos
        if audios:
            if self._processor.__class__.__name__ in {
                "Gemma3nProcessor",
                "Qwen2AudioProcessor",
                "Qwen3OmniMoeProcessor",
            }:
                # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
                kwargs["audio"] = audios
            else:
                kwargs["audios"] = audios

        processor = self._processor
        if (
            hasattr(processor, "image_processor")
            and isinstance(processor.image_processor, BaseImageProcessorFast)
            and not self.server_args.disable_fast_image_processor
        ):
            if not _is_npu:
                kwargs["device"] = "cuda"
            elif processor.__class__.__name__ not in {
                "Qwen2_5_VLProcessor",
                "Qwen3VLProcessor",
            }:
                # Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
                kwargs["device"] = "npu"

        result = processor(
            text=[input_text],
            padding=True,
            return_tensors="pt",
            **kwargs,
        )

        if not self.server_args.keep_mm_feature_on_device:
            # move feature tensors to cpu
            for feature_name in self.FEATURE_NAMES:
                if feature_name in result and isinstance(
                    result[feature_name], torch.Tensor
                ):
                    result[feature_name] = result[feature_name].to("cpu")

        return result

    async def process_mm_data_async(
        self,
        image_data,
        audio_data,
        input_text,
        request_obj,
        **kwargs,
    ) -> Optional[Dict[str, Any]]:
        """Hook for subclasses; the base implementation returns ``None``.

        NOTE(review): the class derives from ``ABC`` but this method is not
        marked ``@abstractmethod`` here -- confirm whether it is meant to be
        abstract before relying on the no-op default.
        """
        pass

    def get_estimated_frames_list(self, image_data):
        """
        estimate the total frame count from all visual input

        String entries starting with ``"video:"`` are opened with decord and
        contribute their frame count; every other entry counts as one frame.

        :return: one frame count per entry of ``image_data`` (``[]`` if empty)
        """
        if not image_data:
            return []

        estimated_frames_list = []
        for image in image_data:
            if isinstance(image, str) and image.startswith("video:"):
                # Lazy import because decord is not available on some arm
                # platforms; done only when a video is actually present so
                # image-only requests never need it.
                from decord import VideoReader, cpu

                path = image[len("video:") :]
                vr = VideoReader(path, ctx=cpu(0))
                num_frames = len(vr)
            else:
                # For images, each contributes one frame
                num_frames = 1
            estimated_frames_list.append(num_frames)

        return estimated_frames_list

    @staticmethod
    def _load_single_item(
        data,
        modality: Modality,
        frame_count_limit=None,
        audio_sample_rate: Optional[int] = None,
        discard_alpha_channel=True,
    ):
        """
        Load a single multimodal data.
        If data is precomputed, returns directly.
        Static method that can be pickled for multiprocessing

        :param data: a dict (precomputed, returned unchanged) or a raw input
            handed to ``load_image`` / ``load_video`` / ``load_audio``
        :param frame_count_limit: forwarded to ``load_video``
        :param audio_sample_rate: forwarded to ``load_audio``
        :param discard_alpha_channel: convert non-RGB images to RGB
        :return: the loaded item; ``None`` for a modality other than
            IMAGE / VIDEO / AUDIO
        :raises RuntimeError: if loading fails (original error is chained)
        """
        if isinstance(data, dict):
            return data
        try:
            if modality == Modality.IMAGE:
                img, _ = load_image(data)
                if discard_alpha_channel and img.mode != "RGB":
                    img = img.convert("RGB")
                return img
            elif modality == Modality.VIDEO:
                return load_video(data, frame_count_limit)
            elif modality == Modality.AUDIO:
                return load_audio(data, audio_sample_rate)
        except Exception as e:
            raise RuntimeError(f"Error while loading data {data}: {e}") from e

    def submit_data_loading_tasks(
        self,
        text_parts: List[str],
        multimodal_tokens: MultimodalSpecialTokens,
        data_iterators: dict[Modality, Iterator[Any]],
        discard_alpha_channel: bool = True,
        image_estimated_frames_iter: Optional[Iterator[int]] = None,
        image_scaling_factor: float = 1.0,
        max_image_frames: int = 30,
        audio_sample_rate: Optional[int] = None,
    ) -> Tuple[List, List]:
        """
        load multimodal data parallelly using iterators.

        Each text part recognised as a multimodal token consumes the next
        item of the matching iterator and schedules its loading on
        ``io_executor``. ``max_image_frames`` is currently unused.

        :return: ``(futures, task_info)`` in token order; each ``task_info``
            entry is ``(modality, raw_data, frame_count_limit)``
        :raises ValueError: if a token has no iterator or no remaining data
        """
        futures = []
        task_info = []

        for text_part in text_parts:
            modality = multimodal_tokens.get_modality_of_token(text_part)
            if modality is None:
                continue

            data_iterator = data_iterators.get(modality)
            if data_iterator is None:
                raise ValueError(f"No data iterator found for token: {text_part}")

            try:
                data = next(data_iterator)
            except StopIteration:
                raise ValueError(
                    f"Mismatch: More '{text_part}' tokens found than corresponding data items provided."
                )

            frame_count_limit = None
            if modality == Modality.IMAGE and image_estimated_frames_iter:
                try:
                    estimated_frames = next(image_estimated_frames_iter)
                except StopIteration:
                    raise ValueError(
                        "Mismatch between image tokens and estimated frame counts."
                    )
                # Use the pre-calculated scaling factor; never go below one
                # frame. No extra cap is applied here.
                frame_count_limit = max(
                    1, int(estimated_frames * image_scaling_factor)
                )

            futures.append(
                self.io_executor.submit(
                    BaseMultimodalProcessor._load_single_item,
                    data,
                    modality,
                    frame_count_limit,
                    audio_sample_rate,
                    discard_alpha_channel,
                )
            )
            task_info.append((modality, data, frame_count_limit))

        # Best-effort check for data items that no token consumed.
        for modality, iterator in data_iterators.items():
            try:
                next(iterator)
            except StopIteration:
                continue
            except Exception:
                # The probe is only diagnostic; never fail the request on it.
                logger.debug(
                    "Ignoring error while probing for surplus %s data items.",
                    modality,
                    exc_info=True,
                )
                continue
            logger.warning(
                f"Warning: More {modality.name.lower()} data items provided than corresponding tokens found in the prompt."
            )

        return futures, task_info

    def load_mm_data(
        self,
        prompt: str,
        multimodal_tokens: MultimodalSpecialTokens,
        image_data: Optional[list] = None,
        video_data: Optional[list] = None,
        audio_data: Optional[list] = None,
        return_text: Optional[bool] = True,
        discard_alpha_channel: bool = True,
        audio_sample_rate: Optional[int] = None,
    ) -> BaseMultiModalProcessorOutput:
        """
        Each frame of video/image will be replaced by a single image token

        Args:
            prompt: the prompt text; a list of token ids is decoded with the
                processor's tokenizer when ``return_text`` is true
            multimodal_tokens (MultimodalSpecialTokens): special tokens, each denoting a single multimodal data
                e.g. image token or audio token
            discard_alpha_channel: if True, discards the alpha channel in the returned images

        Returns:
            The rebuilt prompt text together with the loaded images, videos
            and audios, in prompt order per modality.

        Raises:
            RuntimeError: if consuming a loaded item fails (cause chained).
        """
        multimodal_tokens_pattern = multimodal_tokens.get_combined_regex()

        if isinstance(prompt, list) and return_text:
            assert len(prompt) and isinstance(prompt[0], int)
            prompt = self._processor.tokenizer.decode(prompt)

        assert isinstance(prompt, str)
        # split text into list of normal text and special tokens
        text_parts = re.split(multimodal_tokens_pattern, prompt)

        # collect all data
        data_iterators = {}
        if multimodal_tokens.image_token and image_data:
            data_iterators[Modality.IMAGE] = iter(image_data)
        if multimodal_tokens.video_token and video_data:
            data_iterators[Modality.VIDEO] = iter(video_data)
        if multimodal_tokens.audio_token and audio_data:
            data_iterators[Modality.AUDIO] = iter(audio_data)

        # futures: the futures of loaded data
        # task_info: modality, raw_data, and other metadata of each data
        futures, task_info = self.submit_data_loading_tasks(
            text_parts=text_parts,
            multimodal_tokens=multimodal_tokens,
            data_iterators=data_iterators,
            discard_alpha_channel=discard_alpha_channel,
            audio_sample_rate=audio_sample_rate,
        )
        task_info_iter = iter(task_info)
        futures_iter = iter(futures)

        # Process results
        images, videos, audios = [], [], []
        new_text_parts = []
        for text_part in text_parts:
            try:
                if multimodal_tokens_pattern.match(text_part):
                    modality, raw_data, frame_limit = next(task_info_iter)
                    is_precomputed = isinstance(raw_data, dict)
                    result = next(futures_iter).result()

                    # NOTE: ``new_text_parts += <str>`` extends the list one
                    # character at a time; the final "".join makes that
                    # equivalent to appending the whole string.
                    if modality == Modality.IMAGE:
                        # If data is already processed it will be a
                        # dictionary(precomputed). In this case we want to keep the
                        # expanded tokens in text_part. Otherwise, we will
                        # call the processor code, so keep only a single image
                        # token.
                        mm_tokens = (
                            text_part
                            if is_precomputed
                            else multimodal_tokens.image_token
                        )
                        frames = [result] if not isinstance(result, list) else result
                        if frames:
                            # only for minicpmv
                            images += frames
                            new_text_parts += mm_tokens * len(frames)
                    elif modality == Modality.VIDEO:
                        # load as video
                        mm_tokens = (
                            text_part
                            if is_precomputed
                            else multimodal_tokens.video_token
                        )
                        videos += [result]
                        new_text_parts += mm_tokens
                    elif modality == Modality.AUDIO:
                        # audio
                        mm_tokens = (
                            text_part
                            if is_precomputed
                            else multimodal_tokens.audio_token
                        )
                        audios += [result]
                        new_text_parts += mm_tokens
                else:
                    # normal text
                    new_text_parts += [text_part]

            except Exception as e:
                raise RuntimeError(
                    f"An exception occurred while loading multimodal data: {e}"
                ) from e

        return BaseMultiModalProcessorOutput(
            images=images,
            audios=audios,
            videos=videos,
            input_text="".join(new_text_parts),
        )

    @staticmethod
    def get_mm_items_offset(
        input_ids: torch.Tensor, mm_token_id: int
    ) -> List[Tuple[int, int]]:
        """
        Get a set of range for mm_items from input_ids

        Each range is an inclusive ``(start, end)`` index pair covering one
        maximal run of ``mm_token_id``. ``input_ids`` is treated as 1-D
        (callers in this class pass a flattened tensor). Runs touching the
        first or last position are handled without wrap-around.

        Example:
            input_ids = [1, 2, 3, 3, 3, 4, 3, 3]
            mm_token_id = 3
            return result = [(2,4),(6,7)]
        """
        mask = input_ids == mm_token_id

        # Shift without wrapping (torch.roll would let the last element act
        # as the predecessor of the first, and vice versa).
        prev_is_mm = torch.zeros_like(mask)
        prev_is_mm[1:] = mask[:-1]
        next_is_mm = torch.zeros_like(mask)
        next_is_mm[:-1] = mask[1:]

        start_positions = (mask & ~prev_is_mm).nonzero(as_tuple=True)[0]
        end_positions = (mask & ~next_is_mm).nonzero(as_tuple=True)[0]

        return list(zip(start_positions.tolist(), end_positions.tolist()))

    @staticmethod
    def get_mm_items_offset_by_pair(
        input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int
    ) -> List[Tuple[int, int]]:
        """Return inclusive ranges strictly between start/end marker tokens.

        The i-th occurrence of ``mm_start_id`` is paired with the i-th
        occurrence of ``mm_end_id``; the range runs from the position after
        the start marker to the position before the end marker. Unpaired
        markers are dropped by ``zip``.
        """
        indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1
        indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1

        return list(zip(indices_start.tolist(), indices_end.tolist()))

    def collect_mm_items_from_processor_output(
        self, data_dict: dict
    ) -> List[MultimodalDataItem]:
        """Create mm_items directly from processor output.

        Attributes are grouped by modality via ``ATTR_NAME_TO_MODALITY``
        (one item per modality). ``input_ids`` and unknown attributes are
        skipped; attributes listed in ``FEATURE_NAMES`` are stored under the
        name ``"feature"``.
        """
        items: dict[Modality, MultimodalDataItem] = {}
        for attr_name, value in data_dict.items():
            if attr_name == "input_ids":
                continue

            # Get modality for this attribute
            modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)

            if attr_name == "precomputed_embeddings":
                # Modality comes from the dict's "modality" entry; IMAGE is
                # the fallback when it is missing or not recognised.
                modality_str = data_dict.get("modality")
                modality = Modality.IMAGE
                if modality_str:
                    try:
                        modality = Modality.from_str(modality_str)
                    except ValueError:
                        pass

            if modality:
                # Create item if needed
                if modality not in items:
                    items[modality] = MultimodalDataItem(modality=modality)

                if attr_name in self.FEATURE_NAMES:
                    attr_name = "feature"

                items[modality].set(attr_name, value)

        return list(items.values())

    def _tokenize_text(self, text: str) -> torch.Tensor:
        """Tokenize ``text`` with the processor's tokenizer into flat ids."""
        return self._processor.tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=True,
        ).input_ids.flatten()

    def _process_and_collect_mm_items(
        self, input_text: str, images=None, audios=None, videos=None, **kwargs
    ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
        """
        Helper method to process multimodal data and create mm_items in one step.

        Returns:
            Tuple of (created mm_items, flattened input_ids, raw processor output)
        """
        ret = self.process_mm_data(
            input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
        )

        input_ids = ret["input_ids"].flatten()
        collected_items = self.collect_mm_items_from_processor_output(ret)

        return collected_items, input_ids, ret

    def process_and_combine_mm_data(
        self,
        base_output: BaseMultiModalProcessorOutput,
        mm_tokens: MultimodalSpecialTokens,
        **kwargs,
    ) -> Tuple[List[MultimodalDataItem], torch.Tensor, Optional[dict]]:
        """
        Process multimodal data and return the combined multimodal items and input_ids.
        Supports mixed modalities (images and audio in the same request).

        Returns:
            Tuple of (list of mm_items, input_ids, processor output). The
            third element is ``{}`` for text-only input and ``None`` when
            every item was precomputed (no processor call was made).

        Raises:
            ValueError: for an item of unknown modality, or when no token id
                is configured for an item's modality.
        """
        # Collect all items and categorize them
        all_items = base_output.organize_results()

        # Handle text-only case
        if not all_items:
            return [], self._tokenize_text(base_output.input_text), {}

        dict_items, raw_images, raw_audios, raw_videos = [], [], [], []
        for modality, item in all_items:
            if isinstance(item, dict):
                dict_items.append(item)
            elif modality == Modality.IMAGE:
                raw_images.append(item)
            elif modality == Modality.AUDIO:
                raw_audios.append(item)
            elif modality == Modality.VIDEO:
                raw_videos.append(item)
            else:
                raise ValueError(f"Unknown multimodal item type: {type(item)}")

        # Process items and get input_ids
        all_collected_items: list[MultimodalDataItem] = []
        input_ids = None
        ret = None

        # Handle raw items (need processing)
        if raw_images or raw_audios or raw_videos:
            all_collected_items, input_ids, ret = self._process_and_collect_mm_items(
                input_text=base_output.input_text,
                images=raw_images,
                audios=raw_audios,
                videos=raw_videos,
                **kwargs,
            )

        # Handle dict items (already processed)
        for dict_item in dict_items:
            all_collected_items.extend(
                self.collect_mm_items_from_processor_output(dict_item)
            )

        # Fallback tokenization if no raw items were processed
        if input_ids is None:
            input_ids = self._tokenize_text(base_output.input_text)

        # Add offsets to all items
        for mm_item in all_collected_items:
            mm_token_id = mm_tokens.get_token_id_by_modality(mm_item.modality)
            if mm_token_id is None:
                raise ValueError(f"No token id found for modality: {mm_item.modality}")
            mm_item.offsets = self.get_mm_items_offset(
                input_ids=input_ids,
                mm_token_id=mm_token_id,
            )

        return all_collected_items, input_ids, ret
Xet Storage Details
- Size:
- 25.1 kB
- Xet hash:
- 22968374915fa29bec7eb0fa0068f65787c1e061f55dc8ead788311701f0e7ee
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.