| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.utils import PushToHubMixin |
|
|
|
|
class VisionCondenserProcessor(PushToHubMixin):
    """Resize images to a fixed square tensor batch.

    Each image is converted to RGB, resized to ``image_size x image_size``
    with bicubic resampling, and scaled from ``[0, 255]`` to ``[0, 1]``.
    Calling the processor returns a ``BatchFeature`` whose ``pixel_values``
    tensor has shape ``(batch, 3, image_size, image_size)``.

    Constructor keyword arguments other than ``image_size`` (minus a few
    ignored legacy keys) are kept in ``extra_config`` and written back by
    ``save_pretrained``, so they round-trip through the saved config file.
    """

    _auto_class = "AutoProcessor"
    model_input_names = ["pixel_values"]

    # Config file name shared by from_pretrained and save_pretrained.
    _config_file_name = "preprocessor_config.json"

    # Legacy config keys that are accepted for compatibility but discarded.
    _ignored_config_keys = ("patch_size", "pool_size", "max_patches")

    # from_pretrained kwargs that describe where/how to fetch the config
    # file. They are forwarded to the Hub resolver and never stored in the
    # processor config.
    _hub_kwargs = (
        "cache_dir",
        "force_download",
        "resume_download",
        "proxies",
        "token",
        "use_auth_token",
        "local_files_only",
        "revision",
        "subfolder",
        "_commit_hash",
    )

    # Loader-control kwargs (e.g. the trust_remote_code flag a caller such as
    # AutoProcessor may forward). They are dropped so they are not persisted
    # into extra_config by a later save_pretrained.
    _loader_only_kwargs = (
        "trust_remote_code",
        "code_revision",
        "return_unused_kwargs",
        "_from_auto",
        "_from_pipeline",
    )

    def __init__(
        self,
        image_size: int = 224,
        **kwargs,
    ) -> None:
        """Create a processor that outputs ``image_size`` x ``image_size`` images.

        Args:
            image_size: Side length of the square output, coerced with ``int()``.
            **kwargs: Extra config entries, stored in ``extra_config``. The
                legacy keys ``patch_size``, ``pool_size`` and ``max_patches``
                are discarded.

        Raises:
            ValueError: If ``image_size`` is not positive.
        """
        for key in self._ignored_config_keys:
            kwargs.pop(key, None)
        size = int(image_size)
        if size <= 0:
            raise ValueError(f"image_size must be positive, got {image_size}.")
        self.image_size = size
        self.extra_config = dict(kwargs)

    def __repr__(self) -> str:
        extras = "".join(
            f", {key}={value!r}" for key, value in sorted(self.extra_config.items())
        )
        return f"{type(self).__name__}(image_size={self.image_size}{extras})"

    @classmethod
    def register_for_auto_class(cls, auto_class: str = "AutoProcessor"):
        """Record which auto class ``save_pretrained`` maps in ``auto_map``.

        Accepts either the auto class name or the auto class itself.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__
        cls._auto_class = auto_class

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load a processor from a directory, a JSON file, or a Hub repo id.

        Args:
            pretrained_model_name_or_path: A local directory containing
                ``preprocessor_config.json``, a path to such a JSON file, or a
                Hugging Face Hub repository id.
            **kwargs: Hub/loader options (see ``_hub_kwargs`` and
                ``_loader_only_kwargs``) are consumed here. All remaining
                entries override values read from the config file.

        Returns:
            A new ``VisionCondenserProcessor`` (or subclass) instance.

        Raises:
            FileNotFoundError: If the config file cannot be found locally or
                resolved from the Hub.
        """
        hub_kwargs = {}
        for key in cls._hub_kwargs:
            if key in kwargs:
                hub_kwargs[key] = kwargs.pop(key)
        for key in cls._loader_only_kwargs:
            kwargs.pop(key, None)

        config_path = cls._resolve_config_file(pretrained_model_name_or_path, hub_kwargs)
        with config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)

        config.update(kwargs)
        # Bookkeeping keys written by save_pretrained, not constructor args.
        config.pop("processor_class", None)
        config.pop("auto_map", None)
        return cls(**config)

    @classmethod
    def _resolve_config_file(cls, pretrained_model_name_or_path, hub_kwargs) -> Path:
        """Return the local path of the config file, downloading it if needed."""
        location = Path(pretrained_model_name_or_path)
        if location.is_file():
            return location
        if location.is_dir():
            # Local checkpoint. A missing file surfaces as FileNotFoundError
            # when the caller opens it, exactly as before.
            subfolder = hub_kwargs.get("subfolder") or ""
            return location / subfolder / cls._config_file_name

        # Not on disk: treat the argument as a Hugging Face Hub repo id.
        from transformers.utils import cached_file

        not_found = (
            f"Could not find {cls._config_file_name} for "
            f"{pretrained_model_name_or_path!r}: it is not a local directory "
            "or file and could not be resolved on the Hugging Face Hub."
        )
        try:
            resolved = cached_file(
                str(pretrained_model_name_or_path), cls._config_file_name, **hub_kwargs
            )
        except (OSError, ValueError) as err:
            # Keep FileNotFoundError as the failure type callers already handle.
            raise FileNotFoundError(not_found) from err
        if resolved is None:
            raise FileNotFoundError(not_found)
        return Path(resolved)

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``preprocessor_config.json`` plus this module's source file.

        The source file is copied next to the config so that the ``auto_map``
        entry can be resolved for remote-code loading.

        Args:
            save_directory: Target directory. It is created if missing.
            **kwargs: Accepted for API compatibility and otherwise ignored.
                ``push_to_hub=True`` is not supported and only triggers a
                warning.

        Returns:
            A one-element tuple containing the written config path as ``str``.
        """
        if kwargs.get("push_to_hub"):
            import warnings

            warnings.warn(
                f"{type(self).__name__}.save_pretrained does not support "
                "push_to_hub=True; files were only saved locally. Call "
                "push_to_hub() explicitly to upload.",
                stacklevel=2,
            )

        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)

        source_file = Path(__file__)
        # Fixed keys come last so user-supplied extras can never override them.
        # The auto_map module name is derived from the file actually copied below.
        config = {
            **self.extra_config,
            "processor_class": self.__class__.__name__,
            "image_size": self.image_size,
            "auto_map": {
                self._auto_class or "AutoProcessor": (
                    f"{source_file.stem}.VisionCondenserProcessor"
                )
            },
        }

        config_path = save_path / self._config_file_name
        # Serialize before opening the file, so an unserializable extra_config
        # value raises without leaving a truncated config behind.
        payload = json.dumps(config, indent=2, sort_keys=True) + "\n"
        config_path.write_text(payload, encoding="utf-8")

        target_file = save_path / source_file.name
        source_bytes = source_file.read_bytes()
        # Refresh stale copies so saved code always matches the saved config.
        # When target and source are the same file, the bytes match and
        # nothing is written.
        if not target_file.exists() or target_file.read_bytes() != source_bytes:
            target_file.write_bytes(source_bytes)

        return (str(config_path),)

    def _ensure_pil_image(self, image: Any) -> Image.Image:
        """Return ``image`` as a PIL image.

        Non-PIL inputs go through ``np.asarray`` and ``Image.fromarray``, which
        expects ``(H, W)`` or ``(H, W, C)`` layout.
        """
        if isinstance(image, Image.Image):
            return image
        return Image.fromarray(np.asarray(image))

    def _normalize_images(self, images: Any) -> list[Image.Image]:
        """Turn a single image or a batch of images into a list of PIL images.

        Accepts a PIL image, or a 2-D/3-D array or tensor as a single image.
        Also accepts a 4-D array or tensor, or a non-empty list/tuple of
        images, as a batch.

        Raises:
            ValueError: For unsupported containers or an empty batch.
        """
        if isinstance(images, Image.Image):
            return [images]
        if isinstance(images, (np.ndarray, torch.Tensor)):
            if images.ndim in (2, 3):
                return [self._ensure_pil_image(images)]
            if images.ndim == 4:
                # Split along the leading batch axis.
                images = list(images)
        if not isinstance(images, (list, tuple)) or len(images) == 0:
            raise ValueError(
                "images must be a PIL image or a non-empty list of PIL images."
            )
        return [self._ensure_pil_image(image) for image in images]

    def _image_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert one PIL image into a float ``(3, image_size, image_size)`` tensor in ``[0, 1]``."""
        image = image.convert("RGB")
        image = image.resize(
            (self.image_size, self.image_size),
            Image.Resampling.BICUBIC,
        )
        # HWC uint8 -> CHW, then scale to [0, 1].
        pixels = torch.from_numpy(np.array(image)).permute(2, 0, 1).contiguous()
        return pixels.float().div(255.0)

    def __call__(
        self,
        images,
        return_tensors: str | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess ``images`` into a ``BatchFeature`` with ``pixel_values``.

        Args:
            images: A single image or a batch (see ``_normalize_images``).
            return_tensors: ``None`` or ``"pt"``. The data is a torch tensor
                either way.
            **kwargs: Ignored.

        Raises:
            ValueError: For unsupported ``return_tensors`` or invalid images.
        """
        del kwargs
        if return_tensors not in (None, "pt"):
            raise ValueError("Only return_tensors='pt' is supported.")

        pil_images = self._normalize_images(images)
        pixel_values = torch.stack(
            [self._image_to_tensor(image) for image in pil_images],
            dim=0,
        ).contiguous()

        data = {"pixel_values": pixel_values}
        tensor_type = "pt" if return_tensors == "pt" else None
        return BatchFeature(data=data, tensor_type=tensor_type)
|
|