"""DiffusionVL-Qwen2.5 Processor - Combines image processor and tokenizer."""

import ast
import math
import re
from typing import List, Optional, Tuple, Union

import torch
from PIL import Image

from transformers import SiglipImageProcessor
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

DEFAULT_IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_INDEX = -200


def select_best_resolution(original_size: Tuple[int, int], possible_resolutions: List[Tuple[int, int]]) -> Tuple[int, int]:
    """
    Select the best resolution from a list of possible resolutions based on the original size.

    The best fit maximizes the effective (non-upscaled) resolution of the image
    and, among ties, minimizes wasted canvas area.

    Matching training code: llava/mm_utils.py::select_best_resolution
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
        # Scale the image to fit inside the candidate canvas without cropping.
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit
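

# A minimal sketch of how `select_best_resolution` behaves (illustrative values):
# a 1000x500 image checked against a 2x2 grid of 384px cells prefers the wide
# (768, 384) canvas over the square (768, 768) one, since both preserve the same
# effective resolution but the wide canvas wastes less area.
#
#   select_best_resolution((1000, 500), [(384, 384), (768, 384), (384, 768), (768, 768)])
#   # -> (768, 384)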


def resize_and_pad_image(image: Image.Image, target_resolution: Tuple[int, int]) -> Image.Image:
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Matching training code: llava/mm_utils.py::resize_and_pad_image
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    # Fit along the tighter dimension, then center the result on a black canvas.
    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    resized_image = image.resize((new_width, new_height))
    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image
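

# Illustrative behavior: a 1000x500 image targeted at (768, 384) is scaled by
# 0.768 to exactly 768x384 (no padding needed); targeted at (768, 768) it is
# scaled to 768x384 and centered with 192px black bars above and below.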


def divide_to_patches(image: Image.Image, patch_size: int) -> List[Image.Image]:
    """
    Divide an image into square patches of the specified size, in row-major order.

    Matching training code: llava/mm_utils.py::divide_to_patches
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)
    return patches
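

# Illustrative behavior: a 768x768 padded image with patch_size=384 yields four
# 384x384 patches ordered top-left, top-right, bottom-left, bottom-right.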


def expand2square(pil_img: Image.Image, background_color: Tuple[int, int, int]) -> Image.Image:
    """
    Expand an image to a square by padding the shorter side with `background_color`.

    Matching training code: llava/mm_utils.py::expand2square
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
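

# In the "pad" pipeline below, the background color is the processor's
# per-channel image mean rescaled to 0-255, so the padding is neutral after
# normalization, e.g.:
#
#   fill = tuple(int(x * 255) for x in image_processor.image_mean)
#   square = expand2square(img, fill)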


def get_anyres_image_grid_shape(image_size: Tuple[int, int], grid_pinpoints: Union[str, List], patch_size: int) -> Tuple[int, int]:
    """
    Calculate the shape (in patches) of the grid chosen for an image of the given
    size. `grid_pinpoints` may be a list of resolutions, its string literal, or a
    range string such as "(1x1),...,(2x2)".

    Matching training code: llava/mm_utils.py::get_anyres_image_grid_shape
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
        # Expand a range string like "(1x1),...,(2x2)" into all grid shapes
        # between the first and last pinpoints, then convert to pixel sizes.
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
    if isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size
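

# Illustrative behavior: with grid_pinpoints="(1x1),...,(2x2)" and patch_size=384,
# the candidate canvases are 384x384, 384x768, 768x384 and 768x768; a 1000x500
# image selects 768x384, so the grid shape is (2, 1).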


def process_anyres_image(image: Image.Image, processor: SiglipImageProcessor, grid_pinpoints: str) -> torch.Tensor:
    """
    Process an image with variable resolutions (anyres).

    Matching training code: llava/mm_utils.py::process_anyres_image

    Returns: torch.Tensor of shape (num_patches, C, H, W) where num_patches = 1 + grid_patches
    """
    # Resolve the base patch size from the processor's `size` attribute, which
    # may be a dict (e.g. {"height": 384, "width": 384}) or a sequence.
    if isinstance(processor.size, dict):
        patch_size = processor.size.get("shortest_edge", processor.size.get("height", 384))
    else:
        patch_size = processor.size[0] if hasattr(processor.size, "__getitem__") else 384

    crop_size = processor.crop_size.get("height", patch_size) if hasattr(processor, "crop_size") else patch_size

    # Resolve the candidate canvas resolutions from `grid_pinpoints`.
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        assert patch_size in [224, 336, 384, 448, 512], f"patch_size {patch_size} should be in [224, 336, 384, 448, 512]"
        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
        range_start = tuple(map(int, matches[0]))
        range_end = tuple(map(int, matches[-1]))
        grid_pinpoints_list = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
        possible_resolutions = [[dim * patch_size for dim in pair] for pair in grid_pinpoints_list]
    elif isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)

    # Fit the image onto the best canvas and cut it into crop_size tiles.
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)
    patches = divide_to_patches(image_padded, crop_size)

    # The first patch is always a global view of the whole image, resized to
    # the base resolution.
    if isinstance(processor.size, dict):
        shortest_edge = processor.size.get("shortest_edge", processor.size.get("height", 384))
    else:
        shortest_edge = min(processor.size) if hasattr(processor.size, "__iter__") else 384
    image_original_resize = image.resize((shortest_edge, shortest_edge))

    image_patches = [image_original_resize] + patches
    processed_patches = [processor.preprocess(patch, return_tensors="pt")["pixel_values"][0] for patch in image_patches]
    return torch.stack(processed_patches, dim=0)
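

# Illustrative shape: with a SigLIP-so400m 384px processor and
# grid_pinpoints="(1x1),...,(2x2)", a 1000x500 image lands on a 768x384 canvas,
# giving 2 grid tiles plus the global view:
#
#   process_anyres_image(img, image_processor, "(1x1),...,(2x2)").shape
#   # -> torch.Size([3, 3, 384, 384])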


def process_images(images: List[Image.Image], image_processor: SiglipImageProcessor, model_cfg) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Process images matching the training code pipeline.

    Matching training code: llava/mm_utils.py::process_images

    Args:
        images: List of PIL Images
        image_processor: SiglipImageProcessor instance
        model_cfg: Model config with image_aspect_ratio and image_grid_pinpoints

    Returns:
        torch.Tensor of processed image patches, or a List[torch.Tensor] when
        anyres images produce differing patch counts.
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []

    if image_aspect_ratio and "anyres" in image_aspect_ratio:
        grid_pinpoints = getattr(model_cfg, "image_grid_pinpoints", "(1x1),...,(2x2)")
        for image in images:
            processed = process_anyres_image(image, image_processor, grid_pinpoints)
            new_images.append(processed)
    elif image_aspect_ratio == "pad":
        for image in images:
            # Pad to square with the per-channel image mean so the padding is
            # neutral after normalization.
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            processed = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
            new_images.append(processed)
    else:
        # Default: plain resize/normalize via the image processor.
        return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]

    # Stack only when all images produced the same number of patches.
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
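

# A minimal usage sketch, assuming a config object with the two attributes the
# pipeline reads (`image_aspect_ratio`, `image_grid_pinpoints`); the locals
# here are illustrative:
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(image_aspect_ratio="anyres", image_grid_pinpoints="(1x1),...,(2x2)")
#   pixel_values = process_images([img], image_processor, cfg)
#   # one image -> tensor of shape (1, num_patches, 3, 384, 384)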


def tokenizer_image_token(prompt: str, tokenizer, image_token_index: int = IMAGE_TOKEN_INDEX, return_tensors: str = None):
    """
    Tokenize a prompt, replacing each <image> placeholder with `image_token_index`.

    Matching training code: llava/mm_utils.py::tokenizer_image_token

    Args:
        prompt: Text prompt containing <image> placeholders
        tokenizer: Tokenizer instance
        image_token_index: Index to use for image tokens (default: -200)
        return_tensors: If "pt", return a PyTorch tensor

    Returns:
        List of token IDs or torch.Tensor
    """
    # Tokenize the text between <image> placeholders separately, then stitch
    # the chunks back together with the image token index in between.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    # If the tokenizer prepends a BOS token, keep it once at the start and
    # strip the duplicate BOS from every subsequent chunk.
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids
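

# Illustrative output: for a prompt "<image>\nDescribe this image." with a
# tokenizer that has no BOS token (as for Qwen2), each <image> becomes a single
# -200 entry and the surrounding text is tokenized normally:
#
#   tokenizer_image_token("<image>\nDescribe this image.", tokenizer)
#   # -> [-200, <ids for "\nDescribe this image.">]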


class Conversation:
    """Simple conversation class matching LLaVA's conv_templates."""

    def __init__(self, system: str, roles: Tuple[str, str], sep: str, sep2: str = None):
        self.system = system
        self.roles = roles
        self.sep = sep
        self.sep2 = sep2
        self.messages = []

    def copy(self):
        new_conv = Conversation(
            system=self.system,
            roles=self.roles,
            sep=self.sep,
            sep2=self.sep2,
        )
        new_conv.messages = [[role, message] for role, message in self.messages]
        return new_conv

    def append_message(self, role: str, message: str):
        self.messages.append([role, message])

    def get_prompt(self) -> str:
        """Build the ChatML prompt string."""
        ret = ""
        if self.system:
            ret = f"<|im_start|>system\n{self.system}<|im_end|>\n"

        for role, message in self.messages:
            if message:
                ret += f"<|im_start|>{role}\n{message}<|im_end|>\n"
            else:
                # An empty message opens the turn for the model to complete.
                ret += f"<|im_start|>{role}\n"
        return ret


CONV_QWEN_2_5 = Conversation(
    system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
    roles=("user", "assistant"),
    sep="<|im_end|>",
    sep2=None,
)
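

# Typical prompt construction with this template; appending None for the
# assistant turn leaves the prompt open for generation:
#
#   conv = CONV_QWEN_2_5.copy()
#   conv.append_message(conv.roles[0], "<image>\nWhat is in this picture?")
#   conv.append_message(conv.roles[1], None)
#   prompt = conv.get_prompt()
#   # "<|im_start|>system\nYou are Qwen, ...<|im_end|>\n<|im_start|>user\n
#   #  <image>\nWhat is in this picture?<|im_end|>\n<|im_start|>assistant\n"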


class DiffusionVL_Qwen2_5_Processor(ProcessorMixin):
    """
    Processor for the DiffusionVL-Qwen2.5 model.

    Self-contained implementation matching the training code pipeline:
    - Uses SiglipImageProcessor for image preprocessing
    - Implements process_images with anyres support
    - Implements tokenizer_image_token for proper <image> token handling

    The processor stores a model config for the anyres parameters. The config can be:
    1. Passed during __init__ via the config parameter
    2. Set after loading via the set_config() method
    3. Passed per-call via the model_cfg parameter in __call__
    """

    # Only the tokenizer participates in ProcessorMixin's save/load machinery;
    # the image processor is constructed or injected manually below.
    attributes = ["tokenizer"]
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self,
        tokenizer=None,
        image_processor=None,
        config=None,
        **kwargs,
    ):
        if image_processor is None:
            self.image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
        else:
            self.image_processor = image_processor

        self._config = config

        super().__init__(tokenizer)

    def set_config(self, config):
        """Set the model config used for anyres image processing."""
        self._config = config

    def __call__(
        self,
        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        images: Optional[Union[Image.Image, List[Image.Image]]] = None,
        model_cfg=None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """
        Process text and images for model input.

        Args:
            text: Input text or list of texts with <image> placeholders.
            images: PIL Image or list of PIL Images.
            model_cfg: Model config (needed for anyres parameters).
            return_tensors: Return type ("pt" for PyTorch).

        Returns:
            BatchFeature with input_ids and pixel_values.
        """
        if text is None and images is None:
            raise ValueError("You must provide either text or images.")

        if text is not None:
            if isinstance(text, str):
                text = [text]

            all_input_ids = []
            for t in text:
                input_ids = tokenizer_image_token(t, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
                all_input_ids.append(input_ids)

            # Right-pad to the longest sequence in the batch.
            if len(all_input_ids) > 1:
                max_len = max(ids.shape[0] for ids in all_input_ids)
                padded_input_ids = []
                for ids in all_input_ids:
                    if ids.shape[0] < max_len:
                        padding = torch.full((max_len - ids.shape[0],), self.tokenizer.pad_token_id, dtype=torch.long)
                        ids = torch.cat([ids, padding])
                    padded_input_ids.append(ids)
                input_ids = torch.stack(padded_input_ids)
            else:
                input_ids = all_input_ids[0].unsqueeze(0)

            text_inputs = {"input_ids": input_ids}
        else:
            text_inputs = {}

        if images is not None:
            if isinstance(images, Image.Image):
                images = [images]

            # Record original (width, height) sizes; the model needs them to
            # unpad anyres features.
            image_sizes = [img.size for img in images]

            # Per-call config takes precedence over the stored one.
            cfg = model_cfg if model_cfg is not None else self._config

            if cfg is not None:
                pixel_values = process_images(images, self.image_processor, cfg)

                if isinstance(pixel_values, list):
                    # Variable patch counts: remember the count per image, then
                    # flatten to a single (total_patches, C, H, W) tensor.
                    num_patches_per_image = [t.shape[0] for t in pixel_values]
                    pixel_values = torch.cat(pixel_values, dim=0)
                elif pixel_values.dim() == 5:
                    # Uniform patch counts: (B, P, C, H, W) -> (B * P, C, H, W).
                    num_patches_per_image = [pixel_values.shape[1]] * pixel_values.shape[0]
                    pixel_values = pixel_values.view(-1, *pixel_values.shape[2:])
                else:
                    num_patches_per_image = [1] * len(images)
            else:
                # No config: fall back to plain preprocessing, one patch per image.
                pixel_values = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
                num_patches_per_image = [1] * len(images)

            image_inputs = {
                "pixel_values": pixel_values,
                "image_sizes": image_sizes,
            }
        else:
            image_inputs = {}
            num_patches_per_image = None

        result = BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

        # Attach after construction so the list is not coerced into a tensor.
        if num_patches_per_image is not None:
            result["num_patches_per_image"] = num_patches_per_image

        return result

    def batch_decode(self, *args, **kwargs):
        """Decode batches of token IDs to text. Forwards to the tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode token IDs to text. Forwards to the tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = ["pixel_values", "image_sizes", "num_patches_per_image"]
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
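

# A minimal end-to-end sketch, assuming a model config exposing
# `image_aspect_ratio` and `image_grid_pinpoints` (the attribute names read by
# process_images above; everything else here is illustrative):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
#   processor = DiffusionVL_Qwen2_5_Processor(tokenizer=tokenizer)
#   processor.set_config(model.config)  # or pass model_cfg=... per call
#
#   conv = CONV_QWEN_2_5.copy()
#   conv.append_message(conv.roles[0], "<image>\nDescribe this image.")
#   conv.append_message(conv.roles[1], None)
#   inputs = processor(text=conv.get_prompt(), images=image)
#   # inputs.input_ids: (1, seq_len) with -200 at the image position
#   # inputs.pixel_values: (total_patches, 3, 384, 384)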


__all__ = [
    "DiffusionVL_Qwen2_5_Processor",
    "process_images",
    "tokenizer_image_token",
    "get_anyres_image_grid_shape",
    "Conversation",
    "CONV_QWEN_2_5",
    "DEFAULT_IMAGE_TOKEN",
    "IMAGE_TOKEN_INDEX",
]