Fix grid shape ordering for multi-image inputs
Browse files- processing_gemma3_tiled.py +33 -21
processing_gemma3_tiled.py
CHANGED
|
@@ -8,6 +8,7 @@ based on the tile grid dimensions.
|
|
| 8 |
import re
|
| 9 |
from typing import Optional, Union
|
| 10 |
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
|
| 13 |
from transformers.feature_extraction_utils import BatchFeature
|
|
@@ -51,7 +52,7 @@ class Gemma3TiledProcessor(ProcessorMixin):
|
|
| 51 |
"""
|
| 52 |
|
| 53 |
attributes = ["image_processor", "tokenizer"]
|
| 54 |
-
image_processor_class = "
|
| 55 |
tokenizer_class = "AutoTokenizer"
|
| 56 |
|
| 57 |
def __init__(
|
|
@@ -144,8 +145,8 @@ class Gemma3TiledProcessor(ProcessorMixin):
|
|
| 144 |
# Process images to get tiles
|
| 145 |
image_inputs = self.image_processor(images_fetched, **output_kwargs["images_kwargs"])
|
| 146 |
|
| 147 |
-
# Get grid shapes for each image
|
| 148 |
-
tile_grid_shapes = image_inputs.get("tile_grid_shape", [])
|
| 149 |
|
| 150 |
# Create empty text to be replaced with placeholders
|
| 151 |
if not text:
|
|
@@ -158,11 +159,12 @@ class Gemma3TiledProcessor(ProcessorMixin):
|
|
| 158 |
|
| 159 |
# Build flat list of grid shapes across all batches
|
| 160 |
all_grid_shapes = []
|
|
|
|
| 161 |
for imgs in batched_images:
|
| 162 |
for _ in imgs:
|
| 163 |
-
|
| 164 |
-
all_grid_shapes.append(
|
| 165 |
-
|
| 166 |
# Fallback to 1x1 grid
|
| 167 |
all_grid_shapes.append((1, 1))
|
| 168 |
|
|
@@ -170,36 +172,46 @@ class Gemma3TiledProcessor(ProcessorMixin):
|
|
| 170 |
grid_shape_idx = 0
|
| 171 |
for batch_idx, (prompt, imgs) in enumerate(zip(text, batched_images)):
|
| 172 |
image_indexes = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
|
| 173 |
-
|
| 174 |
if len(imgs) != len(image_indexes):
|
| 175 |
raise ValueError(
|
| 176 |
f"Prompt contained {len(image_indexes)} image tokens but received {len(imgs)} images."
|
| 177 |
)
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
# Replace each BOI token with the full image sequence
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
grid_shape_idx += 1
|
| 183 |
-
|
| 184 |
image_sequence = self.build_image_token_sequence(grid_h, grid_w)
|
| 185 |
prompt = prompt[:idx] + image_sequence + prompt[idx + len(self.boi_token):]
|
| 186 |
-
|
| 187 |
text[batch_idx] = prompt
|
| 188 |
|
| 189 |
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 190 |
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
| 191 |
|
| 192 |
-
|
|
|
|
| 193 |
|
| 194 |
# Add token type ids (1 for image tokens, 0 for text)
|
| 195 |
if return_mm_token_type_ids:
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
@property
|
| 205 |
def model_input_names(self):
|
|
|
|
| 8 |
import re
|
| 9 |
from typing import Optional, Union
|
| 10 |
|
| 11 |
+
import torch
|
| 12 |
import numpy as np
|
| 13 |
|
| 14 |
from transformers.feature_extraction_utils import BatchFeature
|
|
|
|
| 52 |
"""
|
| 53 |
|
| 54 |
attributes = ["image_processor", "tokenizer"]
|
| 55 |
+
image_processor_class = "AutoImageProcessor" # Use AutoImageProcessor for compatibility
|
| 56 |
tokenizer_class = "AutoTokenizer"
|
| 57 |
|
| 58 |
def __init__(
|
|
|
|
| 145 |
# Process images to get tiles
|
| 146 |
image_inputs = self.image_processor(images_fetched, **output_kwargs["images_kwargs"])
|
| 147 |
|
| 148 |
+
# Get grid shapes for each image (make a copy to avoid mutating)
|
| 149 |
+
tile_grid_shapes = list(image_inputs.get("tile_grid_shape", []))
|
| 150 |
|
| 151 |
# Create empty text to be replaced with placeholders
|
| 152 |
if not text:
|
|
|
|
| 159 |
|
| 160 |
# Build flat list of grid shapes across all batches
|
| 161 |
all_grid_shapes = []
|
| 162 |
+
grid_shape_iter = iter(tile_grid_shapes)
|
| 163 |
for imgs in batched_images:
|
| 164 |
for _ in imgs:
|
| 165 |
+
try:
|
| 166 |
+
all_grid_shapes.append(next(grid_shape_iter))
|
| 167 |
+
except StopIteration:
|
| 168 |
# Fallback to 1x1 grid
|
| 169 |
all_grid_shapes.append((1, 1))
|
| 170 |
|
|
|
|
| 172 |
grid_shape_idx = 0
|
| 173 |
for batch_idx, (prompt, imgs) in enumerate(zip(text, batched_images)):
|
| 174 |
image_indexes = [m.start() for m in re.finditer(re.escape(self.boi_token), prompt)]
|
| 175 |
+
|
| 176 |
if len(imgs) != len(image_indexes):
|
| 177 |
raise ValueError(
|
| 178 |
f"Prompt contained {len(image_indexes)} image tokens but received {len(imgs)} images."
|
| 179 |
)
|
| 180 |
+
|
| 181 |
+
# Get grid shapes for this batch's images (in order)
|
| 182 |
+
batch_grid_shapes = all_grid_shapes[grid_shape_idx:grid_shape_idx + len(imgs)]
|
| 183 |
+
grid_shape_idx += len(imgs)
|
| 184 |
+
|
| 185 |
# Replace each BOI token with the full image sequence
|
| 186 |
+
# Iterate in reverse to avoid shifting string indices, but also reverse grid shapes to match
|
| 187 |
+
for idx, (grid_h, grid_w) in zip(reversed(image_indexes), reversed(batch_grid_shapes)):
|
|
|
|
|
|
|
| 188 |
image_sequence = self.build_image_token_sequence(grid_h, grid_w)
|
| 189 |
prompt = prompt[:idx] + image_sequence + prompt[idx + len(self.boi_token):]
|
| 190 |
+
|
| 191 |
text[batch_idx] = prompt
|
| 192 |
|
| 193 |
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
| 194 |
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
| 195 |
|
| 196 |
+
# Get text inputs - let tokenizer handle tensor conversion for text
|
| 197 |
+
text_inputs = self.tokenizer(text=text, return_tensors=return_tensors, **output_kwargs["text_kwargs"])
|
| 198 |
|
| 199 |
# Add token type ids (1 for image tokens, 0 for text)
|
| 200 |
if return_mm_token_type_ids:
|
| 201 |
+
if return_tensors == "pt":
|
| 202 |
+
input_ids = text_inputs["input_ids"]
|
| 203 |
+
mm_token_type_ids = torch.zeros_like(input_ids)
|
| 204 |
+
mm_token_type_ids[input_ids == self.image_token_id] = 1
|
| 205 |
+
text_inputs["token_type_ids"] = mm_token_type_ids
|
| 206 |
+
else:
|
| 207 |
+
array_ids = np.array(text_inputs["input_ids"])
|
| 208 |
+
mm_token_type_ids = np.zeros_like(array_ids)
|
| 209 |
+
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
| 210 |
+
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
|
| 211 |
+
|
| 212 |
+
# Combine outputs - DON'T pass tensor_type here because pixel_values
|
| 213 |
+
# has inhomogeneous shapes (different tile counts per image)
|
| 214 |
+
return BatchFeature(data={**text_inputs, **image_inputs})
|
| 215 |
|
| 216 |
@property
|
| 217 |
def model_input_names(self):
|