"""Image processor and decoding helpers for Yasa2.""" from __future__ import annotations import io import math from typing import List, Tuple import numpy as np from PIL import Image from transformers import ConvNextImageProcessor class Yasa2ImageProcessor(ConvNextImageProcessor): """ConvNeXt image processor for Yasa2.""" model_input_names = ["pixel_values"] def __init__(self, *args, **kwargs): """Initialize the image processor with optional tiling metadata. Args: *args: Positional args forwarded to ConvNextImageProcessor. **kwargs: Keyword args forwarded to ConvNextImageProcessor. """ kwargs.setdefault("size", {"shortest_edge": 512}) # Do not force crop_size; ConvNextImageProcessor uses crop_pct by default. kwargs.setdefault("do_resize", True) kwargs.setdefault("do_center_crop", False) kwargs.setdefault("do_normalize", True) # TODO: Non-square inputs can break square-grid assumptions; consider enforcing square outputs or returning spatial dims. super().__init__(*args, **kwargs) self.use_navit = kwargs.get("use_navit", False) self.max_tiles_num = kwargs.get("max_tiles_num", 4) self.patch_size = kwargs.get("patch_size", 14) self.tiling_method = kwargs.get("tiling_method", "llava-uhd") def image_rgb_decoder_pil( image_bytes: bytes, skip_errors: bool = False ) -> dict: """Decode image bytes into a numpy RGB array. Args: image_bytes: Raw image bytes. skip_errors: Whether to return error info instead of raising. Returns: Dict with pixel values or an error message. """ try: image = Image.open(io.BytesIO(image_bytes)).convert("RGB") pixel_values = np.array(image) if pixel_values.ndim == 4: raise ValueError( "Image has 4 dimensions, expected 3 (possible GIF with jpg/png extension)." ) if pixel_values.shape[2] != 3: raise ValueError( f"Image has {pixel_values.shape[2]} channels, expected 3." ) return {"pixel_values": pixel_values} except Exception as exc: if not skip_errors: raise return {"error": str(exc)} def image_rgb_decoder_pil_tiling( image_bytes: bytes, skip_errors: bool = False, size: int = 1024, grid_pinpoints: List[Tuple[int, int]] = None, max_tiles_num: int = 9, patch_size: int = 4, tiling_method: str = "llava-next", ) -> dict: """Decode image bytes into tiled numpy arrays. Args: image_bytes: Raw image bytes. skip_errors: Whether to return error info instead of raising. size: Base tile size. grid_pinpoints: Candidate grid pinpoints. max_tiles_num: Maximum number of tiles for UHD tiling. patch_size: Patch size for UHD tiling. tiling_method: Tiling method name. Returns: Dict with tiled pixel values or an error message. """ if grid_pinpoints is None: grid_pinpoints = [ (2, 2), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (4, 1), ] try: image = Image.open(io.BytesIO(image_bytes)).convert("RGB") if tiling_method.lower() == "llava-next": images = process_anyres_image(image, size, grid_pinpoints) pixel_values = np.array([np.array(img) for img in images]) elif tiling_method.lower() == "llava-uhd": images = process_anyres_image_uhd( image, max_tiles_num=max_tiles_num, scale_resolution=size, patch_size=patch_size, never_split=False, ) pixel_values = [np.array(img) for img in images] else: raise ValueError(f"Unknown tiling method: {tiling_method}") if tiling_method.lower() == "llava-next" and pixel_values.ndim != 4: raise ValueError( "Tiled image has unexpected dimensions (expected 4D)." ) if ( tiling_method.lower() == "llava-next" and pixel_values.shape[3] != 3 ): raise ValueError( f"Tiled image has {pixel_values.shape[3]} channels, expected 3." ) if tiling_method.lower() == "llava-uhd" and pixel_values[-1].ndim != 3: raise ValueError( "UHD tiled image has unexpected dimensions (expected 3D)." ) if ( tiling_method.lower() == "llava-uhd" and pixel_values[-1].shape[2] != 3 ): raise ValueError( f"UHD tiled image has {pixel_values[-1].shape[2]} channels, expected 3." ) return { "pixel_values": pixel_values, "num_tiles": len(pixel_values), "img_tiling": True, } except Exception as exc: if not skip_errors: raise return {"error": str(exc)} def resize_and_pad_image( image: Image.Image, target_resolution: Tuple[int, int] ) -> Image.Image: """Resize and pad an image to target resolution while preserving aspect ratio. Args: image: Input PIL image. target_resolution: Target (width, height). Returns: Resized and padded PIL image. """ original_width, original_height = image.size target_width, target_height = target_resolution scale_w = target_width / original_width scale_h = target_height / original_height if scale_w < scale_h: new_width = target_width new_height = min(math.ceil(original_height * scale_w), target_height) else: new_height = target_height new_width = min(math.ceil(original_width * scale_h), target_width) resized_image = image.resize((new_width, new_height)) new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0)) paste_x = (target_width - new_width) // 2 paste_y = (target_height - new_height) // 2 new_image.paste(resized_image, (paste_x, paste_y)) return new_image def select_best_resolution( original_size: Tuple[int, int], possible_resolutions: List[Tuple[int, int]], ) -> Tuple[int, int]: """Select the best resolution based on aspect ratio and minimal waste. Args: original_size: Original image size (width, height). possible_resolutions: Candidate resolutions. Returns: Best resolution (width, height). """ original_width, original_height = original_size best_fit = None max_effective_resolution = 0 min_wasted_resolution = float("inf") for width, height in possible_resolutions: scale = min(width / original_width, height / original_height) scaled_width, scaled_height = ( int(original_width * scale), int(original_height * scale), ) effective_resolution = min( scaled_width * scaled_height, original_width * original_height ) wasted_resolution = (width * height) - effective_resolution if effective_resolution > max_effective_resolution or ( effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution ): max_effective_resolution = effective_resolution min_wasted_resolution = wasted_resolution best_fit = (width, height) return best_fit def divide_to_patches( image: Image.Image, patch_size: int ) -> List[Image.Image]: """Divide an image into square patches. Args: image: Input PIL image. patch_size: Patch size in pixels. Returns: List of patch images. """ patches = [] width, height = image.size for i in range(0, height, patch_size): for j in range(0, width, patch_size): box = (j, i, j + patch_size, i + patch_size) patches.append(image.crop(box)) return patches def process_anyres_image( image: Image.Image, size: int = 512, grid_pinpoints: List[Tuple[int, int]] = None, ) -> List[Image.Image]: """Process an image into a list of tiles for LLaVA-Next style tiling. Args: image: Input PIL image. size: Base tile size. grid_pinpoints: Candidate grid pinpoints. Returns: List of tiled images (original resize + tiles). """ if grid_pinpoints is None: grid_pinpoints = [(2, 2), (1, 2), (2, 1), (1, 3), (3, 1)] possible_resolutions = [(x * size, y * size) for x, y in grid_pinpoints] best_resolution = select_best_resolution(image.size, possible_resolutions) image_padded = resize_and_pad_image(image, best_resolution) patches = divide_to_patches(image_padded, size) image_original_resize = image.resize((size, size)) return [image_original_resize] + patches def estimate_num_tiles_llava_next( image_size: Tuple[int, int], size: int = 512, grid_pinpoints: List[Tuple[int, int]] = None, ) -> int: """Estimate tile count for LLaVA-Next tiling without decoding images.""" if grid_pinpoints is None: grid_pinpoints = [(2, 2), (1, 2), (2, 1), (1, 3), (3, 1)] possible_resolutions = [(x * size, y * size) for x, y in grid_pinpoints] best_resolution = select_best_resolution(image_size, possible_resolutions) grid_x = int(best_resolution[0] / size) grid_y = int(best_resolution[1] / size) return 1 + (grid_x * grid_y) def split_to_patches( image: Image.Image, grid: Tuple[int, int] ) -> List[Image.Image]: """Divide an image into patches using a fixed grid. Args: image: Input PIL image. grid: Grid dimensions (grid_x, grid_y). Returns: List of patch images. """ patches = [] width, height = image.size grid_x = int(width / grid[0]) grid_y = int(height / grid[1]) for i in range(0, height, grid_y): for j in range(0, width, grid_x): box = (j, i, j + grid_x, i + grid_y) patches.append(image.crop(box)) return patches def ensure_divide(length: float, patch_size: int) -> int: """Round length up to a multiple of patch_size. Args: length: Raw length to align. patch_size: Patch size to align to. Returns: Length aligned to patch_size. """ return max(round(length / patch_size) * patch_size, patch_size) def find_best_resize( original_size: Tuple[int, int], scale_resolution: int, patch_size: int, allow_upscale: bool = False, ) -> Tuple[int, int]: """Find the best resize dimensions for UHD tiling. Args: original_size: Original image size (width, height). scale_resolution: Target scale resolution. patch_size: Patch size for alignment. allow_upscale: Whether to allow upscaling. Returns: Best resized (width, height). """ width, height = original_size if (width * height > scale_resolution * scale_resolution) or allow_upscale: aspect_ratio = width / height height = int(scale_resolution / math.sqrt(aspect_ratio)) width = int(height * aspect_ratio) best_width = ensure_divide(width, patch_size) best_height = ensure_divide(height, patch_size) return (best_width, best_height) def get_refine_size( original_size: Tuple[int, int], grid: Tuple[int, int], scale_resolution: int, patch_size: int, allow_upscale: bool = False, ) -> Tuple[int, int]: """Compute the refined resize based on a tile grid. Args: original_size: Original image size (width, height). grid: Tile grid (grid_x, grid_y). scale_resolution: Target scale resolution. patch_size: Patch size for alignment. allow_upscale: Whether to allow upscaling. Returns: Refined resize (width, height). """ width, height = original_size grid_x, grid_y = grid refine_width = ensure_divide(width, grid_x) refine_height = ensure_divide(height, grid_y) grid_width = refine_width / grid_x grid_height = refine_height / grid_y best_grid_size = find_best_resize( (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale, ) return (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) def process_anyres_image_uhd( image: Image.Image, max_tiles_num: int = 9, scale_resolution: int = 448, patch_size: int = 4, never_split: bool = False, ) -> List[Image.Image]: """Process an image into tiles for LLaVA-UHD style tiling. Args: image: Input PIL image. max_tiles_num: Maximum number of tiles to generate. scale_resolution: Target resolution for scaling. patch_size: Patch size for alignment. never_split: Whether to avoid splitting into tiles. Returns: List of tiles (patches + resized source image). """ original_width, original_height = image.size log_ratio = math.log(original_width / original_height) ratio = (original_width * original_height) / ( scale_resolution * scale_resolution ) multiple = min(math.ceil(ratio), max_tiles_num) patches = [] if multiple <= 1 or never_split: best_size = find_best_resize( image.size, scale_resolution, patch_size, allow_upscale=True ) source_image = image.resize(best_size, Image.Resampling.BICUBIC) return [source_image] candidate_split_grids_nums = [] for i in [multiple - 1, multiple, multiple + 1]: if i == 1 or i > max_tiles_num: continue candidate_split_grids_nums.append(i) best_resize = find_best_resize(image.size, scale_resolution, patch_size) source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) candidate_grids = [] for split_grids_nums in candidate_split_grids_nums: m = 1 while m <= split_grids_nums: if split_grids_nums % m == 0: candidate_grids.append([m, split_grids_nums // m]) m += 1 best_grid = [1, 1] min_error = float("inf") for grid in candidate_grids: error = abs(log_ratio - math.log(grid[0] / grid[1])) if error < min_error: best_grid = grid min_error = error refine_size = get_refine_size( image.size, (best_grid[0], best_grid[1]), scale_resolution, patch_size, allow_upscale=True, ) refine_image = image.resize(refine_size, Image.Resampling.BICUBIC) patches = split_to_patches(refine_image, (best_grid[0], best_grid[1])) return patches + [source_image] def estimate_num_tiles_llava_uhd( image_size: Tuple[int, int], max_tiles_num: int = 9, scale_resolution: int = 448, patch_size: int = 4, never_split: bool = False, ) -> int: """Estimate tile count for LLaVA-UHD tiling without decoding images.""" original_width, original_height = image_size log_ratio = math.log(original_width / original_height) ratio = (original_width * original_height) / ( scale_resolution * scale_resolution ) multiple = min(math.ceil(ratio), max_tiles_num) if multiple <= 1 or never_split: return 1 candidate_split_grids_nums = [] for i in [multiple - 1, multiple, multiple + 1]: if i == 1 or i > max_tiles_num: continue candidate_split_grids_nums.append(i) candidate_grids = [] for split_grids_nums in candidate_split_grids_nums: m = 1 while m <= split_grids_nums: if split_grids_nums % m == 0: candidate_grids.append([m, split_grids_nums // m]) m += 1 best_grid = [1, 1] min_error = float("inf") for grid in candidate_grids: error = abs(log_ratio - math.log(grid[0] / grid[1])) if error < min_error: best_grid = grid min_error = error return (best_grid[0] * best_grid[1]) + 1 Yasa2ImageProcessor.register_for_auto_class()