| | from __future__ import annotations |
| |
|
| | import json |
| | import math |
| | import os |
| | from pathlib import Path |
| | from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union |
| |
|
| | import numpy as np |
| | from PIL import Image |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from transformers import AutoTokenizer, BatchFeature, ProcessorMixin |
| | from transformers.image_processing_utils import BaseImageProcessor |
| | from transformers.utils import TensorType, cached_file |
| |
|
| |
|
# File names consulted when resolving processor configuration.
CONFIG_NAME = "config.json"  # model config; last-resort source of image settings
PREPROCESSOR_CONFIG_NAME = "preprocessor_config.json"  # flat image-processor config
PROCESSOR_CONFIG_NAME = "processor_config.json"  # processor config; may nest an "image_processor" dict
| |
|
| |
|
# A single image accepted by the image processor: a PIL image, a numpy array or a torch tensor.
ImageLike = Union[
    Image.Image,
    np.ndarray,
    torch.Tensor,
]
| |
|
| |
|
| | def _select_cached_file_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: |
| | allowed = { |
| | "cache_dir", |
| | "force_download", |
| | "proxies", |
| | "token", |
| | "local_files_only", |
| | "revision", |
| | "subfolder", |
| | } |
| | out = {k: v for k, v in kwargs.items() if k in allowed} |
| | out.setdefault("_raise_exceptions_for_missing_entries", False) |
| | out.setdefault("_raise_exceptions_for_gated_repo", False) |
| | out.setdefault("_raise_exceptions_for_connection_errors", False) |
| | return out |
| |
|
| |
|
| | def _resolve_repo_file(pretrained_model_name_or_path: Union[str, os.PathLike], filename: str, **kwargs) -> Optional[str]: |
| | path = str(pretrained_model_name_or_path) |
| |
|
| | if os.path.isdir(path): |
| | candidate = os.path.join(path, filename) |
| | return candidate if os.path.exists(candidate) else None |
| |
|
| | if os.path.isfile(path): |
| | return path if os.path.basename(path) == filename else None |
| |
|
| | try: |
| | return cached_file(path, filename, **_select_cached_file_kwargs(kwargs)) |
| | except Exception: |
| | return None |
| |
|
| |
|
| | def _load_json_file(path: str) -> Dict[str, Any]: |
| | with open(path, "r", encoding="utf-8") as f: |
| | return json.load(f) |
| |
|
| |
|
def _load_image_processor_dict(pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> Dict[str, Any]:
    """Return the raw image-processor settings for a checkpoint.

    Sources are tried in order, stopping at the first hit:
      1. the ``"image_processor"`` dict nested in ``processor_config.json``
         (skipped if the file is absent or has no such dict);
      2. the whole of ``preprocessor_config.json``;
      3. the whole of ``config.json``.

    Raises:
        FileNotFoundError: If none of the sources can be resolved.
    """

    def locate(name: str) -> Optional[str]:
        return _resolve_repo_file(pretrained_model_name_or_path, name, **kwargs)

    processor_path = locate(PROCESSOR_CONFIG_NAME)
    if processor_path is not None:
        nested = _load_json_file(processor_path).get("image_processor")
        if isinstance(nested, dict):
            return nested

    # Flat fallbacks; the loop stops at the first file that resolves.
    for fallback_name in (PREPROCESSOR_CONFIG_NAME, CONFIG_NAME):
        fallback_path = locate(fallback_name)
        if fallback_path is not None:
            return _load_json_file(fallback_path)

    raise FileNotFoundError(
        f"Could not find {PREPROCESSOR_CONFIG_NAME}, {PROCESSOR_CONFIG_NAME}, or {CONFIG_NAME} in {pretrained_model_name_or_path!r}."
    )
| |
|
| |
|
class AnandaImageProcessor(BaseImageProcessor):
    """Image processor for Ananda OCR-style visual prefix inputs.

    Behavior mirrored from the development inference path:
      1. Convert to RGB / 3 channels.
      2. Convert to CHW float32 in [0, 1].
      3. Normalize with config mean/std.
      4. Pad H/W up to a multiple of ``patch_size``.
      5. Pad again up to a multiple of ``patch_size * merge_factor``.
      6. Emit ``pixel_values`` and ``patch_attention_mask``.

    Every padding step (4, 5 and the batch padding in ``preprocess``) runs
    after normalization, so ``pad_value`` is a value in normalized space.
    ``patch_attention_mask`` holds one entry per ``patch_size`` x ``patch_size``
    patch in row-major order: 1 for patches of the step-4 image, 0 for patches
    added by merge-factor or batch padding.
    """

    model_input_names = ["pixel_values", "patch_attention_mask"]

    def __init__(
        self,
        patch_size: int = 16,
        merge_factor: int = 1,
        do_convert_rgb: bool = True,
        do_rescale: bool = True,
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, Sequence[float]]] = None,
        image_std: Optional[Union[float, Sequence[float]]] = None,
        pad_value: float = 0.0,
        processor_class: Optional[str] = "AnandaProcessor",
        **kwargs: Any,
    ) -> None:
        """Store preprocessing options.

        Args:
            patch_size: Side length of a square patch; must be positive.
            merge_factor: Patch-merge factor; values below 1 are treated as 1.
            do_convert_rgb: Convert PIL images to RGB before processing.
            do_rescale: Multiply pixel values by ``rescale_factor``.
            rescale_factor: Rescale multiplier (default maps [0, 255] to [0, 1]).
            do_normalize: Apply ``(x - image_mean) / image_std`` per channel.
            image_mean: Per-channel mean, or one float for all channels (default 0.5).
            image_std: Per-channel std, or one float for all channels (default 0.5).
            pad_value: Fill value for padded pixels (normalized space).
            processor_class: Processor name recorded in the serialized config.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.

        Raises:
            ValueError: If ``patch_size`` is not a positive integer.
        """
        super().__init__(**kwargs)
        self.patch_size = int(patch_size)
        if self.patch_size <= 0:
            # Would otherwise surface later as a ZeroDivisionError while padding.
            raise ValueError(f"patch_size must be a positive integer, got {patch_size!r}")
        self.merge_factor = max(int(merge_factor), 1)
        self.do_convert_rgb = bool(do_convert_rgb)
        self.do_rescale = bool(do_rescale)
        self.rescale_factor = float(rescale_factor)
        self.do_normalize = bool(do_normalize)
        self.image_mean = self._per_channel(image_mean)
        self.image_std = self._per_channel(image_std)
        self.pad_value = float(pad_value)
        self.processor_class = processor_class

    @staticmethod
    def _per_channel(value: Optional[Union[float, Sequence[float]]]) -> List[float]:
        """Normalize a mean/std setting to a list: ``None`` -> ``[0.5] * 3``, scalar -> repeated 3 times."""
        if value is None:
            return [0.5, 0.5, 0.5]
        if isinstance(value, (int, float)):
            return [float(value)] * 3
        return list(value)

    @staticmethod
    def _round_up(value: int, multiple: int) -> int:
        """Smallest multiple of ``multiple`` that is >= ``value`` (exact integer arithmetic)."""
        return -(-value // multiple) * multiple

    @classmethod
    def from_model_config(cls, model_config: Union[Dict[str, Any], Any]) -> "AnandaImageProcessor":
        """Build a processor from a model config (a dict, or an object read via ``vars()``).

        Reads ``patch_size``, ``encoder_2d_merge_factor``, ``image_normalization_mean``
        and ``image_normalization_std``; all other options keep their defaults.
        """
        cfg = model_config if isinstance(model_config, dict) else vars(model_config)
        return cls(
            patch_size=int(cfg.get("patch_size", 16)),
            merge_factor=int(cfg.get("encoder_2d_merge_factor", 1)),
            image_mean=cfg.get("image_normalization_mean", [0.5, 0.5, 0.5]),
            image_std=cfg.get("image_normalization_std", [0.5, 0.5, 0.5]),
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs: Any) -> "AnandaImageProcessor":
        """Load settings from a local directory/file or a Hub repo id.

        Processor-style keys (``merge_factor``, ``image_mean``, ...) take precedence
        over model-config-style keys (``encoder_2d_merge_factor``,
        ``image_normalization_mean``, ...). ``kwargs`` are only used to locate files.
        """
        config_dict = _load_image_processor_dict(pretrained_model_name_or_path, **kwargs)
        # A flat fallback file may itself nest the image-processor settings.
        nested = config_dict.get("image_processor")
        if isinstance(nested, dict):
            config_dict = nested

        return cls(
            patch_size=int(config_dict.get("patch_size", 16)),
            merge_factor=int(config_dict.get("merge_factor", config_dict.get("encoder_2d_merge_factor", 1))),
            do_convert_rgb=bool(config_dict.get("do_convert_rgb", True)),
            do_rescale=bool(config_dict.get("do_rescale", True)),
            rescale_factor=float(config_dict.get("rescale_factor", 1.0 / 255.0)),
            do_normalize=bool(config_dict.get("do_normalize", True)),
            image_mean=config_dict.get("image_mean", config_dict.get("image_normalization_mean", [0.5, 0.5, 0.5])),
            image_std=config_dict.get("image_std", config_dict.get("image_normalization_std", [0.5, 0.5, 0.5])),
            pad_value=float(config_dict.get("pad_value", 0.0)),
            processor_class=config_dict.get("processor_class", "AnandaProcessor"),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the settings, including the ``auto_map`` entries for remote-code loading."""
        return {
            "image_processor_type": self.__class__.__name__,
            "processor_class": self.processor_class,
            "auto_map": {
                "AutoImageProcessor": "inference_processor.AnandaImageProcessor",
                "AutoProcessor": "inference_processor.AnandaProcessor",
            },
            "patch_size": self.patch_size,
            "merge_factor": self.merge_factor,
            "do_convert_rgb": self.do_convert_rgb,
            "do_rescale": self.do_rescale,
            "rescale_factor": self.rescale_factor,
            "do_normalize": self.do_normalize,
            "image_mean": list(self.image_mean),
            "image_std": list(self.image_std),
            "pad_value": self.pad_value,
        }

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **_: Any) -> List[str]:
        """Write ``to_dict()`` to ``<save_directory>/preprocessor_config.json``; return the written path in a list."""
        os.makedirs(save_directory, exist_ok=True)
        output_path = os.path.join(save_directory, PREPROCESSOR_CONFIG_NAME)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)
        return [output_path]

    @staticmethod
    def _ensure_list(images: Union[ImageLike, Sequence[ImageLike]]) -> List[ImageLike]:
        """Return ``images`` as a list of single images.

        Lists/tuples are copied; a 4D numpy array or torch tensor is treated as a
        stacked batch and split along its first axis; anything else is wrapped.
        """
        if isinstance(images, (list, tuple)):
            return list(images)
        if (isinstance(images, np.ndarray) or torch.is_tensor(images)) and images.ndim == 4:
            return list(images)
        return [images]

    def _to_chw_uint8(self, image: ImageLike) -> torch.Tensor:
        """Convert one image to a fresh ``(3, H, W)`` uint8 tensor.

        For arrays/tensors, the channel axis is the first dimension if its size is
        1, 3 or 4, otherwise the last one (ambiguous shapes therefore favour CHW).
        1-channel input is replicated to 3 channels and a 4th channel is dropped.
        Floating-point input whose maximum is <= 1 is scaled by 255; floats are
        then rounded and clamped to [0, 255]. Integer input is clamped to [0, 255].
        The caller's array/tensor is never modified.

        Raises:
            ValueError: Unsupported rank or channel count.
            TypeError: Unsupported image type.
        """
        if isinstance(image, Image.Image):
            img = image.convert("RGB") if self.do_convert_rgb else image
            tensor = torch.from_numpy(np.array(img, dtype=np.uint8))
            if tensor.ndim == 2:
                tensor = tensor.unsqueeze(-1)
            tensor = tensor.permute(2, 0, 1)
        elif isinstance(image, np.ndarray):
            arr = image[..., None] if image.ndim == 2 else image
            if arr.ndim != 3:
                raise ValueError(f"Expected 2D or 3D ndarray image, got shape={arr.shape}")
            # torch.from_numpy rejects negative strides (e.g. flipped arrays); copy only if needed.
            tensor = torch.from_numpy(np.ascontiguousarray(arr))
            if tensor.shape[0] in (1, 3, 4):
                pass
            elif tensor.shape[-1] in (1, 3, 4):
                tensor = tensor.permute(2, 0, 1)
            else:
                raise ValueError(f"Could not infer channel dimension from ndarray shape={arr.shape}")
        elif torch.is_tensor(image):
            tensor = image.detach().cpu()
            if tensor.ndim == 2:
                tensor = tensor.unsqueeze(0)
            if tensor.ndim != 3:
                raise ValueError(f"Expected 2D or 3D tensor image, got shape={tuple(tensor.shape)}")
            if tensor.shape[0] in (1, 3, 4):
                pass
            elif tensor.shape[-1] in (1, 3, 4):
                tensor = tensor.permute(2, 0, 1)
            else:
                raise ValueError(f"Could not infer channel dimension from tensor shape={tuple(tensor.shape)}")
        else:
            raise TypeError(f"Unsupported image type: {type(image)!r}")

        if tensor.shape[0] == 1:
            tensor = tensor.expand(3, -1, -1)
        elif tensor.shape[0] == 4:
            tensor = tensor[:3]
        elif tensor.shape[0] != 3:
            raise ValueError(f"Expected 1, 3, or 4 channels, got {tensor.shape[0]}")

        # Out-of-place ops only from here on. `tensor` may alias the caller's data
        # (from_numpy / detach share memory) or be a stride-0 expanded view, on
        # which in-place ops such as clamp_ raise a RuntimeError.
        if tensor.dtype.is_floating_point:
            max_val = float(tensor.max().item()) if tensor.numel() else 0.0
            if max_val <= 1.0 + 1e-6:
                tensor = tensor * 255.0
            tensor = tensor.round().clamp(0.0, 255.0).to(torch.uint8)
        else:
            tensor = tensor.clamp(0, 255).to(torch.uint8)

        return tensor.contiguous()

    def _normalize(self, chw_u8: torch.Tensor) -> torch.Tensor:
        """Convert to float32, optionally rescale, then optionally apply ``(x - mean) / std``.

        Mean/std broadcast over channels, so one value or one value per channel works.
        """
        x = chw_u8.to(torch.float32)
        if self.do_rescale:
            x = x * self.rescale_factor
        if self.do_normalize:
            mean = torch.tensor(self.image_mean, dtype=torch.float32).reshape(-1, 1, 1)
            std = torch.tensor(self.image_std, dtype=torch.float32).reshape(-1, 1, 1)
            x = (x - mean) / std
        return x

    def _pad_to_patch_multiple(self, img: torch.Tensor) -> torch.Tensor:
        """Right/bottom-pad a ``(C, H, W)`` tensor with ``pad_value`` so H and W are multiples of ``patch_size``."""
        _, h, w = img.shape
        target_h = self._round_up(h, self.patch_size)
        target_w = self._round_up(w, self.patch_size)
        if target_h != h or target_w != w:
            img = F.pad(img, (0, target_w - w, 0, target_h - h), value=self.pad_value)
        return img

    def _pad_for_merge_factor(self, img_norm: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad a patch-multiple image up to a multiple of ``patch_size * merge_factor``.

        Returns:
            ``(padded_image, mask)`` where ``mask`` is a flat long tensor with one
            entry per patch of the padded image: 1 for the input's patches,
            0 for the patches added here.

        Raises:
            ValueError: If the input is not 3D or not already a patch multiple.
        """
        if img_norm.ndim != 3:
            raise ValueError(f"Expected image tensor with shape (3,H,W), got {tuple(img_norm.shape)}")

        p = self.patch_size
        base = p * self.merge_factor
        _, h, w = img_norm.shape

        if h % p != 0 or w % p != 0:
            raise ValueError(f"Image must be patch-multiple before merge padding, got H={h}, W={w}, patch_size={p}")

        target_h = self._round_up(h, base)
        target_w = self._round_up(w, base)

        ph, pw = h // p, w // p
        target_ph, target_pw = target_h // p, target_w // p

        mask_2d = torch.ones((ph, pw), dtype=torch.bool)
        if target_ph != ph or target_pw != pw:
            mask_2d = F.pad(mask_2d, (0, target_pw - pw, 0, target_ph - ph), value=False)

        if target_h != h or target_w != w:
            img_norm = F.pad(img_norm, (0, target_w - w, 0, target_h - h), value=self.pad_value)

        return img_norm, mask_2d.reshape(-1).to(torch.long)

    def _preprocess_single(self, image: ImageLike) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run steps 1-5 on one image; return ``(pixels, flat_patch_mask)``."""
        chw_u8 = self._to_chw_uint8(image)
        img = self._normalize(chw_u8)
        img = self._pad_to_patch_multiple(img)
        return self._pad_for_merge_factor(img)

    def preprocess(
        self,
        images: Union[ImageLike, Sequence[ImageLike]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **_: Any,
    ) -> BatchFeature:
        """Preprocess one image or a batch into padded, stacked tensors.

        Args:
            images: A single image, a list/tuple of images, or a 4D array/tensor
                holding a stacked batch.
            return_tensors: Forwarded to ``BatchFeature`` as ``tensor_type``.

        Returns:
            ``BatchFeature`` with ``pixel_values`` of shape ``(B, 3, H, W)`` and
            ``patch_attention_mask`` of shape ``(B, (H // patch_size) * (W // patch_size))``,
            where H/W are the largest padded height/width in the batch.

        Raises:
            ValueError: If ``images`` is empty or an image cannot be interpreted.
        """
        image_list = self._ensure_list(images)
        if len(image_list) == 0:
            raise ValueError("`images` must contain at least one image")

        processed = [self._preprocess_single(image) for image in image_list]

        max_h = max(px.shape[1] for px, _ in processed)
        max_w = max(px.shape[2] for px, _ in processed)
        p = self.patch_size
        batch_patch_h = max_h // p
        batch_patch_w = max_w // p

        batch_pixels: List[torch.Tensor] = []
        batch_masks: List[torch.Tensor] = []
        for px, pm in processed:
            _, h, w = px.shape
            ph, pw = h // p, w // p

            # Pad pixels and the patch grid to the batch-wide size; new patches are masked out.
            if h != max_h or w != max_w:
                px = F.pad(px, (0, max_w - w, 0, max_h - h), value=self.pad_value)

            pm_2d = pm.view(ph, pw).to(torch.bool)
            if ph != batch_patch_h or pw != batch_patch_w:
                pm_2d = F.pad(pm_2d, (0, batch_patch_w - pw, 0, batch_patch_h - ph), value=False)

            batch_pixels.append(px)
            batch_masks.append(pm_2d.reshape(-1).to(torch.long))

        data = {
            "pixel_values": torch.stack(batch_pixels, dim=0),
            "patch_attention_mask": torch.stack(batch_masks, dim=0),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)

    __call__ = preprocess
| |
|
| |
|
class AnandaProcessor(ProcessorMixin):
    """Bundle an ``AnandaImageProcessor`` with a tokenizer.

    Calling the processor yields image features (``pixel_values``,
    ``patch_attention_mask``) and text features (``input_ids``,
    ``attention_mask``). Without text, each sample gets a one-token prompt:
    the tokenizer's BOS id, or its EOS id when BOS is undefined.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"
    model_input_names = ["input_ids", "attention_mask", "pixel_values", "patch_attention_mask"]

    def __init__(self, image_processor: AnandaImageProcessor, tokenizer, **kwargs: Any) -> None:
        # Components are attached eagerly, before delegating to ProcessorMixin.__init__.
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.current_processor = self.image_processor
        self._in_target_context_manager = False
        super().__init__(image_processor, tokenizer, **kwargs)

    @classmethod
    def from_model_config(cls, tokenizer, model_config: Union[Dict[str, Any], Any]) -> "AnandaProcessor":
        """Pair ``tokenizer`` with an image processor derived from ``model_config``."""
        return cls(
            image_processor=AnandaImageProcessor.from_model_config(model_config),
            tokenizer=tokenizer,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        trust_remote_code: bool = True,
        **kwargs: Any,
    ) -> "AnandaProcessor":
        """Load the tokenizer, then the image processor, from the same location.

        ``kwargs`` are forwarded to both loaders.
        """
        # NOTE(review): trust_remote_code defaults to True, allowing repository-provided
        # tokenizer code to run; consider whether callers should opt in explicitly.
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        image_processor = AnandaImageProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(image_processor=image_processor, tokenizer=tokenizer)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs: Any) -> List[str]:
        """Save image processor, tokenizer and ``processor_config.json``; return all written paths."""
        os.makedirs(save_directory, exist_ok=True)

        written: List[str] = [
            *self.image_processor.save_pretrained(save_directory),
            *self.tokenizer.save_pretrained(save_directory),
        ]

        payload = {
            "processor_class": self.__class__.__name__,
            "auto_map": {"AutoProcessor": "inference_processor.AnandaProcessor"},
            "image_processor": self.image_processor.to_dict(),
        }
        config_path = os.path.join(save_directory, PROCESSOR_CONFIG_NAME)
        with open(config_path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)
        written.append(config_path)
        return written

    def _prompt_only_text_features(
        self,
        batch_size: int,
        return_tensors: Optional[Union[str, TensorType]],
    ) -> Dict[str, Any]:
        """Build one-token ``input_ids``/``attention_mask`` rows for ``batch_size`` samples.

        Returns torch long tensors for ``return_tensors="pt"``; otherwise nested lists.
        """
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id
        prompt_id = eos_id if bos_id is None else bos_id
        if prompt_id is None:
            raise ValueError("Tokenizer must define bos_token_id or eos_token_id.")

        ids_rows = [[int(prompt_id)] for _ in range(batch_size)]
        mask_rows = [[1] for _ in range(batch_size)]
        if return_tensors in ("pt", TensorType.PYTORCH):
            return {
                "input_ids": torch.tensor(ids_rows, dtype=torch.long),
                "attention_mask": torch.tensor(mask_rows, dtype=torch.long),
            }
        return {"input_ids": ids_rows, "attention_mask": mask_rows}

    def __call__(
        self,
        text: Optional[Union[str, Sequence[str]]] = None,
        images: Optional[Union[ImageLike, Sequence[ImageLike]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        add_special_tokens: bool = True,
        **kwargs: Any,
    ) -> BatchFeature:
        """Encode images and/or text into one ``BatchFeature``.

        Image features come first. When ``text`` is omitted, a one-token prompt is
        generated per image (or a single row if no images are given either).
        Extra ``kwargs`` go to the tokenizer.

        Raises:
            ValueError: If neither ``text`` nor ``images`` is given, or the tokenizer
                defines neither a BOS nor an EOS token id when a prompt is needed.
        """
        if images is None and text is None:
            raise ValueError("At least one of `text` or `images` must be provided")

        encoding: Dict[str, Any] = {}
        batch_size: Optional[int] = None

        if images is not None:
            image_features = self.image_processor(images=images, return_tensors=return_tensors)
            encoding.update(image_features)
            batch_size = int(image_features["pixel_values"].shape[0])

        if text is not None:
            encoding.update(
                self.tokenizer(
                    text,
                    add_special_tokens=add_special_tokens,
                    return_tensors=return_tensors,
                    **kwargs,
                )
            )
        else:
            rows = 1 if batch_size is None else batch_size
            encoding.update(self._prompt_only_text_features(rows, return_tensors))

        return BatchFeature(data=encoding, tensor_type=return_tensors)

    def batch_decode(self, *args: Any, **kwargs: Any):
        """Delegate to ``tokenizer.batch_decode``."""
        decode_many = self.tokenizer.batch_decode
        return decode_many(*args, **kwargs)

    def decode(self, *args: Any, **kwargs: Any):
        """Delegate to ``tokenizer.decode``."""
        decode_one = self.tokenizer.decode
        return decode_one(*args, **kwargs)

    def apply_chat_template(self, *args: Any, **kwargs: Any):
        """Delegate to ``tokenizer.apply_chat_template``."""
        render = self.tokenizer.apply_chat_template
        return render(*args, **kwargs)
| |
|