# amoe/image_processor.py
# Image preprocessing for Falcon Vision
# Handles resizing, normalization, and patchification
import math
import numpy as np
import torch
from PIL import Image
IMAGE_MEAN = [0.5, 0.5, 0.5]
IMAGE_STD = [0.5, 0.5, 0.5]
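# With mean = std = 0.5, normalization maps rescaled pixels from [0, 1] to [-1, 1].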
def smart_resize(
height: int,
width: int,
factor: int = 16,
min_pixels: int = 128 * 128,
max_pixels: int = 256 * 256,
) -> tuple[int, int]:
"""Resize dimensions to be divisible by factor while respecting pixel bounds."""
    if height < factor or width < factor:
        raise ValueError(f"height ({height}) and width ({width}) must each be at least factor ({factor})")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width):.1f}")
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = np.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor
w_bar = math.floor(width / beta / factor) * factor
elif h_bar * w_bar < min_pixels:
beta = np.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
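# Worked example (illustrative): smart_resize(500, 333) with the defaults first
# rounds each side to the nearest multiple of 16, giving 496 x 336 = 166,656
# pixels. That exceeds max_pixels (65,536), so beta = sqrt(166,500 / 65,536)
# ~= 1.59 and both sides are rescaled and floored to multiples of 16, yielding
# (304, 208) = 63,232 pixels, back within bounds.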
def convert_image_to_patches(image: torch.Tensor, patch_size: int) -> torch.Tensor:
"""Convert image (H, W, C) to patches (num_patches, patch_size^2 * C)."""
image_height, image_width, num_channels = image.shape
num_patches_height = image_height // patch_size
num_patches_width = image_width // patch_size
    # (H, W, C) -> (nH, p, nW, p, C): split each spatial dim into (grid, patch)
    patched_image = image.reshape(num_patches_height, patch_size, num_patches_width, patch_size, num_channels)
    # -> (nH, nW, p, p, C): bring the two grid axes together
    patched_image = patched_image.permute(0, 2, 1, 3, 4)
    # -> (nH * nW, p * p * C): one flattened row per patch
    patched_image = patched_image.reshape(num_patches_height * num_patches_width, -1)
return patched_image
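# Shape sketch (illustrative): a 128 x 128 RGB image with patch_size=16 yields
# (128 // 16) * (128 // 16) = 64 patches, each flattened to 16 * 16 * 3 = 768
# values, i.e. a (64, 768) tensor.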
def pad_along_first_dim(
array: torch.Tensor,
target_length: int,
pad_value: float = 0.0,
mask_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Pad the array along the first dimension and return mask."""
current_length = array.shape[0]
padding_length = target_length - current_length
mask = torch.ones(target_length, dtype=mask_dtype, device=array.device)
if padding_length > 0:
paddings = (0, 0, 0, padding_length)
array = torch.nn.functional.pad(array, paddings, mode="constant", value=pad_value)
mask[-padding_length:] = 0
return array, mask
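# Example (illustrative): padding a (64, 768) patch tensor to target_length=256
# returns a (256, 768) tensor and a length-256 mask of 64 ones followed by 192
# zeros.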
class AMOEImageProcessor:
"""Image processor for AMOE model.
"""
def __init__(
self,
patch_size: int = 16,
min_pixels: int = 128 * 128,
max_pixels: int = 256 * 256,
image_mean: list[float] | None = None,
image_std: list[float] | None = None,
do_resize: bool = True,
do_rescale: bool = True,
do_normalize: bool = True,
):
self.patch_size = patch_size
self.min_pixels = min_pixels
self.max_pixels = max_pixels
self.image_mean = image_mean or IMAGE_MEAN
self.image_std = image_std or IMAGE_STD
self.do_resize = do_resize
self.do_rescale = do_rescale
self.do_normalize = do_normalize
def preprocess_single(self, image: Image.Image | np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
"""Preprocess a single image."""
if isinstance(image, Image.Image):
image = image.convert("RGB")
image = np.array(image)
        # Ensure HWC format
        if image.ndim == 2:  # grayscale -> replicate to 3 channels
            image = np.stack([image] * 3, axis=-1)
        elif image.shape[0] == 3 and image.shape[-1] != 3:  # CHW -> HWC
            image = np.transpose(image, (1, 2, 0))
height, width = image.shape[:2]
# Smart resize
if self.do_resize:
resized_height, resized_width = smart_resize(
height, width,
factor=self.patch_size,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
pil_image = Image.fromarray(image.astype(np.uint8))
pil_image = pil_image.resize((resized_width, resized_height), Image.BICUBIC)
image = np.array(pil_image)
else:
resized_height, resized_width = height, width
# Rescale to [0, 1]
if self.do_rescale:
image = image.astype(np.float32) / 255.0
# Normalize
if self.do_normalize:
mean = np.array(self.image_mean, dtype=np.float32)
std = np.array(self.image_std, dtype=np.float32)
image = (image - mean) / std
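        # Patch-grid size: (rows, cols) of patches after resize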
spatial_shape = (resized_height // self.patch_size, resized_width // self.patch_size)
return image, spatial_shape
def preprocess(
self,
images: list[Image.Image] | list[np.ndarray],
) -> tuple[list[np.ndarray], list[tuple[int, int]]]:
"""Preprocess a list of images."""
pixel_values = []
spatial_shapes = []
for image in images:
processed_image, spatial_shape = self.preprocess_single(image)
pixel_values.append(processed_image)
spatial_shapes.append(spatial_shape)
return pixel_values, spatial_shapes
def batch_images_with_mask(
self,
pixel_values: list[np.ndarray],
spatial_shapes: list[tuple[int, int]],
max_num_patches: int = 256,
pad: bool = True,
output_dtype: torch.dtype = torch.float32,
mask_dtype: torch.dtype | None = None,
    ) -> dict[str, torch.Tensor] | None:
"""Batch images into padded tensors with masks.
"""
if not pixel_values:
return None
if mask_dtype is None:
mask_dtype = output_dtype
batched_pixels = []
batched_masks = []
batched_shapes = []
for img, shape in zip(pixel_values, spatial_shapes):
img_tensor = torch.from_numpy(img).to(dtype=output_dtype)
patches = convert_image_to_patches(img_tensor, self.patch_size)
if pad:
patches, mask = pad_along_first_dim(
patches,
max_num_patches,
mask_dtype=mask_dtype,
)
else:
mask = torch.ones(patches.shape[0], dtype=mask_dtype, device=patches.device)
batched_pixels.append(patches)
batched_masks.append(mask)
batched_shapes.append(list(shape))
return {
"pixel_values": torch.stack(batched_pixels),
"padding_mask": torch.stack(batched_masks),
"spatial_shape": torch.tensor(batched_shapes),
}
def __call__(
self,
images: list[Image.Image] | Image.Image,
max_num_patches: int = 256,
n_storage_tokens: int = 4, # kept for API compat, not used here
        return_tensors: str = "pt",  # accepted for API compatibility; output is always PyTorch tensors
pad: bool = True,
output_dtype: torch.dtype = torch.float32,
mask_dtype: torch.dtype | None = None,
    ) -> dict[str, torch.Tensor] | None:
"""Process images and return batched tensors."""
if isinstance(images, Image.Image):
images = [images]
pixel_values, spatial_shapes = self.preprocess(images)
return self.batch_images_with_mask(
pixel_values,
spatial_shapes,
max_num_patches=max_num_patches,
pad=pad,
output_dtype=output_dtype,
mask_dtype=mask_dtype,
)
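# Minimal usage sketch (illustrative; not part of the uploaded model's API):
# runs a random RGB image through the processor and prints the output shapes.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    test_image = Image.fromarray(rng.integers(0, 256, size=(240, 320, 3), dtype=np.uint8))
    processor = AMOEImageProcessor(patch_size=16)
    batch = processor(test_image, max_num_patches=256)
    print(batch["pixel_values"].shape)  # (1, 256, 768) with the defaults
    print(batch["padding_mask"].shape)  # (1, 256); zeros mark padded patch slots
    print(batch["spatial_shape"])       # per-image patch-grid (rows, cols)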