|
|
""" |
|
|
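Standalone FontClassifierImageProcessor for Hugging Face Hub deployment.
It pads each input image to a square before the standard image-processor
pipeline, so that inference matches the preprocessing used in training.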
|
|
""" |
|
|
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoImageProcessor, ViTImageProcessor
|
|
|
|
|
|
|
|
def pad_to_square(image, fill=0):
    """Pad an image with a constant border so that it becomes square.

    The content is centred and never resized, so its aspect ratio is
    preserved. When the size difference is odd, the extra row/column of
    padding goes to the bottom/right edge.

    Supported inputs:
        * ``PIL.Image.Image`` -- padded with ``torchvision.transforms.Pad``.
        * ``numpy.ndarray`` of shape ``(H, W)`` or ``(H, W, C)`` with ``C``
          in ``(1, 3, 4)`` -- padded with ``numpy.pad``; the dtype is kept.

    Any other input (other array layouts, tensors, ...) is returned unchanged.

    Args:
        image: The image to pad.
        fill: Constant value used for the padded border (default ``0``,
            i.e. black).

    Returns:
        The padded image, of the same type as ``image``.
    """
    if isinstance(image, np.ndarray):
        # Pad arrays directly instead of round-tripping through PIL.
        # Casting to uint8 would wrap or truncate non-uint8 data (float images
        # in [0, 1] would turn black) and would silently change the dtype.
        is_hw = image.ndim == 2
        is_hwc = image.ndim == 3 and image.shape[2] in (1, 3, 4)
        if not (is_hw or is_hwc):
            # Ambiguous layout (e.g. channels-first): leave untouched.
            return image
        h, w = image.shape[:2]
        side = max(h, w)
        top = (side - h) // 2
        left = (side - w) // 2
        pad_width = [(top, side - h - top), (left, side - w - left)]
        # Never pad the channel axis.
        pad_width.extend([(0, 0)] * (image.ndim - 2))
        return np.pad(image, pad_width, mode="constant", constant_values=fill)

    if isinstance(image, Image.Image):
        w, h = image.size
        side = max(w, h)
        left = (side - w) // 2
        top = (side - h) // 2
        # torchvision's 4-tuple padding order is (left, top, right, bottom).
        padding = (left, top, side - w - left, side - h - top)
        return T.Pad(padding, fill=fill)(image)

    return image
|
|
|
|
|
class FontClassifierImageProcessor(ViTImageProcessor):
    """Image processor that pads every image to a square before preprocessing.

    Padding happens first (see ``pad_to_square``). The resize/rescale/normalize
    pipeline of the parent class then runs unchanged. This ensures that
    Inference Endpoints apply the same preprocessing as training.

    NOTE(review): the parent class must be the concrete image processor the
    model was trained with. ``ViTImageProcessor`` is assumed here; confirm
    against the training code. ``AutoImageProcessor`` cannot be used as a
    base: it is a factory whose constructor always raises and which defines
    no ``preprocess`` method.
    """

    model_input_names = ["pixel_values"]

    def preprocess(self, images, **kwargs):
        """Pad each image to a square, then run the parent ``preprocess``.

        Args:
            images: A single image (PIL image, numpy array or torch tensor),
                a 4D array/tensor treated as a batch, or an iterable of images.
            **kwargs: Forwarded unchanged to the parent ``preprocess``.

        Returns:
            Whatever the parent ``preprocess`` returns.
        """
        if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
            # A stacked batch: split along the first axis so each image is
            # padded individually.
            images = list(images)
        elif isinstance(images, (Image.Image, np.ndarray, torch.Tensor)):
            # A single image: wrap it so it is not iterated row/channel-wise.
            images = [images]

        # pad_to_square returns inputs it does not support (e.g. tensors)
        # unchanged.
        padded_images = [pad_to_square(img) for img in images]

        # Call the parent directly. Storing the bound method on the instance
        # would put it in ``__dict__``, which ``save_pretrained`` serialises
        # to JSON.
        return super().preprocess(padded_images, **kwargs)


# Record the auto class so that ``save_pretrained`` writes an ``auto_map`` entry
# and copies this module, letting ``AutoImageProcessor.from_pretrained`` (with
# ``trust_remote_code=True``) load this class on the Hub / Inference Endpoints.
FontClassifierImageProcessor.register_for_auto_class("AutoImageProcessor")
|
|
|
|
|
|
|
|
# Register the processor with AutoImageProcessor, keyed by the class name.
# NOTE(review): ``register`` normally expects a config class as its first
# argument; a plain string key is used here instead. Confirm that this has the
# intended effect with the installed transformers version.
_PROCESSOR_NAME = FontClassifierImageProcessor.__name__
AutoImageProcessor.register(_PROCESSOR_NAME, FontClassifierImageProcessor)