# processing_llava_uhd_v3.py — LLaVA-UHD-v3 processor
# Uploaded by Sishxo (commit 73c7365, verified).
from PIL import Image
import torch
import numpy as np
from math import e  # NOTE(review): `e` is unused in this file — looks like a leftover; confirm and remove
from param import output  # NOTE(review): unused, and `param` is a third-party package — presumably an accidental IDE auto-import; verify and remove
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
class LlavaUHDV3Processor(ProcessorMixin):
    """Joint image/text processor for LLaVA-UHD-v3.

    Wraps an image processor and a Qwen2 tokenizer: images are run through the
    image processor (producing ``pixel_values`` and ``grid_hws``), text is
    tokenized, and every occurrence of the ``<image>`` placeholder token in the
    tokenized ids is replaced with the model's image-token sentinel id.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        """Bind the sub-processors and resolve the image placeholder token.

        Args:
            image_processor: image preprocessor producing ``pixel_values``/``grid_hws``.
            tokenizer: Qwen2 tokenizer; may optionally define ``image_token``,
                ``image_token_id`` and ``chat_template`` attributes.
            chat_template: explicit chat template; falls back to the tokenizer's.
        """
        # Placeholder string used in raw prompts; tokenizer may override it.
        self.image_token = getattr(tokenizer, "image_token", "<image>")
        # BUGFIX: compare against None instead of relying on truthiness —
        # a legitimate image_token_id of 0 was previously treated as missing.
        if getattr(tokenizer, "image_token_id", None) is not None:
            self.image_token_id = tokenizer.image_token_id
        else:
            # Register the placeholder so it tokenizes to a single id, then use
            # the model's sentinel (-200) to mark image positions in input_ids.
            tokenizer.add_tokens(["<image>"], special_tokens=True)
            self.image_token_id = -200
        if chat_template is None and hasattr(tokenizer, "chat_template"):
            chat_template = tokenizer.chat_template
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(self, images=None, text=None, max_resolution=None, upscale_rate=1.4, **kwargs):
        """Prepare model inputs from images and/or text.

        Args:
            images: a single image, a list of images, or None. A None entry in
                the list is replaced by a random dummy image so text-only
                samples keep the batch shape consistent.
            text: a string or list of strings containing ``<image>`` placeholders.
            max_resolution: forwarded to the image processor.
            upscale_rate: forwarded to the image processor (default 1.4).
            **kwargs: tokenizer keyword arguments; ``padding`` and ``truncation``
                default to True, ``return_tensors`` is honored by the final
                :class:`BatchFeature`.

        Returns:
            BatchFeature with tokenizer outputs plus, when images were given,
            ``pixel_values`` and ``grid_hws``.
        """
        kwargs.setdefault("padding", True)
        kwargs.setdefault("truncation", True)

        image_inputs = {}
        if images is not None:
            pixel_values, grid_hws = [], []
            for per_images in images if isinstance(images, list) else [images]:
                if per_images is None:
                    # Text-only sample: push a random dummy image through the
                    # processor so downstream batch shapes stay uniform.
                    dummy_image = Image.fromarray(np.random.randint(0, 256, (400, 400, 3), dtype=np.uint8))
                    image_info = self.image_processor(images=dummy_image)
                else:
                    image_info = self.image_processor(images=per_images, max_resolution=max_resolution, upscale_rate=upscale_rate)
                pixel_values.append(image_info.pixel_values)
                grid_hws.append(image_info.grid_hws)
            # BUGFIX: guard against images=[] — torch.concat on an empty list raises.
            if pixel_values:
                image_inputs = {
                    "pixel_values": torch.concat(pixel_values, dim=0),
                    "grid_hws": torch.concat(grid_hws, dim=0),
                }

        # BUGFIX: text=None previously became [None] and crashed the tokenizer.
        if text is None:
            text = []
        elif not isinstance(text, list):
            text = [text]
        else:
            text = text.copy()

        return_tensors = kwargs.pop("return_tensors", None)
        text_inputs = self.tokenizer(text, **kwargs, return_tensors=return_tensors)

        # Replace the tokenizer's <image> id with the model's sentinel id so the
        # model can locate image positions. Works on both list and tensor rows.
        img_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        for ids in text_inputs["input_ids"]:
            for i, token_id in enumerate(ids):
                if token_id == img_token_id:
                    ids[i] = self.image_token_id

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
__all__ = ["LlavaUHDV3Processor"]