| | """Image processor and decoding helpers for Yasa2.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import io |
| | import math |
| | from typing import List, Tuple |
| |
|
| | import numpy as np |
| | from PIL import Image |
| | from transformers import ConvNextImageProcessor |
| |
|
| |
|
class Yasa2ImageProcessor(ConvNextImageProcessor):
    """ConvNeXt-based image processor for Yasa2 models."""

    model_input_names = ["pixel_values"]

    def __init__(self, *args, **kwargs):
        """Initialize the processor with Yasa2 defaults and tiling metadata.

        Args:
            *args: Positional args forwarded to ConvNextImageProcessor.
            **kwargs: Keyword args forwarded to ConvNextImageProcessor.
                Recognized extras: ``use_navit``, ``max_tiles_num``,
                ``patch_size`` and ``tiling_method``.
        """
        # Fill in Yasa2 defaults without overriding caller-provided values.
        defaults = {
            "size": {"shortest_edge": 512},
            "do_resize": True,
            "do_center_crop": False,
            "do_normalize": True,
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)

        super().__init__(*args, **kwargs)

        # Tiling metadata consumed by the tiling decode helpers in this module.
        self.use_navit = kwargs.get("use_navit", False)
        self.max_tiles_num = kwargs.get("max_tiles_num", 4)
        self.patch_size = kwargs.get("patch_size", 14)
        self.tiling_method = kwargs.get("tiling_method", "llava-uhd")
|
| |
|
def image_rgb_decoder_pil(
    image_bytes: bytes, skip_errors: bool = False
) -> dict:
    """Decode raw image bytes into an RGB numpy array.

    Args:
        image_bytes: Encoded image bytes.
        skip_errors: When True, decoding failures are reported in the
            returned dict instead of raised.

    Returns:
        ``{"pixel_values": ndarray}`` on success; ``{"error": str}`` on
        failure when ``skip_errors`` is set.
    """
    try:
        decoded = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        pixel_values = np.array(decoded)
        # Guard against unexpected array shapes from mislabeled files.
        if pixel_values.ndim == 4:
            raise ValueError(
                "Image has 4 dimensions, expected 3 (possible GIF with jpg/png extension)."
            )
        if pixel_values.shape[2] != 3:
            raise ValueError(
                f"Image has {pixel_values.shape[2]} channels, expected 3."
            )
        return {"pixel_values": pixel_values}
    except Exception as exc:
        if skip_errors:
            return {"error": str(exc)}
        raise
| |
|
| |
|
def image_rgb_decoder_pil_tiling(
    image_bytes: bytes,
    skip_errors: bool = False,
    size: int = 1024,
    grid_pinpoints: List[Tuple[int, int]] = None,
    max_tiles_num: int = 9,
    patch_size: int = 4,
    tiling_method: str = "llava-next",
) -> dict:
    """Decode raw image bytes and split the result into tiles.

    Args:
        image_bytes: Encoded image bytes.
        skip_errors: When True, failures are reported in the returned
            dict instead of raised.
        size: Base tile size.
        grid_pinpoints: Candidate grid shapes for LLaVA-Next tiling; a
            default set is used when None.
        max_tiles_num: Tile budget for LLaVA-UHD tiling.
        patch_size: Patch alignment for LLaVA-UHD tiling.
        tiling_method: Either "llava-next" or "llava-uhd".

    Returns:
        ``{"pixel_values": ..., "num_tiles": int, "img_tiling": True}``
        on success; ``{"error": str}`` on failure when ``skip_errors``
        is set.
    """
    if grid_pinpoints is None:
        grid_pinpoints = [
            (2, 2),
            (1, 2),
            (2, 1),
            (1, 3),
            (3, 1),
            (1, 4),
            (4, 1),
        ]
    try:
        method = tiling_method.lower()
        source = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        if method == "llava-next":
            tiles = process_anyres_image(source, size, grid_pinpoints)
            # LLaVA-Next tiles share one shape, so they stack into 4D.
            pixel_values = np.array([np.array(tile) for tile in tiles])
            if pixel_values.ndim != 4:
                raise ValueError(
                    "Tiled image has unexpected dimensions (expected 4D)."
                )
            if pixel_values.shape[3] != 3:
                raise ValueError(
                    f"Tiled image has {pixel_values.shape[3]} channels, expected 3."
                )
        elif method == "llava-uhd":
            tiles = process_anyres_image_uhd(
                source,
                max_tiles_num=max_tiles_num,
                scale_resolution=size,
                patch_size=patch_size,
                never_split=False,
            )
            # UHD tiles may differ in shape; keep a list and sanity-check
            # the last entry (the resized overview image).
            pixel_values = [np.array(tile) for tile in tiles]
            if pixel_values[-1].ndim != 3:
                raise ValueError(
                    "UHD tiled image has unexpected dimensions (expected 3D)."
                )
            if pixel_values[-1].shape[2] != 3:
                raise ValueError(
                    f"UHD tiled image has {pixel_values[-1].shape[2]} channels, expected 3."
                )
        else:
            raise ValueError(f"Unknown tiling method: {tiling_method}")

        return {
            "pixel_values": pixel_values,
            "num_tiles": len(pixel_values),
            "img_tiling": True,
        }
    except Exception as exc:
        if skip_errors:
            return {"error": str(exc)}
        raise
| |
|
| |
|
def resize_and_pad_image(
    image: Image.Image, target_resolution: Tuple[int, int]
) -> Image.Image:
    """Fit an image into target_resolution, padding the rest with black.

    The aspect ratio is preserved: the image is scaled by the smaller of
    the two per-axis scale factors and centered on a black canvas.

    Args:
        image: Input PIL image.
        target_resolution: Target (width, height).

    Returns:
        A new image of exactly target_resolution.
    """
    source_width, source_height = image.size
    target_width, target_height = target_resolution

    width_scale = target_width / source_width
    height_scale = target_height / source_height

    # Scale along the tighter axis; the other axis is capped at the target.
    if width_scale < height_scale:
        new_width = target_width
        new_height = min(math.ceil(source_height * width_scale), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(source_width * height_scale), target_width)

    scaled = image.resize((new_width, new_height))
    canvas = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    offset = ((target_width - new_width) // 2, (target_height - new_height) // 2)
    canvas.paste(scaled, offset)
    return canvas
| |
|
| |
|
def select_best_resolution(
    original_size: Tuple[int, int],
    possible_resolutions: List[Tuple[int, int]],
) -> Tuple[int, int]:
    """Pick the candidate resolution that best fits the original image.

    The winner maximizes the effective (non-upscaled) pixel count and,
    on ties, minimizes wasted canvas area. Earlier candidates win exact
    ties; None is returned for an empty candidate list.

    Args:
        original_size: Original image size (width, height).
        possible_resolutions: Candidate resolutions.

    Returns:
        Best resolution (width, height).
    """
    original_width, original_height = original_size
    original_pixels = original_width * original_height

    best_fit = None
    # Lexicographic score: maximize effective pixels, then minimize waste.
    best_score = (-1, float("-inf"))

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        scaled_pixels = int(original_width * scale) * int(original_height * scale)
        # Cap at the source size so upscaling earns no extra credit.
        effective = min(scaled_pixels, original_pixels)
        wasted = width * height - effective
        score = (effective, -wasted)
        if score > best_score:
            best_score = score
            best_fit = (width, height)

    return best_fit
| |
|
| |
|
| | def divide_to_patches( |
| | image: Image.Image, patch_size: int |
| | ) -> List[Image.Image]: |
| | """Divide an image into square patches. |
| | |
| | Args: |
| | image: Input PIL image. |
| | patch_size: Patch size in pixels. |
| | |
| | Returns: |
| | List of patch images. |
| | """ |
| | patches = [] |
| | width, height = image.size |
| | for i in range(0, height, patch_size): |
| | for j in range(0, width, patch_size): |
| | box = (j, i, j + patch_size, i + patch_size) |
| | patches.append(image.crop(box)) |
| | return patches |
| |
|
| |
|
def process_anyres_image(
    image: Image.Image,
    size: int = 512,
    grid_pinpoints: List[Tuple[int, int]] = None,
) -> List[Image.Image]:
    """Tile an image in the LLaVA-Next style.

    The output starts with the whole image resized to (size, size),
    followed by size-sized tiles of the image fitted to the best grid.

    Args:
        image: Input PIL image.
        size: Base tile size.
        grid_pinpoints: Candidate grid shapes; a default set is used
            when None.

    Returns:
        [overview resize] + tiles.
    """
    grids = grid_pinpoints if grid_pinpoints is not None else [
        (2, 2), (1, 2), (2, 1), (1, 3), (3, 1)
    ]
    candidate_resolutions = [(gx * size, gy * size) for gx, gy in grids]
    target_resolution = select_best_resolution(image.size, candidate_resolutions)
    padded = resize_and_pad_image(image, target_resolution)
    overview = image.resize((size, size))
    return [overview] + divide_to_patches(padded, size)
| |
|
| |
|
def estimate_num_tiles_llava_next(
    image_size: Tuple[int, int],
    size: int = 512,
    grid_pinpoints: List[Tuple[int, int]] = None,
) -> int:
    """Estimate tile count for LLaVA-Next tiling without decoding images.

    Mirrors process_anyres_image: one overview tile plus one tile per
    grid cell of the selected resolution.

    Args:
        image_size: Image size (width, height).
        size: Base tile size.
        grid_pinpoints: Candidate grid shapes; a default set is used
            when None.

    Returns:
        Estimated number of tiles, including the overview tile.
    """
    grids = grid_pinpoints if grid_pinpoints is not None else [
        (2, 2), (1, 2), (2, 1), (1, 3), (3, 1)
    ]
    candidate_resolutions = [(gx * size, gy * size) for gx, gy in grids]
    best_width, best_height = select_best_resolution(image_size, candidate_resolutions)
    # Best resolutions are exact multiples of size, so division is exact.
    return 1 + (best_width // size) * (best_height // size)
| |
|
| |
|
| | def split_to_patches( |
| | image: Image.Image, grid: Tuple[int, int] |
| | ) -> List[Image.Image]: |
| | """Divide an image into patches using a fixed grid. |
| | |
| | Args: |
| | image: Input PIL image. |
| | grid: Grid dimensions (grid_x, grid_y). |
| | |
| | Returns: |
| | List of patch images. |
| | """ |
| | patches = [] |
| | width, height = image.size |
| | grid_x = int(width / grid[0]) |
| | grid_y = int(height / grid[1]) |
| | for i in range(0, height, grid_y): |
| | for j in range(0, width, grid_x): |
| | box = (j, i, j + grid_x, i + grid_y) |
| | patches.append(image.crop(box)) |
| | return patches |
| |
|
| |
|
def ensure_divide(length: float, patch_size: int) -> int:
    """Round length to the nearest multiple of patch_size, at least patch_size.

    Note: despite the previous docstring's claim, this rounds to the
    NEAREST multiple (Python's round, banker's rounding at .5), not up:
    e.g. ensure_divide(10, 4) == 8.

    Args:
        length: Raw length to align.
        patch_size: Patch size to align to.

    Returns:
        Length aligned to a positive multiple of patch_size.
    """
    return max(round(length / patch_size) * patch_size, patch_size)
| |
|
| |
|
def find_best_resize(
    original_size: Tuple[int, int],
    scale_resolution: int,
    patch_size: int,
    allow_upscale: bool = False,
) -> Tuple[int, int]:
    """Find the best patch-aligned resize for UHD tiling.

    If the image exceeds scale_resolution**2 pixels (or upscaling is
    allowed), it is rescaled toward that area while preserving aspect
    ratio; both edges are then snapped to the nearest positive multiple
    of patch_size.

    Args:
        original_size: Original image size (width, height).
        scale_resolution: Target scale resolution.
        patch_size: Patch size for alignment.
        allow_upscale: Whether to allow upscaling.

    Returns:
        Best resized (width, height).
    """
    width, height = original_size
    if allow_upscale or width * height > scale_resolution * scale_resolution:
        aspect_ratio = width / height
        height = int(scale_resolution / math.sqrt(aspect_ratio))
        width = int(height * aspect_ratio)

    def _align(edge: float) -> int:
        # Nearest multiple of patch_size, never below one patch
        # (same alignment rule as ensure_divide).
        return max(round(edge / patch_size) * patch_size, patch_size)

    return (_align(width), _align(height))
| |
|
| |
|
def get_refine_size(
    original_size: Tuple[int, int],
    grid: Tuple[int, int],
    scale_resolution: int,
    patch_size: int,
    allow_upscale: bool = False,
) -> Tuple[int, int]:
    """Compute the refined full-image size for a given tile grid.

    The image size is aligned to the grid, then each grid cell is sized
    via find_best_resize so that tiles come out patch-aligned.

    Args:
        original_size: Original image size (width, height).
        grid: Tile grid (grid_x, grid_y).
        scale_resolution: Target scale resolution per tile.
        patch_size: Patch size for alignment.
        allow_upscale: Whether to allow upscaling of tiles.

    Returns:
        Refined (width, height), an exact multiple of the cell size.
    """
    grid_x, grid_y = grid
    # Align the full size to the grid so cells divide evenly.
    cell_width = ensure_divide(original_size[0], grid_x) / grid_x
    cell_height = ensure_divide(original_size[1], grid_y) / grid_y
    best_cell_width, best_cell_height = find_best_resize(
        (cell_width, cell_height),
        scale_resolution,
        patch_size,
        allow_upscale=allow_upscale,
    )
    return (best_cell_width * grid_x, best_cell_height * grid_y)
| |
|
| |
|
def process_anyres_image_uhd(
    image: Image.Image,
    max_tiles_num: int = 9,
    scale_resolution: int = 448,
    patch_size: int = 4,
    never_split: bool = False,
) -> List[Image.Image]:
    """Tile an image in the LLaVA-UHD style.

    Small images (or never_split) yield a single resized image. Larger
    images are split into a grid chosen to match the image aspect ratio,
    followed by a resized overview of the whole image.

    Args:
        image: Input PIL image.
        max_tiles_num: Maximum number of tiles to generate.
        scale_resolution: Target resolution for scaling.
        patch_size: Patch size for alignment.
        never_split: Whether to avoid splitting into tiles.

    Returns:
        List of tiles (patches + resized source image), or a single
        resized image when no split is performed.
    """
    source_width, source_height = image.size
    log_aspect = math.log(source_width / source_height)
    area_ratio = (source_width * source_height) / (
        scale_resolution * scale_resolution
    )
    tiles_needed = min(math.ceil(area_ratio), max_tiles_num)

    # No split: rescale once (upscaling allowed) to a patch-aligned size.
    if tiles_needed <= 1 or never_split:
        best_size = find_best_resize(
            image.size, scale_resolution, patch_size, allow_upscale=True
        )
        return [image.resize(best_size, Image.Resampling.BICUBIC)]

    # Consider tile counts near the estimate, excluding 1 and counts
    # over the budget.
    candidate_counts = [
        count
        for count in (tiles_needed - 1, tiles_needed, tiles_needed + 1)
        if count != 1 and count <= max_tiles_num
    ]
    # All factorizations count = cols * rows, cols ascending.
    candidate_grids = [
        [cols, count // cols]
        for count in candidate_counts
        for cols in range(1, count + 1)
        if count % cols == 0
    ]

    # Pick the grid whose aspect ratio is closest to the image's
    # (compared in log space); earlier candidates win ties.
    best_grid = min(
        candidate_grids,
        key=lambda g: abs(log_aspect - math.log(g[0] / g[1])),
        default=[1, 1],
    )

    overview_size = find_best_resize(image.size, scale_resolution, patch_size)
    overview = image.copy().resize(overview_size, Image.Resampling.BICUBIC)

    refine_size = get_refine_size(
        image.size,
        (best_grid[0], best_grid[1]),
        scale_resolution,
        patch_size,
        allow_upscale=True,
    )
    refined = image.resize(refine_size, Image.Resampling.BICUBIC)
    patches = split_to_patches(refined, (best_grid[0], best_grid[1]))
    return patches + [overview]
| |
|
| |
|
def estimate_num_tiles_llava_uhd(
    image_size: Tuple[int, int],
    max_tiles_num: int = 9,
    scale_resolution: int = 448,
    patch_size: int = 4,
    never_split: bool = False,
) -> int:
    """Estimate tile count for LLaVA-UHD tiling without decoding images.

    Mirrors the grid selection in process_anyres_image_uhd: grid cells
    plus one overview tile, or a single tile when no split would occur.

    Args:
        image_size: Image size (width, height).
        max_tiles_num: Maximum number of tiles.
        scale_resolution: Target resolution for scaling.
        patch_size: Patch size for alignment (not used by the estimate).
        never_split: Whether splitting is disabled.

    Returns:
        Estimated number of tiles.
    """
    width, height = image_size
    log_aspect = math.log(width / height)
    area_ratio = (width * height) / (scale_resolution * scale_resolution)
    tiles_needed = min(math.ceil(area_ratio), max_tiles_num)
    if tiles_needed <= 1 or never_split:
        return 1

    candidate_counts = [
        count
        for count in (tiles_needed - 1, tiles_needed, tiles_needed + 1)
        if count != 1 and count <= max_tiles_num
    ]
    candidate_grids = [
        [cols, count // cols]
        for count in candidate_counts
        for cols in range(1, count + 1)
        if count % cols == 0
    ]
    # Closest aspect ratio in log space; earlier candidates win ties.
    best_grid = min(
        candidate_grids,
        key=lambda g: abs(log_aspect - math.log(g[0] / g[1])),
        default=[1, 1],
    )
    return best_grid[0] * best_grid[1] + 1
| |
|
| |
|
# Register the processor with transformers' auto-class machinery so it can
# be resolved by Auto* loading (presumably AutoImageProcessor — confirm
# against how the model repo is published).
Yasa2ImageProcessor.register_for_auto_class()
| |
|