"""ukr-htr-convtext — processing_htr.py

Custom processor for handwritten text recognition. Includes a compatibility
method for transformers' AutoProcessor on the HTRProcessor class.
(Original commit 0f8229d by Valerii Sielikhov.)
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
class HTRProcessor:
    """Processor for CTC-based handwritten text recognition models.

    Handles image preprocessing (grayscale conversion, fixed-height resize,
    stride-aligned width padding) and greedy CTC decoding of model logits
    back into strings. Class id 0 is reserved for the CTC blank token, so
    character ids start at 1.
    """

    model_input_names = ["pixel_values"]

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoProcessor"):
        """Compatibility with transformers AutoProcessor (no-op for custom processor)."""
        pass

    def __init__(
        self,
        characters: list[str],
        image_height: int = 64,
        image_max_width: int = 3072,
        width_stride: int = 32,
        resample: str = "bilinear",
    ) -> None:
        """Create a processor.

        Args:
            characters: Ordered vocabulary; id ``i + 1`` maps to ``characters[i]``.
            image_height: Target height every input image is resized to.
            image_max_width: Hard cap on the (pre-padding) image width.
            width_stride: Image widths are right-padded to a multiple of this.
            resample: PIL resampling filter name
                (``nearest``/``bilinear``/``bicubic``/``lanczos``).
        """
        # Copy so external mutation of the caller's list cannot desync id_to_char.
        self.characters = list(characters)
        self.image_height = int(image_height)
        self.image_max_width = int(image_max_width)
        self.width_stride = int(width_stride)
        self.resample = resample
        # id 0 is the CTC blank, so characters are numbered from 1.
        self.id_to_char = {idx + 1: char for idx, char in enumerate(self.characters)}

    @staticmethod
    def _resolve_file(
        path_or_repo_id: str, filename: str, local_files_only: bool
    ) -> str:
        """Return a local path for *filename*, checking a local directory first.

        Falls back to downloading from the Hugging Face Hub when
        *path_or_repo_id* is a repo id rather than a local directory.
        """
        candidate = Path(path_or_repo_id) / filename
        if candidate.exists():
            return str(candidate)
        return hf_hub_download(
            repo_id=path_or_repo_id,
            filename=filename,
            local_files_only=local_files_only,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        local_files_only: bool = False,
        **_: dict[str, Any],
    ) -> "HTRProcessor":
        """Load a processor from a local directory or a Hub repo id.

        Reads ``preprocessor_config.json`` plus the vocab file it references
        (default ``alphabet.json``). The vocab file may be either a plain list
        of characters or ``{"characters": [...]}``.

        Raises:
            ValueError: If the vocab file has an unsupported format.
        """
        cfg_path = cls._resolve_file(
            pretrained_model_name_or_path, "preprocessor_config.json", local_files_only
        )
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        vocab_filename = cfg.get("vocab_file", "alphabet.json")
        vocab_path = cls._resolve_file(
            pretrained_model_name_or_path, vocab_filename, local_files_only
        )
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab_data = json.load(f)
        if isinstance(vocab_data, dict) and "characters" in vocab_data:
            characters = vocab_data["characters"]
        elif isinstance(vocab_data, list):
            characters = vocab_data
        else:
            raise ValueError(
                "Unsupported vocab file format. Expected list or {'characters': [...]} ."
            )
        return cls(
            characters=characters,
            image_height=cfg.get("image_height", 64),
            image_max_width=cfg.get("image_max_width", 3072),
            width_stride=cfg.get("width_stride", 32),
            resample=cfg.get("resample", "bilinear"),
        )

    def save_pretrained(self, save_directory: str) -> None:
        """Write ``alphabet.json`` and ``preprocessor_config.json`` to *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        vocab_path = os.path.join(save_directory, "alphabet.json")
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump({"characters": self.characters}, f, ensure_ascii=False, indent=2)
        preprocessor_cfg = {
            "processor_class": self.__class__.__name__,
            "vocab_file": "alphabet.json",
            "image_height": self.image_height,
            "image_max_width": self.image_max_width,
            "width_stride": self.width_stride,
            "resample": self.resample,
        }
        with open(
            os.path.join(save_directory, "preprocessor_config.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(preprocessor_cfg, f, ensure_ascii=False, indent=2)

    def _load_pil(self, image: str | Image.Image | np.ndarray) -> Image.Image:
        """Coerce a file path, PIL image, or numpy array into a grayscale PIL image.

        Raises:
            ValueError: For ndarrays that are neither HxW nor HxWxC.
            TypeError: For unsupported input types.
        """
        if isinstance(image, Image.Image):
            return image.convert("L")
        if isinstance(image, np.ndarray):
            # Grayscale (HxW) and color (HxWxC) arrays take the same path:
            # PIL converts either to single-channel "L".
            if image.ndim in (2, 3):
                return Image.fromarray(image).convert("L")
            raise ValueError(f"Unsupported ndarray shape: {image.shape}")
        if isinstance(image, str):
            return Image.open(image).convert("L")
        raise TypeError(f"Unsupported image input type: {type(image)}")

    def _preprocess_image(self, image: str | Image.Image | np.ndarray) -> np.ndarray:
        """Resize, crop/pad, and normalize one image to a (1, H, W) float32 array.

        Raises:
            ValueError: If the input image has a non-positive height.
        """
        img = self._load_pil(image)
        w, h = img.size
        if h <= 0:
            raise ValueError("Input image has invalid height.")
        # Keep aspect ratio: scale width proportionally to the fixed height.
        scale = self.image_height / float(h)
        new_w = max(1, int(w * scale))
        resample_map = {
            "nearest": Image.Resampling.NEAREST,
            "bilinear": Image.Resampling.BILINEAR,
            "bicubic": Image.Resampling.BICUBIC,
            "lanczos": Image.Resampling.LANCZOS,
        }
        pil_resample = resample_map.get(
            self.resample.lower(), Image.Resampling.BILINEAR
        )
        img = img.resize((new_w, self.image_height), resample=pil_resample)
        arr = np.array(img)
        # Hard-crop overly wide lines so the model input stays bounded.
        if new_w > self.image_max_width:
            arr = arr[:, : self.image_max_width]
            new_w = self.image_max_width
        # Right-pad with black so the width is a multiple of the model's
        # downsampling stride.
        if new_w % self.width_stride != 0:
            aligned_w = ((new_w // self.width_stride) + 1) * self.width_stride
            arr = np.pad(
                arr,
                ((0, 0), (0, aligned_w - new_w)),
                mode="constant",
                constant_values=0,
            )
        arr = arr.astype(np.float32) / 255.0
        if arr.ndim == 2:
            arr = np.expand_dims(arr, axis=-1)
        # HWC -> CHW for torch-style models.
        return arr.transpose(2, 0, 1).astype(np.float32)

    def __call__(
        self,
        images: str | Image.Image | np.ndarray | list[str | Image.Image | np.ndarray],
        return_tensors: str = "pt",
        **_: dict[str, Any],
    ) -> dict[str, Any]:
        """Preprocess one image or a list of images into a pixel batch.

        Images in a batch may have different widths; narrower ones are
        right-padded with zeros to the widest image so they can be stacked.

        Raises:
            ValueError: If *return_tensors* is not ``"pt"`` or ``"np"``.
        """
        batch_images = images if isinstance(images, list) else [images]
        processed = [self._preprocess_image(img) for img in batch_images]
        # Fix: np.stack previously failed on mixed-width batches. Pad every
        # image to the batch-max width (already stride-aligned) first.
        max_width = max(arr.shape[-1] for arr in processed)
        processed = [
            arr
            if arr.shape[-1] == max_width
            else np.pad(
                arr,
                ((0, 0), (0, 0), (0, max_width - arr.shape[-1])),
                mode="constant",
                constant_values=0,
            )
            for arr in processed
        ]
        pixel_values = np.stack(processed, axis=0)
        if return_tensors == "pt":
            return {"pixel_values": torch.from_numpy(pixel_values)}
        if return_tensors == "np":
            return {"pixel_values": pixel_values}
        raise ValueError("Supported return_tensors values are 'pt' and 'np'.")

    @staticmethod
    def _ctc_greedy_decode(
        logits_tnc: np.ndarray, blank_idx: int = 0
    ) -> list[list[int]]:
        """Greedy CTC decode: argmax per step, collapse repeats, drop blanks.

        Args:
            logits_tnc: Logits with layout (time, batch, classes).
            blank_idx: Index of the CTC blank class.

        Returns:
            One list of class ids per batch element.
        """
        preds = np.argmax(logits_tnc, axis=2)
        _, batch_size, _ = logits_tnc.shape
        decoded: list[list[int]] = []
        for n in range(batch_size):
            chars: list[int] = []
            prev = blank_idx
            for idx in preds[:, n]:
                token = int(idx)
                # CTC collapse: emit only on a change to a non-blank token.
                if token != blank_idx and token != prev:
                    chars.append(token)
                prev = token
            decoded.append(chars)
        return decoded

    def batch_decode(
        self,
        logits: torch.Tensor | np.ndarray,
        blank_idx: int = 0,
        logit_layout: str = "ntc",
    ) -> list[str]:
        """Decode a batch of CTC logits into strings.

        Args:
            logits: Rank-3 logits, (batch, time, classes) for ``"ntc"`` or
                (time, batch, classes) for ``"tnc"``.
            blank_idx: Index of the CTC blank class.
            logit_layout: ``"ntc"`` or ``"tnc"``.

        Raises:
            ValueError: On non-rank-3 logits or an unknown layout.
        """
        logits_np = (
            logits.detach().cpu().numpy()
            if isinstance(logits, torch.Tensor)
            else logits
        )
        if logits_np.ndim != 3:
            raise ValueError(f"Expected logits rank 3, got shape {logits_np.shape}.")
        if logit_layout == "ntc":
            logits_tnc = np.transpose(logits_np, (1, 0, 2))
        elif logit_layout == "tnc":
            logits_tnc = logits_np
        else:
            raise ValueError("logit_layout must be 'ntc' or 'tnc'.")
        decoded_ids = self._ctc_greedy_decode(logits_tnc, blank_idx=blank_idx)
        # Ids outside the vocabulary (e.g. out-of-range classes) are dropped.
        return [
            "".join(self.id_to_char[token] for token in seq if token in self.id_to_char)
            for seq in decoded_ids
        ]