"""
Image processor for Gemma3Tiled that tiles images into grids.
Instead of resizing images to a fixed size or using pan-and-scan crops,
this processor tiles the image into a grid of 896x896 patches that
preserves the spatial layout.
"""
import math
from typing import Optional, Union
import numpy as np
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import (
ImageInput,
make_flat_list_of_images,
valid_images,
infer_channel_dimension_format,
to_numpy_array,
ChannelDimension,
)
from transformers.utils import TensorType
def calculate_tile_grid(
    image_height: int,
    image_width: int,
    tile_size: int,
    max_tiles_h: int,
    max_tiles_w: int,
    min_tiles: int = 1,
) -> tuple[int, int]:
    """
    Pick the tile grid that best preserves the original image's detail.

    Every (rows, cols) grid up to the given maxima is scored by the number
    of original pixels it can represent after an aspect-preserving resize
    into its rows*tile_size x cols*tile_size canvas. Upscaling earns no
    credit (effective resolution is capped at the original pixel count),
    and unused canvas area is subtracted with a small weight as a
    tiebreaker, so the smallest grid capturing full detail wins.

    Args:
        image_height: Original image height in pixels.
        image_width: Original image width in pixels.
        tile_size: Side length of each square tile (e.g. 896).
        max_tiles_h: Maximum number of tile rows.
        max_tiles_w: Maximum number of tile columns.
        min_tiles: Minimum total number of tiles in the grid.

    Returns:
        (grid_h, grid_w): tile counts along height and width.
    """
    original_pixels = image_height * image_width

    def grid_score(rows: int, cols: int) -> float:
        # Aspect-ratio-preserving scale factor that fits the image
        # inside this grid's canvas.
        canvas_h, canvas_w = rows * tile_size, cols * tile_size
        scale = min(canvas_w / image_width, canvas_h / image_height)
        # Original pixels retained; upscaling is not credited.
        effective = min(original_pixels * scale * scale, original_pixels)
        # Lightly penalize wasted canvas area as a tiebreaker.
        wasted = canvas_h * canvas_w - effective
        return effective - 0.001 * wasted

    candidates = [
        (rows, cols)
        for rows in range(1, max_tiles_h + 1)
        for cols in range(1, max_tiles_w + 1)
        if rows * cols >= min_tiles
    ]
    if not candidates:
        # No grid satisfies min_tiles within the maxima; fall back to 1x1
        # (the original default), matching the previous behavior.
        return (1, 1)
    # max() keeps the first-encountered maximum, which matches the original
    # strict-greater-than scan order (rows outer, cols inner).
    return max(candidates, key=lambda rc: grid_score(*rc))
def tile_image(
    image: np.ndarray,
    tile_size: int,
    grid_h: int,
    grid_w: int,
    resample: Image.Resampling = Image.Resampling.BICUBIC,
) -> np.ndarray:
    """
    Tile an image into a grid of fixed-size square patches.

    The image is resized (aspect ratio NOT preserved) to exactly
    (grid_h * tile_size) x (grid_w * tile_size) pixels and then cut into
    grid_h * grid_w non-overlapping tiles, ordered row-major
    (left-to-right, top-to-bottom).

    Args:
        image: Input image as numpy array, (H, W, C) or (C, H, W).
            Float arrays whose max is <= 1.0 are treated as [0, 1] range;
            all other arrays are assumed to already be in [0, 255].
        tile_size: Side length of each square tile.
        grid_h: Number of tiles along height.
        grid_w: Number of tiles along width.
        resample: PIL resampling filter used for the resize.

    Returns:
        uint8 array of shape (grid_h * grid_w, C, tile_size, tile_size).
    """
    # Normalize layout to channels-last for PIL. NOTE(review): this is a
    # heuristic — an (H, W, C) image whose height happens to be 1, 3 or 4
    # would be misdetected as channels-first.
    if image.ndim == 3:
        if image.shape[0] in [1, 3, 4]:  # Likely (C, H, W)
            image = np.transpose(image, (1, 2, 0))  # -> (H, W, C)

    # Convert to uint8 for PIL, handling both [0-1] and [0-255] ranges.
    # Clip before casting so out-of-range values saturate instead of
    # wrapping around on the uint8 cast (astype truncates modulo 256).
    if np.issubdtype(image.dtype, np.floating) and image.max() <= 1.0:
        image = np.clip(image * 255.0, 0, 255).astype(np.uint8)
    else:
        image = np.clip(image, 0, 255).astype(np.uint8)

    pil_image = Image.fromarray(image)

    # Stretch the image so the tile grid covers it exactly.
    target_h = grid_h * tile_size
    target_w = grid_w * tile_size
    pil_image = pil_image.resize((target_w, target_h), resample=resample)

    # Back to numpy: shape (target_h, target_w, C).
    image = np.array(pil_image)

    # Cut into tiles; each tile is transposed to (C, H, W).
    tiles = []
    for row in range(grid_h):
        for col in range(grid_w):
            y_start = row * tile_size
            x_start = col * tile_size
            tile = image[y_start:y_start + tile_size, x_start:x_start + tile_size]
            tiles.append(np.transpose(tile, (2, 0, 1)))

    return np.stack(tiles, axis=0)  # (num_tiles, C, tile_size, tile_size)
class Gemma3TiledImageProcessor(BaseImageProcessor):
    """
    Image processor for Gemma3Tiled that tiles images into grids.

    For each input image this processor:
      1. Picks the tile grid that best preserves the original resolution
         (``calculate_tile_grid``).
      2. Resizes the image onto that grid and cuts it into fixed-size
         tiles (``tile_image``).
      3. Optionally rescales and normalizes the tiles.

    Outputs:
      - pixel_values: all tiles of all images concatenated along dim 0
      - tile_grid_shape: (grid_h, grid_w) per image
      - num_crops: always 0 per image (tiling replaces pan-and-scan)
    """

    model_input_names = ["pixel_values", "tile_grid_shape", "num_crops"]
    _auto_class = "AutoImageProcessor"  # Required for auto_map in preprocessor_config.json

    def __init__(
        self,
        tile_size: int = 896,
        max_tiles_h: int = 4,
        max_tiles_w: int = 4,
        min_tiles: int = 1,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: bool = True,
        resample: Image.Resampling = Image.Resampling.BICUBIC,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tile_size = tile_size
        self.max_tiles_h = max_tiles_h
        self.max_tiles_w = max_tiles_w
        self.min_tiles = min_tiles
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Gemma-style defaults: normalize to [-1, 1].
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.do_convert_rgb = do_convert_rgb
        self.resample = resample

    def preprocess(
        self,
        images: ImageInput,
        tile_size: Optional[int] = None,
        max_tiles_h: Optional[int] = None,
        max_tiles_w: Optional[int] = None,
        min_tiles: Optional[int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        resample: Optional[Image.Resampling] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images by tiling them into grids.

        Args:
            images: Single image or batch of images.
            tile_size / max_tiles_h / max_tiles_w / min_tiles / do_rescale /
                rescale_factor / do_normalize / image_mean / image_std /
                do_convert_rgb / resample: per-call overrides for the
                corresponding instance defaults (None means "use default").
            return_tensors: Output tensor type forwarded to BatchFeature.

        Returns:
            BatchFeature with:
                - pixel_values: single array of all tiles, concatenated
                  along dim 0 (vLLM's flat_from_sizes expects one tensor)
                - tile_grid_shape: list of (grid_h, grid_w) per image
                - num_crops: list of 0s, one per image

        Raises:
            ValueError: on invalid image input or non-3D image arrays.
        """
        tile_size = tile_size if tile_size is not None else self.tile_size
        max_tiles_h = max_tiles_h if max_tiles_h is not None else self.max_tiles_h
        max_tiles_w = max_tiles_w if max_tiles_w is not None else self.max_tiles_w
        min_tiles = min_tiles if min_tiles is not None else self.min_tiles
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        resample = resample if resample is not None else self.resample

        images = make_flat_list_of_images(images)
        if not valid_images(images):
            raise ValueError("Invalid image input")

        all_pixel_values = []
        all_grid_shapes = []

        for image in images:
            image = to_numpy_array(image)

            # Drop an alpha channel in either layout. A leading dim of 4 is
            # treated as channels-first, consistent with the H/W detection
            # heuristic below (previously only (H, W, 4) was handled, and a
            # (4, H, W) RGBA image later broke the 3-channel normalization).
            if do_convert_rgb and image.ndim == 3:
                if image.shape[0] == 4:  # (C, H, W) RGBA
                    image = image[:3]
                elif image.shape[-1] == 4:  # (H, W, C) RGBA
                    image = image[..., :3]

            # Infer spatial dims. NOTE(review): heuristic — an (H, W, C)
            # image whose height is 1, 3 or 4 is misdetected as (C, H, W).
            if image.ndim == 3:
                if image.shape[0] in [1, 3, 4]:  # (C, H, W)
                    h, w = image.shape[1], image.shape[2]
                else:  # (H, W, C)
                    h, w = image.shape[0], image.shape[1]
            else:
                raise ValueError(f"Expected 3D image, got shape {image.shape}")

            # Choose the grid that best preserves this image's resolution.
            grid_h, grid_w = calculate_tile_grid(
                h, w, tile_size, max_tiles_h, max_tiles_w, min_tiles
            )

            tiles = tile_image(
                image, tile_size, grid_h, grid_w, resample=resample
            )

            if do_rescale:
                tiles = tiles.astype(np.float32) * rescale_factor

            if do_normalize:
                # Channel count is taken from image_mean/image_std rather
                # than hard-coded, broadcasting over (N, C, H, W) tiles.
                mean = np.array(image_mean, dtype=np.float32).reshape(1, -1, 1, 1)
                std = np.array(image_std, dtype=np.float32).reshape(1, -1, 1, 1)
                tiles = (tiles - mean) / std

            all_pixel_values.append(tiles)
            all_grid_shapes.append((grid_h, grid_w))

        # num_crops is 0 for each image since we use tiling, not pan-and-scan
        num_crops = [0] * len(all_pixel_values)

        # Concatenate all tiles into a single array for vLLM compatibility:
        # vLLM's flat_from_sizes expects a single tensor, not a list.
        if len(all_pixel_values) > 0:
            concatenated_pixels = np.concatenate(all_pixel_values, axis=0)
        else:
            concatenated_pixels = np.array([])

        data = {
            "pixel_values": concatenated_pixels,
            "tile_grid_shape": all_grid_shapes,
            "num_crops": num_crops,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Gemma3TiledImageProcessor", "calculate_tile_grid", "tile_image"]