"""
NaFlexProcessor — image + text processor for NaFlexCrossForConditionalGeneration.

Image pipeline:
  1. Convert to RGB.
  2. Resize so the patch count stays within `max_num_patches`, snapping H and W
     to multiples of patch_h / patch_w respectively (aspect ratio is preserved
     as closely as the patch grid and the budget allow).
  3. Normalise with per-channel mean/std.
  4. Extract patches row-major; record (row, col) integer position of each patch.
  5. Pad across the batch to the longest sequence; return patch_attention_mask.

Text pipeline:
  Standard HuggingFace tokenizer with padding / truncation.

Usage:
    processor = NaFlexProcessor.from_pretrained("checkpoints/qwen2vl-24m", patch_size=(16, 16))
    batch = processor(
        text=texts,
        images=image_inputs,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=512,
    )
    # batch keys: input_ids, attention_mask,
    #             pixel_values, patch_positions, patch_attention_mask
"""

import json
import math
import os
import shutil
from typing import List, Optional, Union

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, BatchFeature, ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy

_DEFAULT_MEAN = (0.7931, 0.7931, 0.7931)
_DEFAULT_STD = (0.1738, 0.1738, 0.1738)

# File name under which this module is copied into a checkpoint and referenced
# from the `auto_map` entry written by `save_pretrained`.
_MODULE_FILENAME = "naflex_processor.py"

# Cross-attention template: images are encoded by the ViT, NOT injected as tokens.
# Image content blocks are silently dropped; only text reaches the tokenizer.
_DEFAULT_CHAT_TEMPLATE = (
    "{%- for message in messages %}\n"
    "{{- '<|im_start|>' }}\n"
    "{%- if message['content'] is string %}\n"
    "{{- message['content'] }}\n"
    "{%- else %}\n"
    "{%- for content in message['content'] %}\n"
    "{%- if content['type'] == 'text' %}\n"
    "{{- content['text'] }}\n"
    "{%- endif %}\n"
    "{%- endfor %}\n"
    "{%- endif %}\n"
    "{{- '<|im_end|>\\n' }}\n"
    "{%- endfor %}\n"
    "{%- if add_generation_prompt %}\n"
    "{{- '<|im_start|>' }}\n"
    "{%- endif %}\n"
)


class NaFlexProcessor(ProcessorMixin):
    """
    Processor for NaFlexCrossForConditionalGeneration.

    Args:
        tokenizer: Any HuggingFace tokenizer.
        patch_size: (patch_h, patch_w) or single int for square patches.
        max_num_patches: Maximum patches per image (controls resolution budget).
        image_mean: Per-channel normalisation mean (C,).
        image_std: Per-channel normalisation std (C,).
        chat_template: Optional Jinja chat template; the built-in
            cross-attention template is used when omitted or empty.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    # Sentinel: NaFlex has no image tokens in the text sequence.
    # train_sft.py uses this to mask image tokens in labels; -1 is never a real token ID.
    image_token_id: int = -1

    def __init__(
        self,
        tokenizer,
        patch_size: Union[int, tuple] = 16,
        max_num_patches: int = 1024,
        image_mean: tuple = _DEFAULT_MEAN,
        image_std: tuple = _DEFAULT_STD,
        chat_template: Optional[str] = None,
    ):
        """
        Store the image configuration and hand the tokenizer to ProcessorMixin.

        Raises:
            ValueError: if `patch_size` is not an int or a pair of positive ints.
        """
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        patch_size = tuple(int(p) for p in patch_size)
        if len(patch_size) != 2 or min(patch_size) < 1:
            raise ValueError(
                f"`patch_size` must be a positive int or a (patch_h, patch_w) pair, got {patch_size!r}."
            )

        self.patch_size = list(patch_size)  # serialised as [patch_h, patch_w]
        self.patch_h, self.patch_w = patch_size
        self.max_num_patches = max_num_patches
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)

        # ProcessorMixin stores chat_template as its own attribute via kwargs
        super().__init__(tokenizer, chat_template=chat_template or _DEFAULT_CHAT_TEMPLATE)

    # ── Save / load ───────────────────────────────────────────────────────────

    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Save processor config, tokenizer, chat template, and a copy of this module.

        Args:
            save_directory: Target directory; populated by the parent
                `save_pretrained` and then patched in place.
            **kwargs: Forwarded unchanged to `ProcessorMixin.save_pretrained`.
        """
        # Always enforce the NaFlex cross-attention template (not the early-fusion one
        # that may have been inherited from a Qwen2VL tokenizer source)
        if self.tokenizer.chat_template != _DEFAULT_CHAT_TEMPLATE:
            self.tokenizer.chat_template = _DEFAULT_CHAT_TEMPLATE

        super().save_pretrained(save_directory, **kwargs)

        # Write chat_template.jinja and also stamp it directly into tokenizer_config.json.
        # tokenizer.save_pretrained() only writes chat_template there if the tokenizer
        # was originally constructed with it, so we patch it ourselves.
        template_path = os.path.join(save_directory, "chat_template.jinja")
        with open(template_path, "w", encoding="utf-8") as f:
            f.write(self.tokenizer.chat_template)

        tok_cfg_path = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.isfile(tok_cfg_path):
            with open(tok_cfg_path, encoding="utf-8") as f:
                tok_cfg = json.load(f)
            tok_cfg["chat_template"] = self.tokenizer.chat_template
            with open(tok_cfg_path, "w", encoding="utf-8") as f:
                json.dump(tok_cfg, f, indent=2, ensure_ascii=False)

        # Overwrite processor_config.json with all image config fields.
        # Start from an empty config when the parent class did not write the file,
        # so the image fields are always persisted.
        cfg_path = os.path.join(save_directory, "processor_config.json")
        cfg = {}
        if os.path.isfile(cfg_path):
            with open(cfg_path, encoding="utf-8") as f:
                cfg = json.load(f)
        cfg["patch_size"] = self.patch_size
        cfg["max_num_patches"] = self.max_num_patches
        cfg["image_mean"] = self.image_mean.tolist()
        cfg["image_std"] = self.image_std.tolist()
        cfg["processor_class"] = "NaFlexProcessor"
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)

        # Copy module file so AutoProcessor can load with trust_remote_code=True.
        # Skip the copy when source and destination are the same file, which
        # would otherwise make shutil.copy raise SameFileError.
        src = os.path.join(os.path.dirname(os.path.abspath(__file__)), _MODULE_FILENAME)
        dst = os.path.join(save_directory, _MODULE_FILENAME)
        if not (os.path.exists(dst) and os.path.samefile(src, dst)):
            shutil.copy(src, dst)

        # Add AutoProcessor entry to config.json auto_map
        main_cfg_path = os.path.join(save_directory, "config.json")
        if os.path.exists(main_cfg_path):
            with open(main_cfg_path, encoding="utf-8") as f:
                main_cfg = json.load(f)
            main_cfg.setdefault("auto_map", {})
            main_cfg["auto_map"]["AutoProcessor"] = "naflex_processor.NaFlexProcessor"
            with open(main_cfg_path, "w", encoding="utf-8") as f:
                json.dump(main_cfg, f, indent=2)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        patch_size: Union[int, tuple, None] = None,
        max_num_patches: Optional[int] = None,
        image_mean: Optional[tuple] = None,
        image_std: Optional[tuple] = None,
        **kwargs,
    ) -> "NaFlexProcessor":
        """
        Build a processor from a checkpoint directory or tokenizer identifier.

        Each image setting is resolved in this order: explicit argument, then
        `processor_config.json` in the checkpoint, then the built-in default.
        `patch_size` additionally falls back to `vit_patch_size` in the
        checkpoint's `config.json` before the default is used.

        Args:
            pretrained_model_name_or_path: Passed to `AutoTokenizer.from_pretrained`
                and also probed as a local directory for config files.
            patch_size, max_num_patches, image_mean, image_std: Optional
                overrides; `None` means "use the saved value or the default".
            **kwargs: Forwarded to `AutoTokenizer.from_pretrained`
                (`use_fast` is dropped, `trust_remote_code` is passed explicitly).
        """
        # Strip kwargs that belong to AutoProcessor/tokenizer infrastructure
        kwargs.pop("use_fast", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)

        # Read saved image config (if loading from a checkpoint directory)
        saved = {}
        cfg_path = os.path.join(pretrained_model_name_or_path, "processor_config.json")
        if os.path.isfile(cfg_path):
            with open(cfg_path, encoding="utf-8") as f:
                saved = json.load(f)

        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
        )

        # Load chat template: checkpoint file takes priority, then built-in default.
        # This overwrites any early-fusion template inherited from a Qwen2VL tokenizer.
        template_path = os.path.join(pretrained_model_name_or_path, "chat_template.jinja")
        if os.path.isfile(template_path):
            with open(template_path, encoding="utf-8") as f:
                chat_template = f.read()
        else:
            chat_template = _DEFAULT_CHAT_TEMPLATE

        # Fall back to vit_patch_size from model config.json if not in processor_config
        if patch_size is None and "patch_size" not in saved:
            model_cfg_path = os.path.join(pretrained_model_name_or_path, "config.json")
            if os.path.isfile(model_cfg_path):
                with open(model_cfg_path, encoding="utf-8") as f:
                    model_cfg = json.load(f)
                if "vit_patch_size" in model_cfg:
                    saved["patch_size"] = model_cfg["vit_patch_size"]

        def _resolve(explicit, key, default):
            # Explicit `is not None` check: truthiness would raise on NumPy
            # arrays and would silently discard falsy-but-explicit values.
            return explicit if explicit is not None else saved.get(key, default)

        return cls(
            tokenizer=tokenizer,
            chat_template=chat_template,
            patch_size=_resolve(patch_size, "patch_size", [16, 16]),
            max_num_patches=_resolve(max_num_patches, "max_num_patches", 1024),
            image_mean=_resolve(image_mean, "image_mean", _DEFAULT_MEAN),
            image_std=_resolve(image_std, "image_std", _DEFAULT_STD),
        )

    # ── Compatibility shim ────────────────────────────────────────────────────

    @property
    def image_processor(self):
        """
        Shim for code that expects a Qwen2VL-style image_processor
        (e.g. TokenBudgetDataset).

        Returns a plain object exposing:
            max_pixels: max_num_patches * patch_h * patch_w
            min_pixels: patch_h * patch_w (one patch)
            image_tokens_in_text: always False. NaFlex encodes images via a
                cross-attention ViT, so images add no tokens to the text sequence.

        A new object is built on every access.
        """
        patch_pixels = self.patch_h * self.patch_w
        max_pixels = self.max_num_patches * patch_pixels

        class _Shim:
            pass

        shim = _Shim()
        shim.max_pixels = max_pixels
        shim.min_pixels = patch_pixels
        # Signal to any caller that images don't contribute to the text token count
        shim.image_tokens_in_text = False
        return shim

    # ── Image helpers ─────────────────────────────────────────────────────────

    def _resize(self, img: Image.Image, max_num_patches: Optional[int] = None) -> Image.Image:
        """
        Resize image so that:
          - H is a multiple of patch_h, W is a multiple of patch_w
          - patch count <= max_num_patches (or self.max_num_patches if not given)
          - aspect ratio is preserved as closely as possible

        Images already within the budget are only snapped to the patch grid,
        never upscaled to fill it.

        Raises:
            ValueError: if the patch budget is below 1 or the image has zero area.
        """
        max_n = max_num_patches if max_num_patches is not None else self.max_num_patches
        if max_n < 1:
            raise ValueError(f"`max_num_patches` must be >= 1, got {max_n}.")

        W_orig, H_orig = img.size
        if W_orig < 1 or H_orig < 1:
            raise ValueError(f"Cannot resize an image with zero area (size={img.size}).")

        # Target: scale uniformly so total patches == max_n
        #   patches = ceil(H/P_h) * ceil(W/P_w) ≈ (H*W) / (P_h*P_w)
        area_per_patch = self.patch_h * self.patch_w
        scale = math.sqrt(max_n * area_per_patch / (H_orig * W_orig))

        # Round to nearest patch grid boundary (at least one patch per axis)
        n_rows = max(1, round(H_orig * scale / self.patch_h))
        n_cols = max(1, round(W_orig * scale / self.patch_w))

        # If already within budget, only snap to grid without upscaling
        if H_orig <= n_rows * self.patch_h and W_orig <= n_cols * self.patch_w:
            n_rows = max(1, round(H_orig / self.patch_h))
            n_cols = max(1, round(W_orig / self.patch_w))

        # Enforce the budget. Rounding both axes to nearest, or the one-patch
        # floor on a very thin image, can push rows * cols above max_n.
        n_rows = min(n_rows, max_n)
        n_cols = min(n_cols, max_n)
        while n_rows * n_cols > max_n:
            # Shrink the axis that is currently stretched most relative to the
            # original image, which keeps the aspect ratio as close as possible.
            rows_stretch = n_rows * self.patch_h / H_orig
            cols_stretch = n_cols * self.patch_w / W_orig
            if n_rows > 1 and (n_cols == 1 or rows_stretch >= cols_stretch):
                n_rows -= 1
            else:
                n_cols -= 1

        H_new = n_rows * self.patch_h
        W_new = n_cols * self.patch_w
        return img.resize((W_new, H_new), Image.BICUBIC)

    def _patchify(
        self, img: Image.Image, max_num_patches: Optional[int] = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Convert one image into a flat sequence of normalised patches.

        Returns:
            patches:   [N, 3 * patch_h * patch_w]  float32, each patch flattened
                       in (patch_h, patch_w, channel) order
            positions: [N, 2]  int32  (row, col), row-major
        """
        img = img.convert("RGB")
        img = self._resize(img, max_num_patches=max_num_patches)

        arr = np.array(img, dtype=np.float32) / 255.0  # [H, W, 3]
        arr = (arr - self.image_mean) / self.image_std  # normalise

        H, W, _ = arr.shape
        n_rows = H // self.patch_h
        n_cols = W // self.patch_w

        # Reshape to [n_rows, n_cols, patch_h, patch_w, 3] then flatten per patch
        arr = arr.reshape(n_rows, self.patch_h, n_cols, self.patch_w, 3)
        arr = arr.transpose(0, 2, 1, 3, 4)  # [n_rows, n_cols, P_h, P_w, 3]
        patches = arr.reshape(n_rows * n_cols, self.patch_h * self.patch_w * 3)  # channels-last order

        rows, cols = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
        positions = np.stack([rows.ravel(), cols.ravel()], axis=-1).astype(np.int32)

        return patches, positions

    # ── Main __call__ ─────────────────────────────────────────────────────────

    def __call__(
        self,
        text: Union[str, List[str], None] = None,
        images: Union[Image.Image, List[Image.Image], None] = None,
        return_tensors: Optional[str] = "pt",
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        max_num_patches: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Tokenize text and/or patchify images into a single BatchFeature.

        Args:
            text: A string or list of strings for the tokenizer.
            images: A PIL image or a non-empty list of PIL images.
            return_tensors: "pt" returns image outputs as torch tensors; any
                other value returns them as NumPy arrays. Also forwarded to
                the tokenizer.
            padding, truncation, max_length: Forwarded to the tokenizer.
            max_num_patches: Per-call override of the patch budget.
            **kwargs: Extra keyword arguments forwarded to the tokenizer.

        Returns:
            BatchFeature with the tokenizer outputs (when `text` is given) and
            `pixel_values` [B, N, D], `patch_positions` [B, N, 2],
            `patch_attention_mask` [B, N] (when `images` is given), where N is
            the largest patch count in the batch.

        Raises:
            ValueError: if both `text` and `images` are None, or `images` is empty.
        """
        if text is None and images is None:
            raise ValueError("At least one of `text` or `images` must be provided.")

        # ── Text ──────────────────────────────────────────────────────────────
        encoding = {}
        if text is not None:
            text_enc = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            encoding.update(text_enc)

        # ── Images ────────────────────────────────────────────────────────────
        if images is not None:
            if isinstance(images, Image.Image):
                images = [images]
            images = list(images)
            if not images:
                raise ValueError("`images` must contain at least one image.")

            all_patches, all_positions = [], []
            for img in images:
                patches, positions = self._patchify(img, max_num_patches=max_num_patches)
                all_patches.append(patches)
                all_positions.append(positions)

            # Pad to max N in batch; padded slots stay zero with mask False
            max_n = max(p.shape[0] for p in all_patches)
            patch_dim = all_patches[0].shape[1]

            padded_patches = np.zeros((len(images), max_n, patch_dim), dtype=np.float32)
            padded_positions = np.zeros((len(images), max_n, 2), dtype=np.int32)
            patch_attn_mask = np.zeros((len(images), max_n), dtype=np.bool_)

            for i, (p, pos) in enumerate(zip(all_patches, all_positions)):
                n = p.shape[0]
                padded_patches[i, :n] = p
                padded_positions[i, :n] = pos
                patch_attn_mask[i, :n] = True

            if return_tensors == "pt":
                encoding["pixel_values"] = torch.from_numpy(padded_patches)
                encoding["patch_positions"] = torch.from_numpy(padded_positions).long()
                encoding["patch_attention_mask"] = torch.from_numpy(patch_attn_mask)
            else:
                encoding["pixel_values"] = padded_patches
                encoding["patch_positions"] = padded_positions
                encoding["patch_attention_mask"] = patch_attn_mask

        return BatchFeature(data=encoding, tensor_type=None)