import json
import os

import numpy as np
import torch
from PIL import Image, ImageOps
from transformers import ImageProcessingMixin

# File name shared by save_pretrained/from_pretrained so the two cannot drift apart.
_CONFIG_NAME = "preprocessor_config.json"


class Im2LatexProcessor(ImageProcessingMixin):
    """Turn formula images into normalized single-channel tensors.

    Each image is converted to grayscale (mode ``"L"``), letterboxed onto a
    white canvas of ``image_size`` with ``ImageOps.pad``, and scaled from
    ``[0, 255]`` to ``[0, 1]``.
    """

    def __init__(self, image_size=(256, 256), **kwargs):
        """Create the processor.

        Args:
            image_size: Target canvas size as ``(width, height)``. This is the
                PIL convention expected by ``ImageOps.pad``. Any 2-element
                sequence is accepted. It is stored as a tuple, so a value read
                back from JSON (a list) behaves the same as the default.
            **kwargs: Forwarded unchanged to ``ImageProcessingMixin.__init__``.
        """
        super().__init__(**kwargs)
        self.image_size = tuple(image_size)

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a float32 tensor.

        The image is converted to grayscale and padded (not cropped) to
        ``self.image_size`` with fill value 255. Pixel values are divided by
        255, so the padding becomes 1.0.

        Returns:
            A float32 tensor of shape ``(1, height, width)``, i.e.
            ``(1, image_size[1], image_size[0])``.
        """
        img = image.convert("L")
        img = ImageOps.pad(img, self.image_size, color=255)
        arr = np.asarray(img, dtype=np.float32) / 255.0
        # PIL -> NumPy gives (H, W); add a leading channel axis -> (1, H, W).
        arr = np.expand_dims(arr, 0)
        return torch.tensor(arr, dtype=torch.float32)

    def __call__(self, image_path: "str | os.PathLike | Image.Image") -> torch.Tensor:
        """Load an image and run ``preprocess`` on it.

        Args:
            image_path: A file path or path-like object to open. An
                already-open ``PIL.Image.Image`` is also accepted and is
                preprocessed directly.
        """
        if isinstance(image_path, Image.Image):
            return self.preprocess(image_path)
        # The context manager releases the underlying file handle.
        # preprocess() builds new, fully loaded images via convert()/pad(),
        # so nothing still reads from the file after the with-block exits.
        with Image.open(image_path) as image:
            return self.preprocess(image)

    def save_pretrained(self, save_directory, **kwargs):
        """Write the processor config to ``save_directory``.

        The config is written as JSON to ``preprocessor_config.json``. The
        directory is created if it does not exist. Extra keyword arguments
        are accepted for compatibility with Hugging Face-style callers and
        are ignored.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.image_processor_config = {
            "image_size": list(self.image_size),
        }
        config_path = os.path.join(save_directory, _CONFIG_NAME)
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.image_processor_config, f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a processor from a local directory.

        Reads ``preprocessor_config.json`` from that directory. Only local
        directories are read; nothing is downloaded. Extra keyword arguments
        are accepted for signature compatibility and ignored.
        """
        config_path = os.path.join(pretrained_model_name_or_path, _CONFIG_NAME)
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)