|
|
""" |
|
|
Image processor for Gemma3Tiled that tiles images into grids. |
|
|
|
|
|
Instead of resizing images to a fixed size or using pan-and-scan crops, |
|
|
this processor tiles the image into a grid of 896x896 patches that |
|
|
preserves the spatial layout. |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
|
|
from transformers.image_utils import ( |
|
|
ImageInput, |
|
|
make_flat_list_of_images, |
|
|
valid_images, |
|
|
infer_channel_dimension_format, |
|
|
to_numpy_array, |
|
|
ChannelDimension, |
|
|
) |
|
|
from transformers.utils import TensorType |
|
|
|
|
|
|
|
|
def calculate_tile_grid(
    image_height: int,
    image_width: int,
    tile_size: int,
    max_tiles_h: int,
    max_tiles_w: int,
    min_tiles: int = 1,
) -> tuple[int, int]:
    """
    Choose the tile grid (rows, cols) that best fits an image.

    Every grid up to max_tiles_h x max_tiles_w (with at least min_tiles
    tiles) is scored by how much of the original resolution it preserves,
    with a small penalty for unused canvas area as a tiebreaker:

        score = effective_pixels - 0.001 * wasted_canvas_pixels

    Effective pixels are the original pixel count scaled by the square of
    the fit scale, capped at the original count — upscaling earns no
    credit, so a larger grid only wins when it keeps more real detail.

    Args:
        image_height: Original image height.
        image_width: Original image width.
        tile_size: Side length of each square tile (e.g. 896).
        max_tiles_h: Maximum number of tile rows.
        max_tiles_w: Maximum number of tile columns.
        min_tiles: Minimum total number of tiles.

    Returns:
        (grid_h, grid_w): tile counts along height and width. Falls back
        to (1, 1) if no grid satisfies min_tiles.
    """
    source_area = image_height * image_width

    def _score(grid: tuple[int, int]) -> float:
        # Higher is better: preserved resolution minus a small waste penalty.
        rows, cols = grid
        canvas_h = rows * tile_size
        canvas_w = cols * tile_size
        # Uniform scale that fits the image inside the canvas.
        fit = min(canvas_w / image_width, canvas_h / image_height)
        # Cap at the original area: upscaling adds no real detail.
        effective = min(source_area * fit * fit, source_area)
        wasted = canvas_h * canvas_w - effective
        return effective - 0.001 * wasted

    candidates = [
        (rows, cols)
        for rows in range(1, max_tiles_h + 1)
        for cols in range(1, max_tiles_w + 1)
        if rows * cols >= min_tiles
    ]
    # max() returns the first of equal maxima, matching the original
    # strict-greater-than scan over the same row-major candidate order.
    return max(candidates, key=_score, default=(1, 1))
|
|
|
|
|
|
|
|
def tile_image(
    image: np.ndarray,
    tile_size: int,
    grid_h: int,
    grid_w: int,
    resample: Image.Resampling = Image.Resampling.BICUBIC,
) -> np.ndarray:
    """
    Tile an image into a grid of fixed-size patches.

    The image is resized to (grid_h * tile_size, grid_w * tile_size) and cut
    into grid_h x grid_w non-overlapping tiles of tile_size x tile_size,
    ordered row-major (left-to-right, top-to-bottom).

    Args:
        image: Input image as numpy array, (H, W, C), (C, H, W), or (H, W).
        tile_size: Side length of each square tile.
        grid_h: Number of tiles along the height.
        grid_w: Number of tiles along the width.
        resample: PIL resampling filter used for the resize.

    Returns:
        uint8 array of shape (grid_h * grid_w, C, tile_size, tile_size).
    """
    # Convert channels-first (C, H, W) to channels-last (H, W, C) for PIL.
    # Heuristic: a leading dim of 1/3/4 is assumed to be the channel axis.
    if image.ndim == 3 and image.shape[0] in [1, 3, 4]:
        image = np.transpose(image, (1, 2, 0))

    # Normalize dtype to uint8 for PIL. Clip before casting: a plain
    # astype(np.uint8) wraps out-of-range values (e.g. 300.0 -> 44,
    # -1.0 -> 255) instead of saturating. Floats are rounded, not
    # truncated, so [0, 1] inputs round-trip through *255 exactly.
    if np.issubdtype(image.dtype, np.floating):
        if image.max() <= 1.0:
            # Float image in [0, 1] -> [0, 255].
            image = image * 255.0
        image = np.clip(np.rint(image), 0, 255).astype(np.uint8)
    else:
        image = np.clip(image, 0, 255).astype(np.uint8)
    pil_image = Image.fromarray(image)

    target_h = grid_h * tile_size
    target_w = grid_w * tile_size

    # PIL's resize takes (width, height).
    pil_image = pil_image.resize((target_w, target_h), resample=resample)

    image = np.array(pil_image)

    # Grayscale images come back 2D from PIL; restore a channel axis so
    # the per-tile HWC -> CHW transpose below does not fail.
    if image.ndim == 2:
        image = image[:, :, None]

    tiles = []
    for row in range(grid_h):
        for col in range(grid_w):
            y_start = row * tile_size
            x_start = col * tile_size
            tile = image[y_start:y_start + tile_size, x_start:x_start + tile_size]
            # HWC -> CHW to match model input layout.
            tiles.append(np.transpose(tile, (2, 0, 1)))

    return np.stack(tiles, axis=0)
|
|
|
|
|
|
|
|
class Gemma3TiledImageProcessor(BaseImageProcessor):
    """
    Image processor for Gemma3Tiled that tiles images into grids.

    This processor:
    1. Calculates the optimal tile grid for each image
    2. Resizes and tiles the image
    3. Returns pixel_values and tile_grid_shape metadata
    """

    model_input_names = ["pixel_values", "tile_grid_shape", "num_crops"]
    _auto_class = "AutoImageProcessor"

    def __init__(
        self,
        tile_size: int = 896,
        max_tiles_h: int = 4,
        max_tiles_w: int = 4,
        min_tiles: int = 1,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: bool = True,
        resample: Image.Resampling = Image.Resampling.BICUBIC,
        **kwargs,
    ):
        """
        Args:
            tile_size: Side length of each square tile.
            max_tiles_h: Maximum number of tile rows per image.
            max_tiles_w: Maximum number of tile columns per image.
            min_tiles: Minimum total number of tiles per image.
            do_rescale: Whether to multiply pixel values by rescale_factor.
            rescale_factor: Scale applied when do_rescale (default 1/255).
            do_normalize: Whether to normalize with image_mean / image_std.
            image_mean: Per-channel mean (defaults to [0.5, 0.5, 0.5]).
            image_std: Per-channel std (defaults to [0.5, 0.5, 0.5]).
            do_convert_rgb: Whether to drop the alpha channel of RGBA input.
            resample: PIL resampling filter used when resizing.
        """
        super().__init__(**kwargs)

        self.tile_size = tile_size
        self.max_tiles_h = max_tiles_h
        self.max_tiles_w = max_tiles_w
        self.min_tiles = min_tiles
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.do_convert_rgb = do_convert_rgb
        self.resample = resample

    def preprocess(
        self,
        images: ImageInput,
        tile_size: Optional[int] = None,
        max_tiles_h: Optional[int] = None,
        max_tiles_w: Optional[int] = None,
        min_tiles: Optional[int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[list[float]] = None,
        image_std: Optional[list[float]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        resample: Optional[Image.Resampling] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess images by tiling them into grids.

        Args:
            images: Single image or batch of images.
            resample: Optional per-call override of the resampling filter;
                falls back to the processor default. All other arguments
                likewise override the corresponding constructor settings.

        Returns:
            BatchFeature with:
            - pixel_values: array of shape (total_tiles, C, tile_size, tile_size),
              the tiles of all images concatenated along axis 0
            - tile_grid_shape: list of (grid_h, grid_w) tuples, one per image
            - num_crops: list of ints, one per image
        """
        # Per-call overrides fall back to the values set in __init__.
        tile_size = tile_size if tile_size is not None else self.tile_size
        max_tiles_h = max_tiles_h if max_tiles_h is not None else self.max_tiles_h
        max_tiles_w = max_tiles_w if max_tiles_w is not None else self.max_tiles_w
        min_tiles = min_tiles if min_tiles is not None else self.min_tiles
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        resample = resample if resample is not None else self.resample

        images = make_flat_list_of_images(images)

        if not valid_images(images):
            raise ValueError("Invalid image input")

        all_pixel_values = []
        all_grid_shapes = []

        for image in images:
            image = to_numpy_array(image)

            if image.ndim != 3:
                raise ValueError(f"Expected 3D image, got shape {image.shape}")

            # Heuristic layout detection: a leading dim of 1/3/4 is treated
            # as the channel axis (channels-first); otherwise channels-last.
            channels_first = image.shape[0] in [1, 3, 4]

            # Drop the alpha channel on the correct axis. Slicing the last
            # axis unconditionally would cut the *width* of a channels-first
            # image (and misfire on any CHW image whose width happens to be 4).
            if do_convert_rgb:
                if channels_first and image.shape[0] == 4:
                    image = image[:3, ...]
                elif not channels_first and image.shape[-1] == 4:
                    image = image[..., :3]

            if channels_first:
                h, w = image.shape[1], image.shape[2]
            else:
                h, w = image.shape[0], image.shape[1]

            # Pick the grid that preserves the most original resolution.
            grid_h, grid_w = calculate_tile_grid(
                h, w, tile_size, max_tiles_h, max_tiles_w, min_tiles
            )

            # (num_tiles, C, tile_size, tile_size), uint8.
            tiles = tile_image(
                image, tile_size, grid_h, grid_w, resample=resample
            )

            if do_rescale:
                tiles = tiles.astype(np.float32) * rescale_factor

            if do_normalize:
                # Channel count follows the stats, not a hard-coded 3, so
                # non-RGB configurations normalize correctly.
                num_channels = len(image_mean)
                mean = np.array(image_mean, dtype=np.float32).reshape(1, num_channels, 1, 1)
                std = np.array(image_std, dtype=np.float32).reshape(1, num_channels, 1, 1)
                tiles = (tiles - mean) / std

            all_pixel_values.append(tiles)
            all_grid_shapes.append((grid_h, grid_w))

        # NOTE(review): num_crops is emitted as all zeros — presumably the
        # downstream pipeline counts pan-and-scan crops separately from grid
        # tiles; confirm against the consuming processor/model.
        num_crops = [0] * len(all_pixel_values)

        if len(all_pixel_values) > 0:
            # Flatten every image's tiles into one batch axis; the per-image
            # split can be recovered from tile_grid_shape.
            concatenated_pixels = np.concatenate(all_pixel_values, axis=0)
        else:
            concatenated_pixels = np.array([])

        data = {
            "pixel_values": concatenated_pixels,
            "tile_grid_shape": all_grid_shapes,
            "num_crops": num_crops,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)
|
|
|
|
|
|
|
|
# Public API of this module: the processor class plus the two standalone
# tiling helpers, so they can be imported and tested independently.
__all__ = ["Gemma3TiledImageProcessor", "calculate_tile_grid", "tile_image"]
|
|
|