#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# src/ds_proc.py
# ============================================================
# ImageProcessor (AutoImageProcessor integration)
# ============================================================
from typing import Any
import numpy as np
import torch
from transformers import AutoImageProcessor, AutoConfig
from transformers.image_processing_base import ImageProcessingMixin
from transformers.utils.generic import TensorType
try:
# Hub/Colab: the relative import is correct when loaded as a dynamic module.
from .ds_cfg import BackboneID, BACKBONE_META
except ImportError:
# Local: fall back to an absolute import for `python script.py` or top-level imports.
from ds_cfg import BackboneID, BACKBONE_META
class BackboneMLPHead224ImageProcessor(ImageProcessingMixin):
"""
This processor performs image preprocessing and returns {"pixel_values": ...}.
Key requirements:
1) save_pretrained() must produce a JSON-serializable preprocessor_config.json.
2) Runtime-only objects (delegate processor, timm/torchvision transforms) must NOT be serialized.
3) Runtime objects are rebuilt at init/load time from the backbone meta.
4) For reproducibility, use_fast must be persisted explicitly and honored on load.
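
Illustrative usage sketch (the backbone id below is hypothetical; use a key
that actually exists in BACKBONE_META):

    proc = BackboneMLPHead224ImageProcessor("microsoft/resnet-50")
    out = proc(image, return_tensors="pt")   # -> {"pixel_values": Tensor(B, C, H, W)}
    proc.save_pretrained("./proc_dir")       # writes preprocessor_config.json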
"""
# HF vision models conventionally expect "pixel_values" as the primary input key.
model_input_names = ["pixel_values"]
def __init__(
self,
backbone_name_or_path: BackboneID,
is_training: bool = False, # enables timm train-time data augmentation.
use_fast: bool = False,
**kwargs,
):
# ImageProcessingMixin stores extra kwargs and manages auto_map metadata.
super().__init__(**kwargs)
# Enforce a whitelist via BACKBONE_META to keep behavior stable (fail fast).
if backbone_name_or_path not in BACKBONE_META:
raise ValueError(
f"Unsupported backbone_name_or_path={backbone_name_or_path}. "
f"Allowed: {sorted(BACKBONE_META.keys())}"
)
# Serializable fields only: these are the only values that should appear in preprocessor_config.json.
self.backbone_name_or_path = backbone_name_or_path
self.is_training = bool(is_training)
# Reproducibility switch: pin the fast/slow processor choice for transformers.
self.use_fast = bool(use_fast)
# Runtime-only fields: must never be serialized.
self._meta = None
self._delegate = None
self._timm_transform = None
self._torchvision_transform = None
# Build runtime objects according to the backbone type.
self._build_runtime()
# ============================================================
# Runtime builders
# ============================================================
def _build_runtime(self):
"""
Build the runtime delegate/transform based on BACKBONE_META["type"].
"""
meta = BACKBONE_META[self.backbone_name_or_path]
self._meta = meta
# Always reset runtime fields before rebuilding.
self._delegate = None
self._timm_transform = None
self._torchvision_transform = None
t = meta["type"]
if t == "timm_densenet":
# timm DenseNet uses timm.data transforms for ImageNet-style preprocessing.
self._timm_transform = self._build_timm_transform(
backbone_id=self.backbone_name_or_path,
is_training=self.is_training,
)
return
if t == "torchvision_densenet":
# torchvision DenseNet requires torchvision-style preprocessing (resize/crop/tensor/normalize).
self._torchvision_transform = self._build_torchvision_densenet_transform(
is_training=self.is_training
)
return
# Default: transformers backbones delegate to their official AutoImageProcessor.
#
# IMPORTANT:
# - Pass use_fast explicitly so behavior does not drift with changes to the transformers default.
self._delegate = AutoImageProcessor.from_pretrained(
self.backbone_name_or_path,
use_fast=self.use_fast,
# trust_remote_code = True,
)
@staticmethod
def _build_timm_transform(*, backbone_id: str, is_training: bool):
"""
Create the timm transform without storing non-serializable objects in the config.
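
The resolved data config is a plain dict; an illustrative shape (values vary
per model):

    {"input_size": (3, 224, 224), "interpolation": "bicubic",
     "mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225),
     "crop_pct": 0.875}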
"""
try:
import timm
from timm.data import resolve_model_data_config, create_transform
except Exception as e:
raise ImportError(
"timm backbone processor requires `timm`. Install: pip install timm"
) from e
# Only model metadata is needed to resolve the data config, so pretrained=False avoids a weight download.
m = timm.create_model(f"hf_hub:{backbone_id}", pretrained=False, num_classes=0)
dc = resolve_model_data_config(m)
# create_transform returns a torchvision-style callable mapping PIL -> torch.Tensor(C,H,W).
tfm = create_transform(**dc, is_training=is_training) # is_training toggles train-time augmentation.
return tfm
@staticmethod
def _build_torchvision_densenet_transform(*, is_training: bool):
"""
Build the torchvision preprocessing for DenseNet-121 (224 pipeline).
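
For reference, torchvision's bundled inference preset is roughly equivalent
(sketch; requires torchvision >= 0.13):

    from torchvision.models import DenseNet121_Weights
    tfm = DenseNet121_Weights.IMAGENET1K_V1.transforms()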
"""
try:
from torchvision import transforms
except Exception as e:
raise ImportError(
"torchvision DenseNet processor requires `torchvision`. Install: pip install torchvision"
) from e
# Standard ImageNet normalization statistics used by torchvision weights.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
# A typical training pipeline adds RandomResizedCrop and horizontal flip; both are
# left commented out here, so only Resize(224) is applied. Note that Resize(224)
# scales the shorter side, so non-square inputs remain non-square.
if is_training:
return transforms.Compose(
[
# transforms.RandomResizedCrop(224),
# transforms.RandomHorizontalFlip(p=0.5),
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
]
)
# A typical inference pipeline uses Resize(256) + CenterCrop(224); the crop is left
# commented out here, so outputs keep a shorter side of 256.
return transforms.Compose(
[
transforms.Resize(256),
# transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
]
)
# ============================================================
# Serialization
# ============================================================
def to_dict(self) -> dict[str, Any]:
"""
Return a JSON-serializable dict for preprocessor_config.json.

Important: runtime objects must not leak into the serialized dict.
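
Expected shape of the serialized dict (sketch; exact metadata keys such as
auto_map depend on the transformers version):

    {"image_processor_type": "BackboneMLPHead224ImageProcessor",
     "backbone_name_or_path": "<backbone id>",
     "is_training": false,
     "use_fast": false,
     ...}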
"""
# ImageProcessingMixin.to_dict() adds metadata such as image_processor_type/auto_map.
d = super().to_dict()
# Force the minimal stable fields for long-term compatibility.
d["image_processor_type"] = self.__class__.__name__
d["backbone_name_or_path"] = self.backbone_name_or_path
d["is_training"] = self.is_training
d["use_fast"] = self.use_fast
# Defensively remove any runtime-only fields.
for key in ["_meta", "_delegate", "_timm_transform", "_torchvision_transform"]:
d.pop(key, None)
return d
@classmethod
def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
"""
Standard load path used by BaseImageProcessor / AutoImageProcessor.
"""
backbone = image_processor_dict.get("backbone_name_or_path", None)
if backbone is None:
raise ValueError("preprocessor_config.json missing key: backbone_name_or_path")
is_training = bool(image_processor_dict.get("is_training", False))
use_fast = bool(image_processor_dict.get("use_fast", False))
return cls(
backbone_name_or_path=backbone,
is_training=is_training,
use_fast=use_fast,
**kwargs,
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
"""
Fallback path in case AutoImageProcessor calls class.from_pretrained directly.

Strategy:
- Read config.json via AutoConfig and recover backbone_name_or_path.
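
Minimal config.json fragment this path relies on (sketch):

    {"backbone_name_or_path": "<backbone id>", ...}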
"""
# is_training is runtime-only; it defaults to False for inference/serving.
#
# IMPORTANT:
# - use_fast may arrive via kwargs; honor it when present.
use_fast = bool(kwargs.pop("use_fast", False))
kwargs.pop("trust_remote_code", None)
cfg = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=True,
**kwargs,
)
backbone = getattr(cfg, "backbone_name_or_path", None)
if backbone is None:
raise ValueError("Cannot build processor: backbone_name_or_path not found in config.json")
return cls(backbone_name_or_path=backbone, is_training=False, use_fast=use_fast)
# ============================================================
# Call interface
# ============================================================
@staticmethod
def _ensure_list(images: Any) -> list[Any]:
# Normalize a single image input to a list for a uniform processing path.
if isinstance(images, (list, tuple)):
return list(images)
return [images]
@staticmethod
def _to_pil_rgb(x: Any):
# Convert common image inputs into PIL RGB images.
from PIL import Image as PILImage
if isinstance(x, PILImage.Image):
return x.convert("RGB")
if isinstance(x, np.ndarray) and x.ndim == 3:
return PILImage.fromarray(x).convert("RGB")
raise TypeError(f"Unsupported image type: {type(x)}")
def __call__(
self,
images: Any | list[Any],
return_tensors: str | TensorType | None = "pt",
**kwargs,
) -> dict[str, Any]:
"""
Convert images into {"pixel_values": Tensor/ndarray}.
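
Example (sketch; output H/W depend on the backbone's preprocessing):

    batch = [np.zeros((256, 256, 3), dtype=np.uint8)] * 2
    out = processor(batch, return_tensors="pt")
    out["pixel_values"].shape  # torch.Size([2, C, H, W])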
"""
images = self._ensure_list(images)
# Rebuild the runtime if it is empty (e.g., right after deserialization).
if (self._delegate is None) and (self._timm_transform is None) and (self._torchvision_transform is None):
self._build_runtime()
# timm path: PIL -> normalized float32 torch.Tensor(C,H,W).
if self._timm_transform is not None:
pv: list[torch.Tensor] = []
for im in images:
pil = self._to_pil_rgb(im)
t = self._timm_transform(pil)
if not isinstance(t, torch.Tensor):
raise RuntimeError("Unexpected timm transform output (expected torch.Tensor).")
pv.append(t)
pixel_values = torch.stack(pv, dim=0) # (B,C,H,W)
return self._format_return(pixel_values, return_tensors)
# torchvision path: PIL -> normalized float32 torch.Tensor(C,H,W).
if self._torchvision_transform is not None:
pv: list[torch.Tensor] = []
for im in images:
pil = self._to_pil_rgb(im)
t = self._torchvision_transform(pil)
if not isinstance(t, torch.Tensor):
raise RuntimeError("Unexpected torchvision transform output (expected torch.Tensor).")
pv.append(t)
pixel_values = torch.stack(pv, dim=0) # (B,C,H,W)
return self._format_return(pixel_values, return_tensors)
# transformers delegate path: rely on the official processor behavior.
if self._delegate is None:
raise RuntimeError("Processor runtime not built: delegate is None and no transforms are available.")
return self._delegate(images, return_tensors=return_tensors, **kwargs)
@staticmethod
def _format_return(pixel_values: torch.Tensor, return_tensors: str | TensorType | None) -> dict[str, Any]:
"""
Format pixel_values according to return_tensors.
"""
# None falls back to PyTorch tensors.
if return_tensors is None or return_tensors in ("pt", TensorType.PYTORCH):
return {"pixel_values": pixel_values}
if return_tensors in ("np", TensorType.NUMPY):
return {"pixel_values": pixel_values.detach().cpu().numpy()}
raise ValueError(f"Unsupported return_tensors={return_tensors}. Use 'pt' or 'np'.")
# Register this processor so AutoImageProcessor can resolve it.
if __name__ != "__main__":
BackboneMLPHead224ImageProcessor.register_for_auto_class("AutoImageProcessor")
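
# ------------------------------------------------------------
# Illustrative smoke test (sketch). The backbone id is hypothetical;
# substitute a key that actually exists in BACKBONE_META. Running this
# may download processor assets from the Hub.
# ------------------------------------------------------------
if __name__ == "__main__":
    proc = BackboneMLPHead224ImageProcessor("microsoft/resnet-50")  # hypothetical id
    dummy = np.zeros((224, 224, 3), dtype=np.uint8)
    out = proc(dummy, return_tensors="pt")
    print(out["pixel_values"].shape)  # e.g. torch.Size([1, 3, 224, 224])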