| import torch | |
| import numpy as np | |
| from PIL import Image, ImageOps, ImageEnhance | |
| from transformers.image_processing_utils import BaseImageProcessor, BatchFeature | |
| from transformers.utils import logging | |
| logger = logging.get_logger(__name__) | |
def _prepare_for_inference(img: Image.Image) -> Image.Image:
    """Normalize a real-world image to a clean, light-background RGB image.

    Intended for screenshots, camera photos and PDF crops, so that they look
    closer to the white-background style the model was trained on.

    Steps, in order:
      1. Measure the mean grayscale luminance (scaled to ``[0, 1]``).
      2. If the mean is below 0.45 (dark background, e.g. dark mode),
         invert the image.
      3. Auto-contrast with a 1% cutoff to stretch the histogram, which
         helps low-contrast scans and photos.
      4. Mild sharpening (factor 1.4) to counter screenshot / JPEG blur.

    Args:
        img: Input image in any mode that Pillow can convert to RGB.

    Returns:
        A new RGB image; the input image is not modified.
    """
    # Decide dark vs. light on the original pixels, exactly as before.
    luminance = np.asarray(img.convert("L"), dtype=np.float32) / 255.0
    has_dark_background = luminance.mean() < 0.45

    # ImageOps.invert / autocontrast are LUT based and only handle "L" and
    # "RGB" images. Converting up front keeps RGBA, palette, CMYK, ... inputs
    # from failing on the light-background path, where no conversion used to
    # happen before autocontrast.
    if img.mode != "RGB":
        img = img.convert("RGB")

    if has_dark_background:
        img = ImageOps.invert(img)

    img = ImageOps.autocontrast(img, cutoff=1)
    img = ImageEnhance.Sharpness(img).enhance(1.4)
    return img
class LaTeXOCRImageProcessor(BaseImageProcessor):
    """Turn formula images into fixed-height, channels-first model inputs.

    Each image is converted to RGB, optionally cleaned up with
    ``_prepare_for_inference``, resized to ``image_height`` pixels tall with a
    width that follows the aspect ratio (capped at ``max_image_width`` and
    rounded down to a multiple of ``patch_size``), scaled to ``[-1, 1]`` and
    transposed to ``(channels, height, width)``.
    """

    model_type = "latex_ocr"

    def __init__(
        self,
        image_height=64,
        max_image_width=1024,
        patch_size=16,
        **kwargs
    ):
        """Store the resize configuration.

        Args:
            image_height: Height in pixels every image is resized to.
            max_image_width: Upper bound on the resized width in pixels.
            patch_size: The resized width is rounded down to a multiple of
                this value and is never smaller than it.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.image_height = image_height
        self.max_image_width = max_image_width
        self.patch_size = patch_size

    def _target_width(self, width, height):
        """Return the resized width for an image of ``width`` x ``height``."""
        # max(height, 1) guards against division by zero for degenerate images.
        scaled = int(round(width * self.image_height / max(height, 1)))
        capped = min(scaled, self.max_image_width)
        # Floor to a whole number of patches, but keep at least one patch.
        return max((capped // self.patch_size) * self.patch_size, self.patch_size)

    def preprocess(self, images, do_prepare=True, **kwargs) -> BatchFeature:
        """Convert one image or a batch of images into ``pixel_values``.

        Args:
            images: A single PIL image, or a list / tuple of PIL images.
            do_prepare: When true, run ``_prepare_for_inference`` on every
                image before resizing.
            **kwargs: Only ``return_tensors`` is read. It selects the tensor
                type of the returned ``BatchFeature``; when missing or
                ``None`` it defaults to ``"pt"``, which was previously the
                hard-coded value. Other keyword arguments are ignored.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry holds one
            float32 array per image, laid out ``(3, image_height, width)``
            with values in ``[-1, 1]``.
        """
        # Honour the caller's choice instead of silently forcing "pt".
        return_tensors = kwargs.get("return_tensors") or "pt"

        # A tuple is a batch too; only a bare image gets wrapped.
        if not isinstance(images, (list, tuple)):
            images = [images]

        # NOTE(review): images whose resized widths differ are not padded
        # here; stacking them into a single tensor presumably fails. Confirm,
        # or batch by width upstream.
        pixel_values = []
        for img in images:
            if img.mode != "RGB":
                img = img.convert("RGB")
            if do_prepare:
                img = _prepare_for_inference(img)

            width, height = img.size
            target_size = (self._target_width(width, height), self.image_height)
            if (width, height) != target_size:
                img = img.resize(target_size, Image.BILINEAR)

            # uint8 [0, 255] -> float32 [0, 1] -> [-1, 1].
            scaled = np.array(img).astype(np.float32) / 255.0
            normalized = (scaled - 0.5) / 0.5
            # (height, width, channels) -> (channels, height, width).
            pixel_values.append(np.transpose(normalized, (2, 0, 1)))

        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )