import torch from PIL import Image from typing import Tuple, List import numpy as np from transformers import GemmaTokenizerFast from .paligemma_processor import PaliGemmaProcessor from typing import Optional def process_imgs(imgs: List[Image.Image], img_size: Tuple[int, int], rescale: float, mean: Tuple[float, float, float], std: Tuple[float, float, float]): def normalize(img, mean, std): img = (img - np.array(mean, dtype=img.dtype)) / np.array(std, dtype=img.dtype) return img resized_imgs = [img.resize((img_size[0], img_size[1]), resample=Image.Resampling.BICUBIC) for img in imgs] rescaled_imgs = [np.array(img, dtype=np.float32) * rescale for img in resized_imgs] normalized_imgs = [normalize(img, mean, std) for img in rescaled_imgs] transposed_imgs = [img.transpose(2, 0, 1) for img in normalized_imgs] tensor_imgs = torch.tensor(np.stack(transposed_imgs, axis=0), dtype=torch.float32) return tensor_imgs def process_prompts(prompt, image_token, max_num_image_token, bos_token): return f"{image_token * max_num_image_token}{bos_token}{prompt}\n" class ColPaliProcessor(PaliGemmaProcessor): def __init__(self, tokenizer: GemmaTokenizerFast) -> None: super().__init__(tokenizer=tokenizer) self.mock_image = Image.new(mode='RGB', size=(16, 16), color='black') def process_images(self, images: List[Image.Image]): input_prompts = ["Describe the image."] * len(images) images = [image.convert("RGB") for image in images] return_data = self(images, input_prompts, padding="longest", truncation=False) return return_data def process_queries(self, queries: List[str], max_length: int = 50, suffix: Optional[str] = None): if suffix is None: suffix = "" * 10 texts_query: List[str] = [] for query in queries: query = f"Question: {query}" query += suffix texts_query.append(query) batch_query = self(imgs=[self.mock_image] * len(texts_query), prompts=texts_query, padding="longest", max_length=max_length + self.image_seq_length, truncation=True) del batch_query["pixel_values"] batch_query["input_ids"] = batch_query["input_ids"][..., self.image_seq_length:] batch_query["attention_mask"] = batch_query["attention_mask"][..., self.image_seq_length:] return batch_query