"""AshishOCR processor for handling image and text inputs."""

from typing import List, Optional, Union

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput


class AshishOcrImageProcessor(BaseImageProcessor):
    """Image processor for AshishOCR model.

    Converts each input image to RGB, optionally resizes it so that its
    shortest edge matches ``size["shortest_edge"]`` (both edges floored to a
    multiple of ``patch_size``), optionally rescales and normalizes it, and
    returns a batch of ``(C, T, H, W)`` tensors together with the per-image
    patch grid.

    Args:
        do_resize: Resize images to the configured shortest edge.
        size: Dict with a ``"shortest_edge"`` key; defaults to
            ``{"shortest_edge": 336}``.
        do_rescale: Multiply pixel values by ``rescale_factor``.
        rescale_factor: Scale applied when ``do_rescale`` is true.
        do_normalize: Subtract ``image_mean`` and divide by ``image_std``.
        image_mean: Per-channel mean (three values, RGB order).
        image_std: Per-channel standard deviation (three values, RGB order).
        min_pixels: Stored on the instance; not read by ``preprocess``.
        max_pixels: Stored on the instance; not read by ``preprocess``.
        patch_size: Spatial patch edge used for size alignment and the grid.
        temporal_patch_size: Number of times each image is repeated along
            the temporal axis.
        merge_size: Stored on the instance; not read by ``preprocess``.
    """

    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[dict] = None,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[list] = None,
        image_std: Optional[list] = None,
        min_pixels: int = 56 * 56,
        max_pixels: int = 28 * 28 * 1280,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.size = size if size is not None else {"shortest_edge": 336}
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = (
            image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
        )
        self.image_std = (
            image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
        )
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size

    @staticmethod
    def _load_rgb(image):
        """Return ``image`` (file path, PIL image, or array-like) as an RGB PIL image."""
        import numpy as np
        from PIL import Image

        if isinstance(image, str):
            # The context manager closes the file handle; convert() loads the
            # pixel data into a new, detached image first.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))
        # Grayscale / RGBA / palette inputs would otherwise break the
        # three-channel normalization and the HWC -> CHW permute below.
        return image.convert("RGB")

    def _resized_dims(self, width: int, height: int):
        """Return ``(new_width, new_height)`` for the shortest-edge, patch-aligned resize."""
        target_size = self.size.get("shortest_edge", 336)

        # Scale so the shorter side equals target_size, keeping aspect ratio.
        if width < height:
            new_width = target_size
            new_height = int(height * target_size / width)
        else:
            new_height = target_size
            new_width = int(width * target_size / height)

        # Floor to a multiple of patch_size, but never below one full patch
        # (a zero-sized dimension cannot be resized to).
        patch = self.patch_size
        new_width = max(patch, (new_width // patch) * patch)
        new_height = max(patch, (new_height // patch) * patch)
        return new_width, new_height

    def preprocess(
        self,
        images: ImageInput,
        **kwargs,
    ) -> BatchFeature:
        """Turn one image or a list/tuple of images into model inputs.

        Args:
            images: A single image or a list/tuple of images. Each item may
                be a file path string, a PIL image, or anything
                ``numpy.array`` plus ``PIL.Image.fromarray`` can convert.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``BatchFeature`` with ``pixel_values`` of shape
            ``(batch, 3, temporal_patch_size, H, W)`` and ``image_grid_thw``
            of shape ``(batch, 3)`` holding ``[1, H // patch_size,
            W // patch_size]`` per image.

        Raises:
            RuntimeError: From ``torch.stack`` when the processed images do
                not all share the same height and width (for example mixed
                aspect ratios), or when ``images`` is empty.
        """
        import numpy as np
        import torch
        from PIL import Image

        if not isinstance(images, (list, tuple)):
            images = [images]

        # float32 statistics keep the normalized output in float32 instead of
        # silently promoting it to float64.
        mean = np.asarray(self.image_mean, dtype=np.float32)
        std = np.asarray(self.image_std, dtype=np.float32)

        processed_images = []
        grid_thw = []

        for image in images:
            image = self._load_rgb(image)

            if self.do_resize:
                image = image.resize(self._resized_dims(*image.size), Image.BILINEAR)
            width, height = image.size

            # Convert to a float32 HWC array.
            image_array = np.array(image, dtype=np.float32)

            if self.do_rescale:
                image_array = image_array * self.rescale_factor

            if self.do_normalize:
                image_array = (image_array - mean) / std

            # HWC to CHW
            image_tensor = torch.tensor(image_array).permute(2, 0, 1)

            # Add temporal dimension for 3D conv: (C, H, W) -> (C, T, H, W)
            image_tensor = image_tensor.unsqueeze(1).repeat(
                1, self.temporal_patch_size, 1, 1
            )
            processed_images.append(image_tensor)

            # Grid size in patches (T, H, W); a still image is one temporal step.
            grid_thw.append([1, height // self.patch_size, width // self.patch_size])

        # Stacking requires every processed image to share the same (H, W).
        pixel_values = torch.stack(processed_images, dim=0)
        image_grid_thw = torch.tensor(grid_thw, dtype=torch.long)

        return BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}
        )


class AshishOcrProcessor(ProcessorMixin):
    """Processor for AshishOCR that combines image processor and tokenizer."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AshishOcrImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        """Register the image processor and tokenizer with ``ProcessorMixin``.

        NOTE(review): extra keyword arguments are accepted but not forwarded
        to the base class — confirm that dropping them is intended.
        """
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Optional[
            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
        ] = None,
        images: Optional[ImageInput] = None,
        videos: Optional[ImageInput] = None,
        padding: bool = False,
        truncation: Optional[bool] = None,
        max_length: Optional[int] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Encode images and/or text into a single ``BatchFeature``.

        Args:
            text: Text (or batch of texts) passed to the tokenizer.
            images: Image(s) passed to the image processor.
            videos: Accepted for interface compatibility; not used here.
            padding: Forwarded to the tokenizer.
            truncation: Forwarded to the tokenizer.
            max_length: Forwarded to the tokenizer.
            return_tensors: Forwarded to the tokenizer only.
            **kwargs: Forwarded to both the image processor and the tokenizer.

        Returns:
            A ``BatchFeature`` holding the image features (when ``images`` is
            given) updated with the tokenizer output (when ``text`` is given).
            It is empty when neither is provided.
        """
        encoding = BatchFeature()

        if images is not None:
            image_features = self.image_processor(images, **kwargs)
            encoding.update(image_features)

        if text is not None:
            text_encoding = self.tokenizer(
                text,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                **kwargs,
            )
            encoding.update(text_encoding)

        return encoding

    def batch_decode(self, *args, **kwargs):
        """Delegate to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))