# pylint: disable=no-member # avoid weird pylint warnings from SentencePieceProcessor """Text and Image processor for CASA models using Qwen2.5_VL image encoder""" from math import ceil from typing import TYPE_CHECKING, Any, Literal, TypedDict, cast, overload from typing import cast as type_cast import torch import torchvision.transforms.v2 as T from einops import rearrange from PIL import Image from torchvision.transforms import InterpolationMode from torchvision.transforms.functional import to_tensor as pil_to_tensor from torchvision.transforms.v2 import functional as F from transformers.image_processing_utils import BaseImageProcessor from transformers.processing_utils import ProcessorMixin if TYPE_CHECKING: from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer from transformers.tokenization_utils_fast import PreTrainedTokenizerFast ImageMessage = TypedDict( "ImageMessage", { "type": Literal["image"], "image": str | Image.Image | None, }, ) TextMessage = TypedDict( "TextMessage", { "type": Literal["text"], "text": str, }, ) MessageContent = list[ImageMessage | TextMessage] Message = TypedDict( "Message", { "role": Literal["system", "user", "assistant"], "content": MessageContent, }, ) ProcessorInput = list[list[Message]] | list[Message] __INTERP_NAME_TO_MODE__ = { "nearest": InterpolationMode.NEAREST, "bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC, "lanczos": InterpolationMode.LANCZOS, } __INTERP_INT_TO_MODE__ = { 0: InterpolationMode.NEAREST, 2: InterpolationMode.BILINEAR, 3: InterpolationMode.BICUBIC, 4: InterpolationMode.BOX, 5: InterpolationMode.HAMMING, 1: InterpolationMode.LANCZOS, } @overload def universal_resize( img: Image.Image, size: tuple[int, int], interpolation: str | InterpolationMode | int = "bilinear", antialias: bool = True, ) -> Image.Image: ... @overload def universal_resize( img: torch.Tensor, size: tuple[int, int], interpolation: str | InterpolationMode | int = "bilinear", antialias: bool = True, ) -> torch.Tensor: ... def universal_resize( img: Image.Image | torch.Tensor, size: tuple[int, int], interpolation: str | InterpolationMode | int = "bilinear", antialias: bool = True, ) -> Image.Image | torch.Tensor: """Resize that works for PIL.Image, CHW tensor, or BCHW tensor""" if isinstance(interpolation, str): interpolation = __INTERP_NAME_TO_MODE__[interpolation] elif isinstance(interpolation, int): interpolation = __INTERP_INT_TO_MODE__[interpolation] return F.resize( img, size, interpolation=type_cast(InterpolationMode, interpolation), antialias=antialias ) @overload def convert_to_rgb(img: Image.Image) -> Image.Image: ... @overload def convert_to_rgb(img: torch.Tensor) -> torch.Tensor: ... def convert_to_rgb(img: Image.Image | torch.Tensor) -> Image.Image | torch.Tensor: """Convert any image to RGB in a way that does not throw PIL warning""" if isinstance(img, torch.Tensor): return img if img.mode == "RGB": # no changes return img if img.mode == "P": # palette images need to be converted to RGBA first return img.convert("RGBA").convert("RGB") return img.convert("RGB") class QwenImageProcessor(BaseImageProcessor): """Resizing for the Qwen2.5VL encoder. Note that the normalization is handled in the image_encoder in the model forward""" def __init__( self, img_size: int = 448, interpolation: Literal["bicubic", "bilinear", "nearest", "nearest_exact"] = "bicubic", max_ratio: int = 10, round_to_patch_size: int = 56, use_fast: bool = True, **kwargs: Any, ) -> None: # this will also be used in V2llms to determine whether to remove # the temporal conv self._num_target_channels = 588 self._merge_size = 2 self._patch_size = 14 super().__init__( use_fast=use_fast, do_normalize=False, **kwargs, ) self.img_size = img_size self.interpolation = interpolation self.max_ratio = max_ratio self.round_to_patch_size = round_to_patch_size def resize_transform( self, img: Image.Image | torch.Tensor, img_size: int | None = None ) -> Image.Image | torch.Tensor: if img_size is None: img_size = self.img_size max_area = img_size**2 if isinstance(img, Image.Image): img = convert_to_rgb(img) w_og, h_og = img.size else: h_og, w_og = img.shape[-2:] w, h = w_og, h_og # Qwen requires max ratio of 10 between max and min sizes if self.max_ratio > 0: w, h = max(w, h // self.max_ratio), max(h, w // self.max_ratio) # resize to max area current_area = w * h if current_area > max_area: scale = (max_area / current_area) ** 0.5 w, h = int(w * scale), int(h * scale) # resize to patch size if self.round_to_patch_size > 0: w = ceil(w / self.round_to_patch_size) * self.round_to_patch_size h = ceil((h / self.round_to_patch_size)) * self.round_to_patch_size # resize if w != w_og or h != h_og: img = universal_resize(img, (h, w), self.interpolation) if isinstance(img, torch.Tensor): img = T.ToDtype(torch.float32, scale=True)(T.ToImage()(img)) return img def __process_one__( self, video_or_img: Image.Image | torch.Tensor, img_size: int | None = None ) -> torch.Tensor: """Same operation as __process_one_with_processor__ but without going through numpy""" video_or_img = self.resize_transform(video_or_img, img_size) if isinstance(video_or_img, Image.Image): video_or_img = pil_to_tensor(video_or_img) assert isinstance(video_or_img, torch.Tensor) if video_or_img.ndim == 3: video_or_img = video_or_img[None] assert video_or_img.ndim == 4 and video_or_img.shape[1] == 3, ( f"Invalid shape {video_or_img.shape}." ) t, c, h, w = video_or_img.shape p = self._patch_size m = self._merge_size # Convert to RGB if c == 1: video_or_img = video_or_img.expand((-1, 3, -1, -1)) if c == 4: video_or_img = video_or_img[:, :3] c = video_or_img.shape[1] assert c == 3, "Expecting RGB image in QwenNormalize" # Reshape to t h w c' format h, w = video_or_img.shape[2] // p, video_or_img.shape[3] // p rearrange_dict = dict(p1=p, p2=p, m1=m, m2=m) video_or_img = rearrange( video_or_img, "t c (h m1 p1) (w m2 p2) -> (t h w m1 m2) (c p1 p2)", **rearrange_dict, ) assert video_or_img.shape[-1] == self._num_target_channels, ( f"{video_or_img.shape[-1]} != {self._num_target_channels}" ) video_or_img = video_or_img.view((-1, h, w, self._num_target_channels)) return video_or_img @overload def process_images( self, image: Image.Image | torch.Tensor, img_size: int | None = None ) -> torch.Tensor: ... @overload def process_images( self, image: list[Image.Image] | list[torch.Tensor], img_size: int | None = None ) -> list[torch.Tensor]: ... def process_images( self, image: Image.Image | torch.Tensor | list[Image.Image] | list[torch.Tensor], img_size: int | None = None, ) -> torch.Tensor | list[torch.Tensor]: if isinstance(image, list): return [self.__process_one__(_x, img_size) for _x in image] return self.__process_one__(image, img_size) class ProcessorOutput(dict): input_ids: torch.Tensor attention_mask: torch.Tensor image_embeds_insertion_points: list[torch.Tensor] | None pixel_values: torch.Tensor | list[torch.Tensor] | None def to( self, device: torch.device | str, dtype: torch.dtype = torch.bfloat16 ) -> "ProcessorOutput": return ProcessorOutput( { "input_ids": self["input_ids"].to(device), "attention_mask": self["attention_mask"].to(device), "image_embeds_insertion_points": self["image_embeds_insertion_points"], "pixel_values": ( self["pixel_values"].to(dtype).to(device) if isinstance(self["pixel_values"], torch.Tensor) else [x.to(dtype).to(device) for x in self["pixel_values"]] if self["pixel_values"] is not None else None ), } ) class BaseProcessor(ProcessorMixin): def __init__( self, tokenizer: "PreTrainedTokenizerFast | Qwen2Tokenizer", pre_image_tokens: tuple[int, ...] = (), post_image_tokens: tuple[int, ...] = (), system_start_tokens: tuple[int, ...] = (), system_end_tokens: tuple[int, ...] = (), user_start_tokens: tuple[int, ...] = (), user_end_tokens: tuple[int, ...] = (), asst_start_tokens: tuple[int, ...] = (), asst_end_tokens: tuple[int, ...] = (), allow_system_prompt: bool = True, pad_token: int = 0, bos_token: int | None = None, ) -> None: self.pre_image_tokens = list(pre_image_tokens) self.post_image_tokens = list(post_image_tokens) self.system_start_tokens = list(system_start_tokens) self.system_end_tokens = list(system_end_tokens) self.user_start_tokens = list(user_start_tokens) self.user_end_tokens = list(user_end_tokens) self.asst_start_tokens = list(asst_start_tokens) self.asst_end_tokens = list(asst_end_tokens) self._allow_system_prompt = allow_system_prompt self.tokenizer = tokenizer self._image_processor = None self._pad_token = pad_token self.bos_token = bos_token @property def image_processor(self) -> QwenImageProcessor: assert self._image_processor is not None return self._image_processor def _process_content( self, message_content: MessageContent, role: Literal["system", "user", "assistant"], tokenized_messages: list[torch.Tensor], insertion_points: list[int], image_list: list[torch.Tensor | None], token_count: int, img_size: int | None = None, **kwargs: Any, ) -> int: mapping = { "user": (self.user_start_tokens, self.user_end_tokens), "assistant": (self.asst_start_tokens, self.asst_end_tokens), "system": (self.system_start_tokens, self.system_end_tokens), } if role.lower() not in mapping: raise ValueError(f"Unknown role '{role}' encountered in messages.") start_tokens, end_tokens = mapping[role.lower()] # 1) Add the start tokens if start_tokens: tokenized_messages.append(torch.Tensor(start_tokens).flatten().to(torch.long)) token_count += len(start_tokens) # 2) Process the message content one by one (potentially interleaved image and text) for part in message_content: elt_type = part["type"] if elt_type == "image": part = cast(ImageMessage, part) self._process_image_message( part, tokenized_messages, image_list, img_size=img_size, ) token_count += len(self.pre_image_tokens) insertion_points.append(token_count) token_count += len(self.post_image_tokens) else: part = cast(TextMessage, part) self._process_text_message( part["text"], role=role, token_list=tokenized_messages, **kwargs, ) token_count += tokenized_messages[-1].size(0) # 3) Add the end tokens if end_tokens: tokenized_messages.append(torch.Tensor(end_tokens).flatten().to(torch.long)) token_count += len(end_tokens) return token_count def _process_text_message( self, message: str, role: Literal["system", "user", "assistant"], token_list: list[torch.Tensor], **kwargs: Any, ) -> None: if role.lower() == "system" and not self._allow_system_prompt: raise ValueError("System prompts are not allowed in this tokenizer configuration.") tokens = self.tokenizer.encode( message, add_special_tokens=False, return_tensors="pt", **kwargs ) tokens = cast(torch.Tensor, tokens) token_list.append(tokens.flatten().to(torch.long)) def _process_image_message( self, message: ImageMessage, token_list: list[torch.Tensor], image_list: list[torch.Tensor | None], img_size: int | None = None, ) -> None: img = message["image"] if img is None: image_list.append(None) else: image_list.append( self.image_processor.process_images( self._load_image(img), img_size=img_size ).squeeze(0) ) if self.pre_image_tokens: token_list.append(torch.Tensor(self.pre_image_tokens).flatten().to(torch.long)) if self.post_image_tokens: token_list.append(torch.Tensor(self.post_image_tokens).flatten().to(torch.long)) def _load_image(self, image_path_or_image: str | Image.Image) -> Image.Image: if isinstance(image_path_or_image, str): return Image.open(image_path_or_image).convert("RGB") return image_path_or_image def _maybe_pad(self, tokens: torch.Tensor, pad_len: int, pad_value: int) -> torch.Tensor: return torch.nn.functional.pad( tokens, (0, pad_len) if self.tokenizer.padding_side == "right" else (pad_len, 0), value=pad_value, ) def pad_tokenized_messages( self, tokenized_messages_batch: list[torch.Tensor], image_insertion_points_batch: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]: max_len = max(len(x) for x in tokenized_messages_batch) if image_insertion_points_batch is not None and self.tokenizer.padding_side == "left": image_insertion_points_batch = [ x + max_len - len(tokenized_messages_batch[idx]) for idx, x in enumerate(image_insertion_points_batch) ] input_ids = torch.stack( [ self._maybe_pad(s, max_len - s.size(0), self._pad_token) for s in tokenized_messages_batch ], dim=0, ) attention_mask = torch.stack( [ self._maybe_pad(torch.ones_like(s), max_len - s.size(0), 0) for s in tokenized_messages_batch ], dim=0, ) return input_ids, attention_mask, image_insertion_points_batch def tokenize_messages( self, messages: ProcessorInput, suppress_bos_token: bool = False, **kwargs: Any, ) -> ProcessorOutput | None: """Tokenize a batch of messages into token IDs suitable for Helium1 CASA model. Args: messages (list[list[dict[str, str]]] | list[dict[str, str]]): Batch of message lists (or single list of messages), where each message is a list of dictionaries with 'role' and 'content' keys. continue_final_message (bool, optional): If True, the final message in each list will not have an end token added. Defaults to False. suppress_bos_token (bool, optional): If True, the beginning-of-sequence token will not be added. Defaults to False. **kwargs: Additional keyword arguments passed to the underlying encode method. """ if not messages: return None if isinstance(messages[0], dict): messages = [messages] # type: ignore[assignment] messages = cast(list[list[Message]], messages) image_insertion_points_batch = [] tokenized_messages_batch = [] image_list: list[torch.Tensor | None] = [] for msgs in messages: # msgs.append({ # "role": "assistant", # "content": [{"type": "text", "text": ""}] # }) tokenized_messages = [] if not suppress_bos_token and self.bos_token is not None: tokenized_messages.append(torch.tensor([self.bos_token], dtype=torch.long)) insertion_points = [] token_count = 0 for msg in msgs: token_count = self._process_content( msg["content"], role=msg["role"], tokenized_messages=tokenized_messages, insertion_points=insertion_points, image_list=image_list, token_count=token_count, **kwargs, ) tokenized_messages_batch.append(torch.cat(tokenized_messages, dim=0).to(torch.long)) image_insertion_points_batch.append(torch.tensor(insertion_points, dtype=torch.long)) if msgs and self.asst_end_tokens and msgs[-1]["role"].lower() == "assistant": # Remove the assistant end tokens from the final message end_token_len = len(self.asst_end_tokens) tokenized_messages_batch[-1] = tokenized_messages_batch[-1][:-end_token_len] if msgs and self.asst_start_tokens and msgs[-1]["role"].lower() == "user": # Remove the assistant end tokens from the final message end_token_len = len(self.asst_end_tokens) tokenized_messages_batch[-1] = torch.cat( [ tokenized_messages_batch[-1], torch.Tensor(self.asst_start_tokens).to(torch.long), ] ) input_ids, attention_mask, image_embeds_insertion_points = self.pad_tokenized_messages( tokenized_messages_batch, image_insertion_points_batch ) if image_list: assert sum(img is None for img in image_list) % len(image_list) == 0, ( "Either all or no image must be None." ) pixel_values: None | torch.Tensor | list[torch.Tensor] if image_list[0] is None: pixel_values = None else: pixel_values = cast(list[torch.Tensor], image_list) return ProcessorOutput( input_ids=input_ids, image_embeds_insertion_points=image_embeds_insertion_points, attention_mask=attention_mask, pixel_values=pixel_values, )