|
|
"""AshishOCR processor for handling image and text inputs."""
|
|
|
|
|
|
from typing import List, Optional, Union
|
|
|
|
|
|
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
|
|
from transformers.image_utils import ImageInput
|
|
|
from transformers.processing_utils import ProcessorMixin
|
|
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
|
|
|
|
|
|
|
|
class AshishOcrImageProcessor(BaseImageProcessor):
    """Image processor for AshishOCR model.

    Resizes each image so its shortest edge matches ``size["shortest_edge"]``,
    snaps both dimensions down to a multiple of ``patch_size``, optionally
    rescales and normalizes, duplicates the image along a temporal axis, and
    returns ``pixel_values`` plus a per-image ``[t, h, w]`` patch grid in
    ``image_grid_thw``.
    """

    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[dict] = None,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[list] = None,
        image_std: Optional[list] = None,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ):
        """Store preprocessing configuration.

        Args:
            do_resize: Whether to resize so the shortest edge hits the target.
            size: Dict with a ``"shortest_edge"`` key; defaults to 336.
            do_rescale: Whether to multiply pixel values by ``rescale_factor``.
            rescale_factor: Scale applied to raw pixel values (default 1/255).
            do_normalize: Whether to subtract ``image_mean`` / divide by
                ``image_std`` per channel.
            image_mean: Per-channel mean; defaults to the CLIP mean.
            image_std: Per-channel std; defaults to the CLIP std.
            min_pixels: Lower pixel-count bound (stored; not enforced here —
                NOTE(review): currently unused by ``preprocess``).
            max_pixels: Upper pixel-count bound (stored; likewise unused).
            patch_size: Spatial patch size; resized dims snap to multiples.
            temporal_patch_size: Number of temporal copies of a still image.
            merge_size: Patch-merge factor (stored; not used by ``preprocess``).
        """
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.size = size if size is not None else {"shortest_edge": 336}
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
        self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size

    def preprocess(
        self,
        images: ImageInput,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images.

        Args:
            images: A PIL image, a path string, an array convertible via
                ``np.array``, or a list of any of these.

        Returns:
            ``BatchFeature`` with ``pixel_values`` of shape
            ``(batch, channels, temporal_patch_size, H, W)`` and
            ``image_grid_thw`` of shape ``(batch, 3)`` holding ``[t, h, w]``
            patch counts per image.
        """
        import numpy as np
        import torch
        from PIL import Image

        if not isinstance(images, list):
            images = [images]

        # Hoist normalization constants out of the per-image loop.
        mean = np.array(self.image_mean, dtype=np.float32)
        std = np.array(self.image_std, dtype=np.float32)

        processed_images = []
        grid_thw = []

        for image in images:
            if isinstance(image, str):
                image = Image.open(image).convert("RGB")
            elif not isinstance(image, Image.Image):
                image = Image.fromarray(np.array(image))

            if self.do_resize:
                # Bug fix: the original resized unconditionally, ignoring
                # the do_resize flag configured in __init__.
                width, height = image.size
                target_size = self.size.get("shortest_edge", 336)

                # Scale so the shortest edge equals target_size.
                if width < height:
                    new_width = target_size
                    new_height = int(height * target_size / width)
                else:
                    new_height = target_size
                    new_width = int(width * target_size / height)

                # Snap down to patch multiples; clamp to at least one patch so
                # extreme aspect ratios cannot yield a zero-sized resize
                # (Image.resize((0, h)) raises).
                new_width = max((new_width // self.patch_size) * self.patch_size, self.patch_size)
                new_height = max((new_height // self.patch_size) * self.patch_size, self.patch_size)

                image = image.resize((new_width, new_height), Image.BILINEAR)
            else:
                # Use the image's own dimensions for the grid computation.
                new_width, new_height = image.size

            image_array = np.array(image).astype(np.float32)

            if self.do_rescale:
                image_array = image_array * self.rescale_factor

            if self.do_normalize:
                image_array = (image_array - mean) / std

            # HWC -> CHW.
            image_tensor = torch.tensor(image_array).permute(2, 0, 1)

            # Duplicate along a temporal axis so a still image matches the
            # model's video-style (C, T, H, W) input layout.
            image_tensor = image_tensor.unsqueeze(1).repeat(1, self.temporal_patch_size, 1, 1)

            processed_images.append(image_tensor)

            # Grid: 1 temporal step, spatial patch counts per axis.
            t = 1
            h = new_height // self.patch_size
            w = new_width // self.patch_size
            grid_thw.append([t, h, w])

        # NOTE(review): torch.stack requires every image in the batch to share
        # one resized shape; a batch mixing aspect ratios will raise here.
        pixel_values = torch.stack(processed_images, dim=0)
        image_grid_thw = torch.tensor(grid_thw, dtype=torch.long)

        return BatchFeature(data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw})
|
|
|
|
|
|
|
|
|
class AshishOcrProcessor(ProcessorMixin):
    """Processor for AshishOCR that combines image processor and tokenizer."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AshishOcrImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        """Wire the two sub-processors into the ProcessorMixin machinery."""
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        videos: ImageInput = None,
        padding: bool = False,
        truncation: bool = None,
        max_length: int = None,
        return_tensors: str = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize ``text`` and/or preprocess ``images`` into one BatchFeature.

        Either input may be omitted; whatever is provided is merged into a
        single ``BatchFeature``. ``videos`` is accepted for signature
        compatibility but is not processed here.
        """
        features = BatchFeature()

        if images is not None:
            features.update(self.image_processor(images, **kwargs))

        if text is not None:
            tokenized = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            features.update(tokenized)

        return features

    def batch_decode(self, *args, **kwargs):
        """Delegate batch decoding to the underlying tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate single-sequence decoding to the underlying tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Order-preserving union of tokenizer and image-processor input names."""
        combined = self.tokenizer.model_input_names + self.image_processor.model_input_names
        return list(dict.fromkeys(combined))
|
|
|
|