|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Union
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from transformers.feature_extraction_utils import BatchFeature
|
|
|
from transformers.processing_utils import (
|
|
|
ProcessingKwargs,
|
|
|
ProcessorMixin,
|
|
|
Unpack,
|
|
|
VideosKwargs,
|
|
|
)
|
|
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
|
|
|
|
|
|
|
|
# Image inputs accepted by the processor: one image (PIL, NumPy or torch) or a list of
# images of one of those kinds. PIL/torch members are string forward references, so they
# are not evaluated when this module is imported.
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor",
    List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"],
]
|
|
|
|
|
|
|
|
|
# Video inputs accepted by the processor: a single video as a list of frames or a 4D
# array/tensor, a batch of 4D arrays/tensors, or a batch of frame lists. Members written as
# strings are forward references and are not evaluated at import time.
VideoInput = Union[
    List["PIL.Image.Image"],
    "np.ndarray",
    "torch.Tensor",
    List["np.ndarray"],
    List["torch.Tensor"],
    List[List["PIL.Image.Image"]],
    # Was misspelled "np.ndarrray", which referred to a nonexistent type.
    List[List["np.ndarray"]],
    List[List["torch.Tensor"]],
]
|
|
|
|
|
|
|
|
|
class PaddleOCRVLVideosProcessorKwargs(VideosKwargs, total=False):
    """Video keyword arguments accepted by [`PaddleOCRVLProcessor`], extending `VideosKwargs` with `fps`."""

    # Frame rate used to compute `second_per_grid_ts`: either a single number shared by all
    # videos in the batch, or a sequence with one value per video.
    fps: Union[List[float], float]
|
|
|
|
|
|
|
|
|
class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False):
    """Per-modality keyword arguments accepted by [`PaddleOCRVLProcessor.__call__`]."""

    # Video options use the PaddleOCR-VL specific TypedDict, which adds `fps`.
    videos_kwargs: PaddleOCRVLVideosProcessorKwargs

    # Default values handed to `_merge_kwargs` together with this class (presumably applied
    # when the caller does not override them): no text padding and an `fps` of 2.0.
    _defaults = dict(
        text_kwargs=dict(padding=False),
        videos_kwargs=dict(fps=2.0),
    )
|
|
|
|
|
|
|
|
|
class PaddleOCRVLProcessor(ProcessorMixin):
    r"""
    Constructs a PaddleOCR-VL processor wrapping an image processor and a tokenizer into a single processor.

    [`PaddleOCRVLProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`Qwen2TokenizerFast`]. See the
    [`~PaddleOCRVLProcessor.__call__`] and [`~PaddleOCRVLProcessor.decode`] for more information.

    Args:
        image_processor ([`SiglipImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = [
        "chat_template",
        "image_std",
        "min_pixels",
        "image_mean",
        "merge_size",
        "image_processor_type",
        "temporal_patch_size",
        "patch_size",
        "max_pixels",
    ]

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
    ):
        # Prefer the tokenizer's own special tokens when it defines them; otherwise fall back
        # to the hard-coded placeholders below.
        self.image_token = (
            "<|IMAGE_PLACEHOLDER|>"
            if not hasattr(tokenizer, "image_token")
            else tokenizer.image_token
        )
        self.video_token = (
            "<|video_pad|>"
            if not hasattr(tokenizer, "video_token")
            else tokenizer.video_token
        )
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[PaddleOCRVLProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s), image(s) and video(s) for the model. This method forwards
        the `text` and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not
        `None` to encode the text. To prepare the vision inputs, this method forwards `images` and/or `videos` (plus
        `kwargs` for videos) to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] when they are not `None`.

        Every image (resp. video) token found in `text` is expanded, in order across the batch, into as many tokens
        as the corresponding image (resp. video) occupies after patch merging.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`, *optional*):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). If `None`, only the
                vision inputs are returned. A list passed here is not modified.
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.

        Raises:
            ValueError: If `fps` is a sequence whose length differs from the number of videos.
        """
        output_kwargs = self._merge_kwargs(
            PaddleOCRVLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None:
            # NOTE(review): unlike the video path, `output_kwargs["images_kwargs"]` is not forwarded
            # here, so per-call image options are ignored for still images — confirm this is intended.
            image_inputs = self.image_processor(images=images, return_tensors="pt")
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            # Videos go through the same image processor, with the merged image kwargs.
            videos_inputs = self.image_processor(
                images=None, videos=videos, **output_kwargs["images_kwargs"]
            )
            video_grid_thw = videos_inputs["video_grid_thw"]

            # `fps` is either one number shared by all videos or one value per video.
            fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
            if isinstance(fps, (int, float)):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / fps
                ] * len(video_grid_thw)
            elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
                second_per_grid_ts = [
                    self.image_processor.temporal_patch_size / tmp for tmp in fps
                ]
            else:
                raise ValueError(
                    f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
                )
            videos_inputs.update(
                {"second_per_grid_ts": torch.tensor(second_per_grid_ts)}
            )
        else:
            videos_inputs = {}
            video_grid_thw = None

        if text is None:
            # Vision-only call: there is no text to expand or tokenize.
            return BatchFeature(data={**image_inputs, **videos_inputs})

        # Work on a copy so the caller's list is not rewritten in place.
        text = list(text) if isinstance(text, list) else [text]

        if image_grid_thw is not None:
            text = self._expand_placeholders(text, self.image_token, image_grid_thw)
        if video_grid_thw is not None:
            text = self._expand_placeholders(text, self.video_token, video_grid_thw)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

    def _expand_placeholders(self, text, token, grid_thw):
        """
        Expand each occurrence of `token` in `text` into one token per merged visual patch.

        Occurrences are consumed in order across the whole batch: the k-th occurrence is paired with
        `grid_thw[k]` and replaced by `grid_thw[k].prod() // merge_size // merge_size` copies of `token`.

        Args:
            text (`List[str]`): Prompts to rewrite; updated in place and also returned.
            token (`str`): The image or video token to expand.
            grid_thw: Per-input grid sizes from the image processor (presumably `(t, h, w)` rows,
                per the `*_grid_thw` naming — confirm).

        Returns:
            `List[str]`: `text` with every occurrence of `token` expanded.
        """
        merge_size = self.image_processor.merge_size
        index = 0
        for i, sample in enumerate(text):
            while token in sample:
                num_tokens = grid_thw[index].prod() // merge_size // merge_size
                # Insert a temporary marker so the freshly inserted copies are not matched
                # again by the `while` condition; it is swapped back to `token` below.
                sample = sample.replace(token, "<|placeholder|>" * num_tokens, 1)
                index += 1
            text[i] = sample.replace("<|placeholder|>", token)
        return text

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(
        self,
        generated_outputs,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
            **kwargs:
                Additional arguments to be passed to the tokenizer's `batch_decode` method.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self):
        """Input names of the tokenizer and image processor plus `second_per_grid_ts`, in order, without duplicates."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        # De-duplicate after appending so `second_per_grid_ts` is not listed twice if the
        # image processor already reports it.
        return list(
            dict.fromkeys(
                tokenizer_input_names
                + image_processor_input_names
                + ["second_per_grid_ts"]
            )
        )
|
|
|
|
|
|
|
|
|
# Public API of this module (the name was previously listed twice).
__all__ = ["PaddleOCRVLProcessor"]
|
|
|
|