File size: 4,214 Bytes
146a630
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch

from typing import List, Union
from PIL import Image

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

from .modeling_vora import VoRAForCausalLM


class VoRAProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema accepted by the VoRA processor's ``__call__``.

    The class itself is passed to ``_merge_kwargs`` by the processor;
    presumably ``_defaults`` supplies the per-modality fallback values
    there (verify against ``ProcessorMixin``).
    """

    # Text default: padding disabled. Image kwargs: no overridden defaults.
    _defaults = {"text_kwargs": {"padding": False}, "images_kwargs": {}}


class VoRAProcesser(ProcessorMixin):
    """Combine an image processor and a tokenizer into a single VoRA processor.

    Text prompts are tokenized chunk by chunk around ``image_token`` and each
    occurrence of the token is replaced by the integer ``image_token_index``.
    Images are padded to a square before being handed to the image processor.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = [
        "chat_template",
        "image_token",
    ]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        image_token="<image>",  # set the default and let users change if they have peculiar special tokens in rare cases
        image_token_index=-200,
        **kwargs,
    ):
        """Store the placeholder configuration and initialise the mixin.

        Args:
            image_processor: Image processor forwarded to ``ProcessorMixin``.
            tokenizer: Tokenizer forwarded to ``ProcessorMixin``.
            chat_template: Optional chat template forwarded to ``ProcessorMixin``.
            image_token: Text marker that denotes an image position in a prompt.
            image_token_index: Integer id written into ``input_ids`` wherever
                ``image_token`` appears in the prompt.
            **kwargs: Accepted for compatibility; not used here.
        """
        self.image_token = image_token
        self.image_token_index = image_token_index
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        **kwargs: Unpack[VoRAProcessorKwargs],
    ):
        """Build model inputs from prompts and (optionally) images.

        Args:
            images: A single PIL image, a list of PIL images (one per prompt),
                or a list of image sequences. For a sequence only its first
                image is used.
            text: A prompt string or a list/tuple of prompt strings.
            **kwargs: Extra processing kwargs merged through ``_merge_kwargs``;
                only the ``images_kwargs`` group is consumed here.

        Returns:
            A ``BatchFeature`` with ``input_ids``, ``attention_mask`` and
            ``vision_placeholder_index``. When images are given it also
            contains the image processor outputs, with ``pixel_values``
            renamed to ``frames``, plus ``n_frames``.

        Raises:
            ValueError: If neither ``images`` nor ``text`` is given, if
                ``text`` is not a string or a sequence of strings, or if the
                prompts tokenize to different lengths.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        images, text = _validate_images_text_input_order(images, text)
        output_kwargs = self._merge_kwargs(
            VoRAProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None:
            # Accept a bare image as a batch of one.
            if isinstance(images, Image.Image):
                images = [images]
            # One single-image list per sample. A sample may be given either
            # as an image or as a sequence, of which only the first image is
            # kept.
            images = [
                [self.expand2square(sample if isinstance(sample, Image.Image) else sample[0])]
                for sample in images
            ]
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
            image_inputs["frames"] = image_inputs.pop("pixel_values")
            image_inputs["n_frames"] = [len(sample) for sample in images]
        else:
            # Text-only call: there are no image features to rename or count.
            image_inputs = {}

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, (list, tuple)) or not all(isinstance(t, str) for t in text):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        input_ids = [self.tokenizer_vision_placeholder(prompt) for prompt in text]
        # No padding is applied here, so a batch can only be stacked into a
        # tensor when every prompt has the same number of tokens.
        if len({len(ids) for ids in input_ids}) > 1:
            raise ValueError(
                "All prompts in a batch must tokenize to the same length; "
                "padding is not applied by this processor."
            )
        attention_mask = [[1] * len(ids) for ids in input_ids]
        text_inputs = dict(
            input_ids=torch.as_tensor(input_ids, dtype=torch.int64),
            attention_mask=torch.as_tensor(attention_mask, dtype=torch.int64),
        )
        image_inputs["vision_placeholder_index"] = self.image_token_index
        return BatchFeature(data={**text_inputs, **image_inputs})

    def expand2square(self, pil_img: Image.Image):
        """Pad ``pil_img`` with a black background so it becomes square.

        The original image is centred along its shorter dimension. An image
        that is already square is returned unchanged (same object).

        Args:
            pil_img: The image to pad.

        Returns:
            A square image whose side equals the longer side of ``pil_img``.
        """
        width, height = pil_img.size
        if width == height:
            return pil_img

        background_color = (0, 0, 0)
        side = max(width, height)
        # The offset is zero along the longer dimension and centres the image
        # along the shorter one.
        offset = ((side - width) // 2, (side - height) // 2)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, offset)
        return result

    def tokenizer_vision_placeholder(self, prompt, add_bos=False):
        """Tokenize ``prompt`` and insert the image placeholder id.

        The prompt is split on ``self.image_token``; every piece is encoded
        with ``self.tokenizer.encode`` and ``self.image_token_index`` is
        inserted between consecutive pieces.

        Args:
            prompt: The prompt string.
            add_bos: When true, prepend ``self.tokenizer.bos_token_id``.

        Returns:
            A flat list of token ids.
        """
        # NOTE(review): each chunk is encoded separately with the tokenizer's
        # default settings; if the tokenizer adds special tokens on encode they
        # would appear once per chunk -- confirm for the tokenizer in use.
        input_ids = []
        for position, chunk in enumerate(prompt.split(self.image_token)):
            # Compare against None rather than relying on truthiness so that a
            # placeholder id of 0 is still inserted.
            if position > 0 and self.image_token_index is not None:
                input_ids.append(self.image_token_index)
            input_ids.extend(self.tokenizer.encode(chunk))

        if add_bos:
            input_ids = [self.tokenizer.bos_token_id] + input_ids

        return input_ids