# reka-edge-2603 / image_processing_yasa2.py
# Uploaded by donovanOng92 (commit 7d24555, verified).
"""Image processor and decoding helpers for Yasa2."""
from __future__ import annotations
import io
import math
from typing import List, Tuple
import numpy as np
from PIL import Image
from transformers import ConvNextImageProcessor
class Yasa2ImageProcessor(ConvNextImageProcessor):
    """ConvNeXt-based image processor for Yasa2.

    Layers Yasa2 defaults on top of ``ConvNextImageProcessor`` and stores
    optional tiling metadata (``use_navit``, ``max_tiles_num``,
    ``patch_size``, ``tiling_method``) as instance attributes.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, *args, **kwargs):
        """Create the processor, filling in Yasa2 defaults for unset options.

        Args:
            *args: Positional args forwarded to ConvNextImageProcessor.
            **kwargs: Keyword args forwarded to ConvNextImageProcessor. The
                tiling keys ``use_navit``, ``max_tiles_num``, ``patch_size``
                and ``tiling_method`` are additionally read back here.
        """
        # Applied only for keys the caller did not supply. Built on every call
        # so the ``size`` dict is never shared between instances.
        # Do not force crop_size; ConvNextImageProcessor uses crop_pct by default.
        yasa2_defaults = {
            "size": {"shortest_edge": 512},
            "do_resize": True,
            "do_center_crop": False,
            "do_normalize": True,
        }
        for option, fallback in yasa2_defaults.items():
            kwargs.setdefault(option, fallback)
        # TODO: Non-square inputs can break square-grid assumptions; consider
        # enforcing square outputs or returning spatial dims.
        super().__init__(*args, **kwargs)

        # Tiling metadata, taken from kwargs when present.
        self.use_navit = kwargs.get("use_navit", False)
        self.max_tiles_num = kwargs.get("max_tiles_num", 4)
        self.patch_size = kwargs.get("patch_size", 14)
        self.tiling_method = kwargs.get("tiling_method", "llava-uhd")
def image_rgb_decoder_pil(
    image_bytes: bytes, skip_errors: bool = False
) -> dict:
    """Decode encoded image bytes into an RGB numpy array.

    Args:
        image_bytes: Raw encoded image bytes (any format PIL can open).
        skip_errors: If True, return ``{"error": message}`` instead of
            raising when decoding or validation fails.

    Returns:
        ``{"pixel_values": array}`` on success, where ``array`` is
        ``np.array`` of the RGB-converted image and has been checked to be
        3-D with 3 channels; or ``{"error": str}`` on failure when
        ``skip_errors`` is True.

    Raises:
        Exception: Any decoding or validation error (``ValueError`` for the
            shape checks) when ``skip_errors`` is False.
    """
    try:
        # Close the PIL handle deterministically; ``convert`` returns a
        # converted copy that remains usable after the source is closed.
        with Image.open(io.BytesIO(image_bytes)) as source:
            image = source.convert("RGB")
        pixel_values = np.array(image)
        if pixel_values.ndim == 4:
            raise ValueError(
                "Image has 4 dimensions, expected 3 (possible GIF with jpg/png extension)."
            )
        # Check rank before indexing shape[2] so any other unexpected rank
        # gives a clear ValueError instead of an IndexError.
        if pixel_values.ndim != 3:
            raise ValueError(
                f"Image has {pixel_values.ndim} dimensions, expected 3."
            )
        if pixel_values.shape[2] != 3:
            raise ValueError(
                f"Image has {pixel_values.shape[2]} channels, expected 3."
            )
        return {"pixel_values": pixel_values}
    except Exception as exc:
        if not skip_errors:
            raise
        return {"error": str(exc)}
def image_rgb_decoder_pil_tiling(
    image_bytes: bytes,
    skip_errors: bool = False,
    size: int = 1024,
    grid_pinpoints: List[Tuple[int, int]] = None,
    max_tiles_num: int = 9,
    patch_size: int = 4,
    tiling_method: str = "llava-next",
) -> dict:
    """Decode image bytes and split the RGB image into tiles.

    Args:
        image_bytes: Raw encoded image bytes.
        skip_errors: If True, return ``{"error": message}`` instead of raising.
        size: Tile size for "llava-next"; scale resolution for "llava-uhd".
        grid_pinpoints: Candidate ``(columns, rows)`` grids for "llava-next".
            Defaults to 2x2, 1x2, 2x1, 1x3, 3x1, 1x4 and 4x1.
        max_tiles_num: Maximum number of grid tiles for "llava-uhd".
        patch_size: Size-alignment granularity for "llava-uhd".
        tiling_method: "llava-next" or "llava-uhd" (case-insensitive).

    Returns:
        On success ``{"pixel_values": ..., "num_tiles": int,
        "img_tiling": True}``. For "llava-next", ``pixel_values`` is one
        stacked 4-D array whose last axis is 3. For "llava-uhd", it is a list
        of 3-D arrays, because grid tiles and the resized source image can
        differ in size. On failure with ``skip_errors`` set, returns
        ``{"error": str}``.

    Raises:
        ValueError: Unknown ``tiling_method`` or unexpected tile shapes, when
            ``skip_errors`` is False. Decoding errors also propagate.
    """
    if grid_pinpoints is None:
        grid_pinpoints = [
            (2, 2),
            (1, 2),
            (2, 1),
            (1, 3),
            (3, 1),
            (1, 4),
            (4, 1),
        ]
    try:
        # Close the PIL handle once decoded; ``convert`` returns a copy.
        with Image.open(io.BytesIO(image_bytes)) as source:
            image = source.convert("RGB")
        method = tiling_method.lower()
        if method == "llava-next":
            tiles = process_anyres_image(image, size, grid_pinpoints)
            # All LLaVA-Next tiles are size x size, so they stack into one array.
            pixel_values = np.array([np.array(tile) for tile in tiles])
            if pixel_values.ndim != 4:
                raise ValueError(
                    "Tiled image has unexpected dimensions (expected 4D)."
                )
            if pixel_values.shape[3] != 3:
                raise ValueError(
                    f"Tiled image has {pixel_values.shape[3]} channels, expected 3."
                )
        elif method == "llava-uhd":
            tiles = process_anyres_image_uhd(
                image,
                max_tiles_num=max_tiles_num,
                scale_resolution=size,
                patch_size=patch_size,
                never_split=False,
            )
            # Tile sizes can differ, so keep a list of per-tile arrays.
            pixel_values = [np.array(tile) for tile in tiles]
            # Validate every tile, not just the last one.
            for tile_array in pixel_values:
                if tile_array.ndim != 3:
                    raise ValueError(
                        "UHD tiled image has unexpected dimensions (expected 3D)."
                    )
                if tile_array.shape[2] != 3:
                    raise ValueError(
                        f"UHD tiled image has {tile_array.shape[2]} channels, expected 3."
                    )
        else:
            raise ValueError(f"Unknown tiling method: {tiling_method}")
        return {
            "pixel_values": pixel_values,
            "num_tiles": len(pixel_values),
            "img_tiling": True,
        }
    except Exception as exc:
        if not skip_errors:
            raise
        return {"error": str(exc)}
def resize_and_pad_image(
    image: Image.Image, target_resolution: Tuple[int, int]
) -> Image.Image:
    """Letterbox an image into ``target_resolution``.

    The image is resized to fit inside the target with its aspect ratio
    preserved. The constrained side matches the target exactly; the other is
    rounded up but capped at the target. The result is pasted, centred, on a
    black RGB canvas of the target size.

    Args:
        image: Input PIL image.
        target_resolution: Target ``(width, height)``.

    Returns:
        A new RGB PIL image of exactly ``target_resolution``.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_resolution
    width_ratio = dst_w / src_w
    height_ratio = dst_h / src_h
    if width_ratio < height_ratio:
        # Width is the binding constraint.
        fit_w, fit_h = dst_w, min(math.ceil(src_h * width_ratio), dst_h)
    else:
        # Height is the binding constraint (also taken on ties).
        fit_w, fit_h = min(math.ceil(src_w * height_ratio), dst_w), dst_h
    scaled = image.resize((fit_w, fit_h))
    canvas = Image.new("RGB", (dst_w, dst_h), (0, 0, 0))
    left = (dst_w - fit_w) // 2
    top = (dst_h - fit_h) // 2
    canvas.paste(scaled, (left, top))
    return canvas
def select_best_resolution(
    original_size: Tuple[int, int],
    possible_resolutions: List[Tuple[int, int]],
) -> Tuple[int, int]:
    """Pick the candidate resolution that keeps the most image detail.

    Candidates are ranked first by effective resolution (pixels of the
    aspect-preserving rescaled image, capped at the original pixel count),
    then by least wasted area. On a full tie the earliest candidate wins.

    Args:
        original_size: Original image size ``(width, height)``.
        possible_resolutions: Candidate ``(width, height)`` resolutions.

    Returns:
        The chosen ``(width, height)`` tuple, or ``None`` if there are no
        candidates.
    """
    orig_w, orig_h = original_size
    orig_area = orig_w * orig_h
    # Lexicographic score (effective, -wasted): bigger is better.
    best_score = (0, -float("inf"))
    chosen = None
    for cand_w, cand_h in possible_resolutions:
        fit = min(cand_w / orig_w, cand_h / orig_h)
        used = min(int(orig_w * fit) * int(orig_h * fit), orig_area)
        score = (used, -((cand_w * cand_h) - used))
        if score > best_score:
            best_score = score
            chosen = (cand_w, cand_h)
    return chosen
def divide_to_patches(
    image: Image.Image, patch_size: int
) -> List[Image.Image]:
    """Cut an image into ``patch_size`` x ``patch_size`` crops.

    Crops are produced in row-major order (left to right, then top to
    bottom). When the image size is not a multiple of ``patch_size``, the
    last box in a row or column extends past the image edge; it is passed to
    ``crop`` unchanged.

    Args:
        image: Input PIL image (only ``size`` and ``crop`` are used).
        patch_size: Patch side length in pixels.

    Returns:
        List of cropped patch images.
    """
    img_w, img_h = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, img_h, patch_size)
        for left in range(0, img_w, patch_size)
    ]
def process_anyres_image(
    image: Image.Image,
    size: int = 512,
    grid_pinpoints: List[Tuple[int, int]] = None,
) -> List[Image.Image]:
    """Tile an image LLaVA-Next style.

    The best-fitting grid (from ``grid_pinpoints``, each ``(columns, rows)``
    in units of ``size``) is chosen. The image is letterboxed to that
    resolution and cut into ``size`` x ``size`` tiles.

    Args:
        image: Input PIL image.
        size: Tile side length in pixels.
        grid_pinpoints: Candidate grids. Defaults to 2x2, 1x2, 2x1, 1x3, 3x1.

    Returns:
        ``[overview] + tiles``: the whole image resized to ``size`` x ``size``,
        followed by the grid tiles in row-major order.
    """
    pinpoints = (
        [(2, 2), (1, 2), (2, 1), (1, 3), (3, 1)]
        if grid_pinpoints is None
        else grid_pinpoints
    )
    candidates = [(cols * size, rows * size) for cols, rows in pinpoints]
    target = select_best_resolution(image.size, candidates)
    tiles = divide_to_patches(resize_and_pad_image(image, target), size)
    overview = image.resize((size, size))
    return [overview, *tiles]
def estimate_num_tiles_llava_next(
    image_size: Tuple[int, int],
    size: int = 512,
    grid_pinpoints: List[Tuple[int, int]] = None,
) -> int:
    """Predict the LLaVA-Next tile count from an image size alone.

    Mirrors ``process_anyres_image`` for the same ``size`` and
    ``grid_pinpoints``: one overview tile plus every ``size`` x ``size``
    tile in the chosen grid.

    Args:
        image_size: Image size ``(width, height)``.
        size: Tile side length in pixels.
        grid_pinpoints: Candidate ``(columns, rows)`` grids. Defaults to 2x2,
            1x2, 2x1, 1x3, 3x1.

    Returns:
        Total number of tiles, including the overview.
    """
    pinpoints = grid_pinpoints
    if pinpoints is None:
        pinpoints = [(2, 2), (1, 2), (2, 1), (1, 3), (3, 1)]
    best_w, best_h = select_best_resolution(
        image_size, [(cols * size, rows * size) for cols, rows in pinpoints]
    )
    tiles_across = int(best_w / size)
    tiles_down = int(best_h / size)
    return tiles_across * tiles_down + 1
def split_to_patches(
    image: Image.Image, grid: Tuple[int, int]
) -> List[Image.Image]:
    """Cut an image into exactly ``grid[0] * grid[1]`` equal tiles.

    Each tile is ``int(width / grid[0])`` by ``int(height / grid[1])`` pixels.
    Tiles are returned in row-major order (left to right, then top to bottom).
    If the image size is not divisible by the grid, the leftover right/bottom
    pixels are dropped rather than emitted as extra partial tiles. The tile
    count therefore always equals the grid size, which is what
    ``estimate_num_tiles_llava_uhd`` assumes.

    Args:
        image: Input PIL image (only ``size`` and ``crop`` are used).
        grid: Grid dimensions ``(columns, rows)``.

    Returns:
        List of ``columns * rows`` patch images.

    Raises:
        ValueError: If the image is too small to give every tile at least
            one pixel in each dimension.
    """
    columns, rows = grid
    width, height = image.size
    tile_width = int(width / columns)
    tile_height = int(height / rows)
    if tile_width < 1 or tile_height < 1:
        raise ValueError(
            f"Image of size {image.size} is too small to split into a "
            f"{columns}x{rows} grid."
        )
    # Iterate over grid cells, not pixel offsets, so remainder pixels never
    # produce additional tiles.
    return [
        image.crop(
            (
                col * tile_width,
                row * tile_height,
                (col + 1) * tile_width,
                (row + 1) * tile_height,
            )
        )
        for row in range(rows)
        for col in range(columns)
    ]
def ensure_divide(length: float, patch_size: int) -> int:
    """Snap ``length`` to the nearest multiple of ``patch_size``.

    This uses Python's built-in ``round`` (nearest, ties to even), so the
    result can be smaller than ``length``; it is not a ceiling. The result
    is never below ``patch_size``.

    Args:
        length: Raw length to align (int or float).
        patch_size: Patch size to align to.

    Returns:
        A multiple of ``patch_size`` that is at least ``patch_size``.

    Examples:
        >>> ensure_divide(10, 4)   # 2.5 rounds (to even) down to 2
        8
        >>> ensure_divide(14, 4)   # 3.5 rounds (to even) up to 4
        16
        >>> ensure_divide(1, 4)    # clamped to one patch
        4
    """
    # NOTE(review): the previous docstring said "round up", but the code rounds
    # to nearest. The code is kept as-is because tile sizes depend on it.
    return max(round(length / patch_size) * patch_size, patch_size)
def find_best_resize(
    original_size: Tuple[int, int],
    scale_resolution: int,
    patch_size: int,
    allow_upscale: bool = False,
) -> Tuple[int, int]:
    """Choose patch-aligned dimensions for a UHD source/tile image.

    If the image covers more than ``scale_resolution ** 2`` pixels, or
    ``allow_upscale`` is set, it is first rescaled to roughly that area while
    keeping its aspect ratio. Each side is then snapped to the nearest
    multiple of ``patch_size`` (never below ``patch_size``).

    Args:
        original_size: Original image size ``(width, height)``.
        scale_resolution: Side length of the target square area.
        patch_size: Alignment granularity.
        allow_upscale: Rescale even when the image is already small enough.

    Returns:
        Aligned ``(width, height)``.
    """
    width, height = original_size
    too_large = width * height > scale_resolution * scale_resolution
    if allow_upscale or too_large:
        aspect_ratio = width / height
        height = int(scale_resolution / math.sqrt(aspect_ratio))
        width = int(height * aspect_ratio)

    def _snap(side):
        # Nearest multiple of patch_size (round-half-to-even), at least one patch.
        return max(round(side / patch_size) * patch_size, patch_size)

    return (_snap(width), _snap(height))
def get_refine_size(
    original_size: Tuple[int, int],
    grid: Tuple[int, int],
    scale_resolution: int,
    patch_size: int,
    allow_upscale: bool = False,
) -> Tuple[int, int]:
    """Compute the whole-image resize that yields evenly sized grid cells.

    Each side is snapped to a multiple of its grid count via
    ``ensure_divide``. A single cell is then sized with ``find_best_resize``,
    and the cell size is multiplied back up by the grid.

    Args:
        original_size: Original image size ``(width, height)``.
        grid: Tile grid ``(columns, rows)``.
        scale_resolution: Target scale resolution for one cell.
        patch_size: Alignment granularity for a cell.
        allow_upscale: Forwarded to ``find_best_resize``.

    Returns:
        Refined ``(width, height)`` for the full image.
    """
    img_w, img_h = original_size
    cols, rows = grid
    cell_w = ensure_divide(img_w, cols) / cols
    cell_h = ensure_divide(img_h, rows) / rows
    fit_cell_w, fit_cell_h = find_best_resize(
        (cell_w, cell_h),
        scale_resolution,
        patch_size,
        allow_upscale=allow_upscale,
    )
    return (fit_cell_w * cols, fit_cell_h * rows)
def process_anyres_image_uhd(
    image: Image.Image,
    max_tiles_num: int = 9,
    scale_resolution: int = 448,
    patch_size: int = 4,
    never_split: bool = False,
) -> List[Image.Image]:
    """Tile an image LLaVA-UHD style.

    ``multiple`` is the number of ``scale_resolution ** 2`` areas the image
    covers, capped at ``max_tiles_num``. Among grids with ``multiple - 1``,
    ``multiple`` or ``multiple + 1`` cells (excluding 1 and anything above
    the cap), the grid whose aspect ratio is closest to the image's in log
    space is chosen. The earliest such grid wins ties.

    Args:
        image: Input PIL image.
        max_tiles_num: Maximum number of grid tiles.
        scale_resolution: Target side length of a single tile/source area.
        patch_size: Alignment granularity for tile sizes.
        never_split: Always return only the resized source image.

    Returns:
        ``[source]`` when ``multiple <= 1`` or ``never_split``; here the source
        is resized with upscaling allowed. Otherwise, the grid tiles in
        row-major order followed by a resized copy of the whole image (resized
        without upscaling).
    """
    original_width, original_height = image.size
    log_ratio = math.log(original_width / original_height)
    ratio = (original_width * original_height) / (
        scale_resolution * scale_resolution
    )
    multiple = min(math.ceil(ratio), max_tiles_num)

    if multiple <= 1 or never_split:
        best_size = find_best_resize(
            image.size, scale_resolution, patch_size, allow_upscale=True
        )
        return [image.resize(best_size, Image.Resampling.BICUBIC)]

    # Cell counts to consider: one fewer, the same, one more (never 1, never
    # above the cap).
    candidate_counts = [
        count
        for count in (multiple - 1, multiple, multiple + 1)
        if count != 1 and count <= max_tiles_num
    ]
    # ``resize`` already returns a new image, so no defensive copy is needed.
    best_resize = find_best_resize(image.size, scale_resolution, patch_size)
    source_image = image.resize(best_resize, Image.Resampling.BICUBIC)

    # Every (columns, rows) factorisation of each candidate count.
    candidate_grids = [
        (columns, count // columns)
        for count in candidate_counts
        for columns in range(1, count + 1)
        if count % columns == 0
    ]
    # ``min`` keeps the first minimum, matching strict-improvement selection.
    best_grid = min(
        candidate_grids,
        key=lambda grid: abs(log_ratio - math.log(grid[0] / grid[1])),
        default=(1, 1),
    )

    refine_size = get_refine_size(
        image.size,
        best_grid,
        scale_resolution,
        patch_size,
        allow_upscale=True,
    )
    refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
    patches = split_to_patches(refine_image, best_grid)
    return patches + [source_image]
def estimate_num_tiles_llava_uhd(
    image_size: Tuple[int, int],
    max_tiles_num: int = 9,
    scale_resolution: int = 448,
    patch_size: int = 4,
    never_split: bool = False,
) -> int:
    """Predict the LLaVA-UHD tile count from an image size alone.

    Uses the same grid selection as ``process_anyres_image_uhd``.
    ``patch_size`` is accepted for signature parity but does not affect the
    count.

    Args:
        image_size: Image size ``(width, height)``.
        max_tiles_num: Maximum number of grid tiles.
        scale_resolution: Target side length of a single tile area.
        patch_size: Unused here.
        never_split: If True, the image is never split.

    Returns:
        1 when the image would not be split; otherwise the chosen grid's cell
        count plus one for the resized source image.
    """
    width, height = image_size
    log_aspect = math.log(width / height)
    area_ratio = (width * height) / (scale_resolution * scale_resolution)
    multiple = min(math.ceil(area_ratio), max_tiles_num)
    if never_split or multiple <= 1:
        return 1

    counts = [
        c
        for c in (multiple - 1, multiple, multiple + 1)
        if c != 1 and c <= max_tiles_num
    ]
    grids = [
        (cols, c // cols)
        for c in counts
        for cols in range(1, c + 1)
        if c % cols == 0
    ]
    # First grid with the smallest log-aspect mismatch wins.
    cols, rows = min(
        grids,
        key=lambda g: abs(log_aspect - math.log(g[0] / g[1])),
        default=(1, 1),
    )
    return cols * rows + 1
# Register the custom processor with transformers' auto-class mechanism so it
# can be resolved from a saved repo/config. The default target is presumably
# AutoImageProcessor (per the transformers docs) — confirm.
Yasa2ImageProcessor.register_for_auto_class()