File size: 9,171 Bytes
76a07b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# coding=utf-8
"""Processor for the Molmo-v1 (CLIP vision) VLM.

Reproduces the Molmo preprocessor token layout exactly for this VLM's config
(crop_mode=resize, max_crops=1, image_pooling_2d=none, include_cls_token=true):

  per image block (213 tokens; 197 <im_patch>):
    [<im_start>] [<im_patch>(CLS)] then 14x([<im_patch>*14][<im_col>]) [<im_end>]

  full sequence: [BOS] + <pre-image text> + image_block + <post-image text>
  image_input_idx: the 197 <im_patch> positions (CLS first, then 196 row-major),
                   each +1 for the prepended BOS.
"""

from typing import List, Optional, Union

import numpy as np
import torch

from transformers.processing_utils import ProcessorMixin
from transformers.feature_extraction_utils import BatchFeature


class MolmoOlmo3Processor(ProcessorMixin):
    """Turn one text prompt (plus optional images) into model inputs.

    The processor tokenizes the prompt, replaces each ``<|image|>`` marker (or
    the very start of the prompt when no marker is present) with a fixed block
    of image tokens, prepends BOS, and reports where every ``<im_patch>`` token
    ended up so that image features can be placed at those positions.

    With the default settings an image block is
    ``<im_start> <im_patch>(CLS) 14 x (<im_patch>*14 <im_col>) <im_end>``,
    i.e. 213 tokens of which 197 are ``<im_patch>``.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    # token-id constants (dolma2 base 100278; specials appended at 100278..100282)
    IMAGE_PROMPT_TOKEN_ID = 100282   # <|image|>
    IMAGE_START_TOKEN_ID = 100278    # <im_start>
    IMAGE_END_TOKEN_ID = 100279      # <im_end>
    IMAGE_PATCH_TOKEN_ID = 100280    # <im_patch>
    IMAGE_COL_TOKEN_ID = 100281      # <im_col>
    BOS_TOKEN_ID = 100257            # prepended to every produced sequence

    # The only styles these models were trained on (system_prompt_kind='demo_or_style').
    # long_caption/user_qa/synthetic_qa saw the "{style}:" prefix only ~10% of the time
    # (no prefix the other ~90%); transcript was always prefixed.
    KNOWN_STYLES = ("long_caption", "transcript", "user_qa", "synthetic_qa")

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        image_token_length_w: int = 14,
        image_token_length_h: int = 14,
        include_cls_token: bool = True,
        use_col_tokens: bool = True,
        always_start_with_space: bool = True,
        **kwargs,
    ):
        """Store the image-block layout settings and register the sub-processors.

        Args:
            image_processor: Object called on the images; its result must expose
                a ``"pixel_values"`` entry.
            tokenizer: Tokenizer providing ``encode``, ``decode`` and ``batch_decode``.
            image_token_length_w: Number of ``<im_patch>`` tokens per row.
            image_token_length_h: Number of rows in the image block.
            include_cls_token: Emit one extra ``<im_patch>`` right after ``<im_start>``.
            use_col_tokens: Terminate every row with ``<im_col>``.
            always_start_with_space: Make ``format_prompt`` prepend a single space.
            **kwargs: Forwarded to ``ProcessorMixin.__init__``.
        """
        self.image_token_length_w = image_token_length_w
        self.image_token_length_h = image_token_length_h
        self.include_cls_token = include_cls_token
        self.use_col_tokens = use_col_tokens
        self.always_start_with_space = always_start_with_space
        super().__init__(image_processor, tokenizer, **kwargs)

    def format_prompt(self, question: str, style=None) -> str:
        """Reproduce Molmo's DataFormatter (system_prompt='demo_or_style', message_format='none').

        Usage:
          - VQA / instruction (most common): `text="your question"`, `style=None`
            -> " your question". This matches ~90% of training (no prefix), so leaving
            style unset is usually best.
          - Captioning: `text=""`, `style=None` -> a bare " " prompt; or
            `text="", style="long_caption"` / `style="transcript"` to request that mode
            explicitly. (Training produced captions/transcripts from an empty user turn.)
          - Steer output mode: pass `style` in {long_caption, transcript, user_qa,
            synthetic_qa} -> "{style}: ...". Note long_caption/user_qa/synthetic_qa only
            saw the prefix ~10% of the time in training; transcript was always prefixed.

        always_start_with_space -> a single leading space is always prepended.

        Args:
            question: The user text; may be empty.
            style: Optional style name. A value outside ``KNOWN_STYLES`` triggers a
                warning but is still used as the prefix.

        Returns:
            The formatted prompt string.
        """
        if style is not None and style not in self.KNOWN_STYLES:
            import warnings
            warnings.warn(
                f"style={style!r} was not used to train these models; the model may ignore "
                f"or mishandle it. Known styles: {self.KNOWN_STYLES}. Use style=None for the "
                f"default (no-prefix) behavior the model saw ~90% of the time."
            )
        prefix = f"{style}:" if style else ""
        if prefix and question:
            text = prefix + " " + question
        elif prefix:
            text = prefix
        else:
            text = question
        if self.always_start_with_space:
            text = " " + text
        return text

    def _image_block(self) -> np.ndarray:
        """Return the int32 token ids of the image block for a single resized crop.

        Layout: ``<im_start>``, an optional CLS ``<im_patch>``, then
        ``image_token_length_h`` rows of ``image_token_length_w`` ``<im_patch>``
        tokens (each row optionally closed by ``<im_col>``), then ``<im_end>``.
        With the default settings this is 213 tokens.
        """
        per_row = np.full((self.image_token_length_w,), self.IMAGE_PATCH_TOKEN_ID, dtype=np.int32)
        if self.use_col_tokens:
            per_row = np.concatenate([per_row, [self.IMAGE_COL_TOKEN_ID]], 0)
        rows = np.tile(per_row, [self.image_token_length_h])
        pieces = [[self.IMAGE_START_TOKEN_ID]]
        if self.include_cls_token:
            pieces.append([self.IMAGE_PATCH_TOKEN_ID])
        pieces += [rows, [self.IMAGE_END_TOKEN_ID]]
        return np.concatenate(pieces, 0).astype(np.int32)

    def _image_input_idx(self, image_block: np.ndarray) -> np.ndarray:
        """Positions of <im_patch> within the block, shaped (1, features_per_image).

        The reshape doubles as a sanity check: it raises ``ValueError`` if the
        number of ``<im_patch>`` tokens in ``image_block`` does not equal
        ``w * h`` (+1 when the CLS token is included).
        """
        tokens_per_image = self.image_token_length_w * self.image_token_length_h
        features_per_image = tokens_per_image + (1 if self.include_cls_token else 0)
        idx = np.nonzero(image_block == self.IMAGE_PATCH_TOKEN_ID)[0].astype(np.int32)
        return idx.reshape(1, features_per_image)

    def __call__(
        self,
        text: Union[str, List[str]],
        images=None,
        style=None,
        apply_prompt_format: bool = True,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """Tokenize text + splice image features.

        By default (apply_prompt_format=True) the text is wrapped with the training-time
        formatting (leading space + optional "{style}: " prefix) and the image is placed
        first (Molmo inserts the image at the start when no <|image|> marker is present).
        Pass apply_prompt_format=False to feed pre-formatted text, or include an explicit
        <|image|> marker to control image placement.

        Args:
            text: A single prompt, or a list/tuple holding exactly one prompt.
            images: ``None``, a single image, or a list/tuple of images.
            style: Optional style forwarded to ``format_prompt``.
            apply_prompt_format: Apply ``format_prompt`` unless the text already
                contains an ``<|image|>`` marker.
            return_tensors: ``"pt"`` converts the outputs to torch tensors; any other
                value leaves them as produced (NumPy arrays).
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            A ``BatchFeature`` with ``input_ids`` and ``attention_mask`` of shape
            (1, seq) and, when images are given, ``pixel_values`` and
            ``image_input_idx`` of shape (1, n_images, features_per_image). Every
            index points at an ``<im_patch>`` token of ``input_ids[0]``.

        Raises:
            NotImplementedError: If ``text`` is a list/tuple that does not hold
                exactly one prompt.
            ValueError: If the text contains ``<|image|>`` markers whose count differs
                from the number of images.
        """
        if isinstance(text, (list, tuple)):
            if len(text) != 1:
                raise NotImplementedError("MolmoOlmo3Processor supports a single prompt at a time.")
            text = text[0]
        if images is not None and not isinstance(images, (list, tuple)):
            images = [images]

        # Tokenize once; re-tokenize only if the prompt text is actually rewritten.
        token_list = self.tokenizer.encode(text, add_special_tokens=False)
        if apply_prompt_format and self.IMAGE_PROMPT_TOKEN_ID not in token_list:
            text = self.format_prompt(text, style=style)
            token_list = self.tokenizer.encode(text, add_special_tokens=False)
        tokens = np.array(token_list, dtype=np.int32)

        if not images:
            input_ids = np.pad(tokens, [[1, 0]], constant_values=self.BOS_TOKEN_ID)
            return self._finalize({"input_tokens": input_ids}, None, None, return_tensors)

        marker_pos = np.nonzero(tokens == self.IMAGE_PROMPT_TOKEN_ID)[0].tolist()
        if marker_pos:
            # Explicit (not `assert`) so the check survives `python -O`.
            if len(marker_pos) != len(images):
                raise ValueError(
                    "number of <|image|> markers must match images "
                    f"(got {len(marker_pos)} markers for {len(images)} images)"
                )
            insert_at = marker_pos
        else:
            # No marker -> every image goes first (matching Molmo's no-marker behavior).
            insert_at = [None] * len(images)

        block = self._image_block()
        patch_idx = self._image_input_idx(block)
        all_pixel = self.image_processor(images, return_tensors=None)["pixel_values"]

        out_tokens, all_image_idx = [], []
        cursor = 0    # next unread position in `tokens`
        emitted = 0   # number of tokens already written to the output
        for pos in insert_at:
            if pos is not None:
                segment = tokens[cursor:pos]
                out_tokens.append(segment)
                emitted += len(segment)
                cursor = pos + 1  # drop the <|image|> marker itself
            # Offset by the *output* position so that later images account for the
            # image blocks (and removed markers) that precede them.
            all_image_idx.append(patch_idx + emitted)
            out_tokens.append(block)
            emitted += len(block)
        out_tokens.append(tokens[cursor:])

        input_ids = np.concatenate(out_tokens, 0)
        image_input_idx = np.concatenate(all_image_idx, 0)

        # prepend BOS; shift image_input_idx by +1 (matches Molmo inference path)
        input_ids = np.pad(input_ids, [[1, 0]], constant_values=self.BOS_TOKEN_ID)
        image_input_idx = image_input_idx + 1

        return self._finalize(
            {"input_tokens": input_ids, "image_input_idx": image_input_idx[None]},
            all_pixel, image_input_idx[None], return_tensors,
        )

    @staticmethod
    def _stack_pixel_values(pixel_values):
        """Stack a list/tuple of per-image arrays into one array; pass anything else through."""
        if isinstance(pixel_values, (list, tuple)):
            return np.stack([np.asarray(p) for p in pixel_values], axis=0)
        return pixel_values

    def _finalize(self, out, pixel_values, image_input_idx, return_tensors):
        """Assemble the final ``BatchFeature`` with a leading batch dimension of 1.

        Args:
            out: Dict whose ``"input_tokens"`` entry holds the 1-D token-id array.
            pixel_values: ``None`` for text-only input, otherwise the image
                processor output (an array/tensor or a list/tuple of per-image arrays).
            image_input_idx: Patch positions, already carrying the batch dimension;
                only used when ``pixel_values`` is given.
            return_tensors: ``"pt"`` converts every entry to a torch tensor and casts
                ``pixel_values`` to float32.
        """
        input_ids = out["input_tokens"].astype(np.int64)[None]  # (1, seq)
        attention_mask = np.ones_like(input_ids)
        data = {"input_ids": input_ids, "attention_mask": attention_mask}
        if pixel_values is not None:
            # Image processors return a plain list when return_tensors is unset,
            # and a list cannot be indexed with None, so stack it first.
            stacked = self._stack_pixel_values(pixel_values)
            data["pixel_values"] = stacked[None]  # (1, n_images, 3, H, W)
            data["image_input_idx"] = image_input_idx  # (1, n_images, features_per_image)
        if return_tensors == "pt":
            data = {k: torch.as_tensor(v) for k, v in data.items()}
            if "pixel_values" in data:
                data["pixel_values"] = data["pixel_values"].to(torch.float32)
        return BatchFeature(data=data, tensor_type=None)

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)


# Names exported by a star-import of this module.
__all__ = ["MolmoOlmo3Processor"]