# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/colpali/modular_colpali.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be made, please apply it to the
# modular_colpali.py file directly. One of our CI checks enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, make_flat_list_of_images
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
from ...utils import is_torch_available

if is_torch_available():
    import torch


class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": "longest",
},
"images_kwargs": {
"data_format": "channels_first",
"do_convert_rgb": True,
},
"common_kwargs": {"return_tensors": "pt"},
}


IMAGE_TOKEN = "<image>"
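# PaliGemma's extra vocabulary: 1024 location tokens (`<loc0000>` ... `<loc1023>`) and
# 128 segmentation tokens (`<seg000>` ... `<seg127>`), used for detection/segmentation outputs.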
EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]


def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
"""
Builds a string from the input prompt and image tokens.
    For example, for the call:

    build_string_from_input(
        prompt="Prefix str",
        bos_token="<s>",
        image_seq_len=3,
        image_token="<im>",
        num_images=1,
    )

    The output will be:

    "<im><im><im><s>Prefix str\n"
Args:
prompt (`list[Union[str, ImageInput]]`): The input prompt.
bos_token (`str`): The beginning of sentence token.
image_seq_len (`int`): The length of the image sequence.
image_token (`str`): The image token.
num_images (`int`): Number of images in the prompt.
"""
return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"


class ColPaliProcessor(ProcessorMixin):
r"""
    Constructs a ColPali processor, which wraps a PaliGemmaProcessor and adds special methods to process images and queries,
    as well as to compute the late-interaction retrieval score.
[`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
for more information.
Args:
image_processor ([`SiglipImageProcessor`], *optional*):
The image processor is a required input.
        tokenizer ([`GemmaTokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
visual_prompt_prefix (`str`, *optional*, defaults to `"Describe the image."`):
A string that gets tokenized and prepended to the image tokens.
query_prefix (`str`, *optional*, defaults to `"Question: "`):
A prefix to be used for the query.
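
    Example (a minimal sketch; the checkpoint name `"vidore/colpali-v1.2-hf"` is illustrative):

    ```python
    >>> from transformers import ColPaliProcessor

    >>> processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2-hf")
    ```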
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast")
tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")

    def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
visual_prompt_prefix: str = "Describe the image.",
query_prefix: str = "Question: ",
):
super().__init__(image_processor, tokenizer, chat_template=chat_template)
if not hasattr(image_processor, "image_seq_length"):
raise ValueError("Image processor is missing an `image_seq_length` attribute.")
self.image_seq_length = image_processor.image_seq_length
if not hasattr(tokenizer, "image_token"):
image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [image_token]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
self.image_token = IMAGE_TOKEN
else:
self.image_token_id = tokenizer.image_token_id
self.image_token = tokenizer.image_token
tokenizer.add_tokens(EXTRA_TOKENS)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
self.visual_prompt_prefix = visual_prompt_prefix
self.query_prefix = query_prefix

    def __call__(
self,
images: Optional[ImageInput] = None,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
        Main method to prepare for the model either (1) one or several texts or (2) one or several image(s). This method is a
        custom wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method, adapted for the ColPali model.
        It cannot process both text and images at the same time.

        When preparing the text(s), this method forwards the `text` and `kwargs` arguments to GemmaTokenizerFast's
        [`~GemmaTokenizerFast.__call__`].
When preparing the image(s), this method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's
[`~SiglipImageProcessor.__call__`].
Please refer to the docstring of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
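
        Example (a minimal sketch; the checkpoint name is illustrative):

        ```python
        >>> from PIL import Image
        >>> from transformers import ColPaliProcessor

        >>> processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2-hf")
        >>> # Images and text cannot be mixed in a single call:
        >>> image_batch = processor(images=Image.new("RGB", (448, 448)))
        >>> query_batch = processor(text="What does this document describe?")
        ```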
"""
output_kwargs = self._merge_kwargs(
ColPaliProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
suffix = output_kwargs["text_kwargs"].pop("suffix", None)
return_token_type_ids = suffix is not None
if text is None and images is None:
raise ValueError("Either text or images must be provided")
if text is not None and images is not None:
raise ValueError("Only one of text or images can be processed at a time")
if images is not None:
images = self.image_processor.fetch_images(images)
images = make_flat_list_of_images(images)
texts_doc = [self.visual_prompt_prefix] * len(images)
images = [image.convert("RGB") for image in images]
input_strings = [
build_string_from_input(
prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length,
image_token=IMAGE_TOKEN,
num_images=len(image_list) if isinstance(image_list, list) else 1,
)
for prompt, image_list in zip(texts_doc, images)
]
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
# max_length has to account for the image tokens
if output_kwargs["text_kwargs"].get("max_length", None) is not None:
output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
            inputs = self.tokenizer(
                input_strings,
                return_token_type_ids=return_token_type_ids,
                **output_kwargs["text_kwargs"],
            )
return_data = {**inputs, "pixel_values": pixel_values}
if return_token_type_ids:
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
return_data.update({"labels": labels})
return BatchFeature(data=return_data)
elif text is not None:
if isinstance(text, str):
text = [text]
elif not (isinstance(text, list) and isinstance(text[0], str)):
raise ValueError("Text must be a string or a list of strings")
if suffix is None:
suffix = self.query_augmentation_token * 10
texts_query: list[str] = []
for query in text:
query = self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n"
texts_query.append(query)
output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
batch_query = self.tokenizer(
texts_query,
return_token_type_ids=False,
**output_kwargs["text_kwargs"],
)
return batch_query

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
            image_sizes (`list[list[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
num_image_tokens = [self.image_seq_length] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)

    @property
def query_augmentation_token(self) -> str:
"""
Return the query augmentation token.
Query augmentation buffers are used as reasoning buffers during inference.
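
        Example (a minimal sketch; the checkpoint name is illustrative):

        ```python
        >>> processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2-hf")
        >>> suffix = processor.query_augmentation_token * 10  # the default suffix used by `__call__`
        ```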
"""
return self.tokenizer.pad_token

    def process_images(
self,
images: Optional[ImageInput] = None,
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `images` and `kwargs` arguments to the image processor.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
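
        Example (a minimal sketch; the checkpoint name is illustrative):

        ```python
        >>> from PIL import Image
        >>> from transformers import ColPaliProcessor

        >>> processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2-hf")
        >>> batch_images = processor.process_images(images=[Image.new("RGB", (448, 448))])
        ```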
"""
return self.__call__(images=images, **kwargs)

    def process_queries(
self,
text: Union[TextInput, list[TextInput]],
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `text` and `kwargs` arguments to the tokenizer.
Args:
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
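
        Example (a minimal sketch; the checkpoint name is illustrative):

        ```python
        >>> from transformers import ColPaliProcessor

        >>> processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2-hf")
        >>> batch_queries = processor.process_queries(text=["What is late interaction?"])
        ```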
"""
return self.__call__(text=text, **kwargs)

    def score_retrieval(
self,
query_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
passage_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
batch_size: int = 128,
output_dtype: Optional["torch.dtype"] = None,
output_device: Union["torch.device", str] = "cpu",
) -> "torch.Tensor":
"""
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings and passage embeddings. For ColPali, a passage is the image of a
        document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as either:

        (1) a list of tensors, where the i-th tensor is of shape `(sequence_length_i, embedding_dim)`, or
        (2) a single tensor of shape `(n_passages, max_sequence_length, embedding_dim)`, usually
            obtained by padding the list of tensors.
Args:
            query_embeddings (`Union[torch.Tensor, list[torch.Tensor]]`): Query embeddings.
            passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]]`): Passage embeddings.
batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            output_dtype (`torch.dtype`, *optional*):
                The dtype of the output tensor. If `None` (the default), the dtype of the input embeddings is used.
output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
            tensor is moved to `output_device` ("cpu" by default).
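
        Example (a minimal sketch with random embeddings; the embedding dimension 128 is an
        assumption for illustration, not a property of a specific checkpoint):

        ```python
        >>> import torch

        >>> # `processor` is a loaded ColPaliProcessor (see the class-level example)
        >>> qs = [torch.randn(20, 128) for _ in range(2)]  # 2 queries
        >>> ps = [torch.randn(1024, 128) for _ in range(4)]  # 4 passages (document pages)
        >>> scores = processor.score_retrieval(qs, ps)
        >>> scores.shape
        torch.Size([2, 4])
        ```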
"""
if len(query_embeddings) == 0:
raise ValueError("No queries provided")
if len(passage_embeddings) == 0:
raise ValueError("No passages provided")
if query_embeddings[0].device != passage_embeddings[0].device:
raise ValueError("Queries and passages must be on the same device")
if query_embeddings[0].dtype != passage_embeddings[0].dtype:
raise ValueError("Queries and passages must have the same dtype")
if output_dtype is None:
output_dtype = query_embeddings[0].dtype
scores: list[torch.Tensor] = []
for i in range(0, len(query_embeddings), batch_size):
batch_scores: list[torch.Tensor] = []
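            # Pad this batch of variable-length query embeddings to a common sequence length.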
batch_queries = torch.nn.utils.rnn.pad_sequence(
query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
)
for j in range(0, len(passage_embeddings), batch_size):
batch_passages = torch.nn.utils.rnn.pad_sequence(
passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
)
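                # MaxSim: compute all query-token/passage-token similarities, take the max over
                # passage tokens (dim=3), then sum over query tokens (dim=2).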
batch_scores.append(
torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
)
scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
return torch.cat(scores, dim=0)


__all__ = ["ColPaliProcessor"]