# coding=utf-8
# Copyright 2025 The HustVL Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on Qwen2.5 and SigLIP. It has been modified to create DiffusionVL.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DiffusionVL-Qwen2.5 Processor - Combines image processor and tokenizer."""
import ast
import math
import re
from typing import List, Optional, Tuple, Union
import torch
import numpy as np
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers import SiglipImageProcessor
# Image placeholder token and its sentinel token index for LLaVA-format prompts
DEFAULT_IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_INDEX = -200
def select_best_resolution(original_size: Tuple[int, int], possible_resolutions: List[Tuple[int, int]]) -> Tuple[int, int]:
"""
    Selects the best-fitting resolution from a list of candidates: maximize the effective
    (usable, non-padded) content area, breaking ties by minimal wasted canvas area.
Matching training code: llava/mm_utils.py::select_best_resolution
"""
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float("inf")
for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
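# Example (hand-traced from the logic above): for a 1000x800 image, a 768x768
# candidate preserves the most pixels after aspect-ratio downscaling (768x614
# of usable content), so it beats the smaller candidates:
#
#     select_best_resolution((1000, 800), [(384, 384), (768, 384), (768, 768)])
#     # -> (768, 768)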
def resize_and_pad_image(image: Image.Image, target_resolution: Tuple[int, int]) -> Image.Image:
"""
Resize and pad an image to a target resolution while maintaining aspect ratio.
Matching training code: llava/mm_utils.py::resize_and_pad_image
"""
original_width, original_height = image.size
target_width, target_height = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
resized_image = image.resize((new_width, new_height))
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))
return new_image
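# Example (hand-traced): resizing a 1000x800 image onto a 768x768 canvas keeps
# the aspect ratio (content becomes 768x615) and centers it over black padding:
#
#     img = Image.new("RGB", (1000, 800))
#     resize_and_pad_image(img, (768, 768)).size
#     # -> (768, 768), content pasted at y-offset (768 - 615) // 2 = 76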
def divide_to_patches(image: Image.Image, patch_size: int) -> List[Image.Image]:
"""
Divides an image into patches of a specified size.
Matching training code: llava/mm_utils.py::divide_to_patches
"""
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)
return patches
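# Example: a 768x768 padded image with patch_size=384 yields a 2x2 grid of
# patches in row-major order (top-left, top-right, bottom-left, bottom-right):
#
#     patches = divide_to_patches(Image.new("RGB", (768, 768)), 384)
#     # len(patches) == 4, each patch.size == (384, 384)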
def expand2square(pil_img: Image.Image, background_color: Tuple[int, int, int]) -> Image.Image:
"""
Expand image to square by padding.
Matching training code: llava/mm_utils.py::expand2square
"""
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
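# Example: a 640x480 image becomes 640x640, with the original pasted at
# vertical offset (640 - 480) // 2 = 80 over the given background color
# (process_images below passes the processor's mean pixel value):
#
#     expand2square(Image.new("RGB", (640, 480)), (127, 127, 127)).size
#     # -> (640, 640)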
def get_anyres_image_grid_shape(image_size: Tuple[int, int], grid_pinpoints: Union[str, List[Tuple[int, int]]], patch_size: int) -> Tuple[int, int]:
    """
    Calculate the (columns, rows) shape of the image patch grid after anyres preprocessing.
Matching training code: llava/mm_utils.py::get_anyres_image_grid_shape
"""
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], f"patch_size {patch_size} should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
if isinstance(grid_pinpoints, list):
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
width, height = select_best_resolution(image_size, possible_resolutions)
return width // patch_size, height // patch_size
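# Example (hand-traced): with grid_pinpoints "(1x1),...,(2x2)" and
# patch_size=384, the candidate canvases are 384x384, 384x768, 768x384 and
# 768x768; a 1000x800 image selects 768x768, i.e. a 2x2 patch grid:
#
#     get_anyres_image_grid_shape((1000, 800), "(1x1),...,(2x2)", 384)
#     # -> (2, 2)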
def process_anyres_image(image: Image.Image, processor: SiglipImageProcessor, grid_pinpoints: Union[str, List[Tuple[int, int]]]) -> torch.Tensor:
"""
Process an image with variable resolutions (anyres).
Matching training code: llava/mm_utils.py::process_anyres_image
Returns: torch.Tensor of shape (num_patches, C, H, W) where num_patches = 1 + grid_patches
"""
# Get patch size from processor
if isinstance(processor.size, dict):
patch_size = processor.size.get("shortest_edge", processor.size.get("height", 384))
else:
patch_size = processor.size[0] if hasattr(processor.size, '__getitem__') else 384
crop_size = processor.crop_size.get("height", patch_size) if hasattr(processor, 'crop_size') else patch_size
# Parse grid pinpoints
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
assert patch_size in [224, 336, 384, 448, 512], f"patch_size {patch_size} should be in [224, 336, 384, 448, 512]"
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
range_start = tuple(map(int, matches[0]))
range_end = tuple(map(int, matches[-1]))
grid_pinpoints_list = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
possible_resolutions = [[dim * patch_size for dim in pair] for pair in grid_pinpoints_list]
elif isinstance(grid_pinpoints, list):
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)
patches = divide_to_patches(image_padded, crop_size)
# Base image (resized to patch size) - matching training code behavior
if isinstance(processor.size, dict):
shortest_edge = processor.size.get("shortest_edge", processor.size.get("height", 384))
else:
shortest_edge = min(processor.size) if hasattr(processor.size, '__iter__') else 384
image_original_resize = image.resize((shortest_edge, shortest_edge))
# Combine: base image + grid patches (same order as training code)
image_patches = [image_original_resize] + patches
# Preprocess all patches using the HF processor
processed_patches = [processor.preprocess(patch, return_tensors="pt")["pixel_values"][0] for patch in image_patches]
return torch.stack(processed_patches, dim=0)
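# Example (a sketch, assuming the default SigLIP-so400m-384 processor, whose
# size resolves to 384): a 1000x800 image maps to a 768x768 canvas, i.e. four
# 384x384 grid patches plus the base resize, giving a (5, 3, 384, 384) tensor:
#
#     proc = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
#     out = process_anyres_image(Image.new("RGB", (1000, 800)), proc, "(1x1),...,(2x2)")
#     # out.shape == torch.Size([5, 3, 384, 384])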
def process_images(images: List[Image.Image], image_processor: SiglipImageProcessor, model_cfg) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Process images matching the training code pipeline.
Matching training code: llava/mm_utils.py::process_images
Args:
images: List of PIL Images
image_processor: SiglipImageProcessor instance
model_cfg: Model config with image_aspect_ratio and image_grid_pinpoints
Returns:
torch.Tensor or List[torch.Tensor] of processed image patches
"""
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if image_aspect_ratio == "anyres" or (image_aspect_ratio and "anyres" in image_aspect_ratio):
grid_pinpoints = getattr(model_cfg, "image_grid_pinpoints", "(1x1),...,(2x2)")
for image in images:
processed = process_anyres_image(image, image_processor, grid_pinpoints)
new_images.append(processed)
elif image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
processed = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
new_images.append(processed)
else:
# Default: simple preprocessing
return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
# Stack if all same shape, otherwise return list
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
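# Example (a sketch; ExampleCfg is a hypothetical stand-in exposing the two
# config attributes read above, and the SigLIP checkpoint is the same default
# used by the processor class below):
#
#     class ExampleCfg:
#         image_aspect_ratio = "anyres"
#         image_grid_pinpoints = "(1x1),...,(2x2)"
#
#     proc = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
#     batch = process_images([Image.new("RGB", (1000, 800))], proc, ExampleCfg())
#     # anyres path: batch.shape == torch.Size([1, 5, 3, 384, 384])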
def tokenizer_image_token(prompt: str, tokenizer, image_token_index: int = IMAGE_TOKEN_INDEX, return_tensors: Optional[str] = None):
"""
Tokenize prompt with proper handling of <image> tokens.
Matching training code: llava/mm_utils.py::tokenizer_image_token
Args:
prompt: Text prompt containing <image> placeholders
tokenizer: Tokenizer instance
image_token_index: Index to use for image tokens (default: -200)
return_tensors: If "pt", return PyTorch tensor
Returns:
List of token IDs or torch.Tensor
"""
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f"Unsupported tensor type: {return_tensors}")
return input_ids
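# Example (a sketch, assuming a Qwen2.5 tokenizer; Qwen2 tokenizers define no
# BOS token, so the offset logic above is a no-op and each "<image>"
# placeholder becomes a single IMAGE_TOKEN_INDEX (-200) entry spliced between
# the tokenized text chunks):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#     ids = tokenizer_image_token("<image>\nDescribe this image.", tok)
#     # ids[0] == -200; the rest are the token IDs of "\nDescribe this image."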
class Conversation:
"""Simple conversation class matching LLaVA's conv_templates."""
    def __init__(self, system: str, roles: Tuple[str, str], sep: str, sep2: Optional[str] = None):
self.system = system
self.roles = roles
self.sep = sep
self.sep2 = sep2
self.messages = []
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
sep=self.sep,
sep2=self.sep2,
)
def append_message(self, role: str, message: str):
self.messages.append([role, message])
def get_prompt(self) -> str:
"""Build the prompt string."""
ret = ""
if self.system:
ret = f"<|im_start|>system\n{self.system}<|im_end|>\n"
for role, message in self.messages:
if message:
ret += f"<|im_start|>{role}\n{message}<|im_end|>\n"
else:
ret += f"<|im_start|>{role}\n"
return ret
# Pre-defined conversation template for Qwen2.5
CONV_QWEN_2_5 = Conversation(
system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
roles=("user", "assistant"),
sep="<|im_end|>",
sep2=None,
)
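# Example: building a single-turn ChatML prompt from the template above;
# appending None as the assistant message leaves the turn open for the model
# to complete:
#
#     conv = CONV_QWEN_2_5.copy()
#     conv.append_message(conv.roles[0], "<image>\nDescribe this image.")
#     conv.append_message(conv.roles[1], None)
#     conv.get_prompt()
#     # '<|im_start|>system\nYou are Qwen, ...<|im_end|>\n'
#     # '<|im_start|>user\n<image>\nDescribe this image.<|im_end|>\n'
#     # '<|im_start|>assistant\n'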
class DiffusionVL_Qwen2_5_Processor(ProcessorMixin):
"""
Processor for DiffusionVL-Qwen2.5 model.
Self-contained implementation matching the training code pipeline:
- Uses SiglipImageProcessor for image preprocessing
- Implements process_images with anyres support
- Implements tokenizer_image_token for proper <image> token handling
The processor stores model config for anyres parameters. Config can be:
1. Passed during __init__ via config parameter
2. Set after loading via set_config() method
3. Passed per-call via model_cfg parameter in __call__
"""
attributes = ["tokenizer"]
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
self,
tokenizer=None,
image_processor=None,
config=None,
**kwargs
):
# Use provided image_processor or create default SiglipImageProcessor
if image_processor is None:
self.image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
else:
self.image_processor = image_processor
# Store config for anyres processing
self._config = config
super().__init__(tokenizer)
def set_config(self, config):
"""Set model config for anyres image processing."""
self._config = config
def __call__(
self,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
images: Optional[Union[Image.Image, List[Image.Image]]] = None,
model_cfg=None,
return_tensors: Optional[str] = "pt",
**kwargs,
) -> BatchFeature:
"""
Process text and images for model input.
Args:
text: Input text or list of texts with <image> placeholder.
images: PIL Image or list of PIL Images.
model_cfg: Model config (needed for anyres parameters).
return_tensors: Return type ("pt" for PyTorch).
Returns:
BatchFeature with input_ids and pixel_values.
"""
if text is None and images is None:
raise ValueError("You must provide either text or images.")
# Process text using tokenizer_image_token
if text is not None:
if isinstance(text, str):
text = [text]
all_input_ids = []
for t in text:
input_ids = tokenizer_image_token(t, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
all_input_ids.append(input_ids)
# Pad sequences if multiple
if len(all_input_ids) > 1:
max_len = max(ids.shape[0] for ids in all_input_ids)
padded_input_ids = []
for ids in all_input_ids:
if ids.shape[0] < max_len:
padding = torch.full((max_len - ids.shape[0],), self.tokenizer.pad_token_id, dtype=torch.long)
ids = torch.cat([ids, padding])
padded_input_ids.append(ids)
input_ids = torch.stack(padded_input_ids)
else:
input_ids = all_input_ids[0].unsqueeze(0)
text_inputs = {"input_ids": input_ids}
else:
text_inputs = {}
# Process images using process_images
if images is not None:
if isinstance(images, Image.Image):
images = [images]
# Get image sizes before processing
image_sizes = [img.size for img in images]
# Use model_cfg if provided, otherwise use stored config
cfg = model_cfg if model_cfg is not None else self._config
if cfg is not None:
pixel_values = process_images(images, self.image_processor, cfg)
# Calculate num_patches_per_image for anyres
if isinstance(pixel_values, list):
num_patches_per_image = [t.shape[0] for t in pixel_values]
# Concatenate all patches into single tensor
pixel_values = torch.cat(pixel_values, dim=0)
elif pixel_values.dim() == 5:
# Shape: (num_images, num_patches, C, H, W)
num_patches_per_image = [pixel_values.shape[1]] * pixel_values.shape[0]
pixel_values = pixel_values.view(-1, *pixel_values.shape[2:])
else:
# Shape: (total_patches, C, H, W) - 1 patch per image
num_patches_per_image = [1] * len(images)
else:
# Fallback to simple preprocessing if no config
pixel_values = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
num_patches_per_image = [1] * len(images)
image_inputs = {
"pixel_values": pixel_values,
"image_sizes": image_sizes,
}
else:
image_inputs = {}
num_patches_per_image = None
# Create BatchFeature first
result = BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
# Add num_patches_per_image as plain Python list (not converted to tensor)
# This is needed for prepare_inputs_labels_for_multimodal
if num_patches_per_image is not None:
result["num_patches_per_image"] = num_patches_per_image
return result
def batch_decode(self, *args, **kwargs):
"""Decode token IDs to text."""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""Decode token IDs to text."""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = ["pixel_values", "image_sizes", "num_patches_per_image"]
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
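# Example (a sketch of end-to-end usage; the tokenizer checkpoint is
# illustrative, and with no model config set the image falls back to the
# simple single-patch preprocessing path):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#     proc = DiffusionVL_Qwen2_5_Processor(tokenizer=tok)
#     inputs = proc(text="<image>\nDescribe this image.", images=Image.new("RGB", (640, 480)))
#     # inputs: input_ids of shape (1, L), pixel_values of shape (1, 3, 384, 384),
#     # image_sizes tensor([[640, 480]]), num_patches_per_image [1]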
__all__ = [
"DiffusionVL_Qwen2_5_Processor",
"process_images",
"tokenizer_image_token",
"get_anyres_image_grid_shape",
"Conversation",
"CONV_QWEN_2_5",
"DEFAULT_IMAGE_TOKEN",
"IMAGE_TOKEN_INDEX",
]