Update image_processing_gemma3_tiled.py
Browse files
image_processing_gemma3_tiled.py
CHANGED
|
@@ -15,7 +15,7 @@ from PIL import Image
|
|
| 15 |
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| 16 |
from transformers.image_utils import (
|
| 17 |
ImageInput,
|
| 18 |
-
|
| 19 |
valid_images,
|
| 20 |
infer_channel_dimension_format,
|
| 21 |
to_numpy_array,
|
|
@@ -167,7 +167,7 @@ class Gemma3TiledImageProcessor(BaseImageProcessor):
|
|
| 167 |
3. Returns pixel_values and tile_grid_shape metadata
|
| 168 |
"""
|
| 169 |
|
| 170 |
-
model_input_names = ["pixel_values", "tile_grid_shape"]
|
| 171 |
|
| 172 |
def __init__(
|
| 173 |
self,
|
|
@@ -236,7 +236,7 @@ class Gemma3TiledImageProcessor(BaseImageProcessor):
|
|
| 236 |
image_std = image_std if image_std is not None else self.image_std
|
| 237 |
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 238 |
|
| 239 |
-
images =
|
| 240 |
|
| 241 |
if not valid_images(images):
|
| 242 |
raise ValueError("Invalid image input")
|
|
@@ -284,12 +284,18 @@ class Gemma3TiledImageProcessor(BaseImageProcessor):
|
|
| 284 |
all_pixel_values.append(tiles)
|
| 285 |
all_grid_shapes.append((grid_h, grid_w))
|
| 286 |
|
|
|
|
|
|
|
|
|
|
| 287 |
data = {
|
| 288 |
"pixel_values": all_pixel_values,
|
| 289 |
"tile_grid_shape": all_grid_shapes,
|
|
|
|
| 290 |
}
|
| 291 |
|
| 292 |
-
|
|
|
|
|
|
|
| 293 |
|
| 294 |
|
| 295 |
__all__ = ["Gemma3TiledImageProcessor", "calculate_tile_grid", "tile_image"]
|
|
|
|
| 15 |
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| 16 |
from transformers.image_utils import (
|
| 17 |
ImageInput,
|
| 18 |
+
make_flat_list_of_images,
|
| 19 |
valid_images,
|
| 20 |
infer_channel_dimension_format,
|
| 21 |
to_numpy_array,
|
|
|
|
| 167 |
3. Returns pixel_values and tile_grid_shape metadata
|
| 168 |
"""
|
| 169 |
|
| 170 |
+
model_input_names = ["pixel_values", "tile_grid_shape", "num_crops"]
|
| 171 |
|
| 172 |
def __init__(
|
| 173 |
self,
|
|
|
|
| 236 |
image_std = image_std if image_std is not None else self.image_std
|
| 237 |
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 238 |
|
| 239 |
+
images = make_flat_list_of_images(images)
|
| 240 |
|
| 241 |
if not valid_images(images):
|
| 242 |
raise ValueError("Invalid image input")
|
|
|
|
| 284 |
all_pixel_values.append(tiles)
|
| 285 |
all_grid_shapes.append((grid_h, grid_w))
|
| 286 |
|
| 287 |
+
# num_crops is 0 for each image since we use tiling, not pan-and-scan
|
| 288 |
+
num_crops = [0] * len(all_pixel_values)
|
| 289 |
+
|
| 290 |
data = {
|
| 291 |
"pixel_values": all_pixel_values,
|
| 292 |
"tile_grid_shape": all_grid_shapes,
|
| 293 |
+
"num_crops": num_crops,
|
| 294 |
}
|
| 295 |
|
| 296 |
+
# Don't convert to tensors here - pixel_values have inhomogeneous shapes
|
| 297 |
+
# (different images have different tile counts). Let the model handle it.
|
| 298 |
+
return BatchFeature(data=data, tensor_type=None)
|
| 299 |
|
| 300 |
|
| 301 |
__all__ = ["Gemma3TiledImageProcessor", "calculate_tile_grid", "tile_image"]
|