# fae_siglip2 / processing_vision_condenser.py
# Uploaded by toilaluan using huggingface_hub (commit 76bd3f0, verified).
import json
from pathlib import Path
from typing import Any
import numpy as np
import torch
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import PushToHubMixin
class VisionCondenserProcessor(PushToHubMixin):
    """Resize images to a fixed square tensor batch.

    Every input image is converted to RGB, resized with bicubic resampling to
    ``image_size x image_size`` pixels and scaled to float values in ``[0, 1]``.
    No mean/std normalization is applied by this processor.

    Args:
        image_size: Side length, in pixels, of the square output images.
            Must be positive.
        **kwargs: Extra configuration entries. ``patch_size``, ``pool_size``
            and ``max_patches`` are accepted but ignored; every other entry is
            kept in ``extra_config`` and written back by ``save_pretrained``.

    Raises:
        ValueError: If ``image_size`` is not positive.
    """

    _auto_class = "AutoProcessor"
    model_input_names = ["pixel_values"]

    _CONFIG_NAME = "preprocessor_config.json"

    # Keyword arguments that loaders pass to ``from_pretrained`` purely for
    # file resolution (``AutoProcessor.from_pretrained`` forwards e.g.
    # ``trust_remote_code`` and ``_from_auto``). They are not processor
    # configuration and must never leak into ``extra_config`` or the saved
    # config file.
    _LOADING_KWARGS = frozenset(
        {
            "cache_dir",
            "force_download",
            "resume_download",
            "proxies",
            "token",
            "use_auth_token",
            "revision",
            "code_revision",
            "local_files_only",
            "subfolder",
            "trust_remote_code",
            "_from_auto",
            "_from_pipeline",
            "_commit_hash",
        }
    )

    def __init__(
        self,
        image_size: int = 224,
        **kwargs,
    ) -> None:
        # These keys may appear in saved configs but have no effect here, so
        # they are dropped instead of being stored in ``extra_config``.
        for ignored_key in ("patch_size", "pool_size", "max_patches"):
            kwargs.pop(ignored_key, None)
        self.image_size = int(image_size)
        if self.image_size <= 0:
            raise ValueError(f"image_size must be positive, got {image_size}.")
        self.extra_config = dict(kwargs)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(image_size={self.image_size}, "
            f"extra_config={self.extra_config!r})"
        )

    @classmethod
    def register_for_auto_class(cls, auto_class: str = "AutoProcessor"):
        """Record which ``Auto*`` class should load this processor."""
        cls._auto_class = auto_class

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Instantiate the processor from a saved ``preprocessor_config.json``.

        Args:
            pretrained_model_name_or_path: A local directory containing the
                config file or, when no such directory exists, a Hugging Face
                Hub repo id (resolved via ``transformers.utils.cached_file``).
            **kwargs: Loading options (``revision``, ``cache_dir``, ``token``,
                ``subfolder``, ``trust_remote_code``, ...) are used only to
                locate the file. All other entries override values read from
                the config.

        Returns:
            A new ``VisionCondenserProcessor``.

        Raises:
            FileNotFoundError: If a local directory lacks the config file.
            OSError: If the file cannot be resolved from the Hub.
        """
        loading_kwargs = {
            key: kwargs.pop(key) for key in list(kwargs) if key in cls._LOADING_KWARGS
        }
        subfolder = loading_kwargs.get("subfolder") or ""

        local_dir = Path(pretrained_model_name_or_path)
        if local_dir.is_dir():
            config_path = local_dir / subfolder / cls._CONFIG_NAME
        else:
            config_path = Path(
                cls._resolve_remote_config(
                    pretrained_model_name_or_path, subfolder, loading_kwargs
                )
            )

        with config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        config.update(kwargs)
        # Bookkeeping keys written by ``save_pretrained``; not __init__ arguments.
        config.pop("processor_class", None)
        config.pop("auto_map", None)
        return cls(**config)

    @classmethod
    def _resolve_remote_config(cls, repo_id, subfolder: str, loading_kwargs: dict):
        """Return a local path to the Hub-hosted config file (downloading it if needed)."""
        from transformers.utils import cached_file

        hub_keys = (
            "cache_dir",
            "force_download",
            "proxies",
            "token",
            "revision",
            "local_files_only",
        )
        hub_kwargs = {key: loading_kwargs[key] for key in hub_keys if key in loading_kwargs}
        # ``use_auth_token`` is the older spelling of ``token``.
        if "token" not in hub_kwargs and "use_auth_token" in loading_kwargs:
            hub_kwargs["token"] = loading_kwargs["use_auth_token"]
        return cached_file(
            str(repo_id), cls._CONFIG_NAME, subfolder=subfolder, **hub_kwargs
        )

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``preprocessor_config.json`` plus a copy of this module.

        The config carries an ``auto_map`` entry pointing at this module, so the
        module file is copied next to it for ``trust_remote_code`` loading.

        Args:
            save_directory: Target directory; created if missing.
            **kwargs: Ignored. ``push_to_hub=True`` is not supported here and
                triggers a warning (use ``push_to_hub()`` instead).

        Returns:
            A one-element tuple with the path of the written config file.
        """
        if kwargs.get("push_to_hub", False):
            import warnings

            warnings.warn(
                f"{type(self).__name__}.save_pretrained ignores push_to_hub=True; "
                "call push_to_hub() to upload the processor.",
                stacklevel=2,
            )

        save_path = Path(save_directory)
        save_path.mkdir(parents=True, exist_ok=True)
        config = {
            "processor_class": self.__class__.__name__,
            "image_size": self.image_size,
            **self.extra_config,
            "auto_map": {
                "AutoProcessor": "processing_vision_condenser.VisionCondenserProcessor"
            },
        }
        config_path = save_path / self._CONFIG_NAME
        with config_path.open("w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, sort_keys=True)
            f.write("\n")

        # Refresh the module copy whenever it differs, so a stale version is
        # never paired with a new config. Saving into the directory this module
        # lives in compares equal and is skipped.
        source_file = Path(__file__)
        target_file = save_path / source_file.name
        source_bytes = source_file.read_bytes()
        if not target_file.exists() or target_file.read_bytes() != source_bytes:
            target_file.write_bytes(source_bytes)
        return (str(config_path),)

    def _ensure_pil_image(self, image: Any) -> Image.Image:
        """Return ``image`` as a PIL image, converting array-likes via ``Image.fromarray``."""
        if isinstance(image, Image.Image):
            return image
        return Image.fromarray(np.asarray(image))

    def _normalize_images(self, images: Any) -> list[Image.Image]:
        """Coerce the ``images`` argument into a non-empty list of PIL images.

        Accepts a single PIL image, a single 2-D/3-D ``np.ndarray`` (treated as
        one image), or a non-empty list/tuple of PIL images or array-likes.

        Raises:
            ValueError: For empty sequences or unsupported container types.
        """
        if isinstance(images, Image.Image):
            return [images]
        # A lone H x W (or H x W x C) array is one image, not a batch.
        if isinstance(images, np.ndarray) and images.ndim in (2, 3):
            return [self._ensure_pil_image(images)]
        if not isinstance(images, (list, tuple)) or len(images) == 0:
            raise ValueError(
                "images must be a PIL image or a non-empty list of PIL images."
            )
        return [self._ensure_pil_image(image) for image in images]

    def _image_to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert one image to a float32 ``(3, image_size, image_size)`` tensor in ``[0, 1]``."""
        image = image.convert("RGB")
        image = image.resize(
            (self.image_size, self.image_size),
            Image.Resampling.BICUBIC,
        )
        # PIL yields H x W x C uint8; torch models expect channels-first.
        pixels = torch.from_numpy(np.array(image)).permute(2, 0, 1).contiguous()
        return pixels.float().div(255.0)

    def __call__(
        self,
        images,
        return_tensors: str | None = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images into a batch.

        Args:
            images: A PIL image, a single array image, or a non-empty
                list/tuple of PIL images or array-likes.
            return_tensors: ``None`` or ``"pt"``; any other value is rejected.
            **kwargs: Accepted for processor-API compatibility and ignored.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` is a float32 tensor of
            shape ``(batch, 3, image_size, image_size)`` with values in ``[0, 1]``.

        Raises:
            ValueError: For unsupported ``return_tensors`` or invalid ``images``.
        """
        del kwargs
        if return_tensors not in (None, "pt"):
            raise ValueError("Only return_tensors='pt' is supported.")
        pil_images = self._normalize_images(images)
        # torch.stack already returns a contiguous tensor.
        pixel_values = torch.stack(
            [self._image_to_tensor(image) for image in pil_images],
            dim=0,
        )
        tensor_type = "pt" if return_tensors == "pt" else None
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=tensor_type)