Fraser committed on
Commit
ff23816
·
verified ·
1 Parent(s): 50deb8f

Update processing_gemma3_tiled.py

Browse files
Files changed (1) hide show
  1. processing_gemma3_tiled.py +106 -7
processing_gemma3_tiled.py CHANGED
@@ -13,10 +13,52 @@ import numpy as np
13
 
14
  from transformers.feature_extraction_utils import BatchFeature
15
  from transformers.image_utils import ImageInput, make_nested_list_of_images
16
- from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, ImagesKwargs
17
  from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class Gemma3TiledImagesKwargs(ImagesKwargs):
21
  tile_size: Optional[int]
22
  max_tiles_h: Optional[int]
@@ -54,6 +96,7 @@ class Gemma3TiledProcessor(ProcessorMixin):
54
  attributes = ["image_processor", "tokenizer"]
55
  image_processor_class = "AutoImageProcessor" # Use AutoImageProcessor for compatibility
56
  tokenizer_class = "AutoTokenizer"
 
57
 
58
  def __init__(
59
  self,
@@ -99,20 +142,24 @@ class Gemma3TiledProcessor(ProcessorMixin):
99
  def build_image_token_sequence(self, grid_h: int, grid_w: int) -> str:
100
  """
101
  Build the image token sequence for a tiled image.
102
-
103
  Returns a string like:
104
- \n\n<boi><img>×(16*grid_w)<img>×(16*grid_w)...(×16*grid_h rows)...<eoi>\n\n
105
-
106
  Note: We use <img> tokens for BOTH actual image positions AND linebreak positions.
107
  The model will replace them with the appropriate embeddings.
 
 
 
 
108
  """
109
  rows = grid_h * self.tokens_per_tile_side
110
  cols = grid_w * self.tokens_per_tile_side
111
-
112
  total_tokens = self.get_num_image_tokens(grid_h, grid_w)
113
  image_tokens = self.image_token * total_tokens
114
-
115
- return f"\n\n{self.boi_token}{image_tokens}{self.eoi_token}\n\n"
116
 
117
  def __call__(
118
  self,
@@ -218,6 +265,58 @@ class Gemma3TiledProcessor(ProcessorMixin):
218
  tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
219
  image_processor_input_names = self.image_processor.model_input_names
220
  return list(set(tokenizer_input_names + image_processor_input_names))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
 
223
  __all__ = ["Gemma3TiledProcessor", "Gemma3TiledProcessorKwargs"]
 
13
 
14
  from transformers.feature_extraction_utils import BatchFeature
15
  from transformers.image_utils import ImageInput, make_nested_list_of_images
16
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, ImagesKwargs, MultiModalData
17
  from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
18
 
19
 
20
def calculate_tile_grid(
    image_height: int,
    image_width: int,
    tile_size: int,
    max_tiles_h: int,
    max_tiles_w: int,
    min_tiles: int = 1,
) -> tuple[int, int]:
    """Pick the (rows, cols) tile grid that best fits an image.

    Every candidate grid up to ``max_tiles_h`` x ``max_tiles_w`` is scored by
    how many original-image pixels survive an aspect-preserving resize onto
    the grid canvas, with a small penalty (0.1% weight) on unused canvas area
    acting as a tiebreaker. The first grid encountered in row-major scan
    order with the strictly highest score wins.

    Args:
        image_height: Height of the input image in pixels.
        image_width: Width of the input image in pixels.
        tile_size: Side length of one square tile in pixels.
        max_tiles_h: Maximum number of tile rows to consider.
        max_tiles_w: Maximum number of tile columns to consider.
        min_tiles: Candidate grids with fewer total tiles are skipped.

    Returns:
        The best ``(rows, cols)`` pair; ``(1, 1)`` if no candidate ever
        improves on the initial score.
    """
    source_pixels = image_height * image_width

    winner = (1, 1)
    winner_score = float("-inf")

    for n_rows in range(1, max_tiles_h + 1):
        for n_cols in range(1, max_tiles_w + 1):
            # Enforce the minimum-tile constraint.
            if n_rows * n_cols < min_tiles:
                continue

            canvas_height = n_rows * tile_size
            canvas_width = n_cols * tile_size

            # Aspect-preserving scale factor onto this canvas; the effective
            # resolution can never exceed the source image's own pixel count.
            fit = min(canvas_width / image_width, canvas_height / image_height)
            preserved = min(image_height * image_width * fit * fit, source_pixels)
            unused = (canvas_height * canvas_width) - preserved

            candidate_score = preserved - 0.001 * unused
            if candidate_score > winner_score:
                winner_score = candidate_score
                winner = (n_rows, n_cols)

    return winner
60
+
61
+
62
  class Gemma3TiledImagesKwargs(ImagesKwargs):
63
  tile_size: Optional[int]
64
  max_tiles_h: Optional[int]
 
96
  attributes = ["image_processor", "tokenizer"]
97
  image_processor_class = "AutoImageProcessor" # Use AutoImageProcessor for compatibility
98
  tokenizer_class = "AutoTokenizer"
99
+ _auto_class = "AutoProcessor" # Required for auto_map in processor_config.json
100
 
101
  def __init__(
102
  self,
 
142
  def build_image_token_sequence(self, grid_h: int, grid_w: int) -> str:
143
  """
144
  Build the image token sequence for a tiled image.
145
+
146
  Returns a string like:
147
+ \n\n<boi><img>×(16*grid_w)<img>×(16*grid_w)...(×16*grid_h rows)...<eoi>
148
+
149
  Note: We use <img> tokens for BOTH actual image positions AND linebreak positions.
150
  The model will replace them with the appropriate embeddings.
151
+
152
+ IMPORTANT: We do NOT add trailing \n\n because when followed by text content
153
+ that starts with \n, it would create \n\n\n which tokenizes differently and
154
+ breaks vLLM's placeholder pattern matching.
155
  """
156
  rows = grid_h * self.tokens_per_tile_side
157
  cols = grid_w * self.tokens_per_tile_side
158
+
159
  total_tokens = self.get_num_image_tokens(grid_h, grid_w)
160
  image_tokens = self.image_token * total_tokens
161
+
162
+ return f"\n\n{self.boi_token}{image_tokens}{self.eoi_token}"
163
 
164
  def __call__(
165
  self,
 
265
  tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
266
  image_processor_input_names = self.image_processor.model_input_names
267
  return list(set(tokenizer_input_names + image_processor_input_names))
268
+
269
+ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
270
+ """
271
+ Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
272
+
273
+ This is required by vLLM for memory profiling and scheduling.
274
+
275
+ Args:
276
+ image_sizes (`list[list[int]]`, *optional*):
277
+ The input sizes formatted as (height, width) per each image.
278
+ **kwargs: Additional arguments (tile_size, max_tiles_h, max_tiles_w, min_tiles)
279
+ that override image processor defaults.
280
+
281
+ Returns:
282
+ `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
283
+ input modalities, along with other useful data.
284
+ """
285
+ vision_data = {}
286
+ if image_sizes is not None:
287
+ # Get tiling parameters from kwargs or fall back to image processor settings
288
+ tile_size = kwargs.get("tile_size", getattr(self.image_processor, "tile_size", 896))
289
+ max_tiles_h = kwargs.get("max_tiles_h", getattr(self.image_processor, "max_tiles_h", 4))
290
+ max_tiles_w = kwargs.get("max_tiles_w", getattr(self.image_processor, "max_tiles_w", 4))
291
+ min_tiles = kwargs.get("min_tiles", getattr(self.image_processor, "min_tiles", 1))
292
+
293
+ num_image_tokens = []
294
+ num_image_patches = []
295
+
296
+ for height, width in image_sizes:
297
+ # Calculate optimal tile grid for this image
298
+ grid_h, grid_w = calculate_tile_grid(
299
+ image_height=height,
300
+ image_width=width,
301
+ tile_size=tile_size,
302
+ max_tiles_h=max_tiles_h,
303
+ max_tiles_w=max_tiles_w,
304
+ min_tiles=min_tiles,
305
+ )
306
+
307
+ # Calculate token count for this grid
308
+ tokens = self.get_num_image_tokens(grid_h, grid_w)
309
+ num_image_tokens.append(tokens)
310
+
311
+ # Number of patches = number of tiles
312
+ num_image_patches.append(grid_h * grid_w)
313
+
314
+ vision_data.update({
315
+ "num_image_tokens": num_image_tokens,
316
+ "num_image_patches": num_image_patches,
317
+ })
318
+
319
+ return MultiModalData(**vision_data)
320
 
321
 
322
  __all__ = ["Gemma3TiledProcessor", "Gemma3TiledProcessorKwargs"]