import math

import numpy as np
import torch
from PIL import Image

IMAGE_MEAN = [0.5, 0.5, 0.5]
IMAGE_STD = [0.5, 0.5, 0.5]


def smart_resize(
    height: int,
    width: int,
    factor: int = 16,
    min_pixels: int = 128 * 128,
    max_pixels: int = 256 * 256,
) -> tuple[int, int]:
    """Resize dimensions to the nearest multiples of `factor` while respecting pixel bounds."""
    if height < factor or width < factor:
        raise ValueError(f"height:{height} and width:{width} must both be at least factor:{factor}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width):.2f}"
        )

    # Round each side to the nearest multiple of the factor.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor

    if h_bar * w_bar > max_pixels:
        # Too many pixels: shrink both sides by the same ratio, then floor onto the factor grid.
        beta = np.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Too few pixels: enlarge both sides by the same ratio, then ceil onto the factor grid.
        beta = np.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    return h_bar, w_bar
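
# Worked example (illustration only, using the defaults above):
#   smart_resize(300, 500) first rounds to (304, 496); since 304 * 496 > max_pixels
#   (65536), both sides are scaled down by beta = sqrt(300 * 500 / 65536) ~ 1.51 and
#   floored back onto the factor grid, giving (192, 320).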


def convert_image_to_patches(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Convert an image of shape (H, W, C) to patches of shape (num_patches, patch_size^2 * C)."""
    image_height, image_width, num_channels = image.shape
    num_patches_height = image_height // patch_size
    num_patches_width = image_width // patch_size

    # Split H and W into (patch index, within-patch offset), bring the two patch
    # indices together, then flatten each patch into a single row.
    patched_image = image.reshape(num_patches_height, patch_size, num_patches_width, patch_size, num_channels)
    patched_image = patched_image.permute(0, 2, 1, 3, 4)
    patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
    return patched_image
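
# Example: a (192, 320, 3) image with patch_size=16 yields 12 * 20 = 240 patches,
# each flattened to 16 * 16 * 3 = 768 values, i.e. a (240, 768) tensor.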


def pad_along_first_dim(
    array: torch.Tensor,
    target_length: int,
    pad_value: float = 0.0,
    mask_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad a 2-D tensor along its first dimension to `target_length` and return (padded, mask)."""
    current_length = array.shape[0]
    padding_length = target_length - current_length
    mask = torch.ones(target_length, dtype=mask_dtype, device=array.device)

    if padding_length > 0:
        # Pad only the first dimension; the mask marks padded rows with 0.
        paddings = (0, 0, 0, padding_length)
        array = torch.nn.functional.pad(array, paddings, mode="constant", value=pad_value)
        mask[-padding_length:] = 0
    # Note: if current_length exceeds target_length the array is returned unchanged,
    # so callers are expected to keep the number of rows at or below target_length.

    return array, mask
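
# Example: padding a (240, 768) patch tensor to target_length=256 returns a
# (256, 768) tensor plus a length-256 mask with 240 ones followed by 16 zeros.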


class AMOEImageProcessor:
    """Image processor for the AMOE model.

    Resizes, rescales, and normalizes images, then converts them into flattened
    patch sequences with padding masks suitable for batching.
    """

    def __init__(
        self,
        patch_size: int = 16,
        min_pixels: int = 128 * 128,
        max_pixels: int = 256 * 256,
        image_mean: list[float] | None = None,
        image_std: list[float] | None = None,
        do_resize: bool = True,
        do_rescale: bool = True,
        do_normalize: bool = True,
    ):
        self.patch_size = patch_size
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        # Fall back to the module-level defaults when no statistics are given.
        self.image_mean = image_mean or IMAGE_MEAN
        self.image_std = image_std or IMAGE_STD
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize

    def preprocess_single(self, image: Image.Image | np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
        """Preprocess a single image and return (processed array, (patches_h, patches_w))."""
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
            image = np.array(image)

        # Grayscale images are expanded to three channels; arrays whose leading
        # dimension is 3 are assumed to be channel-first (C, H, W) and transposed to (H, W, C).
        if image.ndim == 2:
            image = np.stack([image] * 3, axis=-1)
        elif image.shape[0] == 3:
            image = np.transpose(image, (1, 2, 0))

        height, width = image.shape[:2]

        if self.do_resize:
            resized_height, resized_width = smart_resize(
                height, width,
                factor=self.patch_size,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )
            # Resizing goes through PIL, which expects uint8 pixel data.
            pil_image = Image.fromarray(image.astype(np.uint8))
            pil_image = pil_image.resize((resized_width, resized_height), Image.BICUBIC)
            image = np.array(pil_image)
        else:
            resized_height, resized_width = height, width

        if self.do_rescale:
            # Map pixel values from [0, 255] to [0, 1].
            image = image.astype(np.float32) / 255.0

        if self.do_normalize:
            mean = np.array(self.image_mean, dtype=np.float32)
            std = np.array(self.image_std, dtype=np.float32)
            image = (image - mean) / std

        spatial_shape = (resized_height // self.patch_size, resized_width // self.patch_size)
        return image, spatial_shape

    def preprocess(
        self,
        images: list[Image.Image] | list[np.ndarray],
    ) -> tuple[list[np.ndarray], list[tuple[int, int]]]:
        """Preprocess a list of images."""
        pixel_values = []
        spatial_shapes = []

        for image in images:
            processed_image, spatial_shape = self.preprocess_single(image)
            pixel_values.append(processed_image)
            spatial_shapes.append(spatial_shape)

        return pixel_values, spatial_shapes

    def batch_images_with_mask(
        self,
        pixel_values: list[np.ndarray],
        spatial_shapes: list[tuple[int, int]],
        max_num_patches: int = 256,
        pad: bool = True,
        output_dtype: torch.dtype = torch.float32,
        mask_dtype: torch.dtype | None = None,
    ) -> dict[str, torch.Tensor] | None:
        """Batch images into padded patch tensors with masks; returns None for an empty input."""
        if not pixel_values:
            return None

        if mask_dtype is None:
            mask_dtype = output_dtype

        batched_pixels = []
        batched_masks = []
        batched_shapes = []

        for img, shape in zip(pixel_values, spatial_shapes):
            img_tensor = torch.from_numpy(img).to(dtype=output_dtype)
            patches = convert_image_to_patches(img_tensor, self.patch_size)

            if pad:
                patches, mask = pad_along_first_dim(
                    patches,
                    max_num_patches,
                    mask_dtype=mask_dtype,
                )
            else:
                # Without padding, every image must produce the same number of patches
                # for torch.stack below to succeed.
                mask = torch.ones(patches.shape[0], dtype=mask_dtype, device=patches.device)

            batched_pixels.append(patches)
            batched_masks.append(mask)
            batched_shapes.append(list(shape))

        return {
            "pixel_values": torch.stack(batched_pixels),
            "padding_mask": torch.stack(batched_masks),
            "spatial_shape": torch.tensor(batched_shapes),
        }

    def __call__(
        self,
        images: list[Image.Image] | Image.Image,
        max_num_patches: int = 256,
        n_storage_tokens: int = 4,
        return_tensors: str = "pt",
        pad: bool = True,
        output_dtype: torch.dtype = torch.float32,
        mask_dtype: torch.dtype | None = None,
    ) -> dict[str, torch.Tensor] | None:
        """Process one image or a list of images and return batched tensors.

        `n_storage_tokens` and `return_tensors` are currently not used by this method.
        """
        if isinstance(images, Image.Image):
            images = [images]

        pixel_values, spatial_shapes = self.preprocess(images)

        return self.batch_images_with_mask(
            pixel_values,
            spatial_shapes,
            max_num_patches=max_num_patches,
            pad=pad,
            output_dtype=output_dtype,
            mask_dtype=mask_dtype,
        )
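

# Illustrative smoke test (not part of the processor API): builds a random RGB image
# and runs the full pipeline with the default settings. A 300x500 input resizes to
# 192x320, i.e. 12 * 20 = 240 patches padded up to max_num_patches=256.
if __name__ == "__main__":
    processor = AMOEImageProcessor()
    dummy = Image.fromarray(np.random.randint(0, 256, (300, 500, 3), dtype=np.uint8))
    outputs = processor(dummy)
    for name, tensor in outputs.items():
        # Expected shapes: pixel_values (1, 256, 768), padding_mask (1, 256), spatial_shape (1, 2)
        print(name, tuple(tensor.shape))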