#!/usr/bin/env python3 # -*- coding: utf-8 -*- # src/ds_proc.py # ============================================================ # ImageProcessor (AutoImageProcessor integration) # ImageProcessor (AutoImageProcessor 연동) # ============================================================ from typing import Any import numpy as np import torch from transformers import AutoImageProcessor, AutoConfig from transformers.image_processing_base import ImageProcessingMixin from transformers.utils.generic import TensorType try: # Hub/Colab: dynamic module 로딩에서는 상대 import가 정상 from .ds_cfg import BackboneID, BACKBONE_META except ImportError: # 로컬: python script.py 또는 top-level import에서는 절대 import로 fallback from ds_cfg import BackboneID, BACKBONE_META class BackboneMLPHead224ImageProcessor(ImageProcessingMixin): """ This processor performs image preprocessing and outputs {"pixel_values": ...}. 이 processor는 이미지 전처리를 수행하고 {"pixel_values": ...}를 반환함. Key requirements: 핵심 요구사항: 1) save_pretrained() must produce a JSON-serializable preprocessor_config.json. save_pretrained()는 JSON 직렬화 가능한 preprocessor_config.json을 생성해야 함. 2) Runtime-only objects (delegate processor, timm/torchvision transforms) must NOT be serialized. 런타임 객체(delegate processor, timm/torchvision transform)는 절대 직렬화하면 안 됨. 3) Runtime objects are rebuilt at init/load time based on backbone meta. 런타임 객체는 backbone meta에 따라 init/load 시점에 재구성. 4) For reproducibility, use_fast must be explicitly persisted and honored on load. 재현성을 위해 use_fast는 명시적으로 저장되고, 로드시 반드시 반영되어야 함. """ # HF vision models conventionally expect "pixel_values" as the primary input key. # HF vision 모델은 관례적으로 입력 키로 "pixel_values"를 기대. model_input_names = ["pixel_values"] def __init__( self, backbone_name_or_path: BackboneID, is_training: bool = False, # timm 에서 data augmentation 용. use_fast: bool = False, **kwargs, ): # ImageProcessingMixin stores extra kwargs and manages auto_map metadata. # ImageProcessingMixin은 추가 kwargs를 저장하고 auto_map 메타를 관리. super().__init__(**kwargs) # Enforce whitelist via BACKBONE_META to keep behavior stable. # 동작 안정성을 위해 BACKBONE_META 기반 화이트리스트를 강제. - fast fail if backbone_name_or_path not in BACKBONE_META: raise ValueError( f"Unsupported backbone_name_or_path={backbone_name_or_path}. " f"Allowed: {sorted(BACKBONE_META.keys())}" ) # Serializable fields only: these should appear in preprocessor_config.json. # 직렬화 가능한 필드만: 이 값들만 preprocessor_config.json에 들어가야 함 self.backbone_name_or_path = backbone_name_or_path self.is_training = bool(is_training) # Reproducibility switch for transformers processors. # transformers processor의 fast/slow 선택을 재현 가능하게 고정. self.use_fast = bool(use_fast) # Runtime-only fields: must never be serialized. # 런타임 전용 필드: 절대 직렬화되면 안 됨. self._meta = None self._delegate = None self._timm_transform = None self._torchvision_transform = None # Build runtime objects according to backbone type. # backbone type에 따라 런타임 객체를 구성. self._build_runtime() # ============================================================ # Runtime builders # 런타임 빌더 # ============================================================ def _build_runtime(self): """ Build runtime delegate/transform based on BACKBONE_META["type"]. BACKBONE_META["type"]에 따라 런타임 delegate/transform을 구성. """ meta = BACKBONE_META[self.backbone_name_or_path] self._meta = meta # Always reset runtime fields before rebuilding. # 재구성 전 런타임 필드는 항상 초기화. self._delegate = None self._timm_transform = None self._torchvision_transform = None t = meta["type"] if t == "timm_densenet": # timm DenseNet uses timm.data transforms for ImageNet-style preprocessing. # timm DenseNet은 ImageNet 전처리를 위해 timm.data transform을 사용. self._timm_transform = self._build_timm_transform( backbone_id=self.backbone_name_or_path, is_training=self.is_training, ) return if t == "torchvision_densenet": # torchvision DenseNet requires torchvision-style preprocessing (resize/crop/tensor/normalize). # torchvision DenseNet은 torchvision 스타일 전처리(resize/crop/tensor/normalize)가 필요. self._torchvision_transform = self._build_torchvision_densenet_transform( is_training=self.is_training ) return # Default: transformers backbone delegates to its official AutoImageProcessor. # 기본: transformers 백본은 공식 AutoImageProcessor에 위임. # # IMPORTANT: # - use_fast는 transformers 기본값 변경에 흔들리지 않도록 반드시 명시적으로 전달. self._delegate = AutoImageProcessor.from_pretrained( self.backbone_name_or_path, use_fast=self.use_fast, # trust_remote_code = True, ) @staticmethod def _build_timm_transform(*, backbone_id: str, is_training: bool): """ Create timm transform without storing non-serializable objects in config. 비직렬화 객체를 config에 저장하지 않고 timm transform을 생성. """ try: import timm from timm.data import resolve_model_data_config, create_transform except Exception as e: raise ImportError( "timm backbone processor requires `timm`. Install: pip install timm" ) from e # We only need model metadata to resolve data config, so pretrained=False is preferred. # data config 추출만 필요하므로 pretrained=False를 우선 사용. m = timm.create_model(f"hf_hub:{backbone_id}", pretrained=False, num_classes=0) dc = resolve_model_data_config(m) # create_transform returns a torchvision-like callable that maps PIL -> torch.Tensor(C,H,W). # create_transform은 PIL -> torch.Tensor(C,H,W)로 매핑하는 callable을 반환. tfm = create_transform(**dc, is_training=is_training) # is_training :Data Aug. return tfm @staticmethod def _build_torchvision_densenet_transform(*, is_training: bool): """ Build torchvision preprocessing for DenseNet-121 (224 pipeline). DenseNet-121용 torchvision 전처리(224 파이프라인)를 구성. """ try: from torchvision import transforms except Exception as e: raise ImportError( "torchvision DenseNet processor requires `torchvision`. Install: pip install torchvision" ) from e # These are the standard ImageNet normalization stats used by torchvision weights. # 이 값들은 torchvision weights가 사용하는 표준 ImageNet 정규화 통계. mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) # Training pipeline typically uses RandomResizedCrop and horizontal flip. # 학습 파이프라인은 보통 RandomResizedCrop과 좌우반전을 사용. if is_training: return transforms.Compose( [ # transforms.RandomResizedCrop(224), # transforms.RandomHorizontalFlip(p=0.5), transforms.Resize(224), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), ] ) # Inference pipeline typically uses Resize(256) + CenterCrop(224). # 추론 파이프라인은 보통 Resize(256) + CenterCrop(224)를 사용. return transforms.Compose( [ transforms.Resize(256), # transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), ] ) # ============================================================ # Serialization # 직렬화 # ============================================================ def to_dict(self) -> dict[str, Any]: """ Return a JSON-serializable dict for preprocessor_config.json. preprocessor_config.json에 들어갈 JSON 직렬화 dict를 반환. Important: do not leak runtime objects into the serialized dict. 중요: 런타임 객체가 직렬화 dict에 섞이면 안 됨. """ # ImageProcessingMixin.to_dict() adds metadata such as image_processor_type/auto_map. # ImageProcessingMixin.to_dict()는 image_processor_type/auto_map 같은 메타를 추가합니다. d = super().to_dict() # Force minimal stable fields for long-term compatibility. # 장기 호환을 위해 최소 안정 필드를 강제로 지정. d["image_processor_type"] = self.__class__.__name__ d["backbone_name_or_path"] = self.backbone_name_or_path d["is_training"] = self.is_training d["use_fast"] = self.use_fast # Remove any runtime-only fields defensively. # 런타임 전용 필드는 보수적으로 제거. for key in ["_meta", "_delegate", "_timm_transform", "_torchvision_transform"]: d.pop(key, None) return d @classmethod def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): """ Standard load path used by BaseImageProcessor / AutoImageProcessor. BaseImageProcessor / AutoImageProcessor가 사용하는 표준 로드 경로임. """ backbone = image_processor_dict.get("backbone_name_or_path", None) if backbone is None: raise ValueError("preprocessor_config.json missing key: backbone_name_or_path") is_training = bool(image_processor_dict.get("is_training", False)) use_fast = bool(image_processor_dict.get("use_fast", False)) return cls( backbone_name_or_path=backbone, is_training=is_training, use_fast=use_fast, **kwargs, ) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): """ Fallback path if AutoImageProcessor calls class.from_pretrained directly. AutoImageProcessor가 class.from_pretrained를 직접 호출하는 경우를 대비한 메서드. Strategy: 전략: - Read config.json via AutoConfig and recover backbone_name_or_path. AutoConfig로 config.json을 읽고 backbone_name_or_path를 복구. """ # is_training is runtime-only and should default to False for inference/serving. # is_training은 런타임 전용이며 추론/서빙 기본값은 False 임. # # IMPORTANT: # - use_fast는 kwargs로 전달될 수 있으므로, 있으면 반영. use_fast = bool(kwargs.pop("use_fast", False)) kwargs.pop("trust_remote_code", None) cfg = AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code =True, **kwargs) backbone = getattr(cfg, "backbone_name_or_path", None) if backbone is None: raise ValueError("Cannot build processor: backbone_name_or_path not found in config.json") return cls(backbone_name_or_path=backbone, is_training=False, use_fast=use_fast) # ============================================================ # Call interface # 호출 인터페이스 # ============================================================ @staticmethod def _ensure_list(images: Any) -> list[Any]: # Normalize scalar image input to a list for uniform processing. # 단일 입력을 리스트로 정규화하여 동일한 처리 경로를 사용. if isinstance(images, (list, tuple)): return list(images) return [images] @staticmethod def _to_pil_rgb(x: Any): # Convert common image inputs into PIL RGB images. # 일반적인 입력을 PIL RGB 이미지로 변환. from PIL import Image as PILImage if isinstance(x, PILImage.Image): return x.convert("RGB") if isinstance(x, np.ndarray) and x.ndim == 3: return PILImage.fromarray(x).convert("RGB") raise TypeError(f"Unsupported image type: {type(x)}") def __call__( self, images: Any | list[Any], return_tensors: str | TensorType | None = "pt", **kwargs, ) -> dict[str, Any]: """ Convert images into {"pixel_values": Tensor/ndarray}. 이미지를 {"pixel_values": Tensor/ndarray}로 변환. """ images = self._ensure_list(images) # Rebuild runtime if needed (e.g., right after deserialization). # 직렬화 복원 직후 등 런타임이 비어있을 수 있으므로 재구성. if (self._delegate is None) and (self._timm_transform is None) and (self._torchvision_transform is None): self._build_runtime() # timm path: PIL -> torch.Tensor(C,H,W) normalized float32. # timm 경로: PIL -> torch.Tensor(C,H,W) 정규화 float32. if self._timm_transform is not None: pv: list[torch.Tensor] = [] for im in images: pil = self._to_pil_rgb(im) t = self._timm_transform(pil) if not isinstance(t, torch.Tensor): raise RuntimeError("Unexpected timm transform output (expected torch.Tensor).") pv.append(t) pixel_values = torch.stack(pv, dim=0) # (B,C,H,W) return self._format_return(pixel_values, return_tensors) # torchvision path: PIL -> torch.Tensor(C,H,W) normalized float32. # torchvision 경로: PIL -> torch.Tensor(C,H,W) 정규화 float32. if self._torchvision_transform is not None: pv: list[torch.Tensor] = [] for im in images: pil = self._to_pil_rgb(im) t = self._torchvision_transform(pil) if not isinstance(t, torch.Tensor): raise RuntimeError("Unexpected torchvision transform output (expected torch.Tensor).") pv.append(t) pixel_values = torch.stack(pv, dim=0) # (B,C,H,W) return self._format_return(pixel_values, return_tensors) # transformers delegate path: rely on official processor behavior. # transformers 위임 경로: 공식 processor 동작을 그대로 사용. if self._delegate is None: raise RuntimeError("Processor runtime not built: delegate is None and no transforms are available.") return self._delegate(images, return_tensors=return_tensors, **kwargs) @staticmethod def _format_return(pixel_values: torch.Tensor, return_tensors: str | TensorType | None) -> dict[str, Any]: """ Format pixel_values according to return_tensors. return_tensors에 맞춰 pixel_values 반환 포맷을 변환. """ if return_tensors is None or return_tensors in ("pt", TensorType.PYTORCH): return {"pixel_values": pixel_values} if return_tensors in ("np", TensorType.NUMPY): return {"pixel_values": pixel_values.detach().cpu().numpy()} raise ValueError(f"Unsupported return_tensors={return_tensors}. Use 'pt' or 'np'.") # Register this processor for AutoImageProcessor resolution. # AutoImageProcessor 해석을 위해 이 processor를 등록. if __name__ != "__main__": BackboneMLPHead224ImageProcessor.register_for_auto_class("AutoImageProcessor")