| from PIL import Image | |
| import torch | |
| import numpy as np | |
| from math import e | |
| from param import output | |
| from transformers.feature_extraction_utils import BatchFeature | |
| from transformers.processing_utils import ProcessorMixin | |
class LlavaUHDV3Processor(ProcessorMixin):
    """Multimodal processor for LLaVA-UHD v3.

    Wraps an image processor and a Qwen2 tokenizer, producing a single
    ``BatchFeature`` that combines tokenized text (with ``<image>`` tokens
    remapped to the model's placeholder id) and preprocessed image tensors.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        """Bind the sub-processors and resolve the image token / its id.

        Args:
            image_processor: Image preprocessor (AutoImageProcessor-compatible).
            tokenizer: Qwen2 tokenizer (slow or fast).
            chat_template: Optional chat template; falls back to the tokenizer's.
        """
        # Prefer the tokenizer's own image token when it defines one.
        self.image_token = "<image>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        # BUG FIX: compare against None explicitly — the old truthiness check
        # (`if getattr(...)`) would treat a legitimate image_token_id of 0 as missing.
        if getattr(tokenizer, "image_token_id", None) is not None:
            self.image_token_id = tokenizer.image_token_id
        else:
            # BUG FIX: register the actual image token (it may differ from the
            # literal "<image>" when the tokenizer supplies its own); otherwise
            # convert_tokens_to_ids(self.image_token) in __call__ could fail to
            # resolve the token we later try to replace.
            tokenizer.add_tokens([self.image_token], special_tokens=True)
            # -200 is used as a placeholder id — presumably consumed by the
            # model's embedding-merge step (LLaVA convention); TODO confirm.
            self.image_token_id = -200
        if chat_template is None and hasattr(tokenizer, "chat_template"):
            chat_template = tokenizer.chat_template
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(self, images=None, text=None, max_resolution=None, upscale_rate=1.4, **kwargs):
        """Tokenize ``text`` and preprocess ``images`` into one BatchFeature.

        Args:
            images: A single image, a list of per-sample images, or None.
                A None entry inside the list marks a text-only sample.
            text: A string or list of strings to tokenize.
            max_resolution: Forwarded to the image processor.
            upscale_rate: Forwarded to the image processor.
            **kwargs: Tokenizer options; ``padding``/``truncation`` default to
                True, ``return_tensors`` controls the output tensor type.

        Returns:
            BatchFeature with tokenizer outputs plus ``pixel_values`` and
            ``grid_hws`` when images were given; occurrences of the image
            token in ``input_ids`` are replaced with ``self.image_token_id``.
        """
        kwargs.setdefault("padding", True)
        kwargs.setdefault("truncation", True)

        image_inputs = {}
        if images is not None:
            pixel_values, grid_hws = [], []
            for per_images in images if isinstance(images, list) else [images]:
                if per_images is None:
                    # Text-only sample in a mixed batch: feed a random dummy
                    # image so the batch tensors keep a consistent layout.
                    dummy_image = Image.fromarray(np.random.randint(0, 256, (400, 400, 3), dtype=np.uint8))
                    image_info = self.image_processor(images=dummy_image)
                else:
                    image_info = self.image_processor(images=per_images, max_resolution=max_resolution, upscale_rate=upscale_rate)
                pixel_values.append(image_info.pixel_values)
                grid_hws.append(image_info.grid_hws)
            image_inputs.update({
                'pixel_values': torch.concat(pixel_values, dim=0),
                'grid_hws': torch.concat(grid_hws, dim=0),
            })

        if not isinstance(text, list):
            text = [text]
        # Copy so we never mutate the caller's list.
        text = text.copy()

        return_tensors = kwargs.pop("return_tensors", None)
        text_inputs = self.tokenizer(text, **kwargs, return_tensors=return_tensors)

        # Rewrite the tokenizer-assigned id of the image token to the model's
        # expected placeholder id (works for both list and tensor input_ids).
        img_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        for ids in text_inputs["input_ids"]:
            for i, token_id in enumerate(ids):
                if token_id == img_token_id:
                    ids[i] = self.image_token_id

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
| __all__ = ["LlavaUHDV3Processor"] |