"""
Image processor for Gemma3Tiled that tiles images into grids.

Instead of resizing images to a fixed size or using pan-and-scan crops,
this processor tiles the image into a grid of 896x896 patches that
preserves the spatial layout.
"""

import math
from typing import Optional, Union

import numpy as np
from PIL import Image

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import (
    ImageInput,
    make_flat_list_of_images,
    valid_images,
    infer_channel_dimension_format,
    to_numpy_array,
    ChannelDimension,
)
from transformers.utils import TensorType


# Leading-axis sizes that are interpreted as a channel axis (gray, RGB, RGBA).
_CHANNEL_COUNTS = (1, 3, 4)


def _to_channels_last(image: np.ndarray) -> np.ndarray:
    """Return ``image`` laid out as (H, W, C).

    A 2D array is treated as single-channel. A 3D array whose first axis has
    1, 3 or 4 entries is assumed to be (C, H, W) and is transposed; any other
    3D array is assumed to already be (H, W, C).

    Raises:
        ValueError: If ``image`` is neither 2D nor 3D.
    """
    if image.ndim == 2:
        return image[..., np.newaxis]
    if image.ndim != 3:
        raise ValueError(f"Expected a 2D or 3D image, got shape {image.shape}")
    if image.shape[0] in _CHANNEL_COUNTS:
        return np.transpose(image, (1, 2, 0))
    return image


def _to_rgb(image: np.ndarray) -> np.ndarray:
    """Convert an (H, W, C) array to three channels where that is unambiguous.

    RGBA input has its alpha channel dropped (no compositing); single-channel
    input is replicated across three channels. Other channel counts are
    returned unchanged.
    """
    channels = image.shape[-1]
    if channels == 4:
        return image[..., :3]
    if channels == 1:
        return np.repeat(image, 3, axis=-1)
    return image


def _to_uint8(image: np.ndarray) -> np.ndarray:
    """Convert pixel data to uint8 for PIL.

    Floating-point data whose maximum is <= 1.0 is treated as [0, 1] and
    scaled by 255. Floats are rounded (not truncated) so that values such as
    126.99999 map to 127, and everything is clipped to [0, 255] so that
    out-of-range values saturate instead of wrapping around.
    """
    if image.dtype == np.uint8:
        return image
    if np.issubdtype(image.dtype, np.floating):
        if image.size and image.max() <= 1.0:
            image = image * 255
        image = np.rint(image)
    return np.clip(image, 0, 255).astype(np.uint8)


def _tile_channels_last(
    image: np.ndarray,
    tile_size: int,
    grid_h: int,
    grid_w: int,
    resample: Image.Resampling,
) -> np.ndarray:
    """Resize an (H, W, C) image to the grid canvas and split it into tiles.

    Returns an array of shape (grid_h * grid_w, C, tile_size, tile_size) with
    tiles in row-major order (all columns of the first row first).
    """
    pixels = np.ascontiguousarray(_to_uint8(image))
    # PIL has no (H, W, 1) layout; single-channel data must be passed as 2D.
    if pixels.shape[-1] == 1:
        pixels = pixels[..., 0]

    target_h = grid_h * tile_size
    target_w = grid_w * tile_size
    # PIL takes (width, height). The image is stretched to fill the canvas.
    resized = np.array(
        Image.fromarray(pixels).resize((target_w, target_h), resample=resample)
    )
    if resized.ndim == 2:
        resized = resized[..., np.newaxis]

    channels = resized.shape[-1]
    # (target_h, target_w, C) -> (grid_h, tile, grid_w, tile, C)
    #                         -> (grid_h, grid_w, C, tile, tile)
    grid = resized.reshape(grid_h, tile_size, grid_w, tile_size, channels)
    grid = grid.transpose(0, 2, 4, 1, 3)
    return grid.reshape(grid_h * grid_w, channels, tile_size, tile_size)


def calculate_tile_grid(
    image_height: int,
    image_width: int,
    tile_size: int,
    max_tiles_h: int,
    max_tiles_w: int,
    min_tiles: int = 1,
) -> tuple[int, int]:
    """
    Calculate the optimal tile grid dimensions for an image.

    Every grid of ``rows x cols`` tiles with ``rows <= max_tiles_h``,
    ``cols <= max_tiles_w`` and ``rows * cols >= min_tiles`` is scored as
    ``effective - 0.001 * waste`` and the highest score wins, where:

    1. ``effective`` is the number of original pixels preserved when the
       image is scaled (aspect-ratio preserving) to fit the canvas.
    2. ``waste`` is the canvas area minus ``effective``. It is a small
       penalty that mainly separates grids with equal effective resolution.

    Upscaling is not credited - effective resolution is capped at original
    image size. This means larger grids are only chosen if they preserve
    more original detail. On equal scores the grid visited first (fewest
    rows, then fewest columns) is kept.

    Example: For a 1344 (width) x 912 (height) image with 896x896 tiles
    (grids written as rows x cols):

    | Canvas      | Scale | Effective   | Wasted    |
    |-------------|-------|-------------|-----------|
    | 1×1 (896²)  | 0.667 | 544,768     | 258,048   |
    | 1×2         | 0.982 | 1,182,720   | 422,912   |
    | 2×1         | 0.667 | 544,768     | 1,060,864 |
    | 2×2         | 1.333 | 1,225,728 ✓ | 1,985,536 |

    Winner: 2×2 (highest effective resolution = 100% of original pixels)

    Args:
        image_height: Original image height
        image_width: Original image width
        tile_size: Size of each tile (896)
        max_tiles_h: Maximum tiles in height
        max_tiles_w: Maximum tiles in width
        min_tiles: Minimum total tiles

    Returns:
        (grid_h, grid_w): Number of tiles in height and width. If no grid
        satisfies the constraints, (1, 1) is returned.

    Raises:
        ValueError: If ``image_height`` or ``image_width`` is not positive.
    """
    if image_height <= 0 or image_width <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {image_height}x{image_width}"
        )

    original_pixels = image_height * image_width

    best_grid = (1, 1)
    best_score = float('-inf')

    # Search all valid grid configurations
    for rows in range(1, max_tiles_h + 1):
        for cols in range(1, max_tiles_w + 1):
            # Skip if below minimum tiles
            if rows * cols < min_tiles:
                continue

            # Calculate canvas size for this grid
            canvas_h = rows * tile_size
            canvas_w = cols * tile_size

            # Scale factor to fit image in canvas (aspect-ratio preserving)
            scale = min(canvas_w / image_width, canvas_h / image_height)

            # Effective resolution: how many original pixels are preserved
            # Don't credit upscaling - cap at original size
            effective = min(original_pixels * scale * scale, original_pixels)

            # Wasted pixels = canvas area - effective area
            waste = (canvas_h * canvas_w) - effective

            # Score: maximize effective resolution, lightly penalize waste
            score = effective - 0.001 * waste

            # Strict ">" keeps the earliest grid on equal scores
            if score > best_score:
                best_score = score
                best_grid = (rows, cols)

    return best_grid


def tile_image(
    image: np.ndarray,
    tile_size: int,
    grid_h: int,
    grid_w: int,
    resample: Image.Resampling = Image.Resampling.BICUBIC,
) -> np.ndarray:
    """
    Tile an image into a grid of fixed-size patches.

    The image is first resized so that when divided into grid_h x grid_w
    tiles, each tile is exactly tile_size x tile_size. The resize stretches
    the image to the full canvas; the aspect ratio is not preserved.

    Args:
        image: Input image as numpy array (H, W, C), (C, H, W) or (H, W).
            A 3D array whose first axis has 1, 3 or 4 entries is assumed to
            be (C, H, W). Non-uint8 data is converted to uint8 first; floats
            with a maximum <= 1.0 are treated as being in [0, 1].
        tile_size: Size of each tile
        grid_h: Number of tiles in height
        grid_w: Number of tiles in width
        resample: PIL resampling method

    Returns:
        Tiled uint8 image array of shape
        (grid_h * grid_w, C, tile_size, tile_size), tiles in row-major order.

    Raises:
        ValueError: If ``image`` is neither 2D nor 3D.
    """
    return _tile_channels_last(
        _to_channels_last(image), tile_size, grid_h, grid_w, resample
    )


class Gemma3TiledImageProcessor(BaseImageProcessor):
    """
    Image processor for Gemma3Tiled that tiles images into grids.

    This processor:
    1. Calculates the optimal tile grid for each image
    2. Resizes and tiles the image
    3. Returns pixel_values and tile_grid_shape metadata
    """

    model_input_names = ["pixel_values", "tile_grid_shape", "num_crops"]
    _auto_class = "AutoImageProcessor"  # Required for auto_map in preprocessor_config.json

    def __init__(
        self,
        tile_size: int = 896,
        max_tiles_h: int = 4,
        max_tiles_w: int = 4,
        min_tiles: int = 1,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: bool = True,
        resample: Image.Resampling = Image.Resampling.BICUBIC,
        **kwargs,
    ):
        """
        Store the default preprocessing settings.

        Args:
            tile_size: Side length of each square tile
            max_tiles_h: Maximum tiles in height
            max_tiles_w: Maximum tiles in width
            min_tiles: Minimum total tiles per image
            do_rescale: Whether to multiply pixel values by rescale_factor
            rescale_factor: Factor applied when do_rescale is set
            do_normalize: Whether to apply (x - mean) / std per channel
            image_mean: Per-channel mean; defaults to [0.5, 0.5, 0.5]
            image_std: Per-channel std; defaults to [0.5, 0.5, 0.5]
            do_convert_rgb: Whether to convert RGBA/grayscale input to RGB
            resample: PIL resampling method used when resizing
            **kwargs: Forwarded to BaseImageProcessor
        """
        super().__init__(**kwargs)
        self.tile_size = tile_size
        self.max_tiles_h = max_tiles_h
        self.max_tiles_w = max_tiles_w
        self.min_tiles = min_tiles
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.do_convert_rgb = do_convert_rgb
        self.resample = resample

    def preprocess(
        self,
        images: ImageInput,
        tile_size: Optional[int] = None,
        max_tiles_h: Optional[int] = None,
        max_tiles_w: Optional[int] = None,
        min_tiles: Optional[int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images by tiling them into grids.

        Every argument left as None falls back to the value given to the
        constructor. The resampling method always comes from the constructor.

        Args:
            images: Single image or batch of images
            tile_size: Side length of each square tile
            max_tiles_h: Maximum tiles in height
            max_tiles_w: Maximum tiles in width
            min_tiles: Minimum total tiles per image
            do_rescale: Whether to multiply pixel values by rescale_factor
            rescale_factor: Factor applied when do_rescale is set
            do_normalize: Whether to apply (x - mean) / std per channel
            image_mean: Per-channel mean used for normalization
            image_std: Per-channel std used for normalization
            do_convert_rgb: Whether to drop the alpha channel of RGBA images
                and replicate grayscale images to three channels
            return_tensors: Tensor type passed to BatchFeature

        Returns:
            BatchFeature with:
            - pixel_values: Single array [total_tiles, C, H, W] holding the
              tiles of all images concatenated in input order
            - tile_grid_shape: List of (grid_h, grid_w) tuples, one per image
            - num_crops: List of zeros, one per image

        Raises:
            ValueError: If the input is not a valid image, an image is
                neither 2D nor 3D, or image_mean / image_std do not match
                the number of channels.
        """
        tile_size = tile_size if tile_size is not None else self.tile_size
        max_tiles_h = max_tiles_h if max_tiles_h is not None else self.max_tiles_h
        max_tiles_w = max_tiles_w if max_tiles_w is not None else self.max_tiles_w
        min_tiles = min_tiles if min_tiles is not None else self.min_tiles
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_flat_list_of_images(images)

        if not valid_images(images):
            raise ValueError("Invalid image input")

        # Loop-invariant: build the broadcastable statistics once.
        # reshape(1, -1, 1, 1) lines them up with the channel axis of
        # (num_tiles, C, H, W) without hard-coding the channel count.
        if do_normalize:
            mean = np.asarray(image_mean, dtype=np.float32).reshape(1, -1, 1, 1)
            std = np.asarray(image_std, dtype=np.float32).reshape(1, -1, 1, 1)

        all_pixel_values = []
        all_grid_shapes = []

        for image in images:
            # Fix the layout first so every later step can rely on (H, W, C)
            image = _to_channels_last(to_numpy_array(image))

            # Convert to RGB if needed
            if do_convert_rgb:
                image = _to_rgb(image)

            h, w = image.shape[:2]

            # Calculate grid dimensions
            grid_h, grid_w = calculate_tile_grid(
                h, w, tile_size, max_tiles_h, max_tiles_w, min_tiles
            )

            # Tile the image (layout is already known, so do not re-guess it)
            tiles = _tile_channels_last(
                image, tile_size, grid_h, grid_w, self.resample
            )

            # Rescale
            if do_rescale:
                tiles = tiles.astype(np.float32) * rescale_factor

            # Normalize
            if do_normalize:
                channels = tiles.shape[1]
                # Refuse silent broadcasting of e.g. 3 statistics onto 1 channel
                if mean.shape[1] not in (1, channels) or std.shape[1] not in (1, channels):
                    raise ValueError(
                        f"image_mean/image_std have {mean.shape[1]}/{std.shape[1]} "
                        f"entries but the image has {channels} channels"
                    )
                tiles = (tiles - mean) / std

            all_pixel_values.append(tiles)
            all_grid_shapes.append((grid_h, grid_w))

        # num_crops is 0 for each image since we use tiling, not pan-and-scan
        num_crops = [0] * len(all_pixel_values)

        # Concatenate all tiles into a single array for vLLM compatibility
        # vLLM's flat_from_sizes expects a single tensor, not a list
        if all_pixel_values:
            concatenated_pixels = np.concatenate(all_pixel_values, axis=0)
        else:
            concatenated_pixels = np.array([])

        data = {
            "pixel_values": concatenated_pixels,
            "tile_grid_shape": all_grid_shapes,
            "num_crops": num_crops,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)


__all__ = ["Gemma3TiledImageProcessor", "calculate_tile_grid", "tile_image"]