|
|
from transformers import BaseImageProcessor, ImageProcessingMixin |
|
|
from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs |
|
|
import math |
|
|
from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload |
|
|
|
|
|
from PIL import Image |
|
|
import torch |
|
|
import numpy as np |
|
|
import torchvision |
|
|
from torch import nn |
|
|
from torch.nn import functional as F, LayerNorm |
|
|
from torchvision.transforms.functional import InterpolationMode |
|
|
from transformers.activations import ACT2FN |
|
|
from torchvision import transforms |
|
|
from torchvision.transforms.functional import InterpolationMode |
|
|
from transformers.feature_extraction_utils import BatchFeature, TensorType |
|
|
from transformers.image_utils import ImageInput |
|
|
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack |
|
|
from math import ceil |
|
|
from itertools import product |
|
|
|
|
|
|
|
|
|
|
|
# Upper bound (in pixels) for the longer image side; larger images are scaled
# down to this limit by ImagePatcher.get_image_size_for_preprocess.
MAX_IMAGE_SIZE: int = 3024
|
|
|
|
|
class Step3VLImagePixelInputs(TypedDict): |
|
|
type: Literal["pixel_values"] |
|
|
pixel_values: torch.Tensor |
|
|
patch_pixel_values: Optional[torch.Tensor] |
|
|
num_patches: list[int] |
|
|
|
|
|
|
|
|
class Step3VLImageEmbeddingInputs(TypedDict):
    """Image inputs given as embeddings (tagged with ``type="image_embeds"``)."""

    # Discriminator tag for this input variant.
    type: Literal["image_embeds"]
    # Image embedding tensor.
    image_embeds: torch.Tensor
|
|
|
|
|
|
|
|
# Result of patching one image: (resized image, cropped patches, per-patch
# "newline after this patch" flags or None when there are no patches).
# The flags are booleans, matching what ImagePatcher.__call__ returns.
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
|
|
|
|
|
|
|
|
class GPUToTensor(torch.nn.Module):
    """Convert a PIL image or a numpy array into a channel-first tensor.

    PIL images go through ``transforms.ToTensor``. Numpy arrays are moved to
    CUDA when it is available (CPU otherwise), permuted from HWC to CHW, and
    scaled to ``[0, 1]`` float32 only when their dtype is ``uint8``.
    """

    def forward(self, raw_image: Union[np.ndarray, Image.Image]) -> torch.Tensor:
        """Return the tensor form of ``raw_image``."""
        if isinstance(raw_image, Image.Image):
            return transforms.ToTensor()(raw_image)

        # A 2-D (single-channel) array is expanded to three identical channels.
        if raw_image.ndim == 2:
            raw_image = raw_image[:, :, None].repeat(3, -1)

        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        chw = torch.from_numpy(raw_image).to(target)
        chw = chw.permute(2, 0, 1).contiguous()

        if chw.dtype != torch.uint8:
            # Non-uint8 data is passed through without rescaling.
            return chw
        return chw.to(torch.float32).div(255)
|
|
|
|
|
class Step3VisionProcessor(BaseImageProcessor):
    """Turn images into normalized, fixed-size pixel tensors.

    Two pipelines are kept: ``transform`` resizes to ``size`` x ``size`` for
    whole images, and ``patch_transform`` resizes to ``patch_size`` x
    ``patch_size`` for cropped patches. Both convert to a tensor, normalize
    per channel, and then resize (in that order).
    """

    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
        """Build the image and patch pipelines.

        Args:
            size: Edge length of the square output for whole images.
            interpolation_mode: ``"bicubic"`` selects bicubic resampling; any
                other value falls back to bilinear.
            patch_size: Edge length of the square output for patches;
                defaults to ``size`` when omitted.
        """
        # Initialise the base-class state before adding our own attributes.
        super().__init__()

        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        patch_size = patch_size if patch_size is not None else size

        if interpolation_mode == "bicubic":
            interpolation = InterpolationMode.BICUBIC
        else:
            interpolation = InterpolationMode.BILINEAR

        def build_pipeline(side):
            # One pipeline definition shared by the image and patch variants.
            return transforms.Compose([
                GPUToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize(
                    (side, side),
                    interpolation=interpolation,
                    antialias=True,
                ),
            ])

        self.transform = build_pipeline(size)
        # patch_size is only None here when both size and patch_size were None.
        self.patch_transform = (
            build_pipeline(patch_size) if patch_size is not None else None
        )

    def __call__(self, image, is_patch=False):
        """Preprocess one image.

        Args:
            image: Input accepted by ``GPUToTensor`` (PIL image or numpy array).
            is_patch: Use the patch pipeline instead of the whole-image one.

        Returns:
            ``{"pixel_values": tensor}`` where the tensor has a leading
            dimension of size 1 added via ``unsqueeze(0)``.
        """
        pipeline = self.patch_transform if is_patch else self.transform
        return {"pixel_values": pipeline(image).unsqueeze(0)}
|
|
|
|
|
class ImagePatcher:
    """Split an image into a resized global view plus a grid of square patches."""

    def determine_window_size(self, long: int, short: int) -> int:
        """Choose the square window edge from the long and short image sides.

        Returns 0 (meaning "do not patch") when ``long <= 728`` and the aspect
        ratio is at most 1.5. Otherwise returns ``short`` for small elongated
        images, ``min(short, 504)`` for large images with aspect ratio above
        4, and 504 for all other large images.
        """
        if long <= 728:
            return short if long / short > 1.5 else 0
        return min(short, 504) if long / short > 4 else 504

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
        """Enumerate sliding-window boxes over a ``width`` x ``height`` area.

        Args:
            width: Area width.
            height: Area height.
            sizes: ``(window_w, window_h)`` pairs, one per window scale.
            steps: ``(step_w, step_h)`` pairs, matched with ``sizes``.
            img_rate_thr: Must lie in ``[0, 1]``; only validated here.

        Returns:
            ``(boxes, (x_num, y_num))`` where each box is ``(x, y, w, h)`` in
            row-major order (all columns of a row before the next row) and
            ``(x_num, y_num)`` is the grid shape of the last size/step pair.

        Raises:
            ValueError: If ``img_rate_thr`` is outside ``[0, 1]`` or no
                size/step pair is supplied.
        """
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if not 0 <= img_rate_thr <= 1:
            raise ValueError("The `img_rate_thr` should lie in 0~1")

        boxes: list[tuple[int, int, int, int]] = []
        x_num = y_num = 0
        produced = False

        for (size_w, size_h), (step_w, step_h) in zip(sizes, steps):
            produced = True

            x_num = 1 if width <= size_w else ceil((width - size_w) / step_w + 1)
            x_start = [step_w * i for i in range(x_num)]
            # Pull the last column back so it ends exactly at the right edge.
            if len(x_start) > 1 and x_start[-1] + size_w > width:
                x_start[-1] = width - size_w

            y_num = 1 if height <= size_h else ceil((height - size_h) / step_h + 1)
            y_start = [step_h * i for i in range(y_num)]
            # Same adjustment for the last row and the bottom edge.
            if len(y_start) > 1 and y_start[-1] + size_h > height:
                y_start[-1] = height - size_h

            boxes.extend(
                (int(x), int(y), int(size_w), int(size_h))
                for y in y_start
                for x in x_start
            )

        if not produced:
            raise ValueError(
                "`sizes` and `steps` must each contain at least one entry")

        return boxes, (x_num, y_num)

    def square_pad(self, img: Image.Image) -> Image.Image:
        """Pad ``img`` to a square with zeros, keeping it in the top-left corner."""
        w, h = img.size
        if w == h:
            return img
        side = max(w, h)
        canvas = Image.new(img.mode, (side, side), 0)
        canvas.paste(img, (0, 0))
        return canvas

    def get_image_size_for_padding(self, img_width: int,
                                   img_height: int) -> tuple[int, int]:
        """Return the size after optional square padding.

        Padding applies only when the shorter side is below 32 pixels and the
        aspect ratio exceeds 4:1 in either direction.
        """
        ratio = img_width / img_height
        if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
            side = max(img_height, img_width)
            return side, side
        return img_width, img_height

    def get_image_size_for_preprocess(self, img_width: int,
                                      img_height: int) -> tuple[int, int]:
        """Scale the size down so the longer side does not exceed MAX_IMAGE_SIZE."""
        longest = max(img_height, img_width)
        if longest > MAX_IMAGE_SIZE:
            scale_factor = MAX_IMAGE_SIZE / longest
            img_width = int(img_width * scale_factor)
            img_height = int(img_height * scale_factor)
        return img_width, img_height

    @staticmethod
    def _snap_to_window(length: int, window_size: int):
        """Snap one side length to a multiple of ``window_size``.

        Lengths shorter than the window are kept unchanged. Otherwise the
        length is rounded up to the next multiple when the fractional part of
        ``length / window_size`` exceeds 0.2, and rounded down if not.
        """
        ratio = length / window_size
        if ratio < 1:
            return length
        fraction = ratio - length // window_size
        multiples = int(ratio) + 1 if fraction > 0.2 else int(ratio)
        return window_size * multiples

    def get_image_size_for_crop(self, img_width: int, img_height: int,
                                window_size: int) -> tuple[int, int]:
        """Return the image size to use for cropping ``window_size`` patches."""
        width_new = self._snap_to_window(img_width, window_size)
        height_new = self._snap_to_window(img_height, window_size)
        return int(width_new), int(height_new)

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        """Crop the ``tw`` x ``th`` region whose top-left corner is row ``i``, column ``j``."""
        return img.crop((j, i, j + tw, i + th))

    def get_num_patches(self, img_width: int,
                        img_height: int) -> tuple[int, int]:
        """Predict ``(number of patches, number of row breaks)`` for an image size.

        Mirrors the sizing steps of ``__call__`` without touching pixel data.
        Returns ``(0, 0)`` when the image is not patched.
        """
        img_width, img_height = self.get_image_size_for_padding(
            img_width, img_height)
        img_width, img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        window_size = self.determine_window_size(max(img_height, img_width),
                                                 min(img_height, img_width))
        if window_size == 0:
            return 0, 0

        img_width, img_height = self.get_image_size_for_crop(
            img_width, img_height, window_size)
        boxes, (x_num, _) = self.slide_window(
            img_width, img_height, [(window_size, window_size)],
            [(window_size, window_size)])

        count = len(boxes)
        row_breaks = (count - 1) // x_num + 1
        # The final row is not followed by a break when it is complete.
        if count > 0 and count % x_num == 0:
            row_breaks -= 1
        return count, row_breaks

    def __call__(
        self, img: Image.Image
    ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
        """Produce the global view, the patches, and the row-break mask.

        Returns:
            ``(image, patches, mask)``. ``image`` is the (possibly padded and
            downscaled) input. ``patches`` is empty and ``mask`` is None when
            the image is not patched; otherwise ``mask[k]`` is True when patch
            ``k`` ends a grid row that is not the last row.
        """
        img_width, img_height = img.size
        padded_size = self.get_image_size_for_padding(img_width, img_height)
        if padded_size != (img_width, img_height):
            img = self.square_pad(img)
            img_width, img_height = img.size

        new_img_width, new_img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        img = img.resize((new_img_width, new_img_height),
                         Image.Resampling.BILINEAR)
        window_size = self.determine_window_size(
            max(new_img_height, new_img_width),
            min(new_img_height, new_img_width))

        if window_size == 0:
            return img, [], None

        crop_width, crop_height = self.get_image_size_for_crop(
            new_img_width, new_img_height, window_size)
        # Compare with the image's current size (not the pre-resize size) so
        # the resize is skipped exactly when it would be a no-op.
        if (crop_width, crop_height) != img.size:
            img_for_crop = img.resize((crop_width, crop_height),
                                      Image.Resampling.BILINEAR)
        else:
            img_for_crop = img

        boxes, (x_num, _) = self.slide_window(
            crop_width, crop_height, [(window_size, window_size)],
            [(window_size, window_size)])
        patches = [
            self.patch_crop(img_for_crop, y, x, patch_h, patch_w)
            for x, y, patch_w, patch_h in boxes
        ]

        # A patch is followed by a row break when it closes a row, except for
        # the very last patch.
        last_idx = len(patches) - 1
        newline_mask = [
            (idx + 1) % x_num == 0 and idx != last_idx
            for idx in range(len(patches))
        ]
        return img, patches, newline_mask if patches else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Step3VLProcessor(ProcessorMixin):
    """Combine a tokenizer with image patching to build model inputs.

    Each ``<im_patch>`` placeholder in the text is expanded into the token
    sequence for one image: its patches (if any, with row breaks) followed by
    the global image view.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer=None,
        chat_template=None,
        **kwargs
    ) -> None:
        """Set up the fixed image geometry, placeholders, and helper objects."""
        self.image_size = 728
        self.patch_size = 504

        self.image_preprocessor = Step3VisionProcessor(self.image_size,
                                                       "bilinear",
                                                       self.patch_size)

        # Number of placeholder tokens emitted per global view / per patch.
        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = "<im_patch>"
        self.image_feature_placeholder = (self.image_token *
                                          self.num_image_feature_size)
        self.patch_feature_placeholder = (self.image_token *
                                          self.num_patch_feature_size)
        super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        """Vocabulary id of the image placeholder token."""
        return self.tokenizer.get_vocab()[self.image_token]

    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
        """Return how many tokens one image of the given size expands to.

        Each patch contributes its feature tokens plus a start and end
        marker, the global view does the same, and every row break adds one
        token.
        """
        num_patches, num_newlines = self.patcher.get_num_patches(
            img_width, img_height)

        return num_patches * (
            self.num_patch_feature_size +
            2) + self.num_image_feature_size + 2 + num_newlines

    def _split_images(self,
                      images: list[Image.Image]) -> list[ImageWithPatches]:
        """Run the patcher on every image."""
        return [self.patcher(img) for img in images]

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        """Preprocess each image and collect the resulting pixel tensors."""
        return [
            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
            for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> tuple[str, list[int]]:
        """Build the replacement text and token ids for an image's patches.

        Args:
            num_patches: Number of patches to emit.
            patch_newline_mask: Per-patch flags; True appends
                ``<patch_newline>`` after that patch. None means no breaks.

        Returns:
            ``(text, token_ids)`` covering all patches in order.

        Raises:
            ValueError: If the mask length differs from ``num_patches``.
        """
        if num_patches <= 0:
            return "", []

        if patch_newline_mask is None:
            mask = [False] * num_patches
        elif len(patch_newline_mask) != num_patches:
            raise ValueError(
                f"patch_newline_mask has {len(patch_newline_mask)} entries "
                f"but num_patches is {num_patches}.")
        else:
            mask = patch_newline_mask

        # Everything below is identical for every patch, so compute it once
        # (image_token_id rebuilds the vocabulary on each access).
        to_id = self.tokenizer.convert_tokens_to_ids
        patch_text = f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
        patch_ids = (
            [to_id("<patch_start>")] +
            [self.image_token_id] * self.num_patch_feature_size +
            [to_id("<patch_end>")])
        newline_id = to_id("<patch_newline>")

        pieces = []
        token_ids = []
        for ends_row in mask:
            pieces.append(patch_text)
            token_ids.extend(patch_ids)
            if ends_row:
                pieces.append("<patch_newline>")
                token_ids.append(newline_id)
        return "".join(pieces), token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> tuple[str, list[int]]:
        """Build the replacement text and token ids for ``num_images`` global views."""
        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
        token_ids = [
            self.tokenizer.convert_tokens_to_ids("<im_start>")
        ] + [self.image_token_id] * self.num_image_feature_size + [
            self.tokenizer.convert_tokens_to_ids("<im_end>")
        ]
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> tuple[str, list[int]]:
        """Build the full replacement: patches first, then the global view(s)."""
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx)
        else:
            patch_repl = ""
            patch_repl_ids = []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str,
                            repls: list[str]) -> str:
        """Replace the k-th occurrence of ``placeholder`` with ``repls[k]``.

        Raises:
            ValueError: If the number of occurrences differs from
                ``len(repls)``.
        """
        parts = text.split(placeholder)

        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."
            )

        # Splitting first means inserted replacements are never re-scanned,
        # even though they contain the placeholder themselves.
        result = [parts[0]]
        for repl, tail in zip(repls, parts[1:]):
            result.append(repl)
            result.append(tail)

        return "".join(result)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: ImageInput | None = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize ``text`` and preprocess ``images`` into one ``BatchFeature``.

        Args:
            text: A prompt or list of prompts. Every prompt must contain
                exactly one ``<im_patch>`` placeholder per supplied image.
            images: One image, a list of images, or a nested list (only the
                first inner list is used).
            return_tensors: Forwarded to ``BatchFeature`` as ``tensor_type``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The tokenizer outputs plus, when images are given,
            ``pixel_values``, ``num_patches`` and, if any patches exist,
            ``patch_pixel_values`` and ``patch_newline_mask``.
        """
        if images is not None:
            images = self.image_preprocessor.fetch_images(images)
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        elif not isinstance(images, list):
            images = [images]
        elif images and isinstance(images[0], list):
            # Guarded with `images and` so an empty list is treated as
            # "no images" instead of raising IndexError.
            images = images[0]

        if len(images) == 0:
            image_inputs = {}
            text_inputs = self.tokenizer(text)
        else:
            splitted_images_data = self._split_images(images)
            pixel_values_lst = []
            patch_pixel_values_lst = []
            patch_newline_mask_lst = []
            image_repl_str_lst = []
            num_patches = []
            for raw_img, img_patches, patch_newline_mask in splitted_images_data:
                pixel_values_lst.extend(
                    self._convert_images_to_pixel_values([raw_img]))

                if len(img_patches) > 0:
                    patch_pixel_values_lst.extend(
                        self._convert_images_to_pixel_values(img_patches,
                                                             is_patch=True))
                # One entry per image, including images without patches.
                num_patches.append(len(img_patches))

                image_repl_str, _ = self._get_image_repl_features(
                    1, len(img_patches), patch_newline_mask)
                image_repl_str_lst.append(image_repl_str)

                if patch_newline_mask is not None:
                    patch_newline_mask_lst.extend(patch_newline_mask)

            image_inputs = {
                "pixel_values": torch.cat(pixel_values_lst),
                "num_patches": num_patches,
            }
            if patch_pixel_values_lst:
                image_inputs["patch_pixel_values"] = torch.cat(
                    patch_pixel_values_lst)
            if patch_newline_mask_lst:
                image_inputs["patch_newline_mask"] = torch.tensor(
                    patch_newline_mask_lst, dtype=torch.bool)

            text = [
                self.replace_placeholder(t, self.image_token,
                                         image_repl_str_lst) for t in text
            ]
            text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the underlying tokenizer's `batch_decode`. Please refer to the
        docstring of that method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the underlying tokenizer's `decode`. Please refer to the docstring
        of that method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
|
|
|
|
|
# Public API of this module: only the processor class is exported.
__all__ = [
    "Step3VLProcessor",
]
|
|
|