|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Union |
|
|
|
|
|
from ...feature_extraction_utils import BatchFeature |
|
|
from ...image_utils import ImageInput, make_flat_list_of_images |
|
|
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack |
|
|
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput |
|
|
from ...utils import is_torch_available |
|
|
|
|
|
|
|
|
if is_torch_available(): |
|
|
import torch |
|
|
|
|
|
|
|
|
class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
    # Default processing kwargs that `ColPaliProcessor.__call__` hands to `_merge_kwargs`:
    #   - text: pad to the longest sequence in the batch;
    #   - images: channels-first pixel layout with RGB conversion enabled;
    #   - common: return PyTorch tensors.
    _defaults = dict(
        text_kwargs=dict(padding="longest"),
        images_kwargs=dict(data_format="channels_first", do_convert_rgb=True),
        common_kwargs=dict(return_tensors="pt"),
    )
|
|
|
|
|
|
|
|
# Placeholder token repeated once per image patch position in the prompt.
IMAGE_TOKEN = "<image>"

# 1024 location tokens (`<loc0000>` ... `<loc1023>`) followed by 128 segmentation tokens
# (`<seg000>` ... `<seg127>`), added to the tokenizer vocabulary by the processor.
EXTRA_TOKENS = [
    *(f"<loc{loc_idx:04d}>" for loc_idx in range(1024)),
    *(f"<seg{seg_idx:03d}>" for seg_idx in range(128)),
]
|
|
|
|
|
|
|
|
def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
    r"""
    Builds a string from the input prompt and image tokens.

    The image tokens come first (`image_seq_len` tokens for each of the `num_images` images), followed by the
    BOS token, the prompt and a trailing newline.

    For example, for the call:

        build_string_from_input(
            prompt="Prefix str",
            bos_token="<s>",
            image_seq_len=3,
            image_token="<im>",
            num_images=1,
        )

    The output will be:

        "<im><im><im><s>Prefix str\n"

    Args:
        prompt (`str`): The input prompt.
        bos_token (`str`): The beginning of sentence token.
        image_seq_len (`int`): The number of image tokens emitted per image.
        image_token (`str`): The image token.
        num_images (`int`): Number of images in the prompt.

    Returns:
        `str`: The assembled prompt string.
    """
    return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
|
|
|
|
|
|
|
|
class ColPaliProcessor(ProcessorMixin):
    r"""
    Constructs a ColPali processor which wraps a PaliGemmaProcessor and special methods to process images and queries, as
    well as to compute the late-interaction retrieval score.

    [`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
    for more information.

    Args:
        image_processor ([`SiglipImageProcessor`], *optional*):
            The image processor. Although the argument defaults to `None`, it is required: it must expose an
            `image_seq_length` attribute, otherwise a `ValueError` is raised.
        tokenizer ([`GemmaTokenizerFast`], *optional*):
            The tokenizer. Although the argument defaults to `None`, it is required.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
        visual_prompt_prefix (`str`, *optional*, defaults to `"Describe the image."`):
            Text placed after the image tokens and the BOS token in every image prompt.
        query_prefix (`str`, *optional*, defaults to `"Question: "`):
            Text inserted between the BOS token and each query.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast")
    tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        visual_prompt_prefix: str = "Describe the image.",
        query_prefix: str = "Question: ",
    ):
        super().__init__(image_processor, tokenizer, chat_template=chat_template)
        if not hasattr(image_processor, "image_seq_length"):
            raise ValueError("Image processor is missing an `image_seq_length` attribute.")

        # Number of image placeholder tokens emitted per image.
        self.image_seq_length = image_processor.image_seq_length

        if not hasattr(tokenizer, "image_token"):
            # Register `<image>` as a special, non-normalized added token.
            image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
            tokens_to_add = {"additional_special_tokens": [image_token]}
            tokenizer.add_special_tokens(tokens_to_add)
            self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
            self.image_token = IMAGE_TOKEN
        else:
            self.image_token_id = tokenizer.image_token_id
            self.image_token = tokenizer.image_token

        tokenizer.add_tokens(EXTRA_TOKENS)
        # The BOS token is inserted manually into the prompt strings built in `__call__`,
        # so the tokenizer must not add BOS/EOS on its own.
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False
        self.visual_prompt_prefix = visual_prompt_prefix
        self.query_prefix = query_prefix

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[ColPaliProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model either (1) one or several texts, or (2) one or several image(s). This
        method is a custom wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for
        the ColPali model. It cannot process both text and images at the same time.

        When preparing the text(s), each query is turned into `BOS + query_prefix + query + suffix + "\\n"`, and the
        resulting strings are forwarded together with `kwargs` to the tokenizer's `__call__`.
        When preparing the image(s), the `images` and `kwargs` arguments are forwarded to the image processor's
        `__call__`, and one prompt per image (image tokens, BOS token, `visual_prompt_prefix`) is tokenized.
        Please refer to the docstring of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width. Only PIL images are converted to RGB here;
                other inputs rely on the image processor's `do_convert_rgb` option.
            text (`str`, `list[str]`):
                The query or batch of queries to be encoded. Must be a string or a non-empty list of strings.
            audio: Not used by this processor.
            videos: Not used by this processor.
            suffix (`str`, *optional*):
                Only valid for text queries. Appended to every query; defaults to 10 repetitions of
                `query_augmentation_token`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.

        Raises:
            ValueError: If neither or both of `text` and `images` are given, if `text` is not a string or a non-empty
                list of strings, or if `suffix` is passed together with `images`.
        """
        output_kwargs = self._merge_kwargs(
            ColPaliProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        suffix = output_kwargs["text_kwargs"].pop("suffix", None)

        if text is None and images is None:
            raise ValueError("Either text or images must be provided")
        if text is not None and images is not None:
            raise ValueError("Only one of text or images can be processed at a time")

        if images is not None:
            # The image prompts never include a suffix; previously this path failed with a KeyError on
            # `token_type_ids` (never requested from the tokenizer), so reject it explicitly.
            if suffix is not None:
                raise ValueError("`suffix` is only supported when processing text queries, not images.")

            images = self.image_processor.fetch_images(images)
            images = make_flat_list_of_images(images)
            texts_doc = [self.visual_prompt_prefix] * len(images)
            # Only objects exposing `.convert` (PIL images) are converted here; NumPy arrays / torch tensors are
            # passed through and handled by the image processor (`do_convert_rgb` is on by default).
            images = [image.convert("RGB") if hasattr(image, "convert") else image for image in images]

            input_strings = [
                build_string_from_input(
                    prompt=prompt,
                    bos_token=self.tokenizer.bos_token,
                    image_seq_len=self.image_seq_length,
                    # Use the token registered in `__init__` so it matches `self.image_token_id`.
                    image_token=self.image_token,
                    num_images=len(image_list) if isinstance(image_list, list) else 1,
                )
                for prompt, image_list in zip(texts_doc, images)
            ]
            pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

            # `max_length` is expressed in text tokens; extend it so the image placeholder tokens do not eat into it.
            if output_kwargs["text_kwargs"].get("max_length", None) is not None:
                output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length

            inputs = self.tokenizer(
                input_strings,
                return_token_type_ids=False,
                **output_kwargs["text_kwargs"],
            )

            return BatchFeature(data={**inputs, "pixel_values": pixel_values})

        if isinstance(text, str):
            text = [text]
        elif not (isinstance(text, list) and text and all(isinstance(query, str) for query in text)):
            raise ValueError("Text must be a string or a list of strings")

        if suffix is None:
            # Default query augmentation: 10 copies of the augmentation token.
            suffix = self.query_augmentation_token * 10

        texts_query: list[str] = [
            self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n" for query in text
        ]

        # Queries default to `max_length=50` unless the caller provides one.
        output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)

        batch_query = self.tokenizer(
            texts_query,
            return_token_type_ids=False,
            **output_kwargs["text_kwargs"],
        )

        return batch_query

    def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Every image uses `self.image_seq_length` tokens and a single patch, independently of its size.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.

        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """
        vision_data = {}
        if image_sizes is not None:
            num_image_tokens = [self.image_seq_length] * len(image_sizes)
            num_image_patches = [1] * len(image_sizes)
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
        return MultiModalData(**vision_data)

    @property
    def query_augmentation_token(self) -> str:
        """
        Return the query augmentation token (the tokenizer's pad token).

        Query augmentation buffers are used as reasoning buffers during inference.
        """
        return self.tokenizer.pad_token

    def process_images(
        self,
        images: Optional[ImageInput] = None,
        **kwargs: Unpack[ColPaliProcessorKwargs],
    ) -> BatchFeature:
        """
        Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the
        ColPaliProcessor's [`ColPaliProcessor.__call__`].

        This method forwards the `images` and `kwargs` arguments to the image processor.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
            - **pixel_values** -- Pixel values to be fed to a model.
        """
        return self.__call__(images=images, **kwargs)

    def process_queries(
        self,
        text: Union[TextInput, list[TextInput]],
        **kwargs: Unpack[ColPaliProcessorKwargs],
    ) -> BatchFeature:
        """
        Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the
        ColPaliProcessor's [`ColPaliProcessor.__call__`].

        This method forwards the `text` and `kwargs` arguments to the tokenizer.

        Args:
            text (`str`, `list[str]`):
                The query or batch of queries to be encoded. Must be a string or a non-empty list of strings.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
        """
        return self.__call__(text=text, **kwargs)

    def score_retrieval(
        self,
        query_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
        passage_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
        batch_size: int = 128,
        output_dtype: Optional["torch.dtype"] = None,
        output_device: Union["torch.device", str] = "cpu",
    ) -> "torch.Tensor":
        """
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
        image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Args:
            query_embeddings (`Union[torch.Tensor, list[torch.Tensor]]`): Query embeddings.
            passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores. Must be positive.
            output_dtype (`torch.dtype`, *optional*): The dtype of the output tensor.
                If `None`, the dtype of the input embeddings is used.
            output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.

        Returns:
            `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores, placed on
            `output_device`.

        Raises:
            ValueError: If `batch_size` is not positive, if either input is empty, or if queries and passages
                differ in device or dtype.
        """
        if batch_size <= 0:
            raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}")

        if len(query_embeddings) == 0:
            raise ValueError("No queries provided")
        if len(passage_embeddings) == 0:
            raise ValueError("No passages provided")

        if query_embeddings[0].device != passage_embeddings[0].device:
            raise ValueError("Queries and passages must be on the same device")

        if query_embeddings[0].dtype != passage_embeddings[0].dtype:
            raise ValueError("Queries and passages must have the same dtype")

        if output_dtype is None:
            output_dtype = query_embeddings[0].dtype

        scores: list[torch.Tensor] = []

        for i in range(0, len(query_embeddings), batch_size):
            batch_scores: list[torch.Tensor] = []
            batch_queries = torch.nn.utils.rnn.pad_sequence(
                query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
            )
            for j in range(0, len(passage_embeddings), batch_size):
                batch_passages = torch.nn.utils.rnn.pad_sequence(
                    passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
                )
                # (b, n, d) x (c, s, d) -> (b, c, n, s): every query token against every passage token.
                # MaxSim: best passage token per query token (max over s), then sum over query tokens (n).
                # NOTE(review): zero-padded passage tokens score 0 and can win the max when all real
                # similarities of a query token are negative; zero-padded query tokens contribute 0.
                batch_scores.append(
                    torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
                )
            scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))

        return torch.cat(scores, dim=0)
|
|
|
|
|
|
|
|
# Public API of this module: only the processor class is exported by `from ... import *`.
__all__ = ["ColPaliProcessor"]
|
|
|