from math import e  # NOTE(review): appears unused in this file — confirm before removing.

import numpy as np
import torch
from param import output  # NOTE(review): appears unused; `param` is an extra third-party dependency — confirm before removing.
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin


class LlavaUHDV3Processor(ProcessorMixin):
    """Combine an image processor and a Qwen2 tokenizer into one multimodal processor.

    Image entries are passed one at a time through ``image_processor``; the
    resulting ``pixel_values`` and ``grid_hws`` are concatenated along dim 0.
    Text is tokenized with ``tokenizer``, and every occurrence of the tokenizer
    id of ``self.image_token`` in ``input_ids`` is rewritten to
    ``self.image_token_id``.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        """Set up the image placeholder token and initialise the processor.

        Args:
            image_processor: Image processor; forwarded to ``ProcessorMixin``.
            tokenizer: Tokenizer. Methods are called on it here, so it must not be None.
            chat_template: Optional chat template. Defaults to
                ``tokenizer.chat_template`` when the tokenizer has that attribute.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        # NOTE(review): the "" default looks like a placeholder literal (e.g. "<image>")
        # that may have been stripped; an empty string is not a meaningful vocabulary
        # token — confirm the intended value.
        self.image_token = getattr(tokenizer, "image_token", "")
        tokenizer_image_token_id = getattr(tokenizer, "image_token_id", None)
        if tokenizer_image_token_id is not None:
            # `is not None` rather than truthiness, so a legitimate id of 0 is honoured.
            self.image_token_id = tokenizer_image_token_id
        else:
            # Register exactly the token that __call__ later looks up, so both agree.
            tokenizer.add_tokens([self.image_token], special_tokens=True)
            # Sentinel written into input_ids in place of the placeholder's vocab id;
            # presumably recognised by the model as an image position — confirm.
            self.image_token_id = -200
        if chat_template is None and hasattr(tokenizer, "chat_template"):
            chat_template = tokenizer.chat_template
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(self, images=None, text=None, max_resolution=None, upscale_rate=1.4, **kwargs):
        """Preprocess images and tokenize text into a single ``BatchFeature``.

        Args:
            images: One image entry or a list of entries. Each entry is handed to
                ``image_processor`` as-is; a ``None`` entry is replaced by a random
                400x400 RGB dummy image.
            text: A string or a list of strings. A non-list value is wrapped in a list.
            max_resolution: Forwarded to ``image_processor`` for non-dummy entries.
            upscale_rate: Forwarded to ``image_processor`` for non-dummy entries.
            **kwargs: Forwarded to the tokenizer. ``padding`` and ``truncation``
                default to True. ``return_tensors`` is passed to the tokenizer and
                also used as the ``BatchFeature`` tensor type.

        Returns:
            BatchFeature holding the tokenizer outputs plus, when at least one
            image entry was processed, ``pixel_values`` and ``grid_hws``.
        """
        kwargs.setdefault("padding", True)
        kwargs.setdefault("truncation", True)

        image_inputs = self._preprocess_image_entries(images, max_resolution, upscale_rate)

        if not isinstance(text, list):
            text = [text]
        return_tensors = kwargs.pop("return_tensors", None)
        text_inputs = self.tokenizer(text, **kwargs, return_tensors=return_tensors)
        self._replace_image_token_ids(text_inputs["input_ids"])

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def _preprocess_image_entries(self, images, max_resolution, upscale_rate):
        """Run each image entry through the image processor and concatenate the outputs.

        Returns an empty dict when ``images`` is None or an empty list.
        """
        if images is None:
            return {}
        entries = images if isinstance(images, list) else [images]
        pixel_values, grid_hws = [], []
        for entry in entries:
            if entry is None:
                # Stand-in for a sample without an image. NOTE(review): random pixels make
                # the output nondeterministic across calls — confirm this is intended.
                dummy_image = Image.fromarray(np.random.randint(0, 256, (400, 400, 3), dtype=np.uint8))
                image_info = self.image_processor(images=dummy_image)
            else:
                image_info = self.image_processor(
                    images=entry, max_resolution=max_resolution, upscale_rate=upscale_rate
                )
            pixel_values.append(image_info.pixel_values)
            grid_hws.append(image_info.grid_hws)
        if not pixel_values:
            # torch.concat raises on an empty sequence; treat `images=[]` like `images=None`.
            return {}
        return {
            "pixel_values": torch.concat(pixel_values, dim=0),
            "grid_hws": torch.concat(grid_hws, dim=0),
        }

    def _replace_image_token_ids(self, input_ids):
        """Rewrite, in place, every tokenizer id of ``self.image_token`` to ``self.image_token_id``."""
        img_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        if img_token_id is None:
            # No vocabulary id: nothing can match, so there is nothing to rewrite.
            return
        if isinstance(input_ids, (torch.Tensor, np.ndarray)):
            # One vectorized mask assignment instead of a Python loop over tensor elements.
            input_ids[input_ids == img_token_id] = self.image_token_id
            return
        # Plain (nested) Python lists, e.g. when return_tensors is None.
        for ids in input_ids:
            for i, token_id in enumerate(ids):
                if token_id == img_token_id:
                    ids[i] = self.image_token_id


__all__ = ["LlavaUHDV3Processor"]