| | """Processor for Yasa2 that unifies text + media preprocessing.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import urllib.request |
| | from enum import Enum |
| | from typing import Any, Dict, List, Literal, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from transformers import AutoTokenizer, ProcessorMixin |
| | from transformers.processing_utils import MultiModalData |
| |
|
| | from .image_processing_yasa2 import ( |
| | Yasa2ImageProcessor, |
| | estimate_num_tiles_llava_next, |
| | estimate_num_tiles_llava_uhd, |
| | image_rgb_decoder_pil, |
| | image_rgb_decoder_pil_tiling, |
| | process_anyres_image, |
| | process_anyres_image_uhd, |
| | ) |
| | from .video_processing_yasa2 import ( |
| | Yasa2VideoProcessor, |
| | video_rgb_decoder_factory, |
| | ) |
| |
|
| |
|
class MediaType(str, Enum):
    """Kinds of media attachments the processor can place in a prompt.

    Members subclass ``str``, so each one compares equal to its raw value
    (``MediaType.IMAGE == "image"``).
    """

    # Still image; paired with a decoded image dict by the processor.
    IMAGE = "image"
    # Video clip; paired with a dict of sampled frames by the processor.
    VIDEO = "video"
| |
|
| |
|
# Placeholder string repeated once per visual content token in the prompt.
REKA_IMG_TOKEN = "<REKA_IMG_TOKEN>"

# Markers that bracket the placeholder run of a single image / video.
IMAGE_START, IMAGE_END = "<image>", "</image>"
VIDEO_START, VIDEO_END = "<video>", "</video>"

SEP_TOKEN = "<sep>"

# Token id used to left-pad ``input_ids`` in ``apply_chat_template``.
PAD_ID = 100257
| |
|
| |
|
| | def _read_bytes_from_uri(uri: str) -> bytes: |
| | """Read bytes from a local path or HTTP(S) URL. |
| | |
| | Args: |
| | uri: Local file path or HTTP(S) URL. |
| | |
| | Returns: |
| | Raw bytes content. |
| | """ |
| | if uri.startswith("http://") or uri.startswith("https://"): |
| | with urllib.request.urlopen(uri) as response: |
| | return response.read() |
| | with open(uri, "rb") as f: |
| | return f.read() |
| |
|
| |
|
def _decode_image_payload(
    payload: Union[str, bytes],
    img_tiling: bool,
    tiling_method: str,
    tiling_size: int,
    grid_pinpoints: List[Tuple[int, int]],
    max_tiles_num: int,
    patch_size: int,
) -> Dict[str, Any]:
    """Turn an image reference or raw image bytes into a decoded image dict.

    A string payload is first resolved to bytes with ``_read_bytes_from_uri``.
    The bytes are then handed to ``image_rgb_decoder_pil_tiling`` when tiling
    is enabled, or to ``image_rgb_decoder_pil`` otherwise.

    Args:
        payload: Image path/URL or raw bytes.
        img_tiling: Whether to decode with tiling.
        tiling_method: Tiling method identifier passed to the tiling decoder.
        tiling_size: Base tile size (forwarded as ``size``).
        grid_pinpoints: Candidate grid pinpoints.
        max_tiles_num: Maximum tile count.
        patch_size: Patch size.

    Returns:
        The dict produced by the selected decoder.
    """
    raw = _read_bytes_from_uri(payload) if isinstance(payload, str) else payload
    if not img_tiling:
        return image_rgb_decoder_pil(raw)
    tiling_options = {
        "size": tiling_size,
        "grid_pinpoints": grid_pinpoints,
        "max_tiles_num": max_tiles_num,
        "patch_size": patch_size,
        "tiling_method": tiling_method,
    }
    return image_rgb_decoder_pil_tiling(raw, **tiling_options)
| |
|
| |
|
def _decode_video_payload(
    payload: Union[str, bytes],
    num_frames: int,
    sampling: str,
) -> Dict[str, Any]:
    """Turn a video reference or raw video bytes into sampled frames.

    A string payload is first resolved to bytes with ``_read_bytes_from_uri``.
    The bytes are decoded by the callable that ``video_rgb_decoder_factory``
    builds for the requested frame count and sampling strategy.

    Args:
        payload: Video path/URL or raw bytes.
        num_frames: Number of frames to sample.
        sampling: Sampling strategy.

    Returns:
        The dict produced by the video decoder.
    """
    raw = payload if not isinstance(payload, str) else _read_bytes_from_uri(payload)
    decode = video_rgb_decoder_factory(num_frames=num_frames, sampling=sampling)
    return decode(raw)
| |
|
| |
|
class Yasa2Processor(ProcessorMixin):
    """Processor that applies the Yasa2 dialog formatting and media decoding.

    Wraps a tokenizer, an image processor and a video processor. Two entry
    points are provided:

    * ``apply_chat_template`` renders chat messages into a prompt through the
      tokenizer's ``build_chat_prompt`` and, when ``tokenize=True``, also
      returns token ids, token-type ids and preprocessed pixel tensors.
    * ``__call__`` accepts raw ``text``/``images``/``videos``, inserts or
      expands ``<REKA_IMG_TOKEN>`` placeholders, tiles images and then
      delegates to ``ProcessorMixin.__call__``.
    """

    attributes = ["tokenizer", "image_processor", "video_processor"]
    tokenizer_class = "AutoTokenizer"
    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"

    def __init__(
        self,
        tokenizer: Optional[AutoTokenizer] = None,
        image_processor: Optional[Yasa2ImageProcessor] = None,
        video_processor: Optional[Yasa2VideoProcessor] = None,
        num_img_tokens: int = 64,
        image_token_id: int = 100278,
        num_video_frames: int = 6,
        video_sampling: str = "chunk",
        max_tokens: int = 8192,
        **kwargs,
    ) -> None:
        """Initialize the processor with tokenizer and media processors.

        Missing image/video processors are replaced by default-constructed
        ones; the default video processor is configured from
        ``num_video_frames`` and ``video_sampling``.

        Args:
            tokenizer: Tokenizer for text encoding.
            image_processor: Image processor; defaults to ``Yasa2ImageProcessor()``.
            video_processor: Video processor; defaults to a
                ``Yasa2VideoProcessor`` limited to ``num_video_frames`` frames.
            num_img_tokens: Number of placeholder tokens per image tile or
                video frame.
            image_token_id: Token id of the image placeholder token.
            num_video_frames: Number of frames to sample per video.
            video_sampling: Video sampling strategy.
            max_tokens: Default token budget used for truncation.
            **kwargs: Passed to ProcessorMixin.
        """
        if image_processor is None:
            image_processor = Yasa2ImageProcessor()
        if video_processor is None:
            video_processor = Yasa2VideoProcessor(
                num_frames=num_video_frames,
                frame_sample_mode=video_sampling,
                max_num_frames=num_video_frames,
            )
        super().__init__(
            tokenizer=tokenizer,
            image_processor=image_processor,
            video_processor=video_processor,
            **kwargs,
        )
        self.num_img_tokens = num_img_tokens
        self.num_video_frames = num_video_frames
        self.video_sampling = video_sampling
        self.max_tokens = max_tokens
        self.image_token_id = image_token_id

    # ------------------------------------------------------------------
    # Shared private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _default_grid_pinpoints() -> List[Tuple[int, int]]:
        """Return a fresh copy of the default candidate tile grids."""
        return [(2, 2), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1)]

    def _resolve_tiling_config(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve tiling settings from call kwargs and the image processor.

        Explicit kwargs win; otherwise attributes of ``self.image_processor``
        are used, and finally hard-coded defaults.

        Args:
            kwargs: Keyword arguments of the calling method (only read).

        Returns:
            Dict with keys ``img_tiling``, ``tiling_method``, ``tiling_size``,
            ``grid_pinpoints``, ``max_tiles_num`` and ``patch_size``.
        """
        image_processor = getattr(self, "image_processor", None)

        tiling_size = kwargs.get("tiling_size")
        if tiling_size is None and image_processor is not None:
            size = getattr(image_processor, "size", None)
            if isinstance(size, dict) and "shortest_edge" in size:
                tiling_size = int(size["shortest_edge"])
            elif isinstance(size, int):
                tiling_size = size

        grid_pinpoints = kwargs.get("grid_pinpoints")
        if grid_pinpoints is None:
            grid_pinpoints = self._default_grid_pinpoints()

        return {
            "img_tiling": kwargs.get("img_tiling", True),
            "tiling_method": kwargs.get(
                "tiling_method",
                getattr(image_processor, "tiling_method", "llava-uhd"),
            ),
            "tiling_size": tiling_size or 512,
            "grid_pinpoints": grid_pinpoints,
            "max_tiles_num": kwargs.get(
                "max_tiles_num", getattr(image_processor, "max_tiles_num", 4)
            ),
            "patch_size": kwargs.get(
                "patch_size", getattr(image_processor, "patch_size", 14)
            ),
        }

    @staticmethod
    def _estimate_num_tiles(width: int, height: int, tiling: Dict[str, Any]) -> int:
        """Estimate the tile count for an image of the given size.

        Returns 1 when tiling is disabled or the size is unknown (non-positive).
        """
        if not (tiling["img_tiling"] and width > 0 and height > 0):
            return 1
        if str(tiling["tiling_method"]).lower() == "llava-next":
            return estimate_num_tiles_llava_next(
                (width, height),
                size=tiling["tiling_size"],
                grid_pinpoints=tiling["grid_pinpoints"],
            )
        return estimate_num_tiles_llava_uhd(
            (width, height),
            max_tiles_num=tiling["max_tiles_num"],
            scale_resolution=tiling["tiling_size"],
            patch_size=tiling["patch_size"],
            never_split=False,
        )

    @staticmethod
    def _get_image_size(image: Any) -> Tuple[int, int]:
        """Best-effort ``(width, height)`` of an image-like object.

        Checks, in order: a list/tuple ``size`` attribute (read as
        ``(width, height)``), a list/tuple ``shape`` attribute (read as
        ``(height, width, ...)``), and a plain list/tuple (read as
        ``(height, width, ...)``). Returns ``(0, 0)`` when nothing matches.

        NOTE(review): for channel-first tensors ``shape[0]``/``shape[1]`` are
        not height/width; confirm what callers pass before relying on it.
        """
        if hasattr(image, "size"):
            size = image.size
            # numpy's ``size`` is an int and torch's is a method, so only
            # accept a real (width, height) pair here.
            if isinstance(size, (list, tuple)) and len(size) >= 2:
                return int(size[0]), int(size[1])
        if hasattr(image, "shape"):
            shape = image.shape
            if isinstance(shape, (list, tuple)) and len(shape) >= 2:
                return int(shape[1]), int(shape[0])
        if isinstance(image, (list, tuple)) and len(image) >= 2:
            return int(image[1]), int(image[0])
        return 0, 0

    def _estimate_tiles_for_image(self, image: Any, tiling: Dict[str, Any]) -> int:
        """Estimate how many tiles ``image`` yields under ``tiling``."""
        if not tiling["img_tiling"]:
            return 1
        width, height = self._get_image_size(image)
        return self._estimate_num_tiles(width, height, tiling)

    @staticmethod
    def _to_batched_long_tensor(values: Any) -> torch.Tensor:
        """Convert ids to a ``torch.long`` tensor with a leading batch dim."""
        tensor = torch.tensor(values, dtype=torch.long)
        return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

    # ------------------------------------------------------------------
    # Chat-template path
    # ------------------------------------------------------------------

    def _build_prompt_and_media(
        self,
        messages: List[Dict[str, Any]],
        num_img_tokens: int,
        num_video_frames: int,
        video_sampling: str,
        img_tiling: bool,
        tiling_method: str,
        tiling_size: int,
        grid_pinpoints: List[Tuple[int, int]],
        max_tiles_num: int,
        patch_size: int,
        add_generation_prompt: bool,
        tools: Optional[List[Dict[str, Any]]] = None,
        enable_thinking: Optional[bool] = None,
    ) -> Tuple[str, List[Tuple[MediaType, Dict[str, Any]]]]:
        """Build Yasa2 prompt text and decode media payloads in prompt order.

        Prompt formatting is delegated to the tokenizer's
        ``build_chat_prompt``; this method supplies the callbacks that decode
        each image/video item and return its placeholder token sequence.

        Args:
            messages: Conversation messages in HF format.
            num_img_tokens: Placeholder tokens per image tile / video frame.
            num_video_frames: Frames to sample per video.
            video_sampling: Sampling strategy for videos.
            img_tiling: Whether to enable tiling.
            tiling_method: Tiling method identifier.
            tiling_size: Base tile size.
            grid_pinpoints: Candidate grid pinpoints.
            max_tiles_num: Maximum tile count for UHD tiling.
            patch_size: Patch size for UHD tiling.
            add_generation_prompt: Whether to append an assistant prefix.
            tools: Optional tool schema list, forwarded to the tokenizer.
            enable_thinking: Forwarded to the tokenizer's ``build_chat_prompt``.

        Returns:
            Tuple of prompt string and list of ``(MediaType, decoded dict)``
            items in the order the callbacks were invoked.

        Raises:
            ValueError: If no tokenizer is set, or a media item carries no
                payload.
        """
        media_items: List[Tuple[MediaType, Dict[str, Any]]] = []

        def image_builder(item: Dict[str, Any]) -> List[str]:
            """Decode one image item and return its placeholder tokens."""
            payload = item.get("image") or item.get("image_url")
            if payload is None:
                raise ValueError("Image content requires an 'image' field.")
            image_datum = _decode_image_payload(
                payload,
                img_tiling=img_tiling,
                tiling_method=tiling_method,
                tiling_size=tiling_size,
                grid_pinpoints=grid_pinpoints,
                max_tiles_num=max_tiles_num,
                patch_size=patch_size,
            )
            # One run of ``num_img_tokens`` placeholders per decoded tile.
            repeat_tokens = num_img_tokens * image_datum.get("num_tiles", 1)
            media_items.append((MediaType.IMAGE, image_datum))
            return [IMAGE_START] + [REKA_IMG_TOKEN] * repeat_tokens + [IMAGE_END]

        def video_builder(item: Dict[str, Any]) -> List[str]:
            """Decode one video item and return its placeholder tokens."""
            payload = item.get("video") or item.get("video_url")
            if payload is None:
                raise ValueError("Video content requires a 'video' field.")
            video_datum = _decode_video_payload(
                payload,
                num_frames=num_video_frames,
                sampling=video_sampling,
            )
            # One run of ``num_img_tokens`` placeholders per sampled frame.
            repeat_tokens = num_img_tokens * video_datum.get(
                "num_frames", num_video_frames
            )
            media_items.append((MediaType.VIDEO, video_datum))
            return [VIDEO_START] + [REKA_IMG_TOKEN] * repeat_tokens + [VIDEO_END]

        if self.tokenizer is None:
            raise ValueError(
                "Yasa2Processor requires a tokenizer to build prompts."
            )
        prompt = self.tokenizer.build_chat_prompt(
            messages,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=False,
            tools=tools,
            image_token_builder=image_builder,
            video_token_builder=video_builder,
            enable_thinking=enable_thinking,
        )
        return prompt, media_items

    def apply_chat_template(
        self,
        messages: List[Dict[str, Any]],
        tokenize: bool = False,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
        return_tensors: Optional[str] = None,
        return_dict: bool = False,
        max_length: Optional[int] = None,
        padding: Union[bool, Literal["longest", "max_length"]] = False,
        num_img_tokens: Optional[int] = None,
        num_video_frames: Optional[int] = None,
        video_sampling: Optional[str] = None,
        enable_thinking: Optional[bool] = None,
        img_tiling: bool = True,
        tiling_method: str = "llava-uhd",
        tiling_size: int = 512,
        grid_pinpoints: Optional[List[Tuple[int, int]]] = None,
        max_tiles_num: int = 4,
        patch_size: int = 14,
        return_prompt: bool = False,
        **kwargs,
    ) -> Union[str, Dict[str, Any]]:
        """Apply the Yasa2 dialog template and optionally tokenize + decode media.

        Args:
            messages: Conversation messages in HF format.
            tokenize: Whether to tokenize; when False only the prompt string
                is returned (media is still decoded to size the placeholders).
            add_generation_prompt: Whether to append an assistant prefix.
            tools: Optional tool schema list, forwarded to the tokenizer.
            return_tensors: ``"pt"`` converts the id sequences to batched
                ``torch.long`` tensors; otherwise they stay Python lists.
            return_dict: Accepted for compatibility; a dict is always returned
                when ``tokenize`` is True.
            max_length: Optional token budget; falls back to ``self.max_tokens``.
            padding: Padding strategy (False/True/"longest"/"max_length").
                Padding is applied on the left with ``PAD_ID``.
            num_img_tokens: Override for placeholder tokens per tile/frame.
            num_video_frames: Override for video frame count.
            video_sampling: Override for video sampling strategy.
            enable_thinking: Forwarded to the tokenizer's chat prompt builder.
            img_tiling: Whether to enable tiling for images.
            tiling_method: Tiling method identifier.
            tiling_size: Base tile size.
            grid_pinpoints: Candidate grid pinpoints.
            max_tiles_num: Maximum tile count for UHD tiling.
            patch_size: Patch size for UHD tiling.
            return_prompt: Whether to include the prompt string in the output.
            **kwargs: Unused extra arguments for compatibility.

        Returns:
            Prompt string if ``tokenize`` is False, otherwise a dict with
            ``input_ids``, ``attention_mask``, ``token_type_ids``,
            ``mm_token_type_ids``, ``pixel_values`` and
            ``patch_attention_mask`` (plus ``prompt`` when requested).

        Raises:
            ValueError: On an unsupported ``padding`` value, an unsupported
                media type, or when the number of placeholder tokens left
                after truncation differs from what the media requires.
        """
        if grid_pinpoints is None:
            grid_pinpoints = self._default_grid_pinpoints()
        num_img_tokens = num_img_tokens or self.num_img_tokens
        num_video_frames = num_video_frames or self.num_video_frames
        video_sampling = video_sampling or self.video_sampling
        user_max_length = max_length
        max_tokens = user_max_length or self.max_tokens

        prompt, media_items = self._build_prompt_and_media(
            messages=messages,
            num_img_tokens=num_img_tokens,
            num_video_frames=num_video_frames,
            video_sampling=video_sampling,
            img_tiling=img_tiling,
            tiling_method=tiling_method,
            tiling_size=tiling_size,
            grid_pinpoints=grid_pinpoints,
            max_tiles_num=max_tiles_num,
            patch_size=patch_size,
            add_generation_prompt=add_generation_prompt,
            tools=tools,
            enable_thinking=enable_thinking,
        )

        if not tokenize:
            return prompt

        # Placeholder tokens the decoded media requires in the prompt.
        expected_img_tokens = 0
        for media_type, media_datum in media_items:
            if media_type == MediaType.IMAGE:
                expected_img_tokens += num_img_tokens * media_datum.get(
                    "num_tiles", 1
                )
            elif media_type == MediaType.VIDEO:
                expected_img_tokens += num_img_tokens * media_datum.get(
                    "num_frames", num_video_frames
                )

        input_ids = self.tokenizer.tiktoken.encode(
            prompt, allowed_special="all"
        )
        input_ids = input_ids[:max_tokens]
        if expected_img_tokens:
            # Truncation must never cut into a placeholder run, otherwise the
            # ids and the pixel tensors would disagree.
            actual_img_tokens = sum(
                1 for token_id in input_ids if token_id == self.image_token_id
            )
            if actual_img_tokens != expected_img_tokens:
                raise ValueError(
                    "Prompt truncation dropped image placeholder tokens. "
                    "Increase max_length/max_tokens or reduce media inputs."
                )

        attention_mask = [1] * len(input_ids)
        token_type_ids, mm_token_type_ids = self._build_mm_token_type_ids(
            input_ids
        )

        if padding not in (False, True, "longest", "max_length"):
            raise ValueError(f"Unsupported padding value: {padding}")
        if padding in (True, "longest", "max_length"):
            # With a single sequence "longest" is a no-op unless the caller
            # also supplied an explicit max_length.
            pad_to_length = (
                max_tokens
                if (padding == "max_length" or user_max_length)
                else len(input_ids)
            )
            pad_len = pad_to_length - len(input_ids)
            if pad_len > 0:
                # Left padding.
                input_ids = [PAD_ID] * pad_len + input_ids
                attention_mask = [0] * pad_len + attention_mask
                token_type_ids = [0] * pad_len + token_type_ids
                mm_token_type_ids = [0] * pad_len + mm_token_type_ids

        pixel_values_list = []
        patch_attention_list = []
        for media_type, media_datum in media_items:
            if media_type == MediaType.IMAGE:
                image_outputs = self.image_processor(
                    images=media_datum["pixel_values"], return_tensors="pt"
                )
                pixel_values_list.append(image_outputs["pixel_values"])
                if "patch_attention_mask" in image_outputs:
                    patch_attention_list.append(
                        image_outputs["patch_attention_mask"]
                    )
            elif media_type == MediaType.VIDEO:
                video_outputs = self.video_processor.preprocess(
                    videos=media_datum["pixel_values"], return_tensors="pt"
                )
                pixel_values_list.append(video_outputs["pixel_values"])
                patch_attention_list.append(
                    video_outputs["patch_attention_mask"]
                )
            else:
                raise ValueError(f"Unsupported media type: {media_type}")

        # All media is concatenated along dim 0; empty tensors stand in when
        # the conversation has no media.
        pixel_values = (
            torch.cat(pixel_values_list, dim=0)
            if pixel_values_list
            else torch.tensor([])
        )
        patch_attention_mask = (
            torch.cat(patch_attention_list, dim=0)
            if patch_attention_list
            else torch.tensor([])
        )

        if return_tensors == "pt":
            input_ids = self._to_batched_long_tensor(input_ids)
            attention_mask = self._to_batched_long_tensor(attention_mask)
            token_type_ids = self._to_batched_long_tensor(token_type_ids)
            mm_token_type_ids = self._to_batched_long_tensor(mm_token_type_ids)

        output = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "mm_token_type_ids": mm_token_type_ids,
            "pixel_values": pixel_values,
            "patch_attention_mask": patch_attention_mask,
        }
        if return_prompt:
            output["prompt"] = prompt

        return output

    # ------------------------------------------------------------------
    # Raw text/media path
    # ------------------------------------------------------------------

    def __call__(
        self,
        images: Optional[Any] = None,
        text: Optional[Union[str, List[str]]] = None,
        videos: Optional[Any] = None,
        audio: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the processor and ensure multimodal token identifiers are present.

        For a single ``str`` prompt accompanied by media, placeholders are
        either prepended (no markers in the text) or the existing
        ``<image>...</image>`` spans are expanded to the estimated token
        count. Images are then tiled and everything is forwarded to
        ``ProcessorMixin.__call__``.

        Args:
            images: Optional image inputs.
            text: Optional textual inputs.
            videos: Optional video inputs.
            audio: Optional audio inputs.
            **kwargs: Additional keyword arguments forwarded to the base
                processor. Tiling options (``img_tiling``, ``tiling_method``,
                ``tiling_size``, ``grid_pinpoints``, ``max_tiles_num``,
                ``patch_size``) are also read here.

        Returns:
            Any: Processor outputs augmented with ``token_type_ids`` and
            ``mm_token_type_ids`` when the base processor did not supply
            ``token_type_ids``.
        """
        kwargs.pop("return_mm_token_type_ids", None)
        tiling = self._resolve_tiling_config(kwargs)

        if isinstance(text, str) and (
            images is not None or videos is not None
        ):
            has_markers = (
                REKA_IMG_TOKEN in text
                or IMAGE_START in text
                or VIDEO_START in text
            )
            if has_markers:
                text = self._expand_image_placeholders(
                    text=text, images=images, **kwargs
                )
            else:
                text = self._prepend_mm_placeholders(
                    text=text, images=images, videos=videos, **kwargs
                )

        # Tile after the placeholders were sized from the original images.
        if images is not None and tiling["img_tiling"]:
            images = self._tile_images(
                images=images,
                tiling_method=tiling["tiling_method"],
                tiling_size=tiling["tiling_size"],
                grid_pinpoints=tiling["grid_pinpoints"],
                max_tiles_num=tiling["max_tiles_num"],
                patch_size=tiling["patch_size"],
            )

        # A single prompt with a list of images is one sample: wrap both so
        # the list is not read as one image per prompt.
        if isinstance(text, str) and isinstance(images, list):
            text = [text]
            images = [images]
        outputs = super().__call__(
            images=images, text=text, videos=videos, audio=audio, **kwargs
        )
        if "input_ids" in outputs and "token_type_ids" not in outputs:
            token_type_ids, mm_token_type_ids = self._build_mm_token_type_ids(
                outputs["input_ids"]
            )
            outputs["token_type_ids"] = token_type_ids
            outputs["mm_token_type_ids"] = mm_token_type_ids
        return outputs

    def _expand_image_placeholders(
        self,
        text: str,
        images: Optional[Any],
        **kwargs: Any,
    ) -> str:
        """Resize each ``<image>...</image>`` span to the estimated token count.

        Spans are matched to images in order. If the text holds fewer spans
        than there are images, the text is returned unchanged.

        Args:
            text: Prompt that may contain image marker pairs.
            images: A single image or a list/tuple of images.
            **kwargs: Optional tiling overrides (see ``__call__``).

        Returns:
            The prompt with every matched span rewritten as
            ``<image>`` + N placeholder tokens + ``</image>``.
        """
        if images is None or IMAGE_START not in text or IMAGE_END not in text:
            return text
        image_list = (
            list(images) if isinstance(images, (list, tuple)) else [images]
        )
        tiling = self._resolve_tiling_config(kwargs)

        # The shared size helper copes with numpy arrays and tensors, whose
        # ``size`` attribute is not a (width, height) pair.
        expected_tokens = [
            self.num_img_tokens * self._estimate_tiles_for_image(image, tiling)
            for image in image_list
        ]

        parts = []
        remaining = text
        for tokens in expected_tokens:
            start = remaining.find(IMAGE_START)
            end = remaining.find(IMAGE_END, start + len(IMAGE_START))
            if start == -1 or end == -1:
                return text
            parts.append(remaining[:start])
            parts.append(IMAGE_START + (REKA_IMG_TOKEN * tokens) + IMAGE_END)
            remaining = remaining[end + len(IMAGE_END):]
        parts.append(remaining)
        return "".join(parts)

    def _tile_images(
        self,
        images: Any,
        tiling_method: str,
        tiling_size: int,
        grid_pinpoints: List[Tuple[int, int]],
        max_tiles_num: int,
        patch_size: int,
    ) -> Any:
        """Split images into tiles with the configured tiling method.

        ``None`` entries are dropped, torch tensors pass through untouched,
        numpy arrays are converted with ``Image.fromarray`` and PIL images
        are tiled. Any other object is passed through unchanged.

        Args:
            images: A single image or a list/tuple of images.
            tiling_method: ``"llava-next"`` selects ``process_anyres_image``;
                anything else selects ``process_anyres_image_uhd``.
            tiling_size: Base tile size.
            grid_pinpoints: Candidate grid pinpoints (llava-next only).
            max_tiles_num: Maximum tile count (UHD only).
            patch_size: Patch size (UHD only).

        Returns:
            A flat list of tiles/images for list/tuple input. For a single
            input, the lone result when exactly one item was produced, the
            full tile list when tiling produced several, or the input
            unchanged when nothing was produced.
        """
        is_sequence = isinstance(images, (list, tuple))
        image_list = list(images) if is_sequence else [images]
        use_llava_next = str(tiling_method).lower() == "llava-next"

        tiled_images: List[Any] = []
        for image in image_list:
            if image is None:
                continue
            if isinstance(image, torch.Tensor):
                tiled_images.append(image)
                continue
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            if not isinstance(image, Image.Image):
                tiled_images.append(image)
                continue
            if use_llava_next:
                tiles = process_anyres_image(
                    image, size=tiling_size, grid_pinpoints=grid_pinpoints
                )
            else:
                tiles = process_anyres_image_uhd(
                    image,
                    max_tiles_num=max_tiles_num,
                    scale_resolution=tiling_size,
                    patch_size=patch_size,
                    never_split=False,
                )
            tiled_images.extend(tiles)

        if is_sequence:
            return tiled_images
        if not tiled_images:
            return images
        # Keep every tile of a single image: the placeholder count already
        # assumes num_img_tokens * tiles.
        return tiled_images[0] if len(tiled_images) == 1 else tiled_images

    def _prepend_mm_placeholders(
        self,
        text: str,
        images: Optional[Any],
        videos: Optional[Any],
        **kwargs: Any,
    ) -> str:
        """Prepend placeholder tokens when media is provided without markers.

        Image blocks come first (one per image, sized by the estimated tile
        count), then video blocks (one per video, sized by
        ``self.num_video_frames``), then the original text.

        Args:
            text: Prompt without any media markers.
            images: A single image, a list/tuple of images, or None.
            videos: A single video, a list/tuple of videos, or None.
            **kwargs: Optional tiling overrides (see ``__call__``).

        Returns:
            The prompt prefixed with the placeholder blocks, or ``text``
            unchanged when there is no media.
        """
        if isinstance(images, (list, tuple)):
            image_list = list(images)
        else:
            image_list = [images] if images is not None else []
        num_videos = self._count_media_items(videos)
        if not image_list and num_videos == 0:
            return text

        tiling = self._resolve_tiling_config(kwargs)

        pieces: List[str] = []
        for image in image_list:
            tiles = self._estimate_tiles_for_image(image, tiling)
            pieces.append(IMAGE_START)
            pieces.append(REKA_IMG_TOKEN * (self.num_img_tokens * tiles))
            pieces.append(IMAGE_END)
        video_block = (
            VIDEO_START
            + REKA_IMG_TOKEN * (self.num_img_tokens * self.num_video_frames)
            + VIDEO_END
        )
        pieces.extend([video_block] * num_videos)
        pieces.append(text)
        return "".join(pieces)

    @staticmethod
    def _count_media_items(payload: Optional[Any]) -> int:
        """Best-effort count of media items for placeholder insertion."""
        if payload is None:
            return 0
        if isinstance(payload, (list, tuple)):
            return len(payload)
        return 1

    def _build_mm_token_type_ids(self, input_ids: Any) -> Tuple[Any, Any]:
        """Compute token_type_ids that mark multimodal placeholders.

        Positions equal to ``self.image_token_id`` get 1, all others 0. Both
        returned values carry the same content but are independent objects.

        Args:
            input_ids: A tensor, a flat or nested list/tuple of ids, or any
                object exposing ``tolist()``.

        Returns:
            Tuple[Any, Any]: ``(token_type_ids, mm_token_type_ids)``. When no
            tokenizer is set, or the input type is unsupported, ``input_ids``
            itself is returned for both.
        """
        if self.tokenizer is None:
            return input_ids, input_ids
        img_token_id = self.image_token_id

        if isinstance(input_ids, torch.Tensor):
            mm_token_type_ids = (input_ids == img_token_id).long()
            return mm_token_type_ids.clone(), mm_token_type_ids

        if isinstance(input_ids, (list, tuple)):
            if input_ids and isinstance(input_ids[0], (list, tuple)):
                mm_token_type_ids = [
                    [1 if token_id == img_token_id else 0 for token_id in seq]
                    for seq in input_ids
                ]
                # Copy row by row so the two results share no inner lists.
                token_type_ids = [list(row) for row in mm_token_type_ids]
            else:
                mm_token_type_ids = [
                    1 if token_id == img_token_id else 0
                    for token_id in input_ids
                ]
                token_type_ids = list(mm_token_type_ids)
            return token_type_ids, mm_token_type_ids

        if hasattr(input_ids, "tolist"):
            return self._build_mm_token_type_ids(input_ids.tolist())

        return input_ids, input_ids

    def _get_num_multimodal_tokens(
        self,
        image_sizes: Optional[List[List[int]]] = None,
        video_sizes: Optional[List[List[int]]] = None,
        **kwargs: Any,
    ) -> MultiModalData:
        """Estimate the count of multimodal tokens from provided media sizes.

        Args:
            image_sizes: Per-image sizes as (height, width) tuples.
            video_sizes: Per-video sizes as (num_frames, height, width)
                tuples; only the frame count is used and it is capped at
                ``self.num_video_frames``.
            **kwargs: Optional tiling overrides (see ``__call__``).

        Returns:
            MultiModalData: ``num_image_tokens`` and ``num_image_patches``
            (empty lists when ``image_sizes`` is None) and, when
            ``video_sizes`` is given, ``num_video_tokens``.
        """
        vision_data: Dict[str, List[int]] = {}
        num_image_tokens: List[int] = []
        num_image_patches: List[int] = []
        if image_sizes is not None:
            tiling = self._resolve_tiling_config(kwargs)
            for image_size in image_sizes:
                height = width = 0
                if image_size and len(image_size) >= 2:
                    height, width = int(image_size[0]), int(image_size[1])
                tiles = self._estimate_num_tiles(width, height, tiling)
                num_image_tokens.append(self.num_img_tokens * tiles)
                num_image_patches.append(tiles)
        vision_data["num_image_tokens"] = num_image_tokens
        vision_data["num_image_patches"] = num_image_patches

        if video_sizes is not None:
            video_tokens: List[int] = []
            for video_size in video_sizes:
                num_frames = video_size[0] if video_size else 0
                # A missing/zero frame count falls back to the configured
                # count, which is also the upper bound.
                num_frames = min(
                    num_frames or self.num_video_frames, self.num_video_frames
                )
                video_tokens.append(self.num_img_tokens * num_frames)
            vision_data["num_video_tokens"] = video_tokens

        return MultiModalData(**vision_data)
|
| |
|
# Import-time side effect: registers Yasa2Processor with the auto-class
# machinery. NOTE(review): presumably so the processor can be loaded through
# AutoProcessor with custom code; confirm against the transformers docs.
Yasa2Processor.register_for_auto_class()
| |
|