# LocateAnything-3B / image_processing_locateanything.py
# exiawsh's picture
# Initial snapshot
# d979fab
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Image processor class for LocateAnything."""
import math
import numpy as np
from PIL import Image
from typing import Optional, Union
import torch
from torchvision.transforms import functional as TF
from transformers.image_utils import ImageInput, make_list_of_images, valid_images
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.utils import TensorType
from transformers import AutoImageProcessor
# Default per-channel normalization statistics. They are the default values of
# the `image_mean` / `image_std` arguments of `LocateAnythingImageProcessor`.
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)
class LocateAnythingImageProcessor(BaseImageProcessor):
    """Convert images into a flat batch of normalized square patches.

    Every image is resized so that its sides are multiples of
    ``merge_kernel_size * patch_size``, converted to a normalized RGB tensor
    and cut into ``patch_size`` x ``patch_size`` patches.

    Args:
        patch_size: Side length, in pixels, of one square patch.
        image_mean: Per-channel mean passed to the normalization step.
        image_std: Per-channel standard deviation passed to the normalization
            step.
        in_token_limit: Patch-count threshold; images whose patch grid is
            larger than this are downscaled before alignment.
        merge_kernel_size: Two integers ``(h, w)``. Image height is aligned to
            a multiple of ``h * patch_size`` and image width to a multiple of
            ``w * patch_size``.
        **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
    """

    model_type = "locateanything"

    def __init__(
        self,
        patch_size: int = 14,
        image_mean: tuple[float, float, float] = MEAN,
        image_std: tuple[float, float, float] = STD,
        in_token_limit: int = 4096,
        merge_kernel_size: Union[list[int], tuple[int, int]] = (2, 2),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_token_limit = in_token_limit
        self.patch_size = patch_size
        self.image_mean = image_mean
        self.image_std = image_std
        # Keep a per-instance list copy so that instances never share (and
        # cannot accidentally mutate) a common default object.
        self.merge_kernel_size = list(merge_kernel_size)

    def rescale(
        self,
        image: Image.Image,
        merge_kernel_size: Union[list[int], tuple[int, int]] = (2, 2),
    ) -> Image.Image:
        """Resize a PIL image so it can be cut into an aligned patch grid.

        NOTE(review): this method has the same name as
        ``BaseImageProcessor.rescale`` but takes different parameters; confirm
        that no inherited code path relies on the base-class behavior.

        Args:
            image: PIL image to resize.
            merge_kernel_size: Two integers ``(h, w)``; the output height and
                width are multiples of ``h * patch_size`` and
                ``w * patch_size`` respectively.

        Returns:
            The resized image (the input object itself when no resize was
            needed).

        Raises:
            ValueError: If the resulting grid has 512 or more patches along
                either side.
        """
        patch_size = self.patch_size
        width, height = image.size

        # Step 1: shrink images whose patch grid exceeds the token budget,
        # scaling both sides by the same factor.
        num_patches = (width // patch_size) * (height // patch_size)
        if num_patches > self.in_token_limit:
            scale = math.sqrt(self.in_token_limit / num_patches)
            # Clamp to 1 px: for extreme aspect ratios the truncation could
            # otherwise produce a zero-sized side, which PIL rejects.
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            image = image.resize(new_size, Image.Resampling.BICUBIC)

        # Step 2: round each side UP to the next multiple of the alignment.
        # NOTE(review): rounding up can push the patch count slightly above
        # `in_token_limit` again; this mirrors the original behavior.
        width, height = image.size
        align_h = merge_kernel_size[0] * patch_size
        align_w = merge_kernel_size[1] * patch_size
        target_w = math.ceil(width / align_w) * align_w
        target_h = math.ceil(height / align_h) * align_h
        if (target_w, target_h) != (width, height):
            image = image.resize((target_w, target_h), Image.Resampling.BICUBIC)

        # Step 3: reject grids with 512 or more patches per side.
        width, height = image.size
        if width // patch_size >= 512 or height // patch_size >= 512:
            raise ValueError("Exceed pos emb")
        return image

    def to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image to an RGB tensor via torchvision's ``to_tensor``."""
        return TF.to_tensor(image.convert("RGB"))

    def normalize(self, image: torch.Tensor) -> torch.Tensor:
        """Normalize an image tensor with ``self.image_mean`` / ``self.image_std``."""
        return TF.normalize(image, self.image_mean, self.image_std)

    def patchify(self, image: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        """Cut a ``(C, H, W)`` tensor into square patches.

        ``H`` and ``W`` must be multiples of ``self.patch_size``; otherwise the
        reshape below fails.

        Returns:
            A pair ``(patches, grid_hw)`` where ``patches`` has shape
            ``(rows * cols, C, patch_size, patch_size)``, ordered row by row,
            and ``grid_hw`` is ``(rows, cols)``.
        """
        patch_size = self.patch_size
        channels, height, width = image.shape
        rows, cols = height // patch_size, width // patch_size
        # (C, H, W) -> (C, rows, p, cols, p) -> (rows, cols, C, p, p)
        patches = image.reshape(channels, rows, patch_size, cols, patch_size)
        patches = patches.permute(1, 3, 0, 2, 4)
        # Flatten the grid dimensions into a single patch dimension.
        patches = patches.contiguous().view(-1, channels, patch_size, patch_size)
        return patches, (rows, cols)

    @staticmethod
    def _to_pil(image: ImageInput) -> Image.Image:
        """Return `image` as a PIL image, converting arrays/tensors if needed.

        Conversion is delegated to torchvision's ``to_pil_image``, whose
        documentation describes the accepted layouts as ``(C, H, W)`` for
        tensors and ``(H, W, C)`` for numpy arrays.

        Raises:
            TypeError: If `image` is not a PIL image, numpy array or torch
                tensor.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, (np.ndarray, torch.Tensor)):
            return TF.to_pil_image(image)
        raise TypeError(
            f"Unsupported image type {type(image).__name__}; expected a "
            "PIL.Image.Image, numpy.ndarray or torch.Tensor."
        )

    def _preprocess(self, image: ImageInput) -> tuple[torch.Tensor, tuple[int, int]]:
        """Preprocess a single image and patchify it.

        The image is converted to PIL if necessary, resized with
        :meth:`rescale` using ``self.merge_kernel_size``, turned into an RGB
        tensor, normalized and split into patches.

        Args:
            image: Image to preprocess (PIL image, numpy array or torch
                tensor).

        Returns:
            The ``(patches, grid_hw)`` pair produced by :meth:`patchify`.
        """
        # `rescale` relies on PIL's `(width, height)` `.size`, so array and
        # tensor inputs must be converted first.
        image = self._to_pil(image)
        image = self.rescale(image, self.merge_kernel_size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        return self.patchify(image)

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Preprocess one image or a list of images.

        Args:
            images: A single image or a list of images.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.

        Returns:
            A ``BatchFeature`` with two entries:
            ``pixel_values``, the patches of all images concatenated along
            dimension 0, and ``image_grid_hws``, an array holding one
            ``(rows, cols)`` patch-grid pair per image.

        Raises:
            ValueError: If ``valid_images`` rejects the input.
        """
        images = make_list_of_images(images)
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        pixel_values, image_grid_hws = [], []
        for image in images:
            patches, image_grid_hw = self._preprocess(image)
            pixel_values.append(patches)
            image_grid_hws.append(image_grid_hw)

        # Images may have different grid sizes, so their patches are
        # concatenated rather than stacked.
        data = {
            "pixel_values": torch.concat(pixel_values, dim=0),
            "image_grid_hws": np.array(image_grid_hws),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
# Import-time side effect: register this processor class with transformers'
# AutoImageProcessor registry.
# NOTE(review): the first argument here is the class *name* as a string; the
# transformers documentation describes this argument as a config class —
# confirm this key is what the model's loading path actually looks up.
AutoImageProcessor.register("LocateAnythingImageProcessor", LocateAnythingImageProcessor)