File size: 15,546 Bytes
b689296
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
"""
NaFlexProcessor — image + text processor for NaFlexCrossForConditionalGeneration.

Image pipeline:
  1. Convert to RGB
  2. Resize so the patch count stays within `max_num_patches`, snapping H and W
     to multiples of patch_h / patch_w respectively (aspect ratio is preserved
     approximately; images already within budget are only snapped, not upscaled).
  3. Normalise with per-channel mean/std.
  4. Extract patches row-major; record (row, col) integer position of each patch.
  5. Pad across the batch to the longest sequence; return patch_attention_mask.

Text pipeline:
  Standard HuggingFace tokenizer with padding / truncation.

Usage:
    processor = NaFlexProcessor.from_pretrained("checkpoints/qwen2vl-24m",
                                                 patch_size=(16, 16))
    batch = processor(
        text=texts,
        images=image_inputs,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=512,
    )
    # batch keys: input_ids, attention_mask,
    #             pixel_values, patch_positions, patch_attention_mask
"""

import math
from typing import List, Optional, Union

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, BatchFeature, ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy


_DEFAULT_MEAN = (0.7931, 0.7931, 0.7931)
_DEFAULT_STD  = (0.1738, 0.1738, 0.1738)

# Cross-attention template: images are encoded by the ViT, NOT injected as tokens.
# Image content blocks are silently dropped; only text reaches the tokenizer.
_DEFAULT_CHAT_TEMPLATE = (
    "{%- for message in messages %}\n"
    "{{- '<|im_start|>' }\n"
    "{%- if message['content'] is string %}\n"
    "{{- message['content'] }}\n"
    "{%- else %}\n"
    "{%- for content in message['content'] %}\n"
    "{%- if content['type'] == 'text' %}\n"
    "{{- content['text'] }}\n"
    "{%- endif %}\n"
    "{%- endfor %}\n"
    "{%- endif %}\n"
    "{{- '<|im_end|>\\n' }}\n"
    "{%- endfor %}\n"
    "{%- if add_generation_prompt %}\n"
    "{{- '<|im_start|>' }}\n"
    "{%- endif %}\n"
)


class NaFlexProcessor(ProcessorMixin):
    """
    Processor for NaFlexCrossForConditionalGeneration.

    Text goes through the wrapped HuggingFace tokenizer. Images are resized onto
    a patch grid within a patch budget, normalised per channel, cut into
    flattened patches with integer (row, col) positions, and padded across the
    batch. Images never contribute tokens to the text sequence.

    Args:
        tokenizer:        Any HuggingFace tokenizer.
        patch_size:       (patch_h, patch_w) or single int for square patches.
        max_num_patches:  Maximum patches per image (controls resolution budget).
        image_mean:       Per-channel normalisation mean (C,).
        image_std:        Per-channel normalisation std  (C,).
        chat_template:    Jinja chat template; falls back to the built-in
                          cross-attention template when None or empty.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"
    # Sentinel: NaFlex has no image tokens in the text sequence.
    # train_sft.py uses this to mask image tokens in labels; -1 is never a real token ID.
    image_token_id: int = -1

    def __init__(
        self,
        tokenizer,
        patch_size: Union[int, tuple] = 16,
        max_num_patches: int = 1024,
        image_mean: tuple = _DEFAULT_MEAN,
        image_std: tuple = _DEFAULT_STD,
        chat_template: Optional[str] = None,
    ):
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.patch_size = list(patch_size)          # serialised as [patch_h, patch_w]
        self.patch_h, self.patch_w = patch_size
        self.max_num_patches = max_num_patches
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std  = np.array(image_std,  dtype=np.float32)
        # ProcessorMixin stores chat_template as its own attribute via kwargs
        super().__init__(tokenizer, chat_template=chat_template or _DEFAULT_CHAT_TEMPLATE)

    # ── Save / load ───────────────────────────────────────────────────────────

    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Save processor config, tokenizer, chat template, and a copy of this module.

        Side effects:
          - forces ``self.tokenizer.chat_template`` to the built-in cross-attention
            template before saving;
          - writes ``chat_template.jinja`` and stamps the template into
            ``tokenizer_config.json`` (if present);
          - (re)writes ``processor_config.json`` with the image settings;
          - copies this module's source file to ``naflex_processor.py``;
          - registers ``auto_map["AutoProcessor"]`` in ``config.json`` if that file exists.
        """
        import json
        import os
        import shutil

        def _read_json(path):
            with open(path, encoding="utf-8") as f:
                return json.load(f)

        def _write_json(path, data):
            with open(path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)

        # Always enforce the NaFlex cross-attention template (not the early-fusion one
        # that may have been inherited from a Qwen2VL tokenizer source)
        if self.tokenizer.chat_template != _DEFAULT_CHAT_TEMPLATE:
            self.tokenizer.chat_template = _DEFAULT_CHAT_TEMPLATE

        super().save_pretrained(save_directory, **kwargs)

        # Write chat_template.jinja and also stamp it directly into tokenizer_config.json.
        # tokenizer.save_pretrained() only writes chat_template there if the tokenizer
        # was originally constructed with it, so we patch it ourselves.
        template_path = os.path.join(save_directory, "chat_template.jinja")
        with open(template_path, "w", encoding="utf-8") as f:
            f.write(self.tokenizer.chat_template)
        tok_cfg_path = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.isfile(tok_cfg_path):
            tok_cfg = _read_json(tok_cfg_path)
            tok_cfg["chat_template"] = self.tokenizer.chat_template
            _write_json(tok_cfg_path, tok_cfg)

        # Overwrite processor_config.json with all image config fields. The base
        # class may not have written this file, so start from an empty dict
        # rather than failing with FileNotFoundError.
        cfg_path = os.path.join(save_directory, "processor_config.json")
        cfg = _read_json(cfg_path) if os.path.isfile(cfg_path) else {}
        cfg["patch_size"]       = self.patch_size
        cfg["max_num_patches"]  = self.max_num_patches
        cfg["image_mean"]       = self.image_mean.tolist()
        cfg["image_std"]        = self.image_std.tolist()
        cfg["processor_class"]  = "NaFlexProcessor"
        _write_json(cfg_path, cfg)

        # Copy this module so AutoProcessor can load it with trust_remote_code=True.
        # The destination name must match the auto_map entry below, regardless of
        # what this file is called locally, so copy our own __file__.
        src = os.path.abspath(__file__)
        dst = os.path.join(save_directory, "naflex_processor.py")
        try:
            shutil.copy(src, dst)
        except shutil.SameFileError:
            # Saving into the directory this module was loaded from: already in place.
            pass

        # Add AutoProcessor entry to config.json auto_map
        main_cfg_path = os.path.join(save_directory, "config.json")
        if os.path.exists(main_cfg_path):
            main_cfg = _read_json(main_cfg_path)
            main_cfg.setdefault("auto_map", {})
            main_cfg["auto_map"]["AutoProcessor"] = "naflex_processor.NaFlexProcessor"
            _write_json(main_cfg_path, main_cfg)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        patch_size: Union[int, tuple, None] = None,
        max_num_patches: Optional[int] = None,
        image_mean: Optional[tuple] = None,
        image_std: Optional[tuple] = None,
        **kwargs,
    ) -> "NaFlexProcessor":
        """
        Build a processor from a checkpoint path.

        Precedence for each image setting: explicit argument, then
        ``processor_config.json``, then (patch size only) ``vit_patch_size`` from
        ``config.json``, then the built-in default. These JSON files and
        ``chat_template.jinja`` are only read when they exist as local files under
        ``pretrained_model_name_or_path``. The tokenizer is loaded with
        ``AutoTokenizer.from_pretrained`` using the remaining ``kwargs``.
        """
        import json
        import os

        def _read_json(path):
            with open(path, encoding="utf-8") as f:
                return json.load(f)

        # Strip kwargs that belong to AutoProcessor/tokenizer infrastructure
        kwargs.pop("use_fast", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)

        # Read saved image config (if loading from a checkpoint directory)
        saved = {}
        cfg_path = os.path.join(pretrained_model_name_or_path, "processor_config.json")
        if os.path.isfile(cfg_path):
            saved = _read_json(cfg_path)

        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
        )

        # Load chat template: checkpoint file takes priority, then built-in default.
        # This overwrites any early-fusion template inherited from a Qwen2VL tokenizer.
        template_path = os.path.join(pretrained_model_name_or_path, "chat_template.jinja")
        if os.path.isfile(template_path):
            with open(template_path, encoding="utf-8") as f:
                chat_template = f.read()
        else:
            chat_template = _DEFAULT_CHAT_TEMPLATE

        # Fall back to vit_patch_size from model config.json if not in processor_config
        if patch_size is None and "patch_size" not in saved:
            model_cfg_path = os.path.join(pretrained_model_name_or_path, "config.json")
            if os.path.isfile(model_cfg_path):
                model_cfg = _read_json(model_cfg_path)
                if "vit_patch_size" in model_cfg:
                    saved["patch_size"] = model_cfg["vit_patch_size"]

        def _pick(explicit, key, default):
            # `is None` instead of `or`: numpy arrays (e.g. image_mean) have no
            # unambiguous truth value and would raise under `or`.
            return explicit if explicit is not None else saved.get(key, default)

        return cls(
            tokenizer=tokenizer,
            chat_template=chat_template,
            patch_size=_pick(patch_size, "patch_size", [16, 16]),
            max_num_patches=_pick(max_num_patches, "max_num_patches", 1024),
            image_mean=_pick(image_mean, "image_mean", _DEFAULT_MEAN),
            image_std=_pick(image_std, "image_std", _DEFAULT_STD),
        )

    # ── Compatibility shim ────────────────────────────────────────────────────

    @property
    def image_processor(self):
        """
        Shim for code that expects a Qwen2VL-style image_processor (e.g. TokenBudgetDataset).

        NaFlex encodes images via a cross-attention ViT — they do NOT add tokens to the
        text sequence. The returned object exposes ``max_pixels`` (full patch budget
        in pixels), ``min_pixels`` (one patch) and ``image_tokens_in_text = False``.
        """
        import types

        patch_pixels = self.patch_h * self.patch_w
        return types.SimpleNamespace(
            max_pixels=self.max_num_patches * patch_pixels,
            min_pixels=patch_pixels,
            # Signal to any caller that images don't contribute to the text token count
            image_tokens_in_text=False,
        )

    # ── Image helpers ─────────────────────────────────────────────────────────

    def _resize(self, img: Image.Image, max_num_patches: Optional[int] = None) -> Image.Image:
        """
        Resize image so that:
          - H is a multiple of patch_h, W is a multiple of patch_w
          - patch count <= max_num_patches (or self.max_num_patches if not given)
          - aspect ratio is preserved as closely as possible
        Images whose snapped grid already fits are only snapped, never upscaled.

        Raises:
            ValueError: if the patch budget is < 1 or the image has zero area.
        """
        max_n = max_num_patches if max_num_patches is not None else self.max_num_patches
        if max_n < 1:
            raise ValueError(f"max_num_patches must be >= 1, got {max_n}")
        W_orig, H_orig = img.size
        if H_orig <= 0 or W_orig <= 0:
            raise ValueError(f"Cannot resize an image with zero area: size={img.size}")

        # Uniform scale giving exactly max_n patches: rows * cols ~= H*W*scale^2 / (P_h*P_w).
        area_per_patch = self.patch_h * self.patch_w
        scale = math.sqrt(max_n * area_per_patch / (H_orig * W_orig))
        exact_rows = H_orig * scale / self.patch_h
        exact_cols = W_orig * scale / self.patch_w
        # Round to nearest patch grid boundary (at least one patch per side; a side
        # longer than max_n patches can never fit, so cap it up front).
        rows = min(max_n, max(1, round(exact_rows)))
        cols = min(max_n, max(1, round(exact_cols)))
        # Rounding both sides up, or clamping a very thin side up to one patch, can
        # push rows * cols past the budget. Trim the side that overshot its exact
        # size the most, never going below one patch.
        while rows * cols > max_n:
            if cols == 1 or (rows > 1 and rows / exact_rows >= cols / exact_cols):
                rows -= 1
            else:
                cols -= 1
        H_new = rows * self.patch_h
        W_new = cols * self.patch_w
        # If already within budget, only snap to grid without upscaling. Here
        # H_orig <= rows * patch_h implies round(H_orig / patch_h) <= rows (same
        # for width), so the snapped grid still respects the budget.
        if H_orig <= H_new and W_orig <= W_new:
            H_new = max(self.patch_h, round(H_orig / self.patch_h) * self.patch_h)
            W_new = max(self.patch_w, round(W_orig / self.patch_w) * self.patch_w)
        return img.resize((W_new, H_new), Image.BICUBIC)

    def _patchify(
        self, img: Image.Image, max_num_patches: Optional[int] = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Convert to RGB, resize onto the patch grid, normalise, and cut into patches.

        Returns:
            patches:   [N, 3 * patch_h * patch_w]  float32, each patch flattened
                       in (P_h, P_w, channel) order
            positions: [N, 2]                       int32  (row, col), row-major
        """
        img = img.convert("RGB")
        img = self._resize(img, max_num_patches=max_num_patches)
        arr = np.array(img, dtype=np.float32) / 255.0  # [H, W, 3]
        arr = (arr - self.image_mean) / self.image_std  # normalise

        H, W, _ = arr.shape
        n_rows = H // self.patch_h
        n_cols = W // self.patch_w
        # Reshape to [n_rows, n_cols, patch_h, patch_w, 3] then flatten per patch
        arr = arr.reshape(n_rows, self.patch_h, n_cols, self.patch_w, 3)
        arr = arr.transpose(0, 2, 1, 3, 4)                     # [n_rows, n_cols, P_h, P_w, 3]
        patches = arr.reshape(n_rows * n_cols, self.patch_h * self.patch_w * 3)  # channels-last order

        rows, cols = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
        positions = np.stack([rows.ravel(), cols.ravel()], axis=-1).astype(np.int32)

        return patches, positions

    # ── Main __call__ ─────────────────────────────────────────────────────────

    def __call__(
        self,
        text: Union[str, List[str], None] = None,
        images: Union[Image.Image, List[Image.Image], None] = None,
        return_tensors: Optional[str] = "pt",
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        max_num_patches: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Tokenise ``text`` and/or patchify ``images``.

        Text keys come straight from the tokenizer (``padding``, ``truncation``,
        ``max_length``, ``return_tensors`` and extra ``kwargs`` are forwarded).
        Image keys, each padded to the longest patch sequence in the batch:
            pixel_values:          [B, N_max, 3 * patch_h * patch_w] float32
            patch_positions:       [B, N_max, 2] (int64 tensors for "pt", else int32)
            patch_attention_mask:  [B, N_max] bool, True on real patches
        Image arrays are torch tensors when ``return_tensors == "pt"``, numpy otherwise.
        An empty image list yields arrays with B == N_max == 0.

        Raises:
            ValueError: if neither ``text`` nor ``images`` is given.
        """
        if text is None and images is None:
            raise ValueError("At least one of `text` or `images` must be provided.")

        # ── Text ──────────────────────────────────────────────────────────────
        encoding = {}
        if text is not None:
            text_enc = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            encoding.update(text_enc)

        # ── Images ────────────────────────────────────────────────────────────
        if images is not None:
            if isinstance(images, Image.Image):
                images = [images]

            patchified = [
                self._patchify(img, max_num_patches=max_num_patches) for img in images
            ]

            # Pad to max N in batch; `default=0` keeps an empty image list from
            # crashing in max().
            batch_size = len(patchified)
            max_n = max((p.shape[0] for p, _ in patchified), default=0)
            patch_dim = 3 * self.patch_h * self.patch_w

            padded_patches    = np.zeros((batch_size, max_n, patch_dim), dtype=np.float32)
            padded_positions  = np.zeros((batch_size, max_n, 2),         dtype=np.int32)
            patch_attn_mask   = np.zeros((batch_size, max_n),            dtype=np.bool_)

            for i, (p, pos) in enumerate(patchified):
                n = p.shape[0]
                padded_patches[i, :n]   = p
                padded_positions[i, :n] = pos
                patch_attn_mask[i, :n]  = True

            if return_tensors == "pt":
                encoding["pixel_values"]         = torch.from_numpy(padded_patches)
                encoding["patch_positions"]       = torch.from_numpy(padded_positions).long()
                encoding["patch_attention_mask"]  = torch.from_numpy(patch_attn_mask)
            else:
                encoding["pixel_values"]         = padded_patches
                encoding["patch_positions"]       = padded_positions
                encoding["patch_attention_mask"]  = patch_attn_mask

        return BatchFeature(data=encoding, tensor_type=None)