# naflex_cross_1 / naflex_processor.py
# Author: lukbl
# Upload folder using huggingface_hub
# Commit: b689296 (verified)
"""
NaFlexProcessor — image + text processor for NaFlexCrossForConditionalGeneration.
Image pipeline:
1. Convert to RGB
2. Resize so the patch count stays within `max_num_patches`, snapping H and W
to multiples of patch_h / patch_w respectively (preserves aspect ratio).
3. Normalise with per-channel mean/std.
4. Extract patches row-major; record (row, col) integer position of each patch.
5. Pad across the batch to the longest sequence; return patch_attention_mask.
Text pipeline:
Standard HuggingFace tokenizer with padding / truncation.
Usage:
processor = NaFlexProcessor.from_pretrained("checkpoints/qwen2vl-24m",
patch_size=(16, 16))
batch = processor(
text=texts,
images=image_inputs,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=512,
)
# batch keys: input_ids, attention_mask,
# pixel_values, patch_positions, patch_attention_mask
"""
import math
from typing import List, Optional, Union
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, BatchFeature, ProcessorMixin
from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy
_DEFAULT_MEAN = (0.7931, 0.7931, 0.7931)
_DEFAULT_STD = (0.1738, 0.1738, 0.1738)
# Cross-attention template: images are encoded by the ViT, NOT injected as tokens.
# Image content blocks are silently dropped; only text reaches the tokenizer.
_DEFAULT_CHAT_TEMPLATE = (
"{%- for message in messages %}\n"
"{{- '<|im_start|>' }\n"
"{%- if message['content'] is string %}\n"
"{{- message['content'] }}\n"
"{%- else %}\n"
"{%- for content in message['content'] %}\n"
"{%- if content['type'] == 'text' %}\n"
"{{- content['text'] }}\n"
"{%- endif %}\n"
"{%- endfor %}\n"
"{%- endif %}\n"
"{{- '<|im_end|>\\n' }}\n"
"{%- endfor %}\n"
"{%- if add_generation_prompt %}\n"
"{{- '<|im_start|>' }}\n"
"{%- endif %}\n"
)
class NaFlexProcessor(ProcessorMixin):
    """
    Processor for NaFlexCrossForConditionalGeneration.

    Text is tokenized with the wrapped HuggingFace tokenizer. Images are resized onto
    the patch grid, normalised, and flattened into padded patch sequences with
    (row, col) positions; no image tokens are inserted into the text.

    Args:
        tokenizer: Any HuggingFace tokenizer.
        patch_size: (patch_h, patch_w) or single int for square patches.
        max_num_patches: Maximum patches per image (controls resolution budget).
        image_mean: Per-channel normalisation mean (C,).
        image_std: Per-channel normalisation std (C,).
        chat_template: Jinja chat template; the built-in NaFlex template when None.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"
    # Sentinel: NaFlex has no image tokens in the text sequence.
    # train_sft.py uses this to mask image tokens in labels; -1 is never a real token ID.
    image_token_id: int = -1

    def __init__(
        self,
        tokenizer,
        patch_size: Union[int, tuple] = 16,
        max_num_patches: int = 1024,
        image_mean: tuple = _DEFAULT_MEAN,
        image_std: tuple = _DEFAULT_STD,
        chat_template: Optional[str] = None,
    ):
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.patch_size = list(patch_size)  # serialised as [patch_h, patch_w]
        self.patch_h, self.patch_w = patch_size
        self.max_num_patches = max_num_patches
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        # ProcessorMixin stores chat_template as its own attribute via kwargs
        super().__init__(tokenizer, chat_template=chat_template or _DEFAULT_CHAT_TEMPLATE)

    # ── Save / load ───────────────────────────────────────────────────────────

    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Save processor config, tokenizer, chat template, and a copy of this module.

        After the base-class save this method:
          * writes ``chat_template.jinja`` and stamps the template into
            ``tokenizer_config.json`` (when that file exists);
          * writes the image settings into ``processor_config.json``;
          * copies this module into ``save_directory`` and registers it under
            ``auto_map["AutoProcessor"]`` in ``config.json`` (when that file exists).

        Side effect: ``self.tokenizer.chat_template`` is overwritten with the
        built-in NaFlex template.
        """
        import json
        import os
        import shutil

        def _read_json(path):
            with open(path, encoding="utf-8") as f:
                return json.load(f)

        def _write_json(path, data):
            with open(path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)

        # Always enforce the NaFlex cross-attention template (not the early-fusion one
        # that may have been inherited from a Qwen2VL tokenizer source)
        if self.tokenizer.chat_template != _DEFAULT_CHAT_TEMPLATE:
            self.tokenizer.chat_template = _DEFAULT_CHAT_TEMPLATE
        # NOTE(review): if kwargs requests push_to_hub, the base class may upload before
        # the files below are patched — confirm whether a second push is needed.
        super().save_pretrained(save_directory, **kwargs)

        # Write chat_template.jinja and also stamp it directly into tokenizer_config.json.
        # tokenizer.save_pretrained() only writes chat_template there if the tokenizer
        # was originally constructed with it, so we patch it ourselves.
        template_path = os.path.join(save_directory, "chat_template.jinja")
        with open(template_path, "w", encoding="utf-8") as f:
            f.write(self.tokenizer.chat_template)
        tok_cfg_path = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.isfile(tok_cfg_path):
            tok_cfg = _read_json(tok_cfg_path)
            tok_cfg["chat_template"] = self.tokenizer.chat_template
            _write_json(tok_cfg_path, tok_cfg)

        # Overwrite processor_config.json with all image config fields. Start from the
        # base-class file when it exists; otherwise create it instead of crashing.
        cfg_path = os.path.join(save_directory, "processor_config.json")
        cfg = _read_json(cfg_path) if os.path.isfile(cfg_path) else {}
        cfg["patch_size"] = self.patch_size
        cfg["max_num_patches"] = self.max_num_patches
        cfg["image_mean"] = self.image_mean.tolist()
        cfg["image_std"] = self.image_std.tolist()
        cfg["processor_class"] = "NaFlexProcessor"
        _write_json(cfg_path, cfg)

        # Copy module file so AutoProcessor can load with trust_remote_code=True
        src = os.path.join(os.path.dirname(__file__), "naflex_processor.py")
        try:
            shutil.copy(src, os.path.join(save_directory, "naflex_processor.py"))
        except shutil.SameFileError:
            # Saving back into the directory this module was imported from:
            # the file is already in place.
            pass

        # Add AutoProcessor entry to config.json auto_map
        main_cfg_path = os.path.join(save_directory, "config.json")
        if os.path.exists(main_cfg_path):
            main_cfg = _read_json(main_cfg_path)
            main_cfg.setdefault("auto_map", {})
            main_cfg["auto_map"]["AutoProcessor"] = "naflex_processor.NaFlexProcessor"
            _write_json(main_cfg_path, main_cfg)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        patch_size: Union[int, tuple, None] = None,
        max_num_patches: Optional[int] = None,
        image_mean: Optional[tuple] = None,
        image_std: Optional[tuple] = None,
        **kwargs,
    ) -> "NaFlexProcessor":
        """
        Load the tokenizer and NaFlex image settings from a checkpoint.

        Precedence for each image setting: explicit argument, then
        ``processor_config.json``, then (patch size only) ``vit_patch_size`` from
        ``config.json``, then the built-in default.

        NOTE(review): the JSON / Jinja side files are looked up with ``os.path`` only,
        so they are read only when ``pretrained_model_name_or_path`` is a local
        directory. For a Hub repo id, only the tokenizer is fetched and the image
        settings fall back to defaults. Confirm this is acceptable.
        """
        import json
        import os

        def _read_json(path):
            if not os.path.isfile(path):
                return {}
            with open(path, encoding="utf-8") as f:
                return json.load(f)

        # Strip kwargs that belong to AutoProcessor/tokenizer infrastructure
        kwargs.pop("use_fast", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)

        # Read saved image config (if loading from a checkpoint directory)
        saved = _read_json(os.path.join(pretrained_model_name_or_path, "processor_config.json"))

        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
        )

        # Processor chat template: checkpoint file takes priority, then built-in default.
        # The tokenizer's own (possibly early-fusion) template is not modified here.
        template_path = os.path.join(pretrained_model_name_or_path, "chat_template.jinja")
        if os.path.isfile(template_path):
            with open(template_path, encoding="utf-8") as f:
                chat_template = f.read()
        else:
            chat_template = _DEFAULT_CHAT_TEMPLATE

        # Fall back to vit_patch_size from model config.json if not in processor_config
        if patch_size is None and "patch_size" not in saved:
            model_cfg = _read_json(os.path.join(pretrained_model_name_or_path, "config.json"))
            if "vit_patch_size" in model_cfg:
                saved["patch_size"] = model_cfg["vit_patch_size"]

        def _pick(explicit, key, default):
            # Explicit None check: `explicit or ...` raises on numpy arrays and
            # silently discards falsy-but-explicit values.
            return explicit if explicit is not None else saved.get(key, default)

        return cls(
            tokenizer=tokenizer,
            chat_template=chat_template,
            patch_size=_pick(patch_size, "patch_size", [16, 16]),
            max_num_patches=_pick(max_num_patches, "max_num_patches", 1024),
            image_mean=_pick(image_mean, "image_mean", _DEFAULT_MEAN),
            image_std=_pick(image_std, "image_std", _DEFAULT_STD),
        )

    # ── Compatibility shim ────────────────────────────────────────────────────

    @property
    def image_processor(self):
        """
        Shim for code that expects a Qwen2VL-style image_processor (e.g. TokenBudgetDataset).

        NaFlex encodes images via a cross-attention ViT — they do NOT add tokens to the
        text sequence. The shim exposes the pixel budget (``max_pixels`` =
        max_num_patches patches, ``min_pixels`` = one patch) and
        ``image_tokens_in_text = False`` so callers can tell images cost no text tokens.
        """
        import types

        patch_pixels = self.patch_h * self.patch_w
        return types.SimpleNamespace(
            max_pixels=self.max_num_patches * patch_pixels,
            min_pixels=patch_pixels,
            # Signal to any caller that images don't contribute to the text token count
            image_tokens_in_text=False,
        )

    # ── Image helpers ─────────────────────────────────────────────────────────

    def _resize(self, img: Image.Image, max_num_patches: Optional[int] = None) -> Image.Image:
        """
        Resize image so that:
          - H is a multiple of patch_h, W is a multiple of patch_w
          - patch count <= max_num_patches (or self.max_num_patches if not given),
            with a floor of one patch
          - aspect ratio is preserved as closely as possible

        Raises:
            ValueError: if the image has zero width or height.
        """
        max_n = max_num_patches if max_num_patches is not None else self.max_num_patches
        # At least one patch is always produced, so treat budgets below 1 as 1.
        max_n = max(1, int(max_n))
        W_orig, H_orig = img.size
        if W_orig <= 0 or H_orig <= 0:
            raise ValueError(f"Cannot resize an empty image of size {img.size}.")
        ph, pw = self.patch_h, self.patch_w

        # Target: scale uniformly so total patches == max_n
        # patches = (H*s / P_h) * (W*s / P_w)  ≈  max_n
        scale = math.sqrt(max_n * ph * pw / (H_orig * W_orig))
        # Round to nearest patch grid boundary (at least one patch per side)
        n_rows = max(1, round(H_orig * scale / ph))
        n_cols = max(1, round(W_orig * scale / pw))
        # If the budget grid is at least as large as the image on both sides, don't
        # upscale to fill the budget; only snap the original size to the grid.
        if H_orig <= n_rows * ph and W_orig <= n_cols * pw:
            n_rows = max(1, round(H_orig / ph))
            n_cols = max(1, round(W_orig / pw))

        # Independent rounding and the one-patch floor (very thin images) can push the
        # grid over budget. Trim the side that overshoots the original aspect ratio.
        n_rows = min(n_rows, max_n)
        n_cols = min(n_cols, max_n)
        while n_rows * n_cols > max_n:
            rows_too_tall = n_rows * ph * W_orig >= n_cols * pw * H_orig
            if n_rows > 1 and (rows_too_tall or n_cols == 1):
                n_rows -= 1
            else:
                n_cols -= 1
        return img.resize((n_cols * pw, n_rows * ph), Image.BICUBIC)

    def _patchify(
        self, img: Image.Image, max_num_patches: Optional[int] = None
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Convert one image to a row-major sequence of flattened, normalised patches.

        Returns:
            patches:   [N, 3 * patch_h * patch_w] float32 (channels-last within a patch)
            positions: [N, 2] int32 (row, col)
        """
        img = img.convert("RGB")
        img = self._resize(img, max_num_patches=max_num_patches)
        arr = np.array(img, dtype=np.float32) / 255.0  # [H, W, 3]
        arr = (arr - self.image_mean) / self.image_std  # normalise
        H, W, _ = arr.shape
        n_rows = H // self.patch_h
        n_cols = W // self.patch_w
        # Reshape to [n_rows, n_cols, patch_h, patch_w, 3] then flatten per patch
        arr = arr.reshape(n_rows, self.patch_h, n_cols, self.patch_w, 3)
        arr = arr.transpose(0, 2, 1, 3, 4)  # [n_rows, n_cols, P_h, P_w, 3]
        patches = arr.reshape(n_rows * n_cols, self.patch_h * self.patch_w * 3)  # channels-last order
        rows, cols = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
        positions = np.stack([rows.ravel(), cols.ravel()], axis=-1).astype(np.int32)
        return patches, positions

    # ── Main __call__ ─────────────────────────────────────────────────────────

    def __call__(
        self,
        text: Union[str, List[str], None] = None,
        images: Union[Image.Image, List[Image.Image], None] = None,
        return_tensors: Optional[str] = "pt",
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        max_num_patches: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Tokenize text and/or patchify images into a single BatchFeature.

        Args:
            text: String or list of strings passed to the tokenizer.
            images: A PIL image or an iterable of PIL images.
            return_tensors: "pt" for torch tensors. The image arrays are left as numpy
                for any other value; the text outputs follow the tokenizer.
            padding, truncation, max_length, **kwargs: Forwarded to the tokenizer.
            max_num_patches: Per-call override of the per-image patch budget.

        Returns:
            BatchFeature with the tokenizer outputs plus, when images are given,
            ``pixel_values`` [B, N, 3*patch_h*patch_w], ``patch_positions`` [B, N, 2]
            and ``patch_attention_mask`` [B, N] (True for real patches).

        Raises:
            ValueError: if neither ``text`` nor ``images`` is provided.
        """
        if text is None and images is None:
            raise ValueError("At least one of `text` or `images` must be provided.")

        # ── Text ──────────────────────────────────────────────────────────────
        encoding = {}
        if text is not None:
            text_enc = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            encoding.update(text_enc)

        # ── Images ────────────────────────────────────────────────────────────
        if images is not None:
            if isinstance(images, Image.Image):
                images = [images]
            all_patches, all_positions = [], []
            for img in images:
                patches, positions = self._patchify(img, max_num_patches=max_num_patches)
                all_patches.append(patches)
                all_positions.append(positions)

            # Pad to max N in batch. The feature width comes from the patch geometry,
            # so an empty image list yields (0, 0, patch_dim) arrays instead of a crash.
            batch_size = len(all_patches)
            max_n = max((p.shape[0] for p in all_patches), default=0)
            patch_dim = 3 * self.patch_h * self.patch_w
            padded_patches = np.zeros((batch_size, max_n, patch_dim), dtype=np.float32)
            padded_positions = np.zeros((batch_size, max_n, 2), dtype=np.int32)
            patch_attn_mask = np.zeros((batch_size, max_n), dtype=np.bool_)
            for i, (p, pos) in enumerate(zip(all_patches, all_positions)):
                n = p.shape[0]
                padded_patches[i, :n] = p
                padded_positions[i, :n] = pos
                patch_attn_mask[i, :n] = True

            if return_tensors == "pt":
                encoding["pixel_values"] = torch.from_numpy(padded_patches)
                encoding["patch_positions"] = torch.from_numpy(padded_positions).long()
                encoding["patch_attention_mask"] = torch.from_numpy(patch_attn_mask)
            else:
                encoding["pixel_values"] = padded_patches
                encoding["patch_positions"] = padded_positions
                encoding["patch_attention_mask"] = patch_attn_mask

        return BatchFeature(data=encoding, tensor_type=None)