File size: 15,546 Bytes
b689296 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 | """
NaFlexProcessor — image + text processor for NaFlexCrossForConditionalGeneration.
Image pipeline:
1. Convert to RGB
2. Resize so the patch count stays within `max_num_patches`, snapping H and W
to multiples of patch_h / patch_w respectively (preserves aspect ratio).
3. Normalise with per-channel mean/std.
4. Extract patches row-major; record (row, col) integer position of each patch.
5. Pad across the batch to the longest sequence; return patch_attention_mask.
Text pipeline:
Standard HuggingFace tokenizer with padding / truncation.
Usage:
processor = NaFlexProcessor.from_pretrained("checkpoints/qwen2vl-24m",
patch_size=(16, 16))
batch = processor(
text=texts,
images=image_inputs,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=512,
)
# batch keys: input_ids, attention_mask,
# pixel_values, patch_positions, patch_attention_mask
"""
import math
from typing import List, Optional, Union
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, BatchFeature, ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy
_DEFAULT_MEAN = (0.7931, 0.7931, 0.7931)
_DEFAULT_STD = (0.1738, 0.1738, 0.1738)
# Cross-attention template: images are encoded by the ViT, NOT injected as tokens.
# Image content blocks are silently dropped; only text reaches the tokenizer.
_DEFAULT_CHAT_TEMPLATE = (
"{%- for message in messages %}\n"
"{{- '<|im_start|>' }\n"
"{%- if message['content'] is string %}\n"
"{{- message['content'] }}\n"
"{%- else %}\n"
"{%- for content in message['content'] %}\n"
"{%- if content['type'] == 'text' %}\n"
"{{- content['text'] }}\n"
"{%- endif %}\n"
"{%- endfor %}\n"
"{%- endif %}\n"
"{{- '<|im_end|>\\n' }}\n"
"{%- endfor %}\n"
"{%- if add_generation_prompt %}\n"
"{{- '<|im_start|>' }}\n"
"{%- endif %}\n"
)
class NaFlexProcessor(ProcessorMixin):
    """
    Processor for NaFlexCrossForConditionalGeneration.

    Text goes through the wrapped tokenizer. Images are resized onto the patch
    grid, normalised, cut into flattened patches and padded across the batch;
    they are never turned into text tokens.

    Args:
        tokenizer: Any HuggingFace tokenizer.
        patch_size: (patch_h, patch_w) or single int for square patches.
        max_num_patches: Maximum patches per image (controls resolution budget).
        image_mean: Per-channel normalisation mean (C,).
        image_std: Per-channel normalisation std (C,).
        chat_template: Jinja chat template. Falls back to the built-in
            cross-attention template when None or empty.

    Raises:
        ValueError: If `patch_size` is not a positive int or a pair of
            positive values.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    # Sentinel: NaFlex has no image tokens in the text sequence.
    # train_sft.py uses this to mask image tokens in labels; -1 is never a real token ID.
    image_token_id: int = -1

    def __init__(
        self,
        tokenizer,
        patch_size: Union[int, tuple] = 16,
        max_num_patches: int = 1024,
        image_mean: tuple = _DEFAULT_MEAN,
        image_std: tuple = _DEFAULT_STD,
        chat_template: Optional[str] = None,
    ):
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        patch_size = tuple(patch_size)
        # Fail early: a zero/negative patch edge would only surface later as a
        # division-by-zero or a reshape error deep inside the image pipeline.
        if len(patch_size) != 2 or any(p <= 0 for p in patch_size):
            raise ValueError(
                "patch_size must be a positive int or a (patch_h, patch_w) pair "
                f"of positive ints, got {patch_size!r}"
            )

        self.patch_size = list(patch_size)  # serialised as [patch_h, patch_w]
        self.patch_h, self.patch_w = patch_size
        self.max_num_patches = max_num_patches
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        # ProcessorMixin stores chat_template as its own attribute via kwargs
        super().__init__(tokenizer, chat_template=chat_template or _DEFAULT_CHAT_TEMPLATE)

    # ---- Save / load --------------------------------------------------------

    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Save processor config, tokenizer, chat template, and a copy of this module.

        Files written/patched inside `save_directory` (after the parent
        `save_pretrained` has run):
            - chat_template.jinja       built-in cross-attention template
            - tokenizer_config.json     `chat_template` key stamped in (if the file exists)
            - processor_config.json     patch_size / max_num_patches / image_mean /
                                        image_std / processor_class
            - naflex_processor.py       copy of this module
            - config.json               `auto_map.AutoProcessor` entry (if the file exists)

        NOTE: the built-in template is always the one persisted, even if a
        different `chat_template` was passed to `__init__`.

        Args:
            save_directory: Target directory.
            **kwargs: Forwarded unchanged to `ProcessorMixin.save_pretrained`.
        """
        import json
        import os
        import shutil

        # Always enforce the NaFlex cross-attention template (not the early-fusion one
        # that may have been inherited from a Qwen2VL tokenizer source)
        if self.tokenizer.chat_template != _DEFAULT_CHAT_TEMPLATE:
            self.tokenizer.chat_template = _DEFAULT_CHAT_TEMPLATE

        super().save_pretrained(save_directory, **kwargs)

        # Write chat_template.jinja and also stamp it directly into tokenizer_config.json.
        # tokenizer.save_pretrained() only writes chat_template there if the tokenizer
        # was originally constructed with it, so we patch it ourselves.
        # All files are read/written as UTF-8 explicitly so the result does not
        # depend on the platform's locale encoding.
        template_path = os.path.join(save_directory, "chat_template.jinja")
        with open(template_path, "w", encoding="utf-8") as f:
            f.write(self.tokenizer.chat_template)

        tok_cfg_path = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.isfile(tok_cfg_path):
            with open(tok_cfg_path, encoding="utf-8") as f:
                tok_cfg = json.load(f)
            tok_cfg["chat_template"] = self.tokenizer.chat_template
            with open(tok_cfg_path, "w", encoding="utf-8") as f:
                json.dump(tok_cfg, f, indent=2, ensure_ascii=False)

        # Overwrite processor_config.json with all image config fields.
        # Start from an empty config if the parent class did not write one.
        cfg_path = os.path.join(save_directory, "processor_config.json")
        cfg = {}
        if os.path.isfile(cfg_path):
            with open(cfg_path, encoding="utf-8") as f:
                cfg = json.load(f)
        cfg["patch_size"] = self.patch_size
        cfg["max_num_patches"] = self.max_num_patches
        cfg["image_mean"] = self.image_mean.tolist()
        cfg["image_std"] = self.image_std.tolist()
        cfg["processor_class"] = "NaFlexProcessor"
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2, ensure_ascii=False)

        # Copy module file so AutoProcessor can load with trust_remote_code=True.
        # Skip the copy when saving into the directory the module already lives
        # in (shutil.copy would raise SameFileError).
        src = os.path.join(os.path.dirname(__file__), "naflex_processor.py")
        dst = os.path.join(save_directory, "naflex_processor.py")
        if not (os.path.exists(dst) and os.path.samefile(src, dst)):
            shutil.copy(src, dst)

        # Add AutoProcessor entry to config.json auto_map
        main_cfg_path = os.path.join(save_directory, "config.json")
        if os.path.exists(main_cfg_path):
            with open(main_cfg_path, encoding="utf-8") as f:
                main_cfg = json.load(f)
            main_cfg.setdefault("auto_map", {})
            main_cfg["auto_map"]["AutoProcessor"] = "naflex_processor.NaFlexProcessor"
            with open(main_cfg_path, "w", encoding="utf-8") as f:
                json.dump(main_cfg, f, indent=2, ensure_ascii=False)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        patch_size: Union[int, tuple, None] = None,
        max_num_patches: Optional[int] = None,
        image_mean: Optional[tuple] = None,
        image_std: Optional[tuple] = None,
        **kwargs,
    ) -> "NaFlexProcessor":
        """
        Build a processor from a checkpoint.

        Image settings are resolved in this order: explicit argument, then
        `processor_config.json`, then (patch size only) `vit_patch_size` from
        `config.json`, then the built-in defaults. The config files and
        `chat_template.jinja` are only looked up on the local filesystem under
        `pretrained_model_name_or_path`.

        Args:
            pretrained_model_name_or_path: Passed to `AutoTokenizer.from_pretrained`
                and used as the directory to search for the config files.
            patch_size: Optional override for the saved patch size.
            max_num_patches: Optional override for the saved patch budget.
            image_mean: Optional override for the saved normalisation mean.
            image_std: Optional override for the saved normalisation std.
            **kwargs: Forwarded to `AutoTokenizer.from_pretrained`
                (`use_fast` is dropped, `trust_remote_code` is passed through).

        Returns:
            A configured `NaFlexProcessor`.
        """
        import json
        import os

        # Strip kwargs that belong to AutoProcessor/tokenizer infrastructure
        kwargs.pop("use_fast", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)

        # Read saved image config (if loading from a checkpoint directory)
        saved = {}
        cfg_path = os.path.join(pretrained_model_name_or_path, "processor_config.json")
        if os.path.isfile(cfg_path):
            with open(cfg_path, encoding="utf-8") as f:
                saved = json.load(f)

        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
        )

        # Load chat template: checkpoint file takes priority, then built-in default.
        # This overwrites any early-fusion template inherited from a Qwen2VL tokenizer.
        template_path = os.path.join(pretrained_model_name_or_path, "chat_template.jinja")
        if os.path.isfile(template_path):
            with open(template_path, encoding="utf-8") as f:
                chat_template = f.read()
        else:
            chat_template = _DEFAULT_CHAT_TEMPLATE

        # Fall back to vit_patch_size from model config.json if not in processor_config
        if patch_size is None and "patch_size" not in saved:
            model_cfg_path = os.path.join(pretrained_model_name_or_path, "config.json")
            if os.path.isfile(model_cfg_path):
                with open(model_cfg_path, encoding="utf-8") as f:
                    model_cfg = json.load(f)
                if "vit_patch_size" in model_cfg:
                    saved["patch_size"] = model_cfg["vit_patch_size"]

        def _resolve(explicit, key, default):
            # `is not None` instead of truthiness: overrides may be numpy
            # arrays, whose truth value is ambiguous and raises.
            return explicit if explicit is not None else saved.get(key, default)

        return cls(
            tokenizer=tokenizer,
            chat_template=chat_template,
            patch_size=_resolve(patch_size, "patch_size", [16, 16]),
            max_num_patches=_resolve(max_num_patches, "max_num_patches", 1024),
            image_mean=_resolve(image_mean, "image_mean", _DEFAULT_MEAN),
            image_std=_resolve(image_std, "image_std", _DEFAULT_STD),
        )

    # ---- Compatibility shim -------------------------------------------------

    @property
    def image_processor(self):
        """
        Shim for code that expects a Qwen2VL-style image_processor (e.g. TokenBudgetDataset).

        NaFlex encodes images via a cross-attention ViT — they do NOT add tokens
        to the text sequence. The returned object exposes:
            max_pixels            max_num_patches * patch_h * patch_w
            min_pixels            patch_h * patch_w (one patch)
            image_tokens_in_text  always False
        A fresh object is built on every access.
        """
        patch_pixels = self.patch_h * self.patch_w

        class _Shim:
            pass

        shim = _Shim()
        shim.max_pixels = self.max_num_patches * patch_pixels
        shim.min_pixels = patch_pixels
        # Signal to any caller that images don't contribute to the text token count
        shim.image_tokens_in_text = False
        return shim

    # ---- Image helpers ------------------------------------------------------

    def _resize(self, img: Image.Image, max_num_patches: Optional[int] = None) -> Image.Image:
        """
        Resize image so that:
            - H is a multiple of patch_h, W is a multiple of patch_w
            - patch count <= max_num_patches (or self.max_num_patches if not given)
            - aspect ratio is preserved as closely as possible

        Images that already fit the budget are only snapped to the nearest
        grid size (at least one patch per side), not scaled up to fill it.

        Raises:
            ValueError: If the patch budget is < 1 or the image has a zero side.
        """
        max_n = max_num_patches if max_num_patches is not None else self.max_num_patches
        if max_n < 1:
            raise ValueError(f"max_num_patches must be >= 1, got {max_n}")

        W_orig, H_orig = img.size
        if W_orig <= 0 or H_orig <= 0:
            raise ValueError(f"Cannot resize an empty image of size {img.size}")

        # Target: scale uniformly so total patches == max_n
        #   patches = ceil(H/P_h) * ceil(W/P_w) ≈ (H*W) / (P_h*P_w)
        area_per_patch = self.patch_h * self.patch_w
        scale = math.sqrt(max_n * area_per_patch / (H_orig * W_orig))

        # Round to nearest patch grid boundary, never below one patch per side.
        n_rows = max(1, round(H_orig * scale / self.patch_h))
        n_cols = max(1, round(W_orig * scale / self.patch_w))

        # Rounding (or the one-patch minimum on an extreme aspect ratio) can
        # push rows * cols above the budget. Shrink the longer side, which
        # distorts the aspect ratio the least, until the grid fits.
        n_rows = min(n_rows, max_n)
        n_cols = min(n_cols, max_n)
        if n_rows * n_cols > max_n:
            if n_cols >= n_rows:
                n_cols = max_n // n_rows
            else:
                n_rows = max_n // n_cols

        H_new = n_rows * self.patch_h
        W_new = n_cols * self.patch_w

        # If already within budget, only snap to grid without upscaling.
        # The snapped size never exceeds (H_new, W_new), so it stays in budget.
        if H_orig <= H_new and W_orig <= W_new:
            H_new = max(self.patch_h, round(H_orig / self.patch_h) * self.patch_h)
            W_new = max(self.patch_w, round(W_orig / self.patch_w) * self.patch_w)

        return img.resize((W_new, H_new), Image.BICUBIC)

    def _patchify(
        self, img: Image.Image, max_num_patches: Optional[int] = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Convert one image into flattened, normalised patches.

        Args:
            img: Input image; converted to RGB before processing.
            max_num_patches: Optional per-call patch budget passed to `_resize`.

        Returns:
            patches:   [N, patch_h * patch_w * 3] float32, row-major over the
                       patch grid, each patch flattened as (P_h, P_w, C).
            positions: [N, 2] int32 (row, col) grid index of each patch.
        """
        img = img.convert("RGB")
        img = self._resize(img, max_num_patches=max_num_patches)

        arr = np.array(img, dtype=np.float32) / 255.0  # [H, W, 3]
        arr = (arr - self.image_mean) / self.image_std  # normalise

        H, W, _ = arr.shape
        n_rows = H // self.patch_h
        n_cols = W // self.patch_w

        # Split H and W into (grid, within-patch) axes: [n_rows, P_h, n_cols, P_w, 3],
        # then bring the two grid axes to the front and flatten each patch.
        arr = arr.reshape(n_rows, self.patch_h, n_cols, self.patch_w, 3)
        arr = arr.transpose(0, 2, 1, 3, 4)  # [n_rows, n_cols, P_h, P_w, 3]
        patches = arr.reshape(n_rows * n_cols, self.patch_h * self.patch_w * 3)  # channels-last order

        rows, cols = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
        positions = np.stack([rows.ravel(), cols.ravel()], axis=-1).astype(np.int32)
        return patches, positions

    def _batch_images(
        self,
        images: List[Image.Image],
        max_num_patches: Optional[int],
        return_tensors: Optional[str],
    ) -> dict:
        """Patchify every image and zero-pad the batch to the longest patch sequence."""
        per_image = [self._patchify(img, max_num_patches=max_num_patches) for img in images]

        # Pad to max N in batch
        batch = len(per_image)
        max_n = max(patches.shape[0] for patches, _ in per_image)
        patch_dim = per_image[0][0].shape[1]

        padded_patches = np.zeros((batch, max_n, patch_dim), dtype=np.float32)
        padded_positions = np.zeros((batch, max_n, 2), dtype=np.int32)
        patch_attn_mask = np.zeros((batch, max_n), dtype=np.bool_)
        for i, (patches, positions) in enumerate(per_image):
            n = patches.shape[0]
            padded_patches[i, :n] = patches
            padded_positions[i, :n] = positions
            patch_attn_mask[i, :n] = True  # True marks real (non-padding) patches

        if return_tensors == "pt":
            return {
                "pixel_values": torch.from_numpy(padded_patches),
                "patch_positions": torch.from_numpy(padded_positions).long(),
                "patch_attention_mask": torch.from_numpy(patch_attn_mask),
            }
        return {
            "pixel_values": padded_patches,
            "patch_positions": padded_positions,
            "patch_attention_mask": patch_attn_mask,
        }

    # ---- Main __call__ ------------------------------------------------------

    def __call__(
        self,
        text: Union[str, List[str], None] = None,
        images: Union[Image.Image, List[Image.Image], None] = None,
        return_tensors: Optional[str] = "pt",
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        max_num_patches: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Tokenize text and/or patchify images into one batch.

        Args:
            text: A string or list of strings for the tokenizer.
            images: A single PIL image or a list of them. An empty list is
                treated as "no images" when `text` is given.
            return_tensors: "pt" returns torch tensors for the image fields;
                any other value returns numpy arrays. Also forwarded to the tokenizer.
            padding, truncation, max_length: Forwarded to the tokenizer.
            max_num_patches: Per-call override of the patch budget.
            **kwargs: Extra keyword arguments forwarded to the tokenizer.

        Returns:
            BatchFeature holding the tokenizer outputs (when `text` is given) and,
            when images are given:
                pixel_values          [B, N_max, patch_h * patch_w * 3] float32
                patch_positions       [B, N_max, 2] (row, col); int64 for "pt", int32 otherwise
                patch_attention_mask  [B, N_max] bool, True for real patches

        Raises:
            ValueError: If neither text nor any image is provided.
        """
        if text is None and images is None:
            raise ValueError("At least one of `text` or `images` must be provided.")

        # ---- Text ----
        encoding = {}
        if text is not None:
            text_enc = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            encoding.update(text_enc)

        # ---- Images ----
        if images is not None:
            images = [images] if isinstance(images, Image.Image) else list(images)
            if images:
                encoding.update(self._batch_images(images, max_num_patches, return_tensors))
            elif text is None:
                # An empty image list with no text leaves nothing to encode.
                raise ValueError("`images` is empty and no `text` was provided.")

        return BatchFeature(data=encoding, tensor_type=None)
|