| | """ |
| | Processor class for Molmo. |
| | """ |
| |
|
| | from typing import Optional |
| |
|
| | import PIL |
| | from PIL import ImageOps |
| | from PIL.Image import Image |
| |
|
| | try: |
| | from typing import Unpack |
| | except ImportError: |
| | from typing_extensions import Unpack |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from transformers.image_utils import ImageInput |
| | from transformers.processing_utils import ( |
| | TextKwargs, |
| | ProcessingKwargs, |
| | ProcessorMixin, |
| | ) |
| |
|
| | from transformers.tokenization_utils_base import TextInput, PreTokenizedInput |
| | from transformers.utils import logging |
| |
|
| | from transformers import AutoTokenizer |
| | from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor |
| |
|
| |
|
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(name=__name__)
| |
|
| |
|
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"

# Order matters: ``get_special_token_ids`` pairs these strings positionally with
# the ids produced by encoding their concatenation.
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)


def get_special_token_ids(tokenizer):
    """Return a mapping from each token in ``EXTRA_TOKENS`` to its id in ``tokenizer``.

    All extra tokens are encoded in a single ``encode`` call (without adding
    special tokens), and the resulting ids are paired positionally with
    ``EXTRA_TOKENS``.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)`` that
            returns a sequence of integer ids.

    Returns:
        dict mapping each special-token string to its id.

    Raises:
        ValueError: If the number of ids produced differs from the number of
            extra tokens, i.e. the tokens are not encoded as one id each.
    """
    ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
    # An explicit check (not ``assert``) so it still runs under ``python -O``.
    if len(ids) != len(EXTRA_TOKENS):
        raise ValueError(
            f"Expected the tokenizer to encode the {len(EXTRA_TOKENS)} special tokens "
            f"{EXTRA_TOKENS} as one id each, but got {len(ids)} ids: {list(ids)}"
        )
    return dict(zip(EXTRA_TOKENS, ids))
| |
|
| |
|
class MolmoTextKwargs(TextKwargs, total=False):
    """Text-side keyword arguments accepted by ``MolmoProcessor.process``."""

    # NOTE(review): not read by MolmoProcessor.process; presumably kept for
    # config compatibility — TODO confirm.
    style: Optional[str]
    # NOTE(review): not read by MolmoProcessor.process; presumably kept for
    # config compatibility — TODO confirm.
    system_prompt: Optional[str]
    # "none"/None leaves the prompt as-is; "role" wraps it as
    # "User: <prompt> Assistant:"; any other value raises NotImplementedError.
    message_format: Optional[str]
    # When true, a single space is prepended to the prompt before encoding.
    always_start_with_space: Optional[bool]
    # Forwarded as ``sequence_length`` to image_processor.multimodal_preprocess.
    sequence_length: Optional[int]
| |
|
| |
|
class MolmoProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema and per-modality defaults for ``MolmoProcessor``.

    This class is what ``MolmoProcessor.process`` hands to ``_merge_kwargs``;
    ``_defaults`` holds the fallback value for every image and text option.
    """

    text_kwargs: MolmoTextKwargs
    images_kwargs: MolmoImagesKwargs

    _defaults = dict(
        # The merged image kwargs are forwarded to
        # ``image_processor.multimodal_preprocess``.
        images_kwargs=dict(
            max_crops=12,
            overlap_margins=[4, 4],
            base_image_input_size=[336, 336],
            image_token_length_w=12,
            image_token_length_h=12,
            image_patch_size=14,
            image_padding_mask=True,
        ),
        # Of these, ``message_format``, ``always_start_with_space`` and
        # ``sequence_length`` are read directly by ``MolmoProcessor.process``.
        text_kwargs=dict(
            style="long_caption",
            system_prompt="none",
            message_format="role",
            always_start_with_space=True,
            sequence_length=1536,
            padding=False,
        ),
    )
| |
|
| |
|
class MolmoProcessor(ProcessorMixin):
    """Pair a Molmo image processor with a GPT-2-style tokenizer.

    ``process`` tokenizes a prompt (or accepts pre-tokenized ids), normalizes
    the images, and delegates to ``image_processor.multimodal_preprocess`` to
    build the model inputs. It then prepends a BOS id and returns every output
    array as a torch tensor.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")

    def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer : AutoTokenizer = None, **kwargs):
        # Extra keyword arguments are accepted (so loaders may pass them) but are
        # intentionally not forwarded to ProcessorMixin.
        super().__init__(image_processor, tokenizer)
        # Lazily filled cache of special-token ids; see ``special_token_ids``.
        self._special_tokens = None

    @property
    def special_token_ids(self):
        """Map each string in ``EXTRA_TOKENS`` to its tokenizer id (computed once)."""
        if self._special_tokens is None:
            self._special_tokens = get_special_token_ids(self.tokenizer)
        return self._special_tokens

    def get_tokens_input(self, prompt, message_format, always_start_with_space):
        """Format ``prompt`` according to ``message_format`` and encode it.

        Args:
            prompt: The prompt text.
            message_format: ``"none"`` or ``None`` leaves the prompt unchanged;
                ``"role"`` wraps it as ``"User: <prompt> Assistant:"``.
            always_start_with_space: If true, prepend a single space before encoding.

        Returns:
            Token ids from ``tokenizer.encode(..., add_special_tokens=False)``.

        Raises:
            NotImplementedError: For any other ``message_format``.
        """
        if message_format == "none" or message_format is None:
            pass
        elif message_format == "role":
            prompt = "User: " + prompt + " Assistant:"
        else:
            raise NotImplementedError(f"Message format {message_format} not implemented")

        if always_start_with_space:
            prompt = " " + prompt

        return self.tokenizer.encode(prompt, add_special_tokens=False)

    @staticmethod
    def _images_to_arrays(images):
        """Normalize one image or a list/tuple of images into a list of uint8 arrays."""
        if not isinstance(images, (list, tuple)):
            images = [images]
        arrays = []
        for image in images:
            if isinstance(image, Image):
                # PIL ignores the EXIF Orientation tag by default, so apply it
                # explicitly (https://github.com/python-pillow/Pillow/issues/4703).
                # Transpose first so the EXIF data is read from the source image,
                # and keep the transposed result (it was previously discarded).
                image = ImageOps.exif_transpose(image).convert("RGB")
                arrays.append(np.array(image))
            else:
                # Explicit check instead of ``assert`` so it survives ``python -O``.
                if len(image.shape) != 3 or image.shape[-1] != 3:
                    raise ValueError(
                        f"Expected a 3-D image array with 3 channels in the last axis, "
                        f"got shape {tuple(image.shape)}"
                    )
                arrays.append(image.astype(np.uint8))
        return arrays

    def process(
        self,
        text: TextInput = None,
        images: ImageInput = None,
        *,
        tokens: Optional[PreTokenizedInput] = None,
        **kwargs: Unpack[MolmoProcessorKwargs],
    ):
        """Tokenize a prompt, preprocess images, and build model inputs.

        Args:
            text: Prompt string, formatted via ``get_tokens_input``. Required
                unless ``tokens`` is given.
            images: A single image or a list/tuple of images. PIL images are
                EXIF-orientation-corrected and converted to RGB. Any other input
                must be a 3-D array whose last axis has size 3 (presumably
                H, W, RGB); it is cast to uint8.
            tokens: Pre-tokenized prompt ids. When given, ``text`` is ignored.
            **kwargs: Overrides for the ``MolmoProcessorKwargs`` defaults.

        Returns:
            The mapping returned by ``image_processor.multimodal_preprocess``,
            with these changes:
                - a BOS id (EOS if the tokenizer has no BOS) is prepended to
                  ``input_ids``;
                - non-negative ``image_input_idx`` entries are shifted by one;
                - every value is converted with ``torch.from_numpy``.

        Raises:
            ValueError: If neither ``text`` nor ``tokens`` is provided, or an
                array image has the wrong shape.
            NotImplementedError: For an unsupported ``message_format``.
        """
        output_kwargs = self._merge_kwargs(
            MolmoProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        text_kwargs = output_kwargs["text_kwargs"]

        if tokens is None:
            if text is None:
                raise ValueError("Either `text` or `tokens` must be provided.")
            tokens = self.get_tokens_input(
                text,
                text_kwargs["message_format"],
                text_kwargs["always_start_with_space"],
            )

        if images is not None:
            images = self._images_to_arrays(images)
            # NOTE(review): -1 for every image presumably tells
            # multimodal_preprocess where to insert images relative to the text —
            # confirm against MolmoImageProcessor.
            image_idx = [-1] * len(images)
        else:
            image_idx = None

        special = self.special_token_ids
        out = self.image_processor.multimodal_preprocess(
            images=images,
            image_idx=image_idx,
            tokens=np.asarray(tokens).astype(np.int32),
            sequence_length=text_kwargs["sequence_length"],
            image_patch_token_id=special[DEFAULT_IMAGE_PATCH_TOKEN],
            image_col_token_id=special[DEFAULT_IM_COL_TOKEN],
            image_start_token_id=special[DEFAULT_IM_START_TOKEN],
            image_end_token_id=special[DEFAULT_IM_END_TOKEN],
            **output_kwargs["images_kwargs"],
        )

        # Prepend a BOS id, falling back to EOS when the tokenizer has no BOS.
        # Compare with None explicitly so a legitimate id of 0 is not discarded.
        bos = self.tokenizer.bos_token_id
        if bos is None:
            bos = self.tokenizer.eos_token_id
        out["input_ids"] = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
        if "image_input_idx" in out:
            # Shift non-negative positions by one to account for the prepended
            # token; negative entries are left untouched.
            image_input_idx = out["image_input_idx"]
            out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)

        # Convert in place so the mapping type returned by the image processor
        # is preserved.
        for key, value in out.items():
            out[key] = torch.from_numpy(value)

        return out
| |
|
| |
|
# Register MolmoProcessor with the AutoProcessor auto class (the default target
# of ``register_for_auto_class``, spelled out here for clarity).
MolmoProcessor.register_for_auto_class("AutoProcessor")
| |
|