import torch
from transformers import ProcessorMixin, BatchFeature, CLIPImageProcessorFast
from transformers.image_utils import ImageInput
from typing import List, Optional, Union

from .llava_qwen import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN

# Adapted from transformers.models.llava_next.image_processing_llava_next.expand_to_square
def expand_to_square(image: torch.Tensor, background_color=0) -> torch.Tensor:
    """
    Expands an image to a square by adding a background color.
    """
    c, height, width = image.shape
    if width == height:
        return image
    elif width > height:
        result = torch.ones((c, width, width), dtype=image.dtype, device=image.device) * background_color
        result[:, (width - height) // 2 : (width - height) // 2 + height, :] = image
        return result
    else:
        result = torch.ones((c, height, height), dtype=image.dtype, device=image.device) * background_color
        result[:, :, (height - width) // 2 : (height - width) // 2 + width] = image
        return result
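
# A quick sanity check of the padding behavior (illustrative, not part of the
# model code): a 3x2x4 CHW tensor becomes 3x4x4, with the original rows
# vertically centered over the zero background.
#   >>> img = torch.arange(24, dtype=torch.float32).reshape(3, 2, 4)
#   >>> expand_to_square(img).shape
#   torch.Size([3, 4, 4])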


class FastVLMImageProcessor(CLIPImageProcessorFast):
    def _preprocess(self, images, **kwargs):
        # Record each image's original (width, height) before padding, pad each
        # CHW tensor to a square, then let the CLIP fast-processor pipeline
        # handle resizing, rescaling, and normalization.
        image_sizes = [image.shape[-2:][::-1] for image in images]
        images = [expand_to_square(image) for image in images]
        images = super()._preprocess(images, **kwargs)
        pixel_values = torch.stack(images.pixel_values, dim=0)
        return BatchFeature(data={"pixel_values": pixel_values, "image_sizes": image_sizes})
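
# Standalone usage sketch (the checkpoint path is a placeholder; real
# checkpoints ship their own preprocessor_config.json):
#   >>> from PIL import Image
#   >>> proc = FastVLMImageProcessor.from_pretrained("<path-to-checkpoint>")
#   >>> out = proc(images=[Image.new("RGB", (640, 480))])
#   >>> out["image_sizes"]    # original (width, height), kept alongside pixel_values
#   [(640, 480)]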

class FastVLMProcessor(ProcessorMixin):
    attributes = ["tokenizer", "image_processor"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer,
        image_processor,
        chat_template=None,
        **kwargs
    ):
        super().__init__(tokenizer, image_processor, chat_template=chat_template, **kwargs)

    def __call__(
        self,
        images: ImageInput = None,
        text: Optional[Union[str, List[str]]] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) or not all(isinstance(t, str) for t in text):
            raise TypeError("Invalid input text. Please provide a string, or a list of strings")

        image_inputs = {}
        if images is not None:
            image_inputs = self.image_processor(images=images)

        # IMAGE_TOKEN_INDEX (-200) is a sentinel that is not in the tokenizer
        # vocab, so the full prompt cannot be tokenized in one pass. Split each
        # prompt around the DEFAULT_IMAGE_TOKEN placeholder, tokenize the two
        # halves, and splice the sentinel id in between.
        image_token = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=torch.int64)
        ids_list = []
        for prompt in text:
            num_image_tokens = prompt.count(DEFAULT_IMAGE_TOKEN)
            if num_image_tokens > 1:
                raise ValueError(
                    f"Expected at most 1 image token per prompt, got {num_image_tokens} instead."
                )

            pre, sep, post = prompt.partition(DEFAULT_IMAGE_TOKEN)
            pre_ids = self.tokenizer(pre, return_tensors="pt", add_special_tokens=False).input_ids
            post_ids = self.tokenizer(post, return_tensors="pt", add_special_tokens=False).input_ids

            # Only splice the image sentinel in if the placeholder was present.
            pieces = [pre_ids, image_token, post_ids] if sep else [pre_ids]
            ids_list.append(torch.cat(pieces, dim=1).to(dtype=torch.int64)[0])

        # Right-pad to the longest sample so prompts of different lengths can
        # share a batch; padded positions are masked out in the attention mask.
        pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
        max_len = max(ids.shape[0] for ids in ids_list)
        input_ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.int64)
        attention_mask = torch.zeros((len(ids_list), max_len), dtype=torch.int64)
        for i, ids in enumerate(ids_list):
            input_ids[i, : ids.shape[0]] = ids
            attention_mask[i, : ids.shape[0]] = 1

        return BatchFeature(
            data={"input_ids": input_ids, "attention_mask": attention_mask, **image_inputs},
            tensor_type=return_tensors,
        )