| from __future__ import annotations |
|
|
| import json |
| import os |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
|
|
|
|
class HTRProcessor:
    """Processor for CTC-based handwritten text recognition (HTR) models.

    Converts input images (paths, PIL images, or numpy arrays) into
    height-normalized, stride-aligned grayscale tensors, and decodes model
    logits back into strings via greedy CTC decoding.

    Convention: class index 0 is the CTC blank token; character ids run
    from 1 to ``len(characters)``.
    """

    model_input_names = ["pixel_values"]

    @classmethod
    def register_for_auto_class(cls, auto_class: str = "AutoProcessor") -> None:
        """Compatibility with transformers AutoProcessor (no-op for custom processor)."""
        pass

    def __init__(
        self,
        characters: list[str],
        image_height: int = 64,
        image_max_width: int = 3072,
        width_stride: int = 32,
        resample: str = "bilinear",
    ) -> None:
        """Initialize the processor.

        Args:
            characters: Ordered vocabulary; index ``i`` maps to class id ``i + 1``
                (id 0 is reserved for the CTC blank).
            image_height: Target height every image is resized to.
            image_max_width: Hard cap on the resized width (wider images are
                truncated).
            width_stride: Widths are zero-padded up to a multiple of this value
                (typically the model's total downsampling factor).
            resample: PIL resampling filter name ("nearest", "bilinear",
                "bicubic", or "lanczos"); unknown names fall back to bilinear.
        """
        self.characters = characters
        self.image_height = int(image_height)
        self.image_max_width = int(image_max_width)
        self.width_stride = int(width_stride)
        self.resample = resample
        # Shift by +1 so that id 0 stays free for the CTC blank.
        self.id_to_char = {idx + 1: char for idx, char in enumerate(self.characters)}

    @staticmethod
    def _resolve_file(
        path_or_repo_id: str, filename: str, local_files_only: bool
    ) -> str:
        """Return a local path for *filename*, preferring an on-disk copy.

        Checks ``path_or_repo_id/filename`` on the local filesystem first and
        only falls back to downloading from the Hugging Face Hub when no
        local file exists.
        """
        candidate = Path(path_or_repo_id) / filename
        if candidate.exists():
            return str(candidate)
        return hf_hub_download(
            repo_id=path_or_repo_id,
            filename=filename,
            local_files_only=local_files_only,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        local_files_only: bool = False,
        **_: Any,
    ) -> "HTRProcessor":
        """Load a processor from a local directory or a Hub repo id.

        Reads ``preprocessor_config.json`` plus the vocabulary file it points
        to (default ``alphabet.json``). Extra keyword arguments are accepted
        for transformers-API compatibility and ignored.

        Raises:
            ValueError: If the vocab file is neither a plain list nor a dict
                with a ``"characters"`` key.
        """
        cfg_path = cls._resolve_file(
            pretrained_model_name_or_path, "preprocessor_config.json", local_files_only
        )
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        vocab_filename = cfg.get("vocab_file", "alphabet.json")
        vocab_path = cls._resolve_file(
            pretrained_model_name_or_path, vocab_filename, local_files_only
        )
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab_data = json.load(f)

        # Accept both the saved format ({"characters": [...]}) and a bare list.
        if isinstance(vocab_data, dict) and "characters" in vocab_data:
            characters = vocab_data["characters"]
        elif isinstance(vocab_data, list):
            characters = vocab_data
        else:
            raise ValueError(
                "Unsupported vocab file format. Expected list or {'characters': [...]} ."
            )

        return cls(
            characters=characters,
            image_height=cfg.get("image_height", 64),
            image_max_width=cfg.get("image_max_width", 3072),
            width_stride=cfg.get("width_stride", 32),
            resample=cfg.get("resample", "bilinear"),
        )

    def save_pretrained(self, save_directory: str) -> None:
        """Write ``alphabet.json`` and ``preprocessor_config.json`` to *save_directory*.

        The output is round-trippable via :meth:`from_pretrained`.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_path = os.path.join(save_directory, "alphabet.json")
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump({"characters": self.characters}, f, ensure_ascii=False, indent=2)

        preprocessor_cfg = {
            "processor_class": self.__class__.__name__,
            "vocab_file": "alphabet.json",
            "image_height": self.image_height,
            "image_max_width": self.image_max_width,
            "width_stride": self.width_stride,
            "resample": self.resample,
        }
        with open(
            os.path.join(save_directory, "preprocessor_config.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(preprocessor_cfg, f, ensure_ascii=False, indent=2)

    def _load_pil(self, image: str | Image.Image | np.ndarray) -> Image.Image:
        """Coerce a path, PIL image, or numpy array into a grayscale PIL image.

        Raises:
            ValueError: For ndarrays that are neither 2-D nor 3-D.
            TypeError: For unsupported input types.
        """
        if isinstance(image, Image.Image):
            return image.convert("L")
        if isinstance(image, np.ndarray):
            # Both (H, W) and (H, W, C) arrays collapse to single-channel "L".
            if image.ndim in (2, 3):
                return Image.fromarray(image).convert("L")
            raise ValueError(f"Unsupported ndarray shape: {image.shape}")
        if isinstance(image, str):
            return Image.open(image).convert("L")
        raise TypeError(f"Unsupported image input type: {type(image)}")

    def _preprocess_image(self, image: str | Image.Image | np.ndarray) -> np.ndarray:
        """Resize, truncate, pad, and normalize one image.

        Returns:
            A float32 array of shape ``(1, image_height, W)`` with values in
            ``[0, 1]``, where ``W`` is a multiple of ``width_stride``.
        """
        img = self._load_pil(image)
        w, h = img.size
        if h <= 0:
            raise ValueError("Input image has invalid height.")
        # Preserve aspect ratio while normalizing to the target height.
        scale = self.image_height / float(h)
        new_w = max(1, int(w * scale))

        resample_map = {
            "nearest": Image.Resampling.NEAREST,
            "bilinear": Image.Resampling.BILINEAR,
            "bicubic": Image.Resampling.BICUBIC,
            "lanczos": Image.Resampling.LANCZOS,
        }
        pil_resample = resample_map.get(
            self.resample.lower(), Image.Resampling.BILINEAR
        )
        img = img.resize((new_w, self.image_height), resample=pil_resample)
        arr = np.array(img)

        # Hard cap on width: truncate overlong lines rather than downscaling.
        if new_w > self.image_max_width:
            arr = arr[:, : self.image_max_width]
            new_w = self.image_max_width

        # Zero-pad the right edge up to a stride multiple.
        # NOTE(review): if image_max_width is not itself a multiple of
        # width_stride, a truncated image is padded slightly past the cap
        # (the default 3072/32 pair is aligned) — confirm configs keep them aligned.
        if new_w % self.width_stride != 0:
            aligned_w = ((new_w // self.width_stride) + 1) * self.width_stride
            pad_width = aligned_w - new_w
            arr = np.pad(
                arr,
                ((0, 0), (0, pad_width)),
                mode="constant",
                constant_values=0,
            )
            new_w = aligned_w

        # Scale to [0, 1] and move to channel-first (C, H, W) layout.
        arr = arr.astype(np.float32) / 255.0
        if arr.ndim == 2:
            arr = np.expand_dims(arr, axis=-1)
        return arr.transpose(2, 0, 1).astype(np.float32)

    def __call__(
        self,
        images: str | Image.Image | np.ndarray | list[str | Image.Image | np.ndarray],
        return_tensors: str = "pt",
        **_: Any,
    ) -> dict[str, Any]:
        """Preprocess one image or a batch of images.

        Images of different widths are zero-padded on the right to the
        widest image in the batch so they can be stacked into one tensor.

        Args:
            images: A single image or a list of images (paths, PIL, ndarray).
            return_tensors: "pt" for a torch tensor, "np" for a numpy array.

        Returns:
            ``{"pixel_values": <tensor of shape (N, 1, H, W)>}``.

        Raises:
            ValueError: For an empty batch or an unsupported ``return_tensors``.
        """
        batch_images = images if isinstance(images, list) else [images]
        if not batch_images:
            raise ValueError("`images` must contain at least one image.")
        arrays = [self._preprocess_image(img) for img in batch_images]

        # Pad every (C, H, W) array to the widest width so np.stack succeeds
        # for variable-width batches. Zero padding matches the per-image
        # stride padding above.
        max_width = max(a.shape[-1] for a in arrays)
        arrays = [
            a
            if a.shape[-1] == max_width
            else np.pad(
                a,
                ((0, 0), (0, 0), (0, max_width - a.shape[-1])),
                mode="constant",
                constant_values=0,
            )
            for a in arrays
        ]
        pixel_values = np.stack(arrays, axis=0)

        if return_tensors == "pt":
            return {"pixel_values": torch.from_numpy(pixel_values)}
        if return_tensors == "np":
            return {"pixel_values": pixel_values}
        raise ValueError("Supported return_tensors values are 'pt' and 'np'.")

    @staticmethod
    def _ctc_greedy_decode(
        logits_tnc: np.ndarray, blank_idx: int = 0
    ) -> list[list[int]]:
        """Greedy (best-path) CTC decoding.

        Args:
            logits_tnc: Logits of shape ``(time, batch, classes)``.
            blank_idx: Class index of the CTC blank token.

        Returns:
            Per-sample lists of class ids with repeats collapsed and blanks
            removed.
        """
        preds = np.argmax(logits_tnc, axis=2)
        _, batch_size, _ = logits_tnc.shape
        decoded: list[list[int]] = []
        for n in range(batch_size):
            seq = preds[:, n]
            chars: list[int] = []
            prev = blank_idx
            for idx in seq:
                token = int(idx)
                # Standard CTC collapse: drop blanks and immediate repeats;
                # a blank between equal tokens resets `prev`, so "a-a" -> "aa".
                if token != blank_idx and token != prev:
                    chars.append(token)
                prev = token
            decoded.append(chars)
        return decoded

    def batch_decode(
        self,
        logits: torch.Tensor | np.ndarray,
        blank_idx: int = 0,
        logit_layout: str = "ntc",
    ) -> list[str]:
        """Decode model logits into strings.

        Args:
            logits: Rank-3 logits, either ``(batch, time, classes)``
                (``logit_layout="ntc"``) or ``(time, batch, classes)``
                (``logit_layout="tnc"``).
            blank_idx: Class index of the CTC blank token.
            logit_layout: Axis order of ``logits``; "ntc" or "tnc".

        Returns:
            One decoded string per batch item. Ids absent from the vocabulary
            are silently skipped.

        Raises:
            ValueError: For non-rank-3 logits or an unknown layout.
        """
        logits_np = (
            logits.detach().cpu().numpy()
            if isinstance(logits, torch.Tensor)
            else logits
        )
        if logits_np.ndim != 3:
            raise ValueError(f"Expected logits rank 3, got shape {logits_np.shape}.")

        if logit_layout == "ntc":
            logits_tnc = np.transpose(logits_np, (1, 0, 2))
        elif logit_layout == "tnc":
            logits_tnc = logits_np
        else:
            raise ValueError("logit_layout must be 'ntc' or 'tnc'.")

        decoded_ids = self._ctc_greedy_decode(logits_tnc, blank_idx=blank_idx)
        return [
            "".join(self.id_to_char[token] for token in seq if token in self.id_to_char)
            for seq in decoded_ids
        ]
|
|