import json
from pathlib import Path
from typing import Any

import numpy as np
import torch
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import PushToHubMixin


class VisionCondenserProcessor(PushToHubMixin):
    """Resize images to a fixed square tensor batch.

    Every input image is converted to RGB, resized to
    ``image_size x image_size`` with bicubic resampling and scaled to the
    ``[0, 1]`` range. The result is a float32 tensor of shape
    ``(batch, 3, image_size, image_size)`` stored under ``"pixel_values"``.
    """

    _auto_class = "AutoProcessor"
    model_input_names = ["pixel_values"]

    # Keyword arguments that describe *how* a checkpoint is fetched rather
    # than how images are processed. Callers (for example the auto classes)
    # may forward them to ``from_pretrained``; they must never reach
    # ``extra_config`` or they would be serialized by ``save_pretrained``.
    _LOADER_ONLY_KWARGS = frozenset(
        {
            "_commit_hash",
            "_from_auto",
            "_from_pipeline",
            "cache_dir",
            "code_revision",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "token",
            "trust_remote_code",
            "use_auth_token",
        }
    )

    def __init__(
        self,
        image_size: int = 224,
        **kwargs,
    ) -> None:
        """Create a processor.

        Args:
            image_size: Side length, in pixels, of the square output images.
            **kwargs: Extra configuration entries. ``patch_size``,
                ``pool_size`` and ``max_patches`` are discarded; everything
                else is kept in ``extra_config`` and written back by
                ``save_pretrained``.

        Raises:
            ValueError: If ``image_size`` is not positive.
        """
        # These keys are accepted but intentionally ignored.
        for ignored_key in ("patch_size", "pool_size", "max_patches"):
            kwargs.pop(ignored_key, None)

        size = int(image_size)
        if size <= 0:
            raise ValueError(f"image_size must be positive, got {image_size}.")

        self.image_size = size
        self.extra_config = dict(kwargs)

    @classmethod
    def register_for_auto_class(cls, auto_class: str = "AutoProcessor"):
        """Record the auto class name this processor is registered under."""
        cls._auto_class = auto_class

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a processor from a local directory.

        Reads ``preprocessor_config.json`` from
        ``pretrained_model_name_or_path`` (or from its ``subfolder`` when
        that keyword is given) and lets ``kwargs`` override the stored
        values.

        Args:
            pretrained_model_name_or_path: Directory containing
                ``preprocessor_config.json``.
            **kwargs: Configuration overrides. Loader-only options listed in
                ``_LOADER_ONLY_KWARGS`` are dropped instead of being treated
                as configuration.

        Returns:
            A new processor instance.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
        """
        config_dir = Path(pretrained_model_name_or_path)
        subfolder = kwargs.pop("subfolder", None)
        if subfolder:
            config_dir = config_dir / subfolder

        config_path = config_dir / "preprocessor_config.json"
        with config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        config.update(kwargs)

        # Strip serialization metadata and loader-only options. Filtering the
        # merged dict (not only ``kwargs``) also cleans configs that were
        # saved with such options mixed in.
        for key in ("processor_class", "auto_map", *cls._LOADER_ONLY_KWARGS):
            config.pop(key, None)
        return cls(**config)

    def save_pretrained(self, save_directory, **kwargs):
        """Write the configuration and this module's source to a directory.

        Args:
            save_directory: Target directory; created if missing.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A one-element tuple holding the path of the written
            ``preprocessor_config.json``.
        """
        del kwargs
        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)

        # Reserved keys come last so that an entry of the same name in
        # ``extra_config`` can never overwrite them.
        config = {
            **self.extra_config,
            "processor_class": self.__class__.__name__,
            "image_size": self.image_size,
            "auto_map": {
                "AutoProcessor": "processing_vision_condenser.VisionCondenserProcessor"
            },
        }
        config_path = save_path / "preprocessor_config.json"
        with config_path.open("w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)
            f.write("\n")

        # Ship the processor's own source next to the config so the saved
        # directory carries the code referenced by ``auto_map``. An existing
        # copy is left untouched.
        source_file = Path(__file__)
        target_file = save_path / source_file.name
        if not target_file.exists():
            target_file.write_bytes(source_file.read_bytes())
        return (str(config_path),)

    def _ensure_pil_image(self, image: Any) -> Image.Image:
        """Return ``image`` as a PIL image, converting array-likes via NumPy."""
        if isinstance(image, Image.Image):
            return image
        return Image.fromarray(np.asarray(image))

    def _normalize_images(self, images: Any) -> list[Image.Image]:
        """Turn the user input into a non-empty list of PIL images.

        Args:
            images: A single PIL image, or a non-empty list/tuple whose
                items are PIL images or array-likes.

        Raises:
            ValueError: If ``images`` is neither a PIL image nor a non-empty
                list/tuple.
        """
        if isinstance(images, Image.Image):
            return [images]
        if not isinstance(images, (list, tuple)) or not images:
            raise ValueError(
                "images must be a PIL image or a non-empty list of PIL images."
            )
        return [self._ensure_pil_image(image) for image in images]

    def _image_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert one PIL image to a ``(3, image_size, image_size)`` tensor.

        The image is forced to RGB, resized with bicubic resampling and
        scaled from ``[0, 255]`` to ``[0, 1]`` as float32.
        """
        image = image.convert("RGB")
        image = image.resize(
            (self.image_size, self.image_size),
            Image.Resampling.BICUBIC,
        )
        # PIL yields height x width x channel; move channels first.
        pixels = torch.from_numpy(np.array(image)).permute(2, 0, 1).contiguous()
        return pixels.float().div(255.0)

    def __call__(
        self,
        images,
        return_tensors: str | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images.

        Args:
            images: A PIL image, or a non-empty list/tuple of PIL images or
                array-likes.
            return_tensors: ``"pt"`` or ``None``. The pixel values are
                assembled with ``torch.stack`` in both cases; the value only
                selects the ``tensor_type`` passed to ``BatchFeature``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A ``BatchFeature`` with a ``"pixel_values"`` entry of shape
            ``(batch, 3, image_size, image_size)``.

        Raises:
            ValueError: If ``return_tensors`` is not ``None`` or ``"pt"``, or
                if ``images`` has an unsupported type or is empty.
        """
        del kwargs
        if return_tensors not in (None, "pt"):
            raise ValueError("Only return_tensors='pt' is supported.")

        pil_images = self._normalize_images(images)
        pixel_values = torch.stack(
            [self._image_to_tensor(image) for image in pil_images],
            dim=0,
        ).contiguous()

        data = {"pixel_values": pixel_values}
        tensor_type = "pt" if return_tensors == "pt" else None
        return BatchFeature(data=data, tensor_type=tensor_type)