Image-to-Text
Transformers
Safetensors
Japanese
English
sarashina2_vision
text-generation
multimodal
ocr
document-understanding
vision-language
custom_code
Instructions to use subhash4face/sarashina2.2-ocr with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use subhash4face/sarashina2.2-ocr with Transformers:
# Use a pipeline as a high-level helper
# Warning: Pipeline type "image-to-text" is no longer supported in transformers v5.
# You must load the model directly (see below) or downgrade to v4.x with:
# 'pip install "transformers<5.0.0"'
from transformers import pipeline
pipe = pipeline("image-to-text", model="subhash4face/sarashina2.2-ocr", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("subhash4face/sarashina2.2-ocr", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # coding=utf-8 | |
| # Copyright 2026 the SB Intuitions. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Processor class for Sarashina2Vision. | |
| """ | |
| import math | |
| from typing import Dict, List, Optional, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import ( | |
| AutoImageProcessor, | |
| AutoVideoProcessor, | |
| BaseImageProcessor, | |
| BaseVideoProcessor, | |
| ) | |
| from transformers.feature_extraction_utils import BatchFeature | |
| from transformers.image_transforms import ( | |
| convert_to_rgb, | |
| to_channel_dimension_format, | |
| ) | |
| from transformers.image_utils import ( | |
| OPENAI_CLIP_MEAN, | |
| OPENAI_CLIP_STD, | |
| ChannelDimension, | |
| ImageInput, | |
| get_image_size, | |
| infer_channel_dimension_format, | |
| is_scaled_image, | |
| make_flat_list_of_images, | |
| make_list_of_images, | |
| to_numpy_array, | |
| valid_images, | |
| ) | |
| from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize | |
| from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack | |
| from transformers.tokenization_utils_base import PreTokenizedInput, TextInput | |
| from transformers.utils import TensorType, logging | |
| from transformers.video_utils import VideoInput, VideoMetadata, load_video | |
| logger = logging.get_logger(__name__) | |
class Sarashina2VisionImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Sarashina2Vision image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats for each channel
            in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        min_pixels (`int`, *optional*, defaults to `56 * 56`):
            The min pixels of the image to resize the image.
        max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
            The max pixels of the image to resize the image.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        temporal_patch_size (`int`, *optional*, defaults to 2):
            The temporal patch size of the vision encoder.
        merge_size (`int`, *optional*, defaults to 2):
            The merge size of the vision encoder to llm encoder.
    """

    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(
        self,
        do_resize: bool = True,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Fall back to the CLIP statistics when no explicit mean/std is given.
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        # `size` mirrors the pixel budget; the resize itself reads min_pixels/max_pixels.
        self.size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
        self.do_convert_rgb = do_convert_rgb

    def _preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Turn one image (or a list of frames sharing one size) into flattened vision patches.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel
                values range from 0 to 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*):
                Whether to resize the image.
            do_rescale (`bool`, *optional*):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*):
                Mean to use if normalizing the image.
            image_std (`float` or `List[float]`, *optional*):
                Standard deviation to use if normalizing the image.
            do_convert_rgb (`bool`, *optional*):
                Whether to convert the image to RGB.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                Kept for signature compatibility. The returned patches are flattened from a channels-first
                layout regardless of this value, so it does not affect the output.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset it is inferred from the first image.
                Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `Tuple[np.ndarray, Tuple[int, int, int]]`: the flattened patches of shape
            `(grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size)` and the
            `(grid_t, grid_h, grid_w)` grid.
        """
        images = make_list_of_images(images)
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]
        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # The target size is derived from the first image only, so it is computed once.
        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        if do_resize:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=self.patch_size * self.merge_size,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )

        processed_images = []
        for image in images:
            if do_rescale:
                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
            if do_normalize:
                image = self.normalize(
                    image=image,
                    mean=image_mean,
                    std=image_std,
                    input_data_format=input_data_format,
                )
            # Always work channels-first from here on: F.interpolate resizes the last two axes, so a
            # channels-last array would have its (width, channel) axes resized instead of (height, width).
            image = to_channel_dimension_format(
                image, ChannelDimension.FIRST, input_channel_dim=input_data_format
            )
            if do_resize:
                image = (
                    F.interpolate(
                        torch.from_numpy(image).unsqueeze(0),
                        size=(resized_height, resized_width),
                        mode="bicubic",
                    )
                    .squeeze(0)
                    .numpy()
                )
            processed_images.append(image)

        patches = np.array(processed_images)

        # Pad the temporal axis with copies of the last frame up to the next multiple of
        # temporal_patch_size (not a fixed `temporal_patch_size - 1` frames).
        remainder = patches.shape[0] % self.temporal_patch_size
        if remainder != 0:
            repeats = np.repeat(
                patches[-1][np.newaxis], self.temporal_patch_size - remainder, axis=0
            )
            patches = np.concatenate([patches, repeats], axis=0)

        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )
        # Reorder so that the patches of one merge_size x merge_size window are contiguous.
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w,
            channel * self.temporal_patch_size * self.patch_size * self.patch_size,
        )
        return flatten_patches, (grid_t, grid_h, grid_w)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess one image or a batch of images into flattened patches plus their grids.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255.
                If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*):
                Accepted for API compatibility; not used by this implementation. The resize bounds are taken
                from `self.min_pixels` and `self.max_pixels`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                Forwarded to `_preprocess`; the flattened patch output does not depend on it.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is
                inferred from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            [`BatchFeature`] with `pixel_values` (all patches of all images stacked) and `image_grid_thw`
            (one `(grid_t, grid_h, grid_w)` row per image).

        Raises:
            ValueError: If `images` is `None` or contains an unsupported image type.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        # Fail early with a clear message instead of an UnboundLocalError further down.
        if images is None:
            raise ValueError("`images` must be provided to `preprocess`.")

        images = make_flat_list_of_images(images)
        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        pixel_values, vision_grid_thws = [], []
        # Each image is processed on its own because its resized size depends on its original size.
        for image in images:
            patches, image_grid_thw = self._preprocess(
                image,
                do_resize=do_resize,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                do_convert_rgb=do_convert_rgb,
                input_data_format=input_data_format,
            )
            pixel_values.extend(patches)
            vision_grid_thws.append(image_grid_thw)

        data = {
            "pixel_values": np.array(pixel_values),
            "image_grid_thw": np.array(vision_grid_thws),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
class Sarashina2VisionVideoProcessor(BaseVideoProcessor):
    """
    Video processor for Sarashina2Vision.

    Samples frames, rescales/normalizes them, resizes every frame of a video to a common size chosen by
    `smart_resize`, and flattens the result into vision patches together with a `(grid_t, grid_h, grid_w)` grid
    per video.
    """

    def __init__(
        self,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        fps: int = 2,
        fps_min_frames: int = 2,
        fps_max_frames: int = 64,
        video_min_token_num: int = 128,
        video_max_token_num: int = 768,
        total_pixels: int = 3072 * 28 * 28,
        do_sample_frames: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Same fallback as Sarashina2VisionImageProcessor, so that the default `do_normalize=True`
        # has statistics to normalize with.
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.merge_size = merge_size
        # Height and width are resized to multiples of this factor.
        self.image_factor = self.patch_size * self.merge_size
        self.fps = fps
        self.fps_min_frames = fps_min_frames
        self.fps_max_frames = fps_max_frames
        self.do_sample_frames = do_sample_frames
        self.video_min_token_num = video_min_token_num
        self.video_max_token_num = video_max_token_num
        self.temporal_patch_size = temporal_patch_size
        # The whole-video pixel budget is never smaller than the per-frame cap.
        self.total_pixels = max(total_pixels, max_pixels)

    def sample_frames(
        self,
        metadata: VideoMetadata,
        **kwargs,
    ):
        """
        Pick evenly spaced frame indices for a video.

        The frame count targets `self.fps` frames per second of video, is clamped between the configured
        minimum and maximum, and is rounded down to a multiple of `temporal_patch_size`.

        Args:
            metadata (`VideoMetadata`):
                Provides `total_num_frames` and `fps` of the source video.

        Returns:
            `List[int]`: Frame indices spread linearly over `[0, total_num_frames - 1]`.

        Raises:
            ValueError: If the resulting frame count is below `temporal_patch_size` or above the number of
                frames in the video.
        """
        total_num_frames = metadata.total_num_frames
        tps = self.temporal_patch_size
        # Round the configured minimum up to a whole number of temporal patches.
        min_frames = math.ceil(self.fps_min_frames / tps) * tps
        max_frames = min(self.fps_max_frames, total_num_frames)
        nframes = total_num_frames / metadata.fps * self.fps
        if nframes > total_num_frames:
            logger.warning(
                f"smart_nframes: nframes[{nframes}] > total_num_frames[{total_num_frames}]"
            )
        nframes = min(min(max(nframes, min_frames), max_frames), total_num_frames)
        nframes = math.floor(nframes / tps) * tps
        if not (tps <= nframes and nframes <= total_num_frames):
            raise ValueError(
                f"nframes should in interval [{self.temporal_patch_size}, {total_num_frames}], but got {nframes}."
            )
        indices = torch.linspace(0, total_num_frames - 1, nframes).round().long().tolist()
        return indices

    def _preprocess(
        self,
        videos: list["torch.Tensor"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Convert a list of videos into flattened vision patches.

        Args:
            videos (`list[torch.Tensor]`):
                Videos to process; each is unpacked as `(num_frames, channels, height, width)`.
            do_rescale (`bool`): Whether to rescale pixel values by `rescale_factor`.
            rescale_factor (`float`): Scale factor used when `do_rescale` is set.
            do_normalize (`bool`): Whether to normalize with `image_mean` / `image_std`.
            image_mean (`float` or `list[float]`, *optional*): Normalization mean.
            image_std (`float` or `list[float]`, *optional*): Normalization standard deviation.
            return_tensors (`str` or `TensorType`, *optional*): Tensor type of the returned batch.

        Returns:
            [`BatchFeature`] with `pixel_values_video` (all patches of all videos stacked) and
            `video_grid_thw` (one `(grid_t, grid_h, grid_w)` row per video).
        """
        pixel_values = []
        vision_grid_thws = []
        # The lower pixel bound does not depend on the individual video.
        min_pixels = self.video_min_token_num * (self.image_factor**2)
        for video in videos:
            video = self.convert_to_rgb(video)
            video = self.rescale_and_normalize(
                video,
                do_rescale,
                rescale_factor,
                do_normalize,
                image_mean,
                image_std,
            )
            nframes, _, height, width = video.shape
            # Per-frame budget: share the total pixel budget across temporal patches, but never go
            # below slightly more than min_pixels nor above the per-frame cap.
            max_pixels = min(
                self.max_pixels,
                max(
                    self.total_pixels / nframes * self.temporal_patch_size,
                    int(min_pixels * 1.05),
                ),
            )
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=self.image_factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            video = F.interpolate(
                video,
                size=(resized_height, resized_width),
                mode="bicubic",
            )

            # Pad with copies of the last frame up to the next multiple of temporal_patch_size
            # (not a fixed `temporal_patch_size - 1` frames).
            remainder = video.shape[0] % self.temporal_patch_size
            if remainder != 0:
                repeats = (
                    video[-1].unsqueeze(0).repeat(self.temporal_patch_size - remainder, 1, 1, 1)
                )
                patch = torch.cat([video, repeats], dim=0)
            else:
                patch = video

            grid_t = patch.shape[0] // self.temporal_patch_size
            channel = patch.shape[1]
            grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
            patch = patch.reshape(
                grid_t,
                self.temporal_patch_size,
                channel,
                grid_h // self.merge_size,
                self.merge_size,
                self.patch_size,
                grid_w // self.merge_size,
                self.merge_size,
                self.patch_size,
            )
            # Reorder so that the patches of one merge_size x merge_size window are contiguous.
            patch = patch.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
            flatten_patch = patch.reshape(
                grid_t * grid_h * grid_w,
                channel * self.temporal_patch_size * self.patch_size * self.patch_size,
            )
            pixel_values.extend(np.array(flatten_patch))
            vision_grid_thws.append((grid_t, grid_h, grid_w))

        data = {
            "pixel_values_video": np.array(pixel_values),
            "video_grid_thw": np.array(vision_grid_thws),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)

    def fetch_videos(
        self,
        video_url_or_urls: Union[str, list[str], list[list[str]]],
        sample_indices_fn=None,
        backend="torchvision",
    ):
        """
        Load the video(s) referenced by a single url or by a (nested) list of urls.

        A single url is passed to `load_video` and its result is returned as is. A list is processed
        recursively and the per-item results are transposed with `zip`, so the return value is a list of
        tuples grouping the corresponding components of each item's result.

        Args:
            video_url_or_urls (`str`, `list[str]` or `list[list[str]]`): Video location(s) to load.
            sample_indices_fn (*optional*): Forwarded to `load_video` to choose which frames to decode.
            backend (`str`, *optional*, defaults to `"torchvision"`): Decoding backend passed to `load_video`.
        """
        if isinstance(video_url_or_urls, list):
            loaded = [
                self.fetch_videos(x, sample_indices_fn=sample_indices_fn, backend=backend)
                for x in video_url_or_urls
            ]
            return list(zip(*loaded))

        # Use the processor's device when one is set, otherwise decode on CPU.
        device = self.device if hasattr(self, "device") and self.device is not None else "cpu"
        return load_video(
            video_url_or_urls,
            backend=backend,
            sample_indices_fn=sample_indices_fn,
            device=device,
        )
class Sarashina2VisionProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema for `Sarashina2VisionProcessor`; tokenizer padding is off by default."""

    _defaults = {"text_kwargs": {"padding": False}}
class Sarashina2VisionProcessor(ProcessorMixin):
    r"""
    Constructs Sarashina2Vision processor which wraps a Sarashina2Vision image processor, a video processor and a
    LLama tokenizer into a single processor.
    [`Sarashina2VisionProcessor`] offers all the functionalities of [`Sarashina2VisionImageProcessor`] and
    [`LlamaTokenizerFast`]. See the [`~Sarashina2VisionProcessor.__call__`] and
    [`~Sarashina2VisionProcessor.decode`] for more information.

    Args:
        image_processor ([`Sarashina2VisionImageProcessor`], *optional*):
            The image processor is a required input.
        video_processor (*optional*):
            The video processor used when `videos` are passed to `__call__`.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "video_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    def __init__(
        self,
        image_processor=None,
        video_processor=None,
        tokenizer=None,
        chat_template=None,
        **kwargs,
    ):
        # Prefer the placeholder tokens declared by the tokenizer; otherwise use the built-in ones.
        self.image_token = (
            "<|file|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        )
        self.video_token = (
            "<|middle|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
        )
        super().__init__(image_processor, video_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        videos: VideoInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        **kwargs: Unpack[Sarashina2VisionProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s), image(s) and video(s). This method
        forwards the `text` and its kwargs to the tokenizer, `images` and the image kwargs to the image
        processor, and `videos` and the video kwargs to the video processor. Each image/video placeholder token
        in `text` is expanded to as many tokens as the corresponding grid occupies after merging.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            videos (`VideoInput`, *optional*):
                The video or batch of videos to be prepared; forwarded to the video processor.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). A list passed here
                is not modified.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - Whatever the video processor returns (including `video_grid_thw`) when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Sarashina2VisionProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            # Videos get their own kwargs group rather than the image one.
            video_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = video_inputs["video_grid_thw"]
        else:
            video_inputs = {}
            video_grid_thw = None

        # Work on a copy so the placeholder expansion below never rewrites the caller's list.
        text = list(text) if isinstance(text, list) else [text]

        if image_grid_thw is not None or video_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            image_index = 0
            video_index = 0
            for i in range(len(text)):
                if images is not None:
                    # Expand one placeholder at a time via a temporary marker so that already expanded
                    # tokens are not matched again by the loop condition.
                    while self.image_token in text[i]:
                        num_tokens = int(image_grid_thw[image_index].prod()) // merge_length
                        text[i] = text[i].replace(
                            self.image_token, "<|placeholder|>" * num_tokens, 1
                        )
                        image_index += 1
                    text[i] = text[i].replace("<|placeholder|>", self.image_token)
                if videos is not None:
                    while self.video_token in text[i]:
                        num_tokens = int(video_grid_thw[video_index].prod()) // merge_length
                        text[i] = text[i].replace(
                            self.video_token, "<|placeholder|>" * num_tokens, 1
                        )
                        video_index += 1
                    text[i] = text[i].replace("<|placeholder|>", self.video_token)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`].
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`].
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape
                `(batch_size, sequence_length)` or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, without duplicates."""
        # Exposed as a property so attribute access yields the list rather than a bound method.
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
# Make the custom image and video processors resolvable through the Auto* factories.
AutoImageProcessor.register("Sarashina2VisionImageProcessor", Sarashina2VisionImageProcessor)
AutoVideoProcessor.register("Sarashina2VisionVideoProcessor", Sarashina2VisionVideoProcessor)
# Mark the combined processor as the class to use for `AutoProcessor`.
Sarashina2VisionProcessor.register_for_auto_class("AutoProcessor")