Update processing_gemma3_tiled.py
Browse files- processing_gemma3_tiled.py +106 -7
processing_gemma3_tiled.py
CHANGED
|
@@ -13,10 +13,52 @@ import numpy as np
|
|
| 13 |
|
| 14 |
from transformers.feature_extraction_utils import BatchFeature
|
| 15 |
from transformers.image_utils import ImageInput, make_nested_list_of_images
|
| 16 |
-
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, ImagesKwargs
|
| 17 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
class Gemma3TiledImagesKwargs(ImagesKwargs):
|
| 21 |
tile_size: Optional[int]
|
| 22 |
max_tiles_h: Optional[int]
|
|
@@ -54,6 +96,7 @@ class Gemma3TiledProcessor(ProcessorMixin):
|
|
| 54 |
attributes = ["image_processor", "tokenizer"]
|
| 55 |
image_processor_class = "AutoImageProcessor" # Use AutoImageProcessor for compatibility
|
| 56 |
tokenizer_class = "AutoTokenizer"
|
|
|
|
| 57 |
|
| 58 |
def __init__(
|
| 59 |
self,
|
|
@@ -99,20 +142,24 @@ class Gemma3TiledProcessor(ProcessorMixin):
|
|
| 99 |
def build_image_token_sequence(self, grid_h: int, grid_w: int) -> str:
|
| 100 |
"""
|
| 101 |
Build the image token sequence for a tiled image.
|
| 102 |
-
|
| 103 |
Returns a string like:
|
| 104 |
-
\n\n<boi><img>×(16*grid_w)<img>×(16*grid_w)...(×16*grid_h rows)...<eoi
|
| 105 |
-
|
| 106 |
Note: We use <img> tokens for BOTH actual image positions AND linebreak positions.
|
| 107 |
The model will replace them with the appropriate embeddings.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
"""
|
| 109 |
rows = grid_h * self.tokens_per_tile_side
|
| 110 |
cols = grid_w * self.tokens_per_tile_side
|
| 111 |
-
|
| 112 |
total_tokens = self.get_num_image_tokens(grid_h, grid_w)
|
| 113 |
image_tokens = self.image_token * total_tokens
|
| 114 |
-
|
| 115 |
-
return f"\n\n{self.boi_token}{image_tokens}{self.eoi_token}
|
| 116 |
|
| 117 |
def __call__(
|
| 118 |
self,
|
|
@@ -218,6 +265,58 @@ class Gemma3TiledProcessor(ProcessorMixin):
|
|
| 218 |
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
|
| 219 |
image_processor_input_names = self.image_processor.model_input_names
|
| 220 |
return list(set(tokenizer_input_names + image_processor_input_names))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
|
| 223 |
__all__ = ["Gemma3TiledProcessor", "Gemma3TiledProcessorKwargs"]
|
|
|
|
| 13 |
|
| 14 |
from transformers.feature_extraction_utils import BatchFeature
|
| 15 |
from transformers.image_utils import ImageInput, make_nested_list_of_images
|
| 16 |
+
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, ImagesKwargs, MultiModalData
|
| 17 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 18 |
|
| 19 |
|
| 20 |
+
def calculate_tile_grid(
    image_height: int,
    image_width: int,
    tile_size: int,
    max_tiles_h: int,
    max_tiles_w: int,
    min_tiles: int = 1,
) -> tuple[int, int]:
    """Pick the tile grid (rows, cols) that best fits an image.

    Every candidate grid up to ``max_tiles_h x max_tiles_w`` (with at least
    ``min_tiles`` tiles) is scored by:

    1. Effective resolution — how many of the original pixels survive after
       scaling the image to fit the grid's canvas (capped at the original
       pixel count, since upscaling adds no information).
    2. Wasted canvas area, as a small tiebreaker penalty.

    Args:
        image_height: Original image height in pixels.
        image_width: Original image width in pixels.
        tile_size: Side length of one square tile, in pixels.
        max_tiles_h: Maximum number of tile rows to consider.
        max_tiles_w: Maximum number of tile columns to consider.
        min_tiles: Minimum total tile count a candidate grid must have.

    Returns:
        The best ``(rows, cols)`` grid; ``(1, 1)`` if no candidate satisfies
        ``min_tiles``.
    """
    source_pixels = image_height * image_width

    def _score(grid: tuple[int, int]) -> float:
        n_rows, n_cols = grid
        canvas_h = n_rows * tile_size
        canvas_w = n_cols * tile_size
        # Uniform scale that fits the whole image inside the canvas.
        fit_scale = min(canvas_w / image_width, canvas_h / image_height)
        # Pixels preserved from the source, capped at the source resolution.
        kept = min(image_height * image_width * fit_scale * fit_scale, source_pixels)
        unused = (canvas_h * canvas_w) - kept
        return kept - 0.001 * unused

    candidates = (
        (n_rows, n_cols)
        for n_rows in range(1, max_tiles_h + 1)
        for n_cols in range(1, max_tiles_w + 1)
        if n_rows * n_cols >= min_tiles
    )
    # max() returns the first maximal candidate, matching the original
    # "strictly greater" update rule's tie-breaking (row-major order).
    return max(candidates, key=_score, default=(1, 1))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
class Gemma3TiledImagesKwargs(ImagesKwargs):
|
| 63 |
tile_size: Optional[int]
|
| 64 |
max_tiles_h: Optional[int]
|
|
|
|
| 96 |
attributes = ["image_processor", "tokenizer"]
|
| 97 |
image_processor_class = "AutoImageProcessor" # Use AutoImageProcessor for compatibility
|
| 98 |
tokenizer_class = "AutoTokenizer"
|
| 99 |
+
_auto_class = "AutoProcessor" # Required for auto_map in processor_config.json
|
| 100 |
|
| 101 |
def __init__(
|
| 102 |
self,
|
|
|
|
| 142 |
def build_image_token_sequence(self, grid_h: int, grid_w: int) -> str:
    """Build the placeholder token string for one tiled image.

    Args:
        grid_h: Number of tile rows in the image's tile grid.
        grid_w: Number of tile columns in the image's tile grid.

    Returns:
        Two leading newlines, then ``boi_token``, then ``image_token``
        repeated ``get_num_image_tokens(grid_h, grid_w)`` times, then
        ``eoi_token``. The same ``<img>`` token is used for BOTH actual image
        positions AND linebreak positions; the model replaces them with the
        appropriate embeddings.

    IMPORTANT: no trailing double newline is appended — when followed by text
    content that starts with a newline it would create a triple newline,
    which tokenizes differently and breaks vLLM's placeholder pattern
    matching.
    """
    # Fix over previous revision: the per-row/per-column token counts
    # (grid_* x tokens_per_tile_side) were computed here but never used;
    # the total already comes from get_num_image_tokens.
    total_tokens = self.get_num_image_tokens(grid_h, grid_w)
    image_tokens = self.image_token * total_tokens
    return f"\n\n{self.boi_token}{image_tokens}{self.eoi_token}"
|
| 163 |
|
| 164 |
def __call__(
|
| 165 |
self,
|
|
|
|
| 265 |
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
|
| 266 |
image_processor_input_names = self.image_processor.model_input_names
|
| 267 |
return list(set(tokenizer_input_names + image_processor_input_names))
|
| 268 |
+
|
| 269 |
+
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
    """
    Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

    This is required by vLLM for memory profiling and scheduling.

    Args:
        image_sizes (`list[list[int]]`, *optional*):
            The input sizes formatted as (height, width) per each image.
        **kwargs: Additional arguments (tile_size, max_tiles_h, max_tiles_w, min_tiles)
            that override image processor defaults.

    Returns:
        `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
        input modalities, along with other useful data.
    """
    vision_data = {}
    if image_sizes is not None:
        # Tiling parameters: explicit kwargs win, then image-processor
        # attributes, then hard-coded fallbacks.
        tiling = {
            name: kwargs.get(name, getattr(self.image_processor, name, default))
            for name, default in (
                ("tile_size", 896),
                ("max_tiles_h", 4),
                ("max_tiles_w", 4),
                ("min_tiles", 1),
            )
        }

        # Optimal tile grid per image, in input order.
        grids = [
            calculate_tile_grid(image_height=height, image_width=width, **tiling)
            for height, width in image_sizes
        ]

        vision_data.update({
            # Placeholder token count for each image's grid.
            "num_image_tokens": [
                self.get_num_image_tokens(grid_h, grid_w) for grid_h, grid_w in grids
            ],
            # Number of patches = number of tiles in the grid.
            "num_image_patches": [grid_h * grid_w for grid_h, grid_w in grids],
        })

    return MultiModalData(**vision_data)
|
| 320 |
|
| 321 |
|
| 322 |
__all__ = ["Gemma3TiledProcessor", "Gemma3TiledProcessorKwargs"]
|