Fraser committed on
Commit
c9a2bf3
·
verified ·
1 Parent(s): 4f71f8f

Fix grid shape ordering for multi-image inputs

Browse files
Files changed (1) hide show
  1. processing_gemma3_tiled.py +33 -21
processing_gemma3_tiled.py CHANGED
@@ -8,6 +8,7 @@ based on the tile grid dimensions.
8
  import re
9
  from typing import Optional, Union
10
 
 
11
  import numpy as np
12
 
13
  from transformers.feature_extraction_utils import BatchFeature
@@ -51,7 +52,7 @@ class Gemma3TiledProcessor(ProcessorMixin):
51
  """
52
 
53
  attributes = ["image_processor", "tokenizer"]
54
- image_processor_class = "image_processing_gemma3_tiled.Gemma3TiledImageProcessor"
55
  tokenizer_class = "AutoTokenizer"
56
 
57
  def __init__(
@@ -144,8 +145,8 @@ class Gemma3TiledProcessor(ProcessorMixin):
144
  # Process images to get tiles
145
  image_inputs = self.image_processor(images_fetched, **output_kwargs["images_kwargs"])
146
 
147
- # Get grid shapes for each image
148
- tile_grid_shapes = image_inputs.get("tile_grid_shape", [])
149
 
150
  # Create empty text to be replaced with placeholders
151
  if not text:
@@ -158,11 +159,12 @@ class Gemma3TiledProcessor(ProcessorMixin):
158
 
159
  # Build flat list of grid shapes across all batches
160
  all_grid_shapes = []
 
161
  for imgs in batched_images:
162
  for _ in imgs:
163
- if tile_grid_shapes:
164
- all_grid_shapes.append(tile_grid_shapes.pop(0))
165
- else:
166
  # Fallback to 1x1 grid
167
  all_grid_shapes.append((1, 1))
168
 
@@ -170,36 +172,46 @@ class Gemma3TiledProcessor(ProcessorMixin):
170
  grid_shape_idx = 0
171
  for batch_idx, (prompt, imgs) in enumerate(zip(text, batched_images)):
172
  image_indexes = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
173
-
174
  if len(imgs) != len(image_indexes):
175
  raise ValueError(
176
  f"Prompt contained {len(image_indexes)} image tokens but received {len(imgs)} images."
177
  )
178
-
 
 
 
 
179
  # Replace each BOI token with the full image sequence
180
- for idx in reversed(image_indexes):
181
- grid_h, grid_w = all_grid_shapes[grid_shape_idx]
182
- grid_shape_idx += 1
183
-
184
  image_sequence = self.build_image_token_sequence(grid_h, grid_w)
185
  prompt = prompt[:idx] + image_sequence + prompt[idx + len(self.boi_token):]
186
-
187
  text[batch_idx] = prompt
188
 
189
  return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
190
  return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
191
 
192
- text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
 
193
 
194
  # Add token type ids (1 for image tokens, 0 for text)
195
  if return_mm_token_type_ids:
196
- array_ids = np.array(text_inputs["input_ids"])
197
- mm_token_type_ids = np.zeros_like(array_ids)
198
- mm_token_type_ids[array_ids == self.image_token_id] = 1
199
- text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
200
-
201
- # Combine outputs
202
- return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
 
 
 
 
 
 
 
203
 
204
  @property
205
  def model_input_names(self):
 
8
  import re
9
  from typing import Optional, Union
10
 
11
+ import torch
12
  import numpy as np
13
 
14
  from transformers.feature_extraction_utils import BatchFeature
 
52
  """
53
 
54
  attributes = ["image_processor", "tokenizer"]
55
+ image_processor_class = "AutoImageProcessor" # Use AutoImageProcessor for compatibility
56
  tokenizer_class = "AutoTokenizer"
57
 
58
  def __init__(
 
145
  # Process images to get tiles
146
  image_inputs = self.image_processor(images_fetched, **output_kwargs["images_kwargs"])
147
 
148
+ # Get grid shapes for each image (make a copy to avoid mutating)
149
+ tile_grid_shapes = list(image_inputs.get("tile_grid_shape", []))
150
 
151
  # Create empty text to be replaced with placeholders
152
  if not text:
 
159
 
160
  # Build flat list of grid shapes across all batches
161
  all_grid_shapes = []
162
+ grid_shape_iter = iter(tile_grid_shapes)
163
  for imgs in batched_images:
164
  for _ in imgs:
165
+ try:
166
+ all_grid_shapes.append(next(grid_shape_iter))
167
+ except StopIteration:
168
  # Fallback to 1x1 grid
169
  all_grid_shapes.append((1, 1))
170
 
 
172
  grid_shape_idx = 0
173
  for batch_idx, (prompt, imgs) in enumerate(zip(text, batched_images)):
174
  image_indexes = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
175
+
176
  if len(imgs) != len(image_indexes):
177
  raise ValueError(
178
  f"Prompt contained {len(image_indexes)} image tokens but received {len(imgs)} images."
179
  )
180
+
181
+ # Get grid shapes for this batch's images (in order)
182
+ batch_grid_shapes = all_grid_shapes[grid_shape_idx:grid_shape_idx + len(imgs)]
183
+ grid_shape_idx += len(imgs)
184
+
185
  # Replace each BOI token with the full image sequence
186
+ # Iterate in reverse to avoid shifting string indices, but also reverse grid shapes to match
187
+ for idx, (grid_h, grid_w) in zip(reversed(image_indexes), reversed(batch_grid_shapes)):
 
 
188
  image_sequence = self.build_image_token_sequence(grid_h, grid_w)
189
  prompt = prompt[:idx] + image_sequence + prompt[idx + len(self.boi_token):]
190
+
191
  text[batch_idx] = prompt
192
 
193
  return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
194
  return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
195
 
196
+ # Get text inputs - let tokenizer handle tensor conversion for text
197
+ text_inputs = self.tokenizer(text=text, return_tensors=return_tensors, **output_kwargs["text_kwargs"])
198
 
199
  # Add token type ids (1 for image tokens, 0 for text)
200
  if return_mm_token_type_ids:
201
+ if return_tensors == "pt":
202
+ input_ids = text_inputs["input_ids"]
203
+ mm_token_type_ids = torch.zeros_like(input_ids)
204
+ mm_token_type_ids[input_ids == self.image_token_id] = 1
205
+ text_inputs["token_type_ids"] = mm_token_type_ids
206
+ else:
207
+ array_ids = np.array(text_inputs["input_ids"])
208
+ mm_token_type_ids = np.zeros_like(array_ids)
209
+ mm_token_type_ids[array_ids == self.image_token_id] = 1
210
+ text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
211
+
212
+ # Combine outputs - DON'T pass tensor_type here because pixel_values
213
+ # has inhomogeneous shapes (different tile counts per image)
214
+ return BatchFeature(data={**text_inputs, **image_inputs})
215
 
216
  @property
217
  def model_input_names(self):