| """ |
| NaFlexProcessor — image + text processor for NaFlexCrossForConditionalGeneration. |
| |
| Image pipeline: |
| 1. Convert to RGB |
| 2. Resize so the patch count stays within `max_num_patches`, snapping H and W |
| to multiples of patch_h / patch_w respectively (preserves aspect ratio). |
| 3. Normalise with per-channel mean/std. |
| 4. Extract patches row-major; record (row, col) integer position of each patch. |
| 5. Pad across the batch to the longest sequence; return patch_attention_mask. |
| |
| Text pipeline: |
| Standard HuggingFace tokenizer with padding / truncation. |
| |
| Usage: |
| processor = NaFlexProcessor.from_pretrained("checkpoints/qwen2vl-24m", |
| patch_size=(16, 16)) |
| batch = processor( |
| text=texts, |
| images=image_inputs, |
| return_tensors="pt", |
| padding="longest", |
| truncation=True, |
| max_length=512, |
| ) |
| # batch keys: input_ids, attention_mask, |
| # pixel_values, patch_positions, patch_attention_mask |
| """ |
|
|
| import math |
| from typing import List, Optional, Union |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from transformers import AutoTokenizer, BatchFeature, ProcessorMixin |
| from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy |
|
|
|
|
| _DEFAULT_MEAN = (0.7931, 0.7931, 0.7931) |
| _DEFAULT_STD = (0.1738, 0.1738, 0.1738) |
|
|
| |
| |
| _DEFAULT_CHAT_TEMPLATE = ( |
| "{%- for message in messages %}\n" |
| "{{- '<|im_start|>' }\n" |
| "{%- if message['content'] is string %}\n" |
| "{{- message['content'] }}\n" |
| "{%- else %}\n" |
| "{%- for content in message['content'] %}\n" |
| "{%- if content['type'] == 'text' %}\n" |
| "{{- content['text'] }}\n" |
| "{%- endif %}\n" |
| "{%- endfor %}\n" |
| "{%- endif %}\n" |
| "{{- '<|im_end|>\\n' }}\n" |
| "{%- endfor %}\n" |
| "{%- if add_generation_prompt %}\n" |
| "{{- '<|im_start|>' }}\n" |
| "{%- endif %}\n" |
| ) |
|
|
|
|
class NaFlexProcessor(ProcessorMixin):
    """
    Processor for NaFlexCrossForConditionalGeneration.

    Text is tokenised with the wrapped tokenizer. Images are converted to RGB,
    resized to a whole number of patches within a patch budget, normalised,
    cut into row-major patches and padded across the batch.

    Args:
        tokenizer:       Any HuggingFace tokenizer.
        patch_size:      (patch_h, patch_w) or single int for square patches.
        max_num_patches: Maximum patches per image (controls resolution budget).
        image_mean:      Per-channel normalisation mean (C,).
        image_std:       Per-channel normalisation std  (C,).
        chat_template:   Jinja chat template; falls back to the module default
                         when None or empty.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    # Not read anywhere in this file; presumably a sentinel meaning "no image
    # placeholder token in the text sequence" — confirm against callers.
    image_token_id: int = -1

    def __init__(
        self,
        tokenizer,
        patch_size: Union[int, tuple] = 16,
        max_num_patches: int = 1024,
        image_mean: tuple = _DEFAULT_MEAN,
        image_std: tuple = _DEFAULT_STD,
        chat_template: Optional[str] = None,
    ):
        # A single int means square patches.
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        # Stored as a list so it can be written to JSON as-is.
        self.patch_size = list(patch_size)
        self.patch_h, self.patch_w = patch_size
        self.max_num_patches = max_num_patches
        # float32 arrays so normalisation in `_patchify` stays in float32.
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)

        super().__init__(tokenizer, chat_template=chat_template or _DEFAULT_CHAT_TEMPLATE)

    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Save processor config, tokenizer, chat template, and a copy of this module.

        Side effects, in order:
          * forces the tokenizer's chat template to the module default;
          * delegates to the base class `save_pretrained`;
          * writes `chat_template.jinja` and, if `tokenizer_config.json`
            exists, stores the template there too;
          * adds the patch/normalisation settings to `processor_config.json`;
          * copies `naflex_processor.py` next to the saved files;
          * if `config.json` exists, registers this class under
            `auto_map["AutoProcessor"]`.

        Args:
            save_directory: Target directory.
            **kwargs:       Forwarded to the base class `save_pretrained`.
        """
        import json
        import os
        import shutil

        # The tokenizer always carries the default template on save, whatever
        # it held before (this mutates the tokenizer in place).
        if self.tokenizer.chat_template != _DEFAULT_CHAT_TEMPLATE:
            self.tokenizer.chat_template = _DEFAULT_CHAT_TEMPLATE

        super().save_pretrained(save_directory, **kwargs)

        # Persist the template both as a standalone file and inside the
        # tokenizer config so either loading path finds it.
        template = self.tokenizer.chat_template
        with open(
            os.path.join(save_directory, "chat_template.jinja"), "w", encoding="utf-8"
        ) as f:
            f.write(template)
        tok_cfg_path = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.isfile(tok_cfg_path):
            with open(tok_cfg_path, encoding="utf-8") as f:
                tok_cfg = json.load(f)
            tok_cfg["chat_template"] = template
            with open(tok_cfg_path, "w", encoding="utf-8") as f:
                json.dump(tok_cfg, f, indent=2)

        # Merge our settings into processor_config.json. Start from an empty
        # dict if the base class did not write the file, instead of crashing.
        cfg_path = os.path.join(save_directory, "processor_config.json")
        cfg = {}
        if os.path.isfile(cfg_path):
            with open(cfg_path, encoding="utf-8") as f:
                cfg = json.load(f)
        cfg["patch_size"] = self.patch_size
        cfg["max_num_patches"] = self.max_num_patches
        cfg["image_mean"] = self.image_mean.tolist()
        cfg["image_std"] = self.image_std.tolist()
        cfg["processor_class"] = "NaFlexProcessor"
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)

        # Ship the module source with the checkpoint; the auto_map entry below
        # refers to it by this file name.
        src = os.path.join(os.path.dirname(__file__), "naflex_processor.py")
        dst = os.path.join(save_directory, "naflex_processor.py")
        try:
            shutil.copy(src, dst)
        except shutil.SameFileError:
            # Saving into the directory that already holds this module.
            pass

        # Register the processor in the model config when one is present.
        main_cfg_path = os.path.join(save_directory, "config.json")
        if os.path.exists(main_cfg_path):
            with open(main_cfg_path, encoding="utf-8") as f:
                main_cfg = json.load(f)
            main_cfg.setdefault("auto_map", {})
            main_cfg["auto_map"]["AutoProcessor"] = "naflex_processor.NaFlexProcessor"
            with open(main_cfg_path, "w", encoding="utf-8") as f:
                json.dump(main_cfg, f, indent=2)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        patch_size: Union[int, tuple, None] = None,
        max_num_patches: Optional[int] = None,
        image_mean: Optional[tuple] = None,
        image_std: Optional[tuple] = None,
        **kwargs,
    ) -> "NaFlexProcessor":
        """
        Build a processor from a saved directory.

        Each setting is resolved in this order: explicit argument, then
        `processor_config.json`, then the built-in default. For `patch_size`
        only, `vit_patch_size` from `config.json` is consulted before the
        built-in default. The config and template files are looked up with
        `os.path`, so they are only found when the path is a local directory.

        Args:
            pretrained_model_name_or_path: Directory (or name) passed to
                `AutoTokenizer.from_pretrained`.
            patch_size, max_num_patches, image_mean, image_std:
                Optional overrides; None means "use the saved/default value".
            **kwargs: `use_fast` is discarded, `trust_remote_code` is forwarded
                explicitly, everything else goes to `AutoTokenizer.from_pretrained`.

        Returns:
            A configured NaFlexProcessor.
        """
        import json
        import os

        kwargs.pop("use_fast", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)

        # Settings previously written by `save_pretrained`, if any.
        saved = {}
        cfg_path = os.path.join(pretrained_model_name_or_path, "processor_config.json")
        if os.path.isfile(cfg_path):
            with open(cfg_path, encoding="utf-8") as f:
                saved = json.load(f)

        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
        )

        # Prefer the standalone template file; otherwise use the default.
        template_path = os.path.join(pretrained_model_name_or_path, "chat_template.jinja")
        if os.path.isfile(template_path):
            with open(template_path, encoding="utf-8") as f:
                chat_template = f.read()
        else:
            chat_template = _DEFAULT_CHAT_TEMPLATE

        # Fall back to the model config's patch size when nothing else sets it.
        if patch_size is None and "patch_size" not in saved:
            model_cfg_path = os.path.join(pretrained_model_name_or_path, "config.json")
            if os.path.isfile(model_cfg_path):
                with open(model_cfg_path, encoding="utf-8") as f:
                    model_cfg = json.load(f)
                if "vit_patch_size" in model_cfg:
                    saved["patch_size"] = model_cfg["vit_patch_size"]

        def _resolve(explicit, key, default):
            # `is not None` rather than truthiness: an array-valued override
            # has no unambiguous truth value, and this matches the
            # `patch_size is None` test used above.
            return explicit if explicit is not None else saved.get(key, default)

        return cls(
            tokenizer=tokenizer,
            chat_template=chat_template,
            patch_size=_resolve(patch_size, "patch_size", [16, 16]),
            max_num_patches=_resolve(max_num_patches, "max_num_patches", 1024),
            image_mean=_resolve(image_mean, "image_mean", _DEFAULT_MEAN),
            image_std=_resolve(image_std, "image_std", _DEFAULT_STD),
        )

    @property
    def image_processor(self):
        """
        Shim for code that expects a Qwen2VL-style image_processor (e.g. TokenBudgetDataset).
        NaFlex encodes images via cross-attention ViT — they do NOT add tokens to the text
        sequence, so max_pixels / min_pixels are set to signal zero image-token overhead.

        Returns:
            A fresh plain object with `max_pixels` (max_num_patches patches'
            worth of pixels), `min_pixels` (one patch's worth of pixels) and
            `image_tokens_in_text = False`.
        """
        patch_pixels = self.patch_h * self.patch_w
        max_pixels = self.max_num_patches * patch_pixels

        class _Shim:
            pass

        shim = _Shim()
        shim.max_pixels = max_pixels
        shim.min_pixels = patch_pixels
        shim.image_tokens_in_text = False
        return shim

    def _resize(self, img: Image.Image, max_num_patches: Optional[int] = None) -> Image.Image:
        """
        Resize image so that:
          - H is a multiple of patch_h, W is a multiple of patch_w
          - patch count <= max_num_patches (or self.max_num_patches if not given),
            with a floor of one patch
          - aspect ratio is preserved as closely as possible
          - images already within the budget are not upscaled, only snapped
            to the nearest patch multiple

        Args:
            img:             Input PIL image.
            max_num_patches: Optional per-call override of the patch budget.

        Returns:
            The image resized with bicubic resampling.
        """
        max_n = max_num_patches if max_num_patches is not None else self.max_num_patches
        W_orig, H_orig = img.size

        # Scale factor at which the image area equals the full patch budget.
        area_per_patch = self.patch_h * self.patch_w
        scale = math.sqrt(max_n * area_per_patch / (H_orig * W_orig))

        # Patch-grid size at that scale, at least one patch per axis.
        n_rows = max(1, round(H_orig * scale / self.patch_h))
        n_cols = max(1, round(W_orig * scale / self.patch_w))

        # Never upscale: if the budget-sized target is at least as large as
        # the original in both axes, just snap the original to the patch grid.
        if H_orig <= n_rows * self.patch_h and W_orig <= n_cols * self.patch_w:
            n_rows = max(1, round(H_orig / self.patch_h))
            n_cols = max(1, round(W_orig / self.patch_w))

        # Rounding each axis independently can overshoot the budget
        # (e.g. 40.6 x 25.22 -> 41 x 25 = 1025 > 1024). Trim the larger axis,
        # which distorts the aspect ratio least, until the grid fits.
        while n_rows * n_cols > max_n and max(n_rows, n_cols) > 1:
            if n_rows >= n_cols:
                n_rows -= 1
            else:
                n_cols -= 1

        # PIL takes (width, height).
        return img.resize((n_cols * self.patch_w, n_rows * self.patch_h), Image.BICUBIC)

    def _patchify(
        self, img: Image.Image, max_num_patches: Optional[int] = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Convert one image into a flat sequence of normalised patches.

        Args:
            img:             Input PIL image (any mode; converted to RGB).
            max_num_patches: Optional per-call override of the patch budget.

        Returns:
            patches:   [N, 3 * patch_h * patch_w]  float32; each row is one
                       patch flattened in (patch_row, patch_col, channel) order
            positions: [N, 2]  int32  (row, col), row-major like `patches`
        """
        img = img.convert("RGB")
        img = self._resize(img, max_num_patches=max_num_patches)
        # Scale to [0, 1], then normalise per channel (broadcast over H, W).
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = (arr - self.image_mean) / self.image_std

        H, W, _ = arr.shape
        n_rows = H // self.patch_h
        n_cols = W // self.patch_w

        # [H, W, 3] -> [rows, ph, cols, pw, 3] -> [rows, cols, ph, pw, 3]
        arr = arr.reshape(n_rows, self.patch_h, n_cols, self.patch_w, 3)
        arr = arr.transpose(0, 2, 1, 3, 4)
        patches = arr.reshape(n_rows * n_cols, self.patch_h * self.patch_w * 3)

        # "ij" indexing makes the raveled grid row-major, matching `patches`.
        rows, cols = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
        positions = np.stack([rows.ravel(), cols.ravel()], axis=-1).astype(np.int32)

        return patches, positions

    def __call__(
        self,
        text: Union[str, List[str], None] = None,
        images: Union[Image.Image, List[Image.Image], None] = None,
        return_tensors: Optional[str] = "pt",
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        max_num_patches: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Tokenise text and/or patchify images into a single batch.

        Args:
            text:            String or list of strings for the tokenizer.
            images:          One PIL image or a sequence of PIL images.
            return_tensors:  Passed to the tokenizer. Image outputs are torch
                             tensors when this is "pt" and numpy arrays otherwise.
            padding, truncation, max_length: Passed to the tokenizer.
            max_num_patches: Optional per-call override of the patch budget.
            **kwargs:        Extra tokenizer arguments.

        Returns:
            BatchFeature holding the tokenizer outputs (when `text` is given)
            and, when `images` is given:
              pixel_values         [B, N_max, 3 * patch_h * patch_w] float32
              patch_positions      [B, N_max, 2] (row, col); int64 for "pt",
                                   int32 otherwise
              patch_attention_mask [B, N_max] bool, True for real patches

        Raises:
            ValueError: If both `text` and `images` are None, or `images` is empty.
        """
        if text is None and images is None:
            raise ValueError("At least one of `text` or `images` must be provided.")

        encoding = {}

        # --- text ---
        if text is not None:
            text_enc = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            encoding.update(text_enc)

        # --- images ---
        if images is not None:
            if isinstance(images, Image.Image):
                images = [images]
            else:
                images = list(images)
            if not images:
                # Fail with a clear message instead of `max()` on an empty sequence.
                raise ValueError("`images` must contain at least one image.")

            patchified = [
                self._patchify(img, max_num_patches=max_num_patches) for img in images
            ]

            # Pad every image to the longest patch sequence in the batch.
            batch_size = len(patchified)
            max_n = max(p.shape[0] for p, _ in patchified)
            patch_dim = patchified[0][0].shape[1]

            padded_patches = np.zeros((batch_size, max_n, patch_dim), dtype=np.float32)
            padded_positions = np.zeros((batch_size, max_n, 2), dtype=np.int32)
            patch_attn_mask = np.zeros((batch_size, max_n), dtype=np.bool_)

            for i, (p, pos) in enumerate(patchified):
                n = p.shape[0]
                padded_patches[i, :n] = p
                padded_positions[i, :n] = pos
                patch_attn_mask[i, :n] = True

            if return_tensors == "pt":
                encoding["pixel_values"] = torch.from_numpy(padded_patches)
                encoding["patch_positions"] = torch.from_numpy(padded_positions).long()
                encoding["patch_attention_mask"] = torch.from_numpy(patch_attn_mask)
            else:
                encoding["pixel_values"] = padded_patches
                encoding["patch_positions"] = padded_positions
                encoding["patch_attention_mask"] = patch_attn_mask

        # tensor_type=None: values were already converted above.
        return BatchFeature(data=encoding, tensor_type=None)
|
|