# ocr-character / processing_ashish_ocr.py
# Uploaded via huggingface_hub (revision c2d866c, verified)
"""AshishOCR processor for handling image and text inputs."""
from typing import List, Optional, Union
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
class AshishOcrImageProcessor(BaseImageProcessor):
    """Image processor for AshishOCR model.

    Resizes each image so its shortest edge matches ``size["shortest_edge"]``,
    snaps both dimensions to multiples of ``patch_size``, rescales/normalizes,
    and emits per-image tensors with a replicated temporal axis plus a
    (T, H, W) patch-grid descriptor.

    NOTE(review): all images in one call must resize to identical dimensions,
    since the results are combined with ``torch.stack`` — confirm callers
    only batch same-shaped inputs.
    """

    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[dict] = None,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[list] = None,
        image_std: Optional[list] = None,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ):
        """Store preprocessing configuration.

        Args:
            do_resize: whether to resize to the target shortest edge.
            size: dict with key ``"shortest_edge"``; defaults to 336.
            do_rescale: whether to multiply pixel values by ``rescale_factor``.
            rescale_factor: scale applied to raw pixel values (default 1/255).
            do_normalize: whether to apply per-channel mean/std normalization.
            image_mean: per-channel means; defaults to the CLIP statistics.
            image_std: per-channel stds; defaults to the CLIP statistics.
            min_pixels / max_pixels: stored bounds (not enforced in
                ``preprocess`` as written — NOTE(review): confirm whether
                enforcement is expected elsewhere).
            patch_size: spatial patch edge in pixels.
            temporal_patch_size: number of replicated temporal frames.
            merge_size: stored patch-merge factor (unused here).
        """
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.size = size if size is not None else {"shortest_edge": 336}
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Defaults are the OpenAI CLIP normalization statistics.
        self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
        self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size

    def preprocess(
        self,
        images: ImageInput,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images.

        Args:
            images: a file path, PIL image, array-like, or a list of these.

        Returns:
            BatchFeature with ``pixel_values`` — float tensor of shape
            (N, C, T, H, W) — and ``image_grid_thw`` — long tensor of shape
            (N, 3) holding (temporal, height, width) patch counts per image.
        """
        import numpy as np
        import torch
        from PIL import Image

        if not isinstance(images, list):
            images = [images]

        processed_images = []
        grid_thw = []
        for image in images:
            if isinstance(image, str):
                image = Image.open(image)
            elif not isinstance(image, Image.Image):
                image = Image.fromarray(np.array(image))
            # Fix: force 3-channel RGB for *every* input (previously only
            # done for file-path inputs), so RGBA/grayscale images don't
            # break the per-channel normalization broadcast below.
            image = image.convert("RGB")

            # Scale so the shortest edge hits target_size, preserving aspect.
            width, height = image.size
            target_size = self.size.get("shortest_edge", 336)
            if width < height:
                new_width = target_size
                new_height = int(height * target_size / width)
            else:
                new_height = target_size
                new_width = int(width * target_size / height)
            # Snap down to a multiple of patch_size, but never below one
            # patch (fix: plain flooring could yield a 0-sized dimension for
            # extreme aspect ratios, crashing the resize).
            new_width = max(self.patch_size, (new_width // self.patch_size) * self.patch_size)
            new_height = max(self.patch_size, (new_height // self.patch_size) * self.patch_size)
            image = image.resize((new_width, new_height), Image.BILINEAR)

            image_array = np.array(image).astype(np.float32)
            if self.do_rescale:
                image_array = image_array * self.rescale_factor
            if self.do_normalize:
                image_array = (image_array - np.array(self.image_mean)) / np.array(self.image_std)
            # HWC -> CHW, then add a temporal axis replicated to
            # temporal_patch_size frames for the 3D patch embedding.
            image_tensor = torch.tensor(image_array).permute(2, 0, 1)
            image_tensor = image_tensor.unsqueeze(1).repeat(1, self.temporal_patch_size, 1, 1)
            processed_images.append(image_tensor)

            # Grid size in patches; a single still image occupies one
            # temporal slot.
            grid_thw.append([1, new_height // self.patch_size, new_width // self.patch_size])

        pixel_values = torch.stack(processed_images, dim=0)
        image_grid_thw = torch.tensor(grid_thw, dtype=torch.long)
        return BatchFeature(data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw})
class AshishOcrProcessor(ProcessorMixin):
    """Processor for AshishOCR that combines image processor and tokenizer.

    Images go through :class:`AshishOcrImageProcessor`; text goes through the
    tokenizer; both sets of features are merged into a single
    :class:`BatchFeature`.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AshishOcrImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        """Wrap an image processor and a tokenizer.

        Fix: forward **kwargs to ProcessorMixin instead of silently dropping
        them — attributes such as ``chat_template`` arrive via kwargs when the
        processor is loaded from a saved checkpoint.
        """
        super().__init__(image_processor, tokenizer, **kwargs)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        videos: ImageInput = None,
        padding: bool = False,
        truncation: bool = None,
        max_length: int = None,
        return_tensors: str = None,
        **kwargs,
    ) -> BatchFeature:
        """Encode text and/or images into a single BatchFeature.

        Args:
            text: text input(s) forwarded to the tokenizer.
            images: image input(s) forwarded to the image processor.
            videos: accepted for signature compatibility but currently
                ignored — no video path is implemented here.
            padding / truncation / max_length / return_tensors: standard
                tokenizer arguments.
            **kwargs: forwarded to BOTH the image processor and the tokenizer
                (NOTE(review): a kwarg meant for only one of them reaches the
                other too — confirm callers rely on this).

        Returns:
            BatchFeature combining image features (``pixel_values``,
            ``image_grid_thw``) and tokenizer outputs.
        """
        encoding = BatchFeature()
        if images is not None:
            image_features = self.image_processor(images, **kwargs)
            encoding.update(image_features)
        if text is not None:
            text_encoding = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            encoding.update(text_encoding)
        return encoding

    def batch_decode(self, *args, **kwargs):
        """Delegate to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Union of tokenizer and image-processor input names, de-duplicated
        while preserving order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))