File size: 1,470 Bytes
b0377a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from PIL import Image, ImageOps
import numpy as np
import torch
from transformers import ImageProcessingMixin
import os
import json

class Im2LatexProcessor(ImageProcessingMixin):
    """Turn an image into a single-channel float tensor of a fixed size.

    The image is converted to 8-bit grayscale, fitted into ``image_size``
    with ``ImageOps.pad`` (padding filled with 255), scaled to ``[0, 1]``
    and returned with a leading channel axis.
    """

    def __init__(self, image_size=(256, 256), **kwargs):
        """Create a processor.

        Args:
            image_size: Target size handed to ``ImageOps.pad``. PIL sizes are
                ``(width, height)``.
            **kwargs: Forwarded unchanged to ``ImageProcessingMixin``.
        """
        super().__init__(**kwargs)
        # JSON has no tuple type, so a size restored by ``from_pretrained``
        # arrives as a list; normalise so both construction paths leave the
        # same attribute type behind.
        self.image_size = tuple(image_size)

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Process a PIL image and return a tensor.

        Args:
            image: Source image in any mode; it is converted to mode ``"L"``.

        Returns:
            A ``float32`` tensor of shape ``(1, H, W)`` with values in
            ``[0, 1]``, where ``(W, H)`` is ``self.image_size``.
        """
        grayscale = image.convert("L")
        # 255 is white in mode "L", so the padding matches a white background.
        padded = ImageOps.pad(grayscale, self.image_size, color=255)
        pixels = np.asarray(padded, dtype=np.float32) / 255.0
        pixels = np.expand_dims(pixels, 0)  # (1, H, W)
        return torch.tensor(pixels, dtype=torch.float32)

    def __call__(self, image_path: str) -> torch.Tensor:
        """Process an image file path.

        Args:
            image_path: Path of the image file to load.

        Returns:
            The tensor produced by ``preprocess`` for that image.
        """
        # The context manager closes the underlying file once the pixel data
        # has been read by ``preprocess``; a bare ``Image.open`` would leave
        # the handle open.
        with Image.open(image_path) as image:
            return self.preprocess(image)

    def save_pretrained(self, save_directory):
        """Save the processor config as ``preprocessor_config.json``.

        Args:
            save_directory: Directory to write into; created if missing.
        """
        self.image_processor_config = {
            "image_size": self.image_size,
        }
        # Writing into a directory that does not exist yet would otherwise
        # fail with FileNotFoundError.
        os.makedirs(save_directory, exist_ok=True)
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.image_processor_config, f)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load a processor from a directory written by ``save_pretrained``.

        Args:
            pretrained_model_name_or_path: Local directory containing
                ``preprocessor_config.json``.
            **kwargs: Accepted for call compatibility and not used.

        Returns:
            A new processor built from the stored config.
        """
        config_path = os.path.join(
            pretrained_model_name_or_path, "preprocessor_config.json"
        )
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)