File size: 2,824 Bytes
8f993ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73c7365
8f993ed
 
 
 
 
 
 
 
 
 
 
 
73c7365
8f993ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from PIL import Image
import torch
import numpy as np
from math import e
from param import output
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin

class LlavaUHDV3Processor(ProcessorMixin):
    """Wrap an image processor and a Qwen2 tokenizer into a single processor.

    Images are turned into ``pixel_values`` / ``grid_hws`` by the wrapped image
    processor, text is tokenized by the wrapped tokenizer, and every occurrence
    of the image placeholder token in ``input_ids`` is rewritten to
    ``self.image_token_id``.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        """Resolve the image placeholder token and initialise the processor.

        Args:
            image_processor: Image processor applied to ``images`` in ``__call__``.
            tokenizer: Tokenizer applied to ``text`` in ``__call__``. If it exposes
                ``image_token`` that string is the placeholder, else ``"<image>"``.
            chat_template: Chat template; falls back to ``tokenizer.chat_template``.
            **kwargs: Currently ignored.
        """
        self.image_token = "<image>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        tokenizer_image_token_id = getattr(tokenizer, "image_token_id", None)
        if tokenizer_image_token_id is not None:
            # `is not None` rather than truthiness so a legitimate id of 0 is kept.
            self.image_token_id = tokenizer_image_token_id
        else:
            # Register the placeholder actually used by this processor (not a
            # hard-coded "<image>") so convert_tokens_to_ids() in __call__ finds it.
            tokenizer.add_tokens([self.image_token], special_tokens=True)
            # -200 is presumably the LLaVA-style IMAGE_TOKEN_INDEX sentinel that
            # the model replaces with vision features — confirm against the model.
            self.image_token_id = -200

        if chat_template is None and hasattr(tokenizer, "chat_template"):
            chat_template = tokenizer.chat_template
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(self, images=None, text=None, max_resolution=None, upscale_rate=1.4, **kwargs):
        """Preprocess images and/or text into a ``BatchFeature``.

        Args:
            images: ``None``, a single image, or a list whose entries are images or
                ``None``. A ``None`` entry is replaced by a 400x400 random RGB dummy
                image, processed without ``max_resolution``/``upscale_rate``.
            text: A string or a list of strings. If ``None``, no tokenization is
                done and only image features are returned.
            max_resolution: Forwarded to the image processor for real images.
            upscale_rate: Forwarded to the image processor for real images.
            **kwargs: Forwarded to the tokenizer. ``padding`` and ``truncation``
                default to ``True``; ``return_tensors`` also selects the tensor
                type of the returned ``BatchFeature``.

        Returns:
            BatchFeature holding the tokenizer outputs (when ``text`` is given) and
            ``pixel_values`` / ``grid_hws`` (when at least one image is given).
        """
        kwargs.setdefault("padding", True)
        kwargs.setdefault("truncation", True)
        return_tensors = kwargs.pop("return_tensors", None)

        image_inputs = self._collect_image_inputs(images, max_resolution, upscale_rate)

        text_inputs = {}
        if text is not None:
            if not isinstance(text, list):
                text = [text]
            text_inputs = self.tokenizer(text, **kwargs, return_tensors=return_tensors)
            placeholder_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
            # Skip the rewrite when it would be a no-op (tokenizer already
            # provides the id) or when the placeholder is not in the vocabulary.
            if placeholder_id is not None and placeholder_id != self.image_token_id:
                self._remap_token_ids(text_inputs["input_ids"], placeholder_id, self.image_token_id)

        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def _collect_image_inputs(self, images, max_resolution, upscale_rate):
        """Run the image processor over ``images`` and concatenate the results.

        Returns an empty dict when ``images`` is ``None`` or an empty list.
        """
        if images is None:
            return {}
        pixel_values, grid_hws = [], []
        for per_images in images if isinstance(images, list) else [images]:
            if per_images is None:
                # Stand-in for a sample without an image. Note the random content
                # makes the output differ from call to call.
                dummy_image = Image.fromarray(np.random.randint(0, 256, (400, 400, 3), dtype=np.uint8))
                image_info = self.image_processor(images=dummy_image)
            else:
                image_info = self.image_processor(images=per_images, max_resolution=max_resolution, upscale_rate=upscale_rate)
            pixel_values.append(image_info.pixel_values)
            grid_hws.append(image_info.grid_hws)
        if not pixel_values:
            # images == []: torch.concat() raises on an empty sequence.
            return {}
        return {
            "pixel_values": torch.concat(pixel_values, dim=0),
            "grid_hws": torch.concat(grid_hws, dim=0),
        }

    @staticmethod
    def _remap_token_ids(input_ids, old_id, new_id):
        """Replace every ``old_id`` in ``input_ids`` with ``new_id``, in place."""
        if isinstance(input_ids, (torch.Tensor, np.ndarray)):
            # One vectorized masked assignment instead of a Python loop that
            # performs a tensor comparison per token.
            input_ids[input_ids == old_id] = new_id
            return
        # Nested Python lists (return_tensors=None) or other containers.
        for ids in input_ids:
            for i, token_id in enumerate(ids):
                if token_id == old_id:
                    ids[i] = new_id

# Names exported by `from <module> import *`.
__all__ = ["LlavaUHDV3Processor"]