"""Processor for Yasa2 that unifies text + media preprocessing.""" from __future__ import annotations import urllib.request from enum import Enum from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy as np import torch from PIL import Image from transformers import AutoTokenizer, ProcessorMixin from transformers.processing_utils import MultiModalData from .image_processing_yasa2 import ( Yasa2ImageProcessor, estimate_num_tiles_llava_next, estimate_num_tiles_llava_uhd, image_rgb_decoder_pil, image_rgb_decoder_pil_tiling, process_anyres_image, process_anyres_image_uhd, ) from .video_processing_yasa2 import ( Yasa2VideoProcessor, video_rgb_decoder_factory, ) class MediaType(str, Enum): IMAGE = "image" VIDEO = "video" REKA_IMG_TOKEN = "" IMAGE_START = "" IMAGE_END = "" VIDEO_START = "" SEP_TOKEN = "" PAD_ID = 100257 # <|endoftext|> def _read_bytes_from_uri(uri: str) -> bytes: """Read bytes from a local path or HTTP(S) URL. Args: uri: Local file path or HTTP(S) URL. Returns: Raw bytes content. """ if uri.startswith("http://") or uri.startswith("https://"): with urllib.request.urlopen(uri) as response: return response.read() with open(uri, "rb") as f: return f.read() def _decode_image_payload( payload: Union[str, bytes], img_tiling: bool, tiling_method: str, tiling_size: int, grid_pinpoints: List[Tuple[int, int]], max_tiles_num: int, patch_size: int, ) -> Dict[str, Any]: """Decode image payload bytes or path into a normalized pixel dict. Args: payload: Image path/URL or raw bytes. img_tiling: Whether to enable tiling. tiling_method: Tiling method identifier. tiling_size: Base tile size. grid_pinpoints: Candidate grid pinpoints. max_tiles_num: Maximum tile count for UHD tiling. patch_size: Patch size for UHD tiling. Returns: Dict with decoded image data and tiling metadata. """ if isinstance(payload, str): payload = _read_bytes_from_uri(payload) if img_tiling: return image_rgb_decoder_pil_tiling( payload, size=tiling_size, grid_pinpoints=grid_pinpoints, max_tiles_num=max_tiles_num, patch_size=patch_size, tiling_method=tiling_method, ) return image_rgb_decoder_pil(payload) def _decode_video_payload( payload: Union[str, bytes], num_frames: int, sampling: str, ) -> Dict[str, Any]: """Decode video payload bytes or path into sampled frames. Args: payload: Video path/URL or raw bytes. num_frames: Number of frames to sample. sampling: Sampling strategy. Returns: Dict with sampled frames and metadata. """ if isinstance(payload, str): payload = _read_bytes_from_uri(payload) decoder = video_rgb_decoder_factory( num_frames=num_frames, sampling=sampling ) return decoder(payload) class Yasa2Processor(ProcessorMixin): """Processor that applies the Yasa2 dialog formatting and media decoding.""" attributes = ["tokenizer", "image_processor", "video_processor"] tokenizer_class = "AutoTokenizer" image_processor_class = "AutoImageProcessor" video_processor_class = "AutoVideoProcessor" def __init__( self, tokenizer: AutoTokenizer | None = None, image_processor: Yasa2ImageProcessor | None = None, video_processor: Yasa2VideoProcessor | None = None, num_img_tokens: int = 64, image_token_id: int = 100278, num_video_frames: int = 6, video_sampling: str = "chunk", max_tokens: int = 8192, **kwargs, ) -> None: """Initialize the processor with tokenizer and media processors. Args: tokenizer: Tokenizer for text encoding. image_processor: Image processor for ConvNeXt inputs. video_processor: Video processor for sampled frames. num_img_tokens: Number of image content tokens per image. image_token_id: Token ID for image content tokens. num_video_frames: Number of frames to sample per video. video_sampling: Video sampling strategy. max_tokens: Maximum text token budget. **kwargs: Passed to ProcessorMixin. """ if image_processor is None: image_processor = Yasa2ImageProcessor() if video_processor is None: video_processor = Yasa2VideoProcessor( num_frames=num_video_frames, frame_sample_mode=video_sampling, max_num_frames=num_video_frames, ) super().__init__( tokenizer=tokenizer, image_processor=image_processor, video_processor=video_processor, **kwargs, ) self.num_img_tokens = num_img_tokens self.num_video_frames = num_video_frames self.video_sampling = video_sampling self.max_tokens = max_tokens self.image_token_id = image_token_id def _build_prompt_and_media( self, messages: List[Dict[str, Any]], num_img_tokens: int, num_video_frames: int, video_sampling: str, img_tiling: bool, tiling_method: str, tiling_size: int, grid_pinpoints: List[Tuple[int, int]], max_tiles_num: int, patch_size: int, add_generation_prompt: bool, tools: Optional[List[Dict[str, Any]]] = None, enable_thinking: Optional[bool] = None, ) -> Tuple[str, List[Tuple[MediaType, Dict[str, Any]]]]: """Build Yasa2 prompt text and decode media payloads in prompt order. Prompt formatting is delegated to the tokenizer's shared chat template. Args: messages: Conversation messages in HF format. num_img_tokens: Content tokens per image. num_video_frames: Frames to sample per video. video_sampling: Sampling strategy for videos. img_tiling: Whether to enable tiling. tiling_method: Tiling method identifier. tiling_size: Base tile size. grid_pinpoints: Candidate grid pinpoints. max_tiles_num: Maximum tile count for UHD tiling. patch_size: Patch size for UHD tiling. add_generation_prompt: Whether to append an assistant prefix. tools: Optional tool schema list for system prompt injection. enable_thinking: Unused compatibility flag. Returns: Tuple of prompt string and list of decoded media items. """ media_items: List[Tuple[MediaType, Dict[str, Any]]] = [] def image_builder(item: Dict[str, Any]) -> List[str]: """Serialize an image placeholder sequence for the chat prompt. Args: item: Raw message dict with image metadata. Returns: List[str]: Tokens that represent the image placeholder. """ payload = item.get("image") or item.get("image_url") if payload is None: raise ValueError("Image content requires an 'image' field.") image_datum = _decode_image_payload( payload, img_tiling=img_tiling, tiling_method=tiling_method, tiling_size=tiling_size, grid_pinpoints=grid_pinpoints, max_tiles_num=max_tiles_num, patch_size=patch_size, ) num_tiles = image_datum.get("num_tiles", 1) repeat_tokens = num_img_tokens * num_tiles media_items.append((MediaType.IMAGE, image_datum)) return ( [IMAGE_START] + [REKA_IMG_TOKEN] * repeat_tokens + [IMAGE_END] ) def video_builder(item: Dict[str, Any]) -> List[str]: """Serialize a video placeholder sequence for the chat prompt. Args: item: Raw message dict with video metadata. Returns: List[str]: Tokens that represent the video placeholder. """ payload = item.get("video") or item.get("video_url") if payload is None: raise ValueError("Video content requires a 'video' field.") video_datum = _decode_video_payload( payload, num_frames=num_video_frames, sampling=video_sampling, ) repeat_tokens = num_img_tokens * video_datum.get( "num_frames", num_video_frames ) media_items.append((MediaType.VIDEO, video_datum)) return ( [VIDEO_START] + [REKA_IMG_TOKEN] * repeat_tokens + [VIDEO_END] ) if self.tokenizer is None: raise ValueError( "Yasa2Processor requires a tokenizer to build prompts." ) prompt = self.tokenizer.build_chat_prompt( messages, add_generation_prompt=add_generation_prompt, continue_final_message=False, tools=tools, image_token_builder=image_builder, video_token_builder=video_builder, enable_thinking=enable_thinking, ) return prompt, media_items def apply_chat_template( self, messages: List[Dict[str, Any]], tokenize: bool = False, add_generation_prompt: bool = True, tools: Optional[List[Dict[str, Any]]] = None, return_tensors: Optional[str] = None, return_dict: bool = False, max_length: Optional[int] = None, padding: Union[bool, Literal["longest", "max_length"]] = False, num_img_tokens: Optional[int] = None, num_video_frames: Optional[int] = None, video_sampling: Optional[str] = None, enable_thinking: Optional[bool] = None, img_tiling: bool = True, tiling_method: str = "llava-uhd", tiling_size: int = 512, grid_pinpoints: Optional[List[Tuple[int, int]]] = None, max_tiles_num: int = 4, patch_size: int = 14, return_prompt: bool = False, **kwargs, ) -> Union[str, Dict[str, Any]]: """Apply the Yasa2 dialog template and optionally tokenize + decode media. The chat template is produced via the tokenizer for consistency with text-only formatting. Args: messages: Conversation messages in HF format. tokenize: Whether to tokenize and return tensors. add_generation_prompt: Whether to append an assistant prefix. tools: Optional tool schema list for system prompt injection. return_tensors: Tensor type for outputs (e.g., "pt"). return_dict: Whether to return a dict payload. max_length: Optional max token length. padding: Padding strategy (False/True/"longest"/"max_length"). num_img_tokens: Override for image content tokens. num_video_frames: Override for video frame count. video_sampling: Override for video sampling strategy. enable_thinking: Unused compatibility flag. img_tiling: Whether to enable tiling for images. tiling_method: Tiling method identifier. tiling_size: Base tile size. grid_pinpoints: Candidate grid pinpoints. max_tiles_num: Maximum tile count for UHD tiling. patch_size: Patch size for UHD tiling. return_prompt: Whether to include the prompt string in output. **kwargs: Unused extra arguments for compatibility. Returns: Prompt string if tokenize is False, otherwise a dict of tensors. """ if grid_pinpoints is None: grid_pinpoints = [ (2, 2), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1), ] num_img_tokens = num_img_tokens or self.num_img_tokens num_video_frames = num_video_frames or self.num_video_frames video_sampling = video_sampling or self.video_sampling user_max_length = max_length max_tokens = user_max_length or self.max_tokens prompt, media_items = self._build_prompt_and_media( messages=messages, num_img_tokens=num_img_tokens, num_video_frames=num_video_frames, video_sampling=video_sampling, img_tiling=img_tiling, tiling_method=tiling_method, tiling_size=tiling_size, grid_pinpoints=grid_pinpoints, max_tiles_num=max_tiles_num, patch_size=patch_size, add_generation_prompt=add_generation_prompt, tools=tools, enable_thinking=enable_thinking, ) if not tokenize: return prompt expected_img_tokens = 0 for media_type, media_datum in media_items: if media_type == MediaType.IMAGE: expected_img_tokens += num_img_tokens * media_datum.get( "num_tiles", 1 ) elif media_type == MediaType.VIDEO: expected_img_tokens += num_img_tokens * media_datum.get( "num_frames", num_video_frames ) input_ids = self.tokenizer.tiktoken.encode( prompt, allowed_special="all" ) input_ids = input_ids[:max_tokens] if expected_img_tokens: actual_img_tokens = sum( 1 for token_id in input_ids if token_id == self.image_token_id ) # Ensure truncation did not drop any media placeholder tokens. if actual_img_tokens != expected_img_tokens: raise ValueError( "Prompt truncation dropped image placeholder tokens. " "Increase max_length/max_tokens or reduce media inputs." ) attention_mask = [1] * len(input_ids) token_type_ids, mm_token_type_ids = self._build_mm_token_type_ids( input_ids ) if padding not in (False, True, "longest", "max_length"): raise ValueError(f"Unsupported padding value: {padding}") if padding in (True, "longest", "max_length"): pad_to_length = ( max_tokens if (padding == "max_length" or user_max_length) else len(input_ids) ) pad_len = pad_to_length - len(input_ids) if pad_len > 0: # GPT-style decoder-only LMs use absolute positions, so left-pad to # keep real tokens aligned at the end and avoid position offsets. input_ids = [PAD_ID] * pad_len + input_ids attention_mask = [0] * pad_len + attention_mask token_type_ids = [0] * pad_len + token_type_ids mm_token_type_ids = [0] * pad_len + mm_token_type_ids pixel_values_list = [] patch_attention_list = [] for media_type, media_datum in media_items: if media_type == MediaType.IMAGE: image_outputs = self.image_processor( images=media_datum["pixel_values"], return_tensors="pt" ) pixel_values_list.append(image_outputs["pixel_values"]) if "patch_attention_mask" in image_outputs: patch_attention_list.append( image_outputs["patch_attention_mask"] ) elif media_type == MediaType.VIDEO: video_outputs = self.video_processor.preprocess( videos=media_datum["pixel_values"], return_tensors="pt" ) pixel_values_list.append(video_outputs["pixel_values"]) patch_attention_list.append( video_outputs["patch_attention_mask"] ) else: raise ValueError(f"Unsupported media type: {media_type}") if pixel_values_list: pixel_values = torch.cat(pixel_values_list, dim=0) else: pixel_values = torch.tensor([]) if patch_attention_list: patch_attention_mask = torch.cat(patch_attention_list, dim=0) else: patch_attention_mask = torch.tensor([]) if return_tensors == "pt": input_ids = torch.tensor(input_ids, dtype=torch.long) attention_mask = torch.tensor(attention_mask, dtype=torch.long) token_type_ids = torch.tensor(token_type_ids, dtype=torch.long) mm_token_type_ids = torch.tensor( mm_token_type_ids, dtype=torch.long ) if input_ids.dim() == 1: input_ids = input_ids.unsqueeze(0) if attention_mask.dim() == 1: attention_mask = attention_mask.unsqueeze(0) if token_type_ids.dim() == 1: token_type_ids = token_type_ids.unsqueeze(0) if mm_token_type_ids.dim() == 1: mm_token_type_ids = mm_token_type_ids.unsqueeze(0) output = { "input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, "mm_token_type_ids": mm_token_type_ids, "pixel_values": pixel_values, "patch_attention_mask": patch_attention_mask, } if return_prompt: output["prompt"] = prompt return output if return_dict else output def __call__( self, images: Optional[Any] = None, text: Optional[Union[str, List[str]]] = None, videos: Optional[Any] = None, audio: Optional[Any] = None, **kwargs: Any, ) -> Any: """Run the processor and ensure multimodal token identifiers are present. Args: images: Optional image inputs. text: Optional textual inputs. videos: Optional video inputs. audio: Optional audio inputs. **kwargs: Additional keyword arguments forwarded to the base processor. Returns: Any: Processor outputs augmented with token type ids when needed. """ kwargs.pop("return_mm_token_type_ids", None) image_processor = getattr(self, "image_processor", None) img_tiling = kwargs.get("img_tiling", True) tiling_method = kwargs.get( "tiling_method", getattr(image_processor, "tiling_method", "llava-uhd"), ) tiling_size = kwargs.get("tiling_size") if tiling_size is None and image_processor is not None: size = getattr(image_processor, "size", None) if isinstance(size, dict) and "shortest_edge" in size: tiling_size = int(size["shortest_edge"]) elif isinstance(size, int): tiling_size = size tiling_size = tiling_size or 512 grid_pinpoints = kwargs.get("grid_pinpoints") if grid_pinpoints is None: grid_pinpoints = [ (2, 2), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1), ] max_tiles_num = kwargs.get( "max_tiles_num", getattr(image_processor, "max_tiles_num", 4) ) patch_size = kwargs.get( "patch_size", getattr(image_processor, "patch_size", 14) ) # vLLM infers placeholder splits from the tokenized prompt, so expand # image placeholders to the exact tile count before tokenization. if isinstance(text, str) and ( images is not None or videos is not None ): if ( REKA_IMG_TOKEN not in text and IMAGE_START not in text and VIDEO_START not in text ): text = self._prepend_mm_placeholders( text=text, images=images, videos=videos, **kwargs ) else: text = self._expand_image_placeholders( text=text, images=images, **kwargs ) # vLLM derives placeholder lengths from processor outputs; tile before tokenization. if images is not None and img_tiling: images = self._tile_images( images=images, tiling_method=tiling_method, tiling_size=tiling_size, grid_pinpoints=grid_pinpoints, max_tiles_num=max_tiles_num, patch_size=patch_size, ) # vLLM should treat tiled images as one prompt with multiple images. if isinstance(text, str) and isinstance(images, list): text = [text] images = [images] outputs = super().__call__( images=images, text=text, videos=videos, audio=audio, **kwargs ) if "input_ids" in outputs and "token_type_ids" not in outputs: token_type_ids, mm_token_type_ids = self._build_mm_token_type_ids( outputs["input_ids"] ) outputs["token_type_ids"] = token_type_ids outputs["mm_token_type_ids"] = mm_token_type_ids return outputs def _expand_image_placeholders( self, text: str, images: Optional[Any], **kwargs: Any, ) -> str: if images is None or IMAGE_START not in text or IMAGE_END not in text: return text image_list = ( list(images) if isinstance(images, (list, tuple)) else [images] ) image_processor = getattr(self, "image_processor", None) img_tiling = kwargs.get("img_tiling", True) tiling_method = kwargs.get( "tiling_method", getattr(image_processor, "tiling_method", "llava-uhd"), ) tiling_size = kwargs.get("tiling_size") if tiling_size is None and image_processor is not None: size = getattr(image_processor, "size", None) if isinstance(size, dict) and "shortest_edge" in size: tiling_size = int(size["shortest_edge"]) elif isinstance(size, int): tiling_size = size tiling_size = tiling_size or 512 grid_pinpoints = kwargs.get("grid_pinpoints") if grid_pinpoints is None: grid_pinpoints = [ (2, 2), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1), ] max_tiles_num = kwargs.get( "max_tiles_num", getattr(image_processor, "max_tiles_num", 4) ) patch_size = kwargs.get( "patch_size", getattr(image_processor, "patch_size", 14) ) expected_tokens = [] for image in image_list: width = height = 0 if hasattr(image, "size"): width, height = image.size elif isinstance(image, (list, tuple)) and len(image) >= 2: height, width = int(image[0]), int(image[1]) if img_tiling and width > 0 and height > 0: if str(tiling_method).lower() == "llava-next": tiles = estimate_num_tiles_llava_next( (width, height), size=tiling_size, grid_pinpoints=grid_pinpoints, ) else: tiles = estimate_num_tiles_llava_uhd( (width, height), max_tiles_num=max_tiles_num, scale_resolution=tiling_size, patch_size=patch_size, never_split=False, ) else: tiles = 1 expected_tokens.append(self.num_img_tokens * tiles) parts = [] remaining = text for tokens in expected_tokens: start = remaining.find(IMAGE_START) end = remaining.find(IMAGE_END, start + len(IMAGE_START)) if start == -1 or end == -1: return text parts.append(remaining[:start]) parts.append(IMAGE_START + (REKA_IMG_TOKEN * tokens) + IMAGE_END) remaining = remaining[end + len(IMAGE_END) :] parts.append(remaining) new_text = "".join(parts) return new_text def _tile_images( self, images: Any, tiling_method: str, tiling_size: int, grid_pinpoints: List[Tuple[int, int]], max_tiles_num: int, patch_size: int, ) -> Any: # vLLM expects one image entry per tile so it can emit per-tile embeddings. image_list = ( list(images) if isinstance(images, (list, tuple)) else [images] ) tiled_images: List[Any] = [] for image in image_list: if image is None: continue if isinstance(image, torch.Tensor): tiled_images.append(image) continue if isinstance(image, np.ndarray): image = Image.fromarray(image) if isinstance(image, Image.Image): # Match the tiling logic used for placeholder expansion. if str(tiling_method).lower() == "llava-next": tiles = process_anyres_image( image, size=tiling_size, grid_pinpoints=grid_pinpoints ) else: tiles = process_anyres_image_uhd( image, max_tiles_num=max_tiles_num, scale_resolution=tiling_size, patch_size=patch_size, never_split=False, ) tiled_images.extend(tiles) continue tiled_images.append(image) return ( tiled_images if isinstance(images, (list, tuple)) else tiled_images[0] ) def _prepend_mm_placeholders( self, text: str, images: Optional[Any], videos: Optional[Any], **kwargs: Any, ) -> str: """Prepend placeholder tokens when media is provided without markers.""" # Keep placeholders aligned with tiling so vLLM doesn't under/over-allocate. image_list = ( list(images) if isinstance(images, (list, tuple)) else ([images] if images is not None else []) ) num_images = len(image_list) num_videos = self._count_media_items(videos) if num_images == 0 and num_videos == 0: return text image_processor = getattr(self, "image_processor", None) img_tiling = kwargs.get("img_tiling", True) tiling_method = kwargs.get( "tiling_method", getattr(image_processor, "tiling_method", "llava-uhd"), ) tiling_size = kwargs.get("tiling_size") if tiling_size is None and image_processor is not None: size = getattr(image_processor, "size", None) if isinstance(size, dict) and "shortest_edge" in size: tiling_size = int(size["shortest_edge"]) elif isinstance(size, int): tiling_size = size tiling_size = tiling_size or 512 grid_pinpoints = kwargs.get("grid_pinpoints") if grid_pinpoints is None: grid_pinpoints = [ (2, 2), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1), ] max_tiles_num = kwargs.get( "max_tiles_num", getattr(image_processor, "max_tiles_num", 4) ) patch_size = kwargs.get( "patch_size", getattr(image_processor, "patch_size", 14) ) def _get_image_size(image: Any) -> Tuple[int, int]: if hasattr(image, "size"): size = image.size if isinstance(size, (list, tuple)) and len(size) >= 2: return int(size[0]), int(size[1]) if hasattr(image, "shape"): shape = image.shape if isinstance(shape, (list, tuple)) and len(shape) >= 2: return int(shape[1]), int(shape[0]) if isinstance(image, (list, tuple)) and len(image) >= 2: return int(image[1]), int(image[0]) return 0, 0 placeholder = "" for image in image_list: tiles = 1 if img_tiling: width, height = _get_image_size(image) if width > 0 and height > 0: if str(tiling_method).lower() == "llava-next": tiles = estimate_num_tiles_llava_next( (width, height), size=tiling_size, grid_pinpoints=grid_pinpoints, ) else: tiles = estimate_num_tiles_llava_uhd( (width, height), max_tiles_num=max_tiles_num, scale_resolution=tiling_size, patch_size=patch_size, never_split=False, ) placeholder += IMAGE_START placeholder += REKA_IMG_TOKEN * (self.num_img_tokens * tiles) placeholder += IMAGE_END for _ in range(num_videos): placeholder += VIDEO_START placeholder += REKA_IMG_TOKEN * ( self.num_img_tokens * self.num_video_frames ) placeholder += VIDEO_END return f"{placeholder}{text}" @staticmethod def _count_media_items(payload: Optional[Any]) -> int: """Best-effort count of media items for placeholder insertion.""" if payload is None: return 0 if isinstance(payload, (list, tuple)): return len(payload) return 1 def _build_mm_token_type_ids(self, input_ids: Any) -> Tuple[Any, Any]: """Compute token_type_ids that mark multimodal placeholders. Args: input_ids: Input IDs or sequences containing tokenizer ids. Returns: Tuple[Any, Any]: Regular and multimodal token type ids detected from placeholders. """ if self.tokenizer is None: return input_ids, input_ids img_token_id = self.image_token_id if isinstance(input_ids, torch.Tensor): mm_token_type_ids = (input_ids == img_token_id).long() token_type_ids = mm_token_type_ids.clone() return token_type_ids, mm_token_type_ids if isinstance(input_ids, (list, tuple)): if input_ids and isinstance(input_ids[0], (list, tuple)): mm_token_type_ids = [ [1 if token_id == img_token_id else 0 for token_id in seq] for seq in input_ids ] else: mm_token_type_ids = [ 1 if token_id == img_token_id else 0 for token_id in input_ids ] token_type_ids = list(mm_token_type_ids) return token_type_ids, mm_token_type_ids if hasattr(input_ids, "tolist"): ids = input_ids.tolist() token_type_ids, mm_token_type_ids = self._build_mm_token_type_ids( ids ) return token_type_ids, mm_token_type_ids return input_ids, input_ids def _get_num_multimodal_tokens( self, image_sizes: Optional[List[List[int]]] = None, video_sizes: Optional[List[List[int]]] = None, **kwargs: Any, ) -> MultiModalData: """Estimate the count of multimodal tokens from provided media sizes. Args: image_sizes: Per-image sizes as (height, width) tuples. video_sizes: Per-video sizes as (num_frames, height, width) tuples. **kwargs: Ignored compatibility arguments accepted by parent helpers. Returns: MultiModalData: Token counts for the vision modalities. """ vision_data: Dict[str, List[int]] = {} if image_sizes is not None: image_processor = getattr(self, "image_processor", None) img_tiling = kwargs.get("img_tiling", True) tiling_method = kwargs.get( "tiling_method", getattr(image_processor, "tiling_method", "llava-uhd"), ) tiling_size = kwargs.get("tiling_size") if tiling_size is None and image_processor is not None: size = getattr(image_processor, "size", None) if isinstance(size, dict) and "shortest_edge" in size: tiling_size = int(size["shortest_edge"]) elif isinstance(size, int): tiling_size = size tiling_size = tiling_size or 512 grid_pinpoints = kwargs.get("grid_pinpoints") if grid_pinpoints is None: grid_pinpoints = [ (2, 2), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1), ] max_tiles_num = kwargs.get( "max_tiles_num", getattr(image_processor, "max_tiles_num", 4) ) patch_size = kwargs.get( "patch_size", getattr(image_processor, "patch_size", 14) ) # vLLM splits placeholder positions using per-image token/patch counts. num_image_tokens: List[int] = [] num_image_patches: List[int] = [] for image_size in image_sizes: height = width = 0 if image_size and len(image_size) >= 2: height, width = int(image_size[0]), int(image_size[1]) tiles = 1 if img_tiling and width > 0 and height > 0: if str(tiling_method).lower() == "llava-next": tiles = estimate_num_tiles_llava_next( (width, height), size=tiling_size, grid_pinpoints=grid_pinpoints, ) else: tiles = estimate_num_tiles_llava_uhd( (width, height), max_tiles_num=max_tiles_num, scale_resolution=tiling_size, patch_size=patch_size, never_split=False, ) num_image_tokens.append(self.num_img_tokens * tiles) num_image_patches.append(tiles) vision_data["num_image_tokens"] = num_image_tokens vision_data["num_image_patches"] = num_image_patches else: vision_data["num_image_tokens"] = [] vision_data["num_image_patches"] = [] if video_sizes is not None: video_tokens: List[int] = [] for video_size in video_sizes: num_frames = video_size[0] if video_size else 0 num_frames = min( num_frames or self.num_video_frames, self.num_video_frames ) video_tokens.append(self.num_img_tokens * num_frames) vision_data["num_video_tokens"] = video_tokens return MultiModalData(**vision_data) Yasa2Processor.register_for_auto_class()