# gemma-3-tiled-4b-it / image_processing_gemma3_tiled.py
# Updated by Fraser (commit 2a55418, verified).
"""
Image processor for Gemma3Tiled that tiles images into grids.
Instead of resizing images to a fixed size or using pan-and-scan crops,
this processor tiles the image into a grid of 896x896 patches that
preserves the spatial layout.
"""
import math
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import (
ImageInput,
make_flat_list_of_images,
valid_images,
infer_channel_dimension_format,
to_numpy_array,
ChannelDimension,
)
from transformers.utils import TensorType
def calculate_tile_grid(
    image_height: int,
    image_width: int,
    tile_size: int,
    max_tiles_h: int,
    max_tiles_w: int,
    min_tiles: int = 1,
) -> tuple[int, int]:
    """
    Pick the tile grid (rows, cols) that best covers an image.

    Each candidate grid defines a canvas of ``rows*tile_size`` by
    ``cols*tile_size`` pixels. The image is scaled (aspect-ratio preserving)
    to fit that canvas, and the grid is scored by:

    1. Effective resolution — original pixels preserved after scaling,
       capped at the original pixel count so upscaling earns no credit.
    2. Wasted canvas area, subtracted with a small weight as a tiebreaker.

    Example: a 1344x912 image with 896x896 tiles picks 2x2, the smallest
    grid whose canvas holds the image at full resolution.

    Args:
        image_height: Original image height.
        image_width: Original image width.
        tile_size: Side length of each square tile (e.g. 896).
        max_tiles_h: Maximum tiles along the height.
        max_tiles_w: Maximum tiles along the width.
        min_tiles: Minimum total number of tiles.

    Returns:
        (grid_h, grid_w): tile counts along height and width. Falls back to
        (1, 1) if no candidate grid satisfies ``min_tiles``.
    """
    source_pixels = image_height * image_width

    def _score(grid: tuple[int, int]) -> float:
        """Effective resolution minus a small penalty for unused canvas."""
        rows, cols = grid
        canvas_h = rows * tile_size
        canvas_w = cols * tile_size
        # Aspect-ratio-preserving fit of the image into the canvas.
        fit = min(canvas_w / image_width, canvas_h / image_height)
        # Cap at the source resolution: upscaling preserves no extra detail.
        kept = min(source_pixels * fit * fit, source_pixels)
        unused = canvas_h * canvas_w - kept
        return kept - 0.001 * unused

    # Enumerate grids in (rows, cols) order; max() keeps the first best,
    # matching the original strict-improvement tie-breaking.
    candidates = [
        (rows, cols)
        for rows in range(1, max_tiles_h + 1)
        for cols in range(1, max_tiles_w + 1)
        if rows * cols >= min_tiles
    ]
    if not candidates:
        return (1, 1)
    return max(candidates, key=_score)
def tile_image(
    image: np.ndarray,
    tile_size: int,
    grid_h: int,
    grid_w: int,
    resample: Image.Resampling = Image.Resampling.BICUBIC,
) -> np.ndarray:
    """
    Split an image into a grid of fixed-size square patches.

    The image is resized so the full grid is exactly
    ``(grid_h * tile_size, grid_w * tile_size)``, then cut row-major into
    tiles.

    Args:
        image: Input image as a numpy array, (H, W, C) or (C, H, W).
        tile_size: Side length of each square tile.
        grid_h: Number of tiles along the height.
        grid_w: Number of tiles along the width.
        resample: PIL resampling filter used for the resize.

    Returns:
        Array of shape (grid_h * grid_w, C, tile_size, tile_size).
    """
    # Normalize layout to channels-last; a leading dim of 1/3/4 is taken as
    # channels-first (heuristic — could misfire on very short images).
    if image.ndim == 3 and image.shape[0] in (1, 3, 4):
        image = np.transpose(image, (1, 2, 0))

    # PIL wants uint8: floats in [0, 1] are scaled up, everything else cast.
    if np.issubdtype(image.dtype, np.floating) and image.max() <= 1.0:
        as_uint8 = (image * 255).astype(np.uint8)
    else:
        as_uint8 = image.astype(np.uint8)

    # Resize so the grid divides the canvas into exact tile_size squares.
    canvas_h = grid_h * tile_size
    canvas_w = grid_w * tile_size
    resized = np.array(
        Image.fromarray(as_uint8).resize((canvas_w, canvas_h), resample=resample)
    )

    # Cut the canvas row-major into (C, tile_size, tile_size) patches.
    tiles = [
        np.transpose(
            resized[
                row * tile_size:(row + 1) * tile_size,
                col * tile_size:(col + 1) * tile_size,
            ],
            (2, 0, 1),
        )
        for row in range(grid_h)
        for col in range(grid_w)
    ]
    return np.stack(tiles, axis=0)
class Gemma3TiledImageProcessor(BaseImageProcessor):
    """
    Image processor for Gemma3Tiled that tiles images into grids.

    This processor:
    1. Calculates the optimal tile grid for each image
    2. Resizes and tiles the image
    3. Returns pixel_values and tile_grid_shape metadata
    """
    model_input_names = ["pixel_values", "tile_grid_shape", "num_crops"]
    _auto_class = "AutoImageProcessor"  # Required for auto_map in preprocessor_config.json

    def __init__(
        self,
        tile_size: int = 896,
        max_tiles_h: int = 4,
        max_tiles_w: int = 4,
        min_tiles: int = 1,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: bool = True,
        resample: Image.Resampling = Image.Resampling.BICUBIC,
        **kwargs,
    ):
        """
        Args:
            tile_size: Side length of each square tile (896 for Gemma3).
            max_tiles_h: Maximum tiles along the height.
            max_tiles_w: Maximum tiles along the width.
            min_tiles: Minimum total number of tiles per image.
            do_rescale: Whether to multiply pixel values by ``rescale_factor``.
            rescale_factor: Scale applied when ``do_rescale`` (default 1/255).
            do_normalize: Whether to apply per-channel mean/std normalization.
            image_mean: Per-channel mean; defaults to [0.5, 0.5, 0.5].
            image_std: Per-channel std; defaults to [0.5, 0.5, 0.5].
            do_convert_rgb: Whether to drop an alpha channel from RGBA input.
            resample: PIL resampling filter used when resizing.
        """
        super().__init__(**kwargs)
        self.tile_size = tile_size
        self.max_tiles_h = max_tiles_h
        self.max_tiles_w = max_tiles_w
        self.min_tiles = min_tiles
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.do_convert_rgb = do_convert_rgb
        self.resample = resample

    def preprocess(
        self,
        images: ImageInput,
        tile_size: Optional[int] = None,
        max_tiles_h: Optional[int] = None,
        max_tiles_w: Optional[int] = None,
        min_tiles: Optional[int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images by tiling them into grids.

        Args:
            images: Single image or batch of images. Per-call overrides fall
                back to the processor's configured defaults when None.

        Returns:
            BatchFeature with:
                - pixel_values: all tiles of all images concatenated along
                  axis 0, shape (total_tiles, C, tile_size, tile_size)
                - tile_grid_shape: List of (grid_h, grid_w) tuples
                - num_crops: List of zeros (tiling replaces pan-and-scan)

        Raises:
            ValueError: If the input is not a valid image, or an image is
                not 3-dimensional after conversion to numpy.
        """
        tile_size = tile_size if tile_size is not None else self.tile_size
        max_tiles_h = max_tiles_h if max_tiles_h is not None else self.max_tiles_h
        max_tiles_w = max_tiles_w if max_tiles_w is not None else self.max_tiles_w
        min_tiles = min_tiles if min_tiles is not None else self.min_tiles
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_flat_list_of_images(images)
        if not valid_images(images):
            raise ValueError("Invalid image input")

        all_pixel_values = []
        all_grid_shapes = []
        for image in images:
            image = to_numpy_array(image)
            if image.ndim != 3:
                raise ValueError(f"Expected 3D image, got shape {image.shape}")

            # Determine channel layout first: a leading dim of 1/3/4 is taken
            # as channels-first (heuristic — could misfire on tiny images).
            channels_first = image.shape[0] in (1, 3, 4)
            if channels_first:
                h, w = image.shape[1], image.shape[2]
            else:
                h, w = image.shape[0], image.shape[1]

            # Drop the alpha channel along the correct axis. (Fix: the
            # original only checked shape[-1], so channels-first RGBA input
            # kept 4 channels and later broke the 3-channel normalization.)
            if do_convert_rgb:
                if channels_first and image.shape[0] == 4:
                    image = image[:3, ...]
                elif not channels_first and image.shape[-1] == 4:
                    image = image[..., :3]

            # Choose the grid that preserves the most original resolution.
            grid_h, grid_w = calculate_tile_grid(
                h, w, tile_size, max_tiles_h, max_tiles_w, min_tiles
            )

            # Resize and cut into (num_tiles, C, tile_size, tile_size).
            tiles = tile_image(
                image, tile_size, grid_h, grid_w, resample=self.resample
            )

            if do_rescale:
                tiles = tiles.astype(np.float32) * rescale_factor

            if do_normalize:
                # Broadcast per-channel stats over (num_tiles, 3, H, W).
                mean = np.array(image_mean, dtype=np.float32).reshape(1, 3, 1, 1)
                std = np.array(image_std, dtype=np.float32).reshape(1, 3, 1, 1)
                tiles = (tiles - mean) / std

            all_pixel_values.append(tiles)
            all_grid_shapes.append((grid_h, grid_w))

        # num_crops is 0 for each image since we use tiling, not pan-and-scan
        num_crops = [0] * len(all_pixel_values)

        # Concatenate all tiles into a single array for vLLM compatibility:
        # vLLM's flat_from_sizes expects a single tensor, not a list.
        if len(all_pixel_values) > 0:
            concatenated_pixels = np.concatenate(all_pixel_values, axis=0)
        else:
            concatenated_pixels = np.array([])

        data = {
            "pixel_values": concatenated_pixels,
            "tile_grid_shape": all_grid_shapes,
            "num_crops": num_crops,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Gemma3TiledImageProcessor", "calculate_tile_grid", "tile_image"]