|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import AutoImageProcessor, AutoConfig |
|
|
from transformers.image_processing_base import ImageProcessingMixin |
|
|
from transformers.utils.generic import TensorType |
|
|
|
|
|
try: |
|
|
|
|
|
from .ds_cfg import BackboneID, BACKBONE_META |
|
|
except ImportError: |
|
|
|
|
|
from ds_cfg import BackboneID, BACKBONE_META |
|
|
|
|
|
|
|
|
class BackboneMLPHead224ImageProcessor(ImageProcessingMixin):
    """Image preprocessor for a 224-pixel backbone + MLP head.

    Converts input images into ``{"pixel_values": ...}`` using a
    backbone-specific pipeline selected via ``BACKBONE_META``:

    * ``timm_densenet``       -> timm transform resolved from the model's data config
    * ``torchvision_densenet``-> hand-built torchvision 224 pipeline
    * anything else           -> delegate ``AutoImageProcessor``

    Key requirements:

    1) ``save_pretrained()`` must produce a JSON-serializable
       ``preprocessor_config.json``.
    2) Runtime-only objects (delegate processor, timm/torchvision
       transforms) must NOT be serialized.
    3) Runtime objects are rebuilt at init/load time based on backbone meta.
    4) For reproducibility, ``use_fast`` must be explicitly persisted and
       honored on load.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        backbone_name_or_path: BackboneID,
        is_training: bool = False,
        use_fast: bool = False,
        **kwargs,
    ):
        """Validate the backbone id, store config fields, build the runtime.

        Args:
            backbone_name_or_path: Key into ``BACKBONE_META`` selecting the
                preprocessing pipeline.
            is_training: Build the training-time transform variant when True.
            use_fast: Forwarded to ``AutoImageProcessor`` for delegate
                backbones; persisted for reproducibility (requirement 4).
            **kwargs: Forwarded to ``ImageProcessingMixin.__init__``.

        Raises:
            ValueError: If ``backbone_name_or_path`` is not a known backbone.
        """
        super().__init__(**kwargs)

        if backbone_name_or_path not in BACKBONE_META:
            raise ValueError(
                f"Unsupported backbone_name_or_path={backbone_name_or_path}. "
                f"Allowed: {sorted(BACKBONE_META.keys())}"
            )

        self.backbone_name_or_path = backbone_name_or_path
        self.is_training = bool(is_training)
        # Persisted explicitly so a save/load round-trip rebuilds the exact
        # same delegate processor (requirement 4).
        self.use_fast = bool(use_fast)

        # Runtime-only attributes: stripped from the serialized dict in
        # to_dict() (requirement 2) and rebuilt by _build_runtime().
        self._meta = None
        self._delegate = None
        self._timm_transform = None
        self._torchvision_transform = None

        self._build_runtime()

    def _build_runtime(self):
        """Build the runtime delegate/transform from ``BACKBONE_META["type"]``.

        Exactly one of ``_delegate``, ``_timm_transform`` or
        ``_torchvision_transform`` is non-None afterwards.
        """
        meta = BACKBONE_META[self.backbone_name_or_path]
        self._meta = meta

        # Reset all slots so repeated calls never leave a stale pipeline.
        self._delegate = None
        self._timm_transform = None
        self._torchvision_transform = None

        t = meta["type"]

        if t == "timm_densenet":
            self._timm_transform = self._build_timm_transform(
                backbone_id=self.backbone_name_or_path,
                is_training=self.is_training,
            )
            return

        if t == "torchvision_densenet":
            self._torchvision_transform = self._build_torchvision_densenet_transform(
                is_training=self.is_training
            )
            return

        # Default: delegate preprocessing to the backbone's own HF processor.
        self._delegate = AutoImageProcessor.from_pretrained(
            self.backbone_name_or_path,
            use_fast=self.use_fast,
        )

    @staticmethod
    def _build_timm_transform(*, backbone_id: str, is_training: bool):
        """Create a timm transform without storing non-serializable objects.

        Args:
            backbone_id: timm hub id, resolved as ``hf_hub:<backbone_id>``.
            is_training: Select timm's train or eval transform variant.

        Raises:
            ImportError: If ``timm`` is not installed.
        """
        try:
            import timm
            from timm.data import resolve_model_data_config, create_transform
        except Exception as e:
            raise ImportError(
                "timm backbone processor requires `timm`. Install: pip install timm"
            ) from e

        # The model is instantiated only to resolve its data config; weights
        # are not downloaded (pretrained=False) and the model is discarded.
        m = timm.create_model(f"hf_hub:{backbone_id}", pretrained=False, num_classes=0)
        dc = resolve_model_data_config(m)

        tfm = create_transform(**dc, is_training=is_training)
        return tfm

    @staticmethod
    def _build_torchvision_densenet_transform(*, is_training: bool):
        """Build torchvision preprocessing for DenseNet-121 (224 pipeline).

        Both branches end in a 224x224 center crop so every output tensor is
        (3, 224, 224) and batches can be stacked in ``__call__``.

        Raises:
            ImportError: If ``torchvision`` is not installed.
        """
        try:
            from torchvision import transforms
        except Exception as e:
            raise ImportError(
                "torchvision DenseNet processor requires `torchvision`. Install: pip install torchvision"
            ) from e

        # Standard ImageNet normalization statistics.
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        if is_training:
            return transforms.Compose(
                [
                    transforms.Resize(224),
                    # BUGFIX: Resize(int) only fixes the shorter side, so
                    # non-square inputs stayed non-square and torch.stack()
                    # failed on mixed-aspect batches; crop to 224x224.
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=mean, std=std),
                ]
            )

        return transforms.Compose(
            [
                transforms.Resize(256),
                # BUGFIX: the eval path previously omitted the standard
                # 224x224 center crop, emitting 256-sized tensors for a
                # pipeline documented (and named) as 224.
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict for ``preprocessor_config.json``.

        Important: runtime objects must not leak into the serialized dict
        (requirement 2), so all ``_``-prefixed runtime attributes copied in
        by the mixin's ``to_dict()`` are dropped before returning.
        """
        d = super().to_dict()

        d["image_processor_type"] = self.__class__.__name__
        d["backbone_name_or_path"] = self.backbone_name_or_path
        d["is_training"] = self.is_training
        d["use_fast"] = self.use_fast

        # Strip runtime-only attributes (delegate/transforms/meta).
        for key in ["_meta", "_delegate", "_timm_transform", "_torchvision_transform"]:
            d.pop(key, None)

        return d

    @classmethod
    def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
        """Standard load path used by BaseImageProcessor / AutoImageProcessor.

        Args:
            image_processor_dict: Parsed ``preprocessor_config.json``.
            **kwargs: Extra keyword arguments forwarded to ``__init__``.

        Raises:
            ValueError: If ``backbone_name_or_path`` is missing.
        """
        backbone = image_processor_dict.get("backbone_name_or_path", None)
        if backbone is None:
            raise ValueError("preprocessor_config.json missing key: backbone_name_or_path")

        is_training = bool(image_processor_dict.get("is_training", False))
        # Honor the persisted use_fast flag (requirement 4).
        use_fast = bool(image_processor_dict.get("use_fast", False))

        return cls(
            backbone_name_or_path=backbone,
            is_training=is_training,
            use_fast=use_fast,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Fallback path if AutoImageProcessor calls class.from_pretrained directly.

        Strategy: read ``config.json`` via ``AutoConfig`` and recover
        ``backbone_name_or_path`` from it.

        NOTE(review): this fallback cannot see ``preprocessor_config.json``,
        so ``use_fast`` defaults to False unless passed explicitly by the
        caller; the main load path (``from_dict``) does honor the persisted
        value.
        """
        use_fast = bool(kwargs.pop("use_fast", False))

        # Always load the config with trust_remote_code enabled; drop any
        # caller-supplied value so it is not passed twice.
        kwargs.pop("trust_remote_code", None)
        cfg = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=True,
            **kwargs,
        )
        backbone = getattr(cfg, "backbone_name_or_path", None)
        if backbone is None:
            raise ValueError("Cannot build processor: backbone_name_or_path not found in config.json")

        # is_training=False: loading a saved processor implies deterministic
        # inference-time preprocessing.
        return cls(backbone_name_or_path=backbone, is_training=False, use_fast=use_fast)

    @staticmethod
    def _ensure_list(images: Any) -> list[Any]:
        """Wrap a single image in a list; pass lists/tuples through as lists."""
        if isinstance(images, (list, tuple)):
            return list(images)
        return [images]

    @staticmethod
    def _to_pil_rgb(x: Any):
        """Convert a PIL image or an HxWxC / HxW ndarray to an RGB PIL image.

        Raises:
            TypeError: For unsupported input types.
        """
        from PIL import Image as PILImage

        if isinstance(x, PILImage.Image):
            return x.convert("RGB")
        # Generalization: accept 2-D (grayscale) arrays in addition to the
        # original 3-D case; PIL's RGB conversion replicates the channel.
        if isinstance(x, np.ndarray) and x.ndim in (2, 3):
            return PILImage.fromarray(x).convert("RGB")
        raise TypeError(f"Unsupported image type: {type(x)}")

    def __call__(
        self,
        images: Any | list[Any],
        return_tensors: str | TensorType | None = "pt",
        **kwargs,
    ) -> dict[str, Any]:
        """Convert images into ``{"pixel_values": Tensor/ndarray}``.

        Args:
            images: A single image or a list/tuple of images (PIL / ndarray).
            return_tensors: "pt" (default) or "np"; forwarded verbatim when a
                delegate processor handles the call.
            **kwargs: Extra arguments used only by the delegate processor.
        """
        images = self._ensure_list(images)

        # Runtime objects are never serialized, so rebuild them lazily if
        # this instance was restored without them (requirement 3).
        if (self._delegate is None) and (self._timm_transform is None) and (self._torchvision_transform is None):
            self._build_runtime()

        if self._timm_transform is not None:
            pv: list[torch.Tensor] = []
            for im in images:
                pil = self._to_pil_rgb(im)
                t = self._timm_transform(pil)
                if not isinstance(t, torch.Tensor):
                    raise RuntimeError("Unexpected timm transform output (expected torch.Tensor).")
                pv.append(t)
            pixel_values = torch.stack(pv, dim=0)
            return self._format_return(pixel_values, return_tensors)

        if self._torchvision_transform is not None:
            pv: list[torch.Tensor] = []
            for im in images:
                pil = self._to_pil_rgb(im)
                t = self._torchvision_transform(pil)
                if not isinstance(t, torch.Tensor):
                    raise RuntimeError("Unexpected torchvision transform output (expected torch.Tensor).")
                pv.append(t)
            pixel_values = torch.stack(pv, dim=0)
            return self._format_return(pixel_values, return_tensors)

        if self._delegate is None:
            raise RuntimeError("Processor runtime not built: delegate is None and no transforms are available.")

        return self._delegate(images, return_tensors=return_tensors, **kwargs)

    @staticmethod
    def _format_return(pixel_values: torch.Tensor, return_tensors: str | TensorType | None) -> dict[str, Any]:
        """Format ``pixel_values`` according to ``return_tensors``.

        NOTE: ``return_tensors=None`` intentionally behaves like "pt" here —
        the torch tensor is returned unchanged rather than listified.

        Raises:
            ValueError: For any value other than None / "pt" / "np".
        """
        if return_tensors is None or return_tensors in ("pt", TensorType.PYTORCH):
            return {"pixel_values": pixel_values}
        if return_tensors in ("np", TensorType.NUMPY):
            return {"pixel_values": pixel_values.detach().cpu().numpy()}
        raise ValueError(f"Unsupported return_tensors={return_tensors}. Use 'pt' or 'np'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register this processor with AutoImageProcessor whenever the module is
# imported (e.g. via transformers' trust_remote_code dynamic loading), but
# skip registration when the file is executed directly as a script.
# NOTE(review): the inverted guard (`!=`) looks deliberate but is unusual —
# confirm that skipping registration under direct execution is intended.
if __name__ != "__main__":

    BackboneMLPHead224ImageProcessor.register_for_auto_class("AutoImageProcessor")
|
|
|