from __future__ import annotations import json import os from pathlib import Path from typing import Any import numpy as np import torch from huggingface_hub import hf_hub_download from PIL import Image class HTRProcessor: model_input_names = ["pixel_values"] @classmethod def register_for_auto_class(cls, auto_class="AutoProcessor"): """Compatibility with transformers AutoProcessor (no-op for custom processor).""" pass def __init__( self, characters: list[str], image_height: int = 64, image_max_width: int = 3072, width_stride: int = 32, resample: str = "bilinear", ) -> None: self.characters = characters self.image_height = int(image_height) self.image_max_width = int(image_max_width) self.width_stride = int(width_stride) self.resample = resample self.id_to_char = {idx + 1: char for idx, char in enumerate(self.characters)} @staticmethod def _resolve_file( path_or_repo_id: str, filename: str, local_files_only: bool ) -> str: candidate = Path(path_or_repo_id) / filename if candidate.exists(): return str(candidate) return hf_hub_download( repo_id=path_or_repo_id, filename=filename, local_files_only=local_files_only, ) @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str, local_files_only: bool = False, **_: dict[str, Any], ) -> "HTRProcessor": cfg_path = cls._resolve_file( pretrained_model_name_or_path, "preprocessor_config.json", local_files_only ) with open(cfg_path, "r", encoding="utf-8") as f: cfg = json.load(f) vocab_filename = cfg.get("vocab_file", "alphabet.json") vocab_path = cls._resolve_file( pretrained_model_name_or_path, vocab_filename, local_files_only ) with open(vocab_path, "r", encoding="utf-8") as f: vocab_data = json.load(f) if isinstance(vocab_data, dict) and "characters" in vocab_data: characters = vocab_data["characters"] elif isinstance(vocab_data, list): characters = vocab_data else: raise ValueError( "Unsupported vocab file format. Expected list or {'characters': [...]} ." ) return cls( characters=characters, image_height=cfg.get("image_height", 64), image_max_width=cfg.get("image_max_width", 3072), width_stride=cfg.get("width_stride", 32), resample=cfg.get("resample", "bilinear"), ) def save_pretrained(self, save_directory: str) -> None: os.makedirs(save_directory, exist_ok=True) vocab_path = os.path.join(save_directory, "alphabet.json") with open(vocab_path, "w", encoding="utf-8") as f: json.dump({"characters": self.characters}, f, ensure_ascii=False, indent=2) preprocessor_cfg = { "processor_class": self.__class__.__name__, "vocab_file": "alphabet.json", "image_height": self.image_height, "image_max_width": self.image_max_width, "width_stride": self.width_stride, "resample": self.resample, } with open( os.path.join(save_directory, "preprocessor_config.json"), "w", encoding="utf-8", ) as f: json.dump(preprocessor_cfg, f, ensure_ascii=False, indent=2) def _load_pil(self, image: str | Image.Image | np.ndarray) -> Image.Image: if isinstance(image, Image.Image): return image.convert("L") if isinstance(image, np.ndarray): if image.ndim == 2: return Image.fromarray(image).convert("L") if image.ndim == 3: return Image.fromarray(image).convert("L") raise ValueError(f"Unsupported ndarray shape: {image.shape}") if isinstance(image, str): return Image.open(image).convert("L") raise TypeError(f"Unsupported image input type: {type(image)}") def _preprocess_image(self, image: str | Image.Image | np.ndarray) -> np.ndarray: img = self._load_pil(image) w, h = img.size if h <= 0: raise ValueError("Input image has invalid height.") scale = self.image_height / float(h) new_w = max(1, int(w * scale)) resample_map = { "nearest": Image.Resampling.NEAREST, "bilinear": Image.Resampling.BILINEAR, "bicubic": Image.Resampling.BICUBIC, "lanczos": Image.Resampling.LANCZOS, } pil_resample = resample_map.get( self.resample.lower(), Image.Resampling.BILINEAR ) img = img.resize((new_w, self.image_height), resample=pil_resample) arr = np.array(img) if new_w > self.image_max_width: arr = arr[:, : self.image_max_width] new_w = self.image_max_width if new_w % self.width_stride != 0: aligned_w = ((new_w // self.width_stride) + 1) * self.width_stride pad_width = aligned_w - new_w arr = np.pad( arr, ((0, 0), (0, pad_width)), mode="constant", constant_values=0, ) new_w = aligned_w arr = arr.astype(np.float32) / 255.0 if arr.ndim == 2: arr = np.expand_dims(arr, axis=-1) return arr.transpose(2, 0, 1).astype(np.float32) def __call__( self, images: str | Image.Image | np.ndarray | list[str | Image.Image | np.ndarray], return_tensors: str = "pt", **_: dict[str, Any], ) -> dict[str, Any]: batch_images = images if isinstance(images, list) else [images] pixel_values = np.stack( [self._preprocess_image(img) for img in batch_images], axis=0 ) if return_tensors == "pt": return {"pixel_values": torch.from_numpy(pixel_values)} if return_tensors == "np": return {"pixel_values": pixel_values} raise ValueError("Supported return_tensors values are 'pt' and 'np'.") @staticmethod def _ctc_greedy_decode( logits_tnc: np.ndarray, blank_idx: int = 0 ) -> list[list[int]]: preds = np.argmax(logits_tnc, axis=2) _, batch_size, _ = logits_tnc.shape decoded: list[list[int]] = [] for n in range(batch_size): seq = preds[:, n] chars: list[int] = [] prev = blank_idx for idx in seq: token = int(idx) if token != blank_idx and token != prev: chars.append(token) prev = token decoded.append(chars) return decoded def batch_decode( self, logits: torch.Tensor | np.ndarray, blank_idx: int = 0, logit_layout: str = "ntc", ) -> list[str]: logits_np = ( logits.detach().cpu().numpy() if isinstance(logits, torch.Tensor) else logits ) if logits_np.ndim != 3: raise ValueError(f"Expected logits rank 3, got shape {logits_np.shape}.") if logit_layout == "ntc": logits_tnc = np.transpose(logits_np, (1, 0, 2)) elif logit_layout == "tnc": logits_tnc = logits_np else: raise ValueError("logit_layout must be 'ntc' or 'tnc'.") decoded_ids = self._ctc_greedy_decode(logits_tnc, blank_idx=blank_idx) return [ "".join( self.id_to_char.get(token, "") for token in seq if token in self.id_to_char ) for seq in decoded_ids ]