File size: 1,470 Bytes
from PIL import Image, ImageOps
import numpy as np
import torch
from transformers import ImageProcessingMixin
import os
import json
class Im2LatexProcessor(ImageProcessingMixin):
    """Convert an image into a fixed-size, single-channel float tensor.

    The image is converted to 8-bit grayscale, resized and padded to
    ``image_size`` with ``ImageOps.pad`` (fill value 255), scaled to the
    range [0, 1] and returned with a leading channel axis.

    The only persisted setting is ``image_size``; it is written to and read
    from ``preprocessor_config.json`` inside the given directory.
    """

    def __init__(self, image_size=(256, 256), **kwargs):
        """Create a processor.

        Args:
            image_size: Target size handed to ``ImageOps.pad``, as a
                two-element sequence. Pillow sizes are ``(width, height)``.
            **kwargs: Forwarded unchanged to ``ImageProcessingMixin``.
        """
        super().__init__(**kwargs)
        # JSON has no tuple type, so a size restored by ``from_pretrained``
        # arrives as a list. Normalising here keeps the attribute the same
        # type whether the processor was constructed directly or reloaded.
        self.image_size = tuple(image_size)

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Process a PIL image and return a tensor.

        Args:
            image: The PIL image to process (any mode ``convert("L")``
                accepts).

        Returns:
            A ``float32`` tensor with values in [0, 1] and a leading axis of
            length 1 added in front of the 2-D pixel array.
        """
        img = image.convert("L")
        # color=255 is the fill value used for the padded border.
        img = ImageOps.pad(img, self.image_size, color=255)
        arr = np.asarray(img, dtype=np.float32) / 255.0
        arr = np.expand_dims(arr, 0)  # (1, H, W)
        return torch.tensor(arr, dtype=torch.float32)

    def __call__(self, image_path: str) -> torch.Tensor:
        """Load the image stored at ``image_path`` and preprocess it.

        Args:
            image_path: Path of the image file to open.

        Returns:
            The tensor produced by :meth:`preprocess`.
        """
        # The context manager guarantees the underlying file handle is
        # released even when decoding or conversion raises. ``preprocess``
        # fully reads the pixel data before the block exits.
        with Image.open(image_path) as image:
            return self.preprocess(image)

    def save_pretrained(self, save_directory):
        """Save the processor config to ``preprocessor_config.json``.

        Args:
            save_directory: Directory that receives the config file. It is
                created (including parents) when it does not exist yet.
        """
        os.makedirs(save_directory, exist_ok=True)
        # The config is also stored on the instance, as the original
        # implementation did, in case callers read this attribute.
        self.image_processor_config = {
            "image_size": self.image_size,
        }
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.image_processor_config, f)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load a processor from a directory written by ``save_pretrained``.

        Args:
            pretrained_model_name_or_path: Local directory containing
                ``preprocessor_config.json``.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            A new instance built from the keys stored in the config file.
        """
        config_path = os.path.join(
            pretrained_model_name_or_path, "preprocessor_config.json"
        )
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)
|