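"""Processor for the Capri captioning model.

Pairs a SiglipImageProcessor with a Qwen2 tokenizer and accepts either raw
images or precomputed pooled image embeddings as input.
"""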
from __future__ import annotations

from typing import Any

import torch
from transformers import AutoConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


class CapriProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        image_processor: BaseImageProcessor,
        tokenizer: PreTrainedTokenizerBase,
        prompt_prefix: str = "<image> Caption:",
        image_token: str = "<image>",
        pooled_embedding_dim: int = 768,
    ):
        self.prompt_prefix = prompt_prefix
        self.image_token = image_token
        self.pooled_embedding_dim = pooled_embedding_dim
        super().__init__(image_processor, tokenizer)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
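        # Default to the slow tokenizer unless the caller explicitly overrides it.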
        kwargs.setdefault("use_fast", False)
        processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
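        # Let the checkpoint's config override the prompt prefix, image token,
        # and embedding dimension so the processor stays in sync with the model.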
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=kwargs.get("trust_remote_code", False),
        )
        processor.prompt_prefix = getattr(config, "prompt_prefix", processor.prompt_prefix)
        processor.image_token = getattr(config, "image_token", processor.image_token)
        processor.pooled_embedding_dim = getattr(config, "projector_in_dim", processor.pooled_embedding_dim)
        return processor

    def normalize_images(self, images) -> list[Any]:
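        """Coerce a single image, a batched 4D tensor, or a sequence into a list of images."""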
        if isinstance(images, torch.Tensor):
            if images.ndim == 4:
                return [images[i] for i in range(images.shape[0])]
            return [images]
        if isinstance(images, (list, tuple)):
            return list(images)
        return [images]

    def normalize_pooled_embeddings(self, pooled_embeddings) -> torch.Tensor:
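        """Validate pooled embeddings and return them as a 2D (batch, dim) tensor."""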
        pooled = torch.as_tensor(pooled_embeddings)
        if pooled.ndim == 1:
            pooled = pooled.unsqueeze(0)
        if pooled.ndim != 2:
            raise ValueError("`pooled_embeddings` must be a 1D embedding or a 2D batch of embeddings.")
        if pooled.shape[-1] != self.pooled_embedding_dim:
            raise ValueError(
                f"Expected pooled embedding dim {self.pooled_embedding_dim}, got {pooled.shape[-1]}."
            )
        return pooled

    def __call__(
        self,
        images=None,
        pooled_embeddings=None,
        text=None,
        return_tensors: str | None = "pt",
        padding: bool | str = True,
        truncation: bool = False,
        max_length: int | None = None,
        **kwargs: Any,
    ) -> BatchFeature:
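        """Build model inputs from images, pooled embeddings, and/or text.

        When no text is given, the caption prompt prefix is broadcast across
        the inferred batch.
        """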
        if images is None and pooled_embeddings is None and text is None:
            raise ValueError("Provide `images`, `pooled_embeddings`, or `text`.")

        batch = {}
        batch_size = None

        if images is not None:
            image_features = self.image_processor(images=images, return_tensors=return_tensors, **kwargs)
            batch.update(dict(image_features))
            # len() handles both a stacked tensor (len is the first dim) and,
            # when return_tensors is None, a list of per-image arrays.
            batch_size = len(batch["pixel_values"])

        if pooled_embeddings is not None:
            pooled = self.normalize_pooled_embeddings(pooled_embeddings)
            batch["pooled_embeddings"] = pooled
            batch_size = pooled.shape[0]

        if text is None and batch_size is not None:
            text = [self.prompt_prefix] * batch_size

        if text is not None:
            if isinstance(text, str):
                text = [text]
            tokenized = self.tokenizer(
                text,
                add_special_tokens=False,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
            )
            batch.update(dict(tokenized))

        return BatchFeature(data=batch, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)
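

# Minimal usage sketch. The checkpoint id below is a placeholder, not a real
# model id; `pooled_embeddings` takes any (batch, pooled_embedding_dim) tensor.
#
#   from PIL import Image
#   processor = CapriProcessor.from_pretrained("your-org/capri", trust_remote_code=True)
#   batch = processor(images=Image.open("example.png"))
#   # -> BatchFeature with pixel_values plus the tokenized "<image> Caption:" prompt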