# Source: bb_mlp_224 / ds_proc.py (author: dsaint31)
# Commit 200cb5d (verified): Add/Update backbone checkpoints (count=6)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# src/ds_proc.py
# ============================================================
# ImageProcessor (AutoImageProcessor integration)
# ImageProcessor (AutoImageProcessor ์—ฐ๋™)
# ============================================================
from typing import Any
import numpy as np
import torch
from transformers import AutoImageProcessor, AutoConfig
from transformers.image_processing_base import ImageProcessingMixin
from transformers.utils.generic import TensorType
try:
# Hub/Colab: dynamic module ๋กœ๋”ฉ์—์„œ๋Š” ์ƒ๋Œ€ import๊ฐ€ ์ •์ƒ
from .ds_cfg import BackboneID, BACKBONE_META
except ImportError:
# ๋กœ์ปฌ: python script.py ๋˜๋Š” top-level import์—์„œ๋Š” ์ ˆ๋Œ€ import๋กœ fallback
from ds_cfg import BackboneID, BACKBONE_META
class BackboneMLPHead224ImageProcessor(ImageProcessingMixin):
    """
    Image processor that preprocesses images and returns {"pixel_values": ...}.

    Key requirements:
    1) save_pretrained() must produce a JSON-serializable preprocessor_config.json.
    2) Runtime-only objects (delegate processor, timm/torchvision transforms) must NOT
       be serialized.
    3) Runtime objects are rebuilt at init/load time based on BACKBONE_META.
    4) For reproducibility, use_fast is explicitly persisted and honored on load
       through from_dict().

    The preprocessing path is selected by BACKBONE_META[backbone_name_or_path]["type"]:
    - "timm_densenet":        timm.data transform resolved from the hub model config.
    - "torchvision_densenet": fixed torchvision pipeline with ImageNet mean/std.
    - anything else:          delegate to the backbone's own AutoImageProcessor.
    """

    # HF vision models conventionally expect "pixel_values" as the primary input key.
    model_input_names = ["pixel_values"]

    # Instance attributes that hold runtime objects and must never reach
    # preprocessor_config.json (stripped in to_dict).
    _RUNTIME_FIELDS = ("_meta", "_delegate", "_timm_transform", "_torchvision_transform")

    def __init__(
        self,
        backbone_name_or_path: BackboneID,
        is_training: bool = False,  # used by timm for data augmentation
        use_fast: bool = False,
        **kwargs,
    ):
        """
        Args:
            backbone_name_or_path: Key of BACKBONE_META selecting the backbone.
            is_training: Passed to the timm / torchvision transform builders
                (timm: create_transform(is_training=...)).
            use_fast: Passed explicitly to AutoImageProcessor.from_pretrained for
                transformers backbones.
            **kwargs: Forwarded to ImageProcessingMixin.__init__, which stores
                extra kwargs on the instance.

        Raises:
            ValueError: If backbone_name_or_path is not a key of BACKBONE_META.
        """
        # ImageProcessingMixin stores extra kwargs (as attributes).
        super().__init__(**kwargs)

        # Enforce the BACKBONE_META whitelist so behavior stays stable (fail fast).
        if backbone_name_or_path not in BACKBONE_META:
            raise ValueError(
                f"Unsupported backbone_name_or_path={backbone_name_or_path}. "
                f"Allowed: {sorted(BACKBONE_META.keys())}"
            )

        # Serializable fields: these are what preprocessor_config.json should contain.
        self.backbone_name_or_path = backbone_name_or_path
        self.is_training = bool(is_training)
        # Pin the fast/slow transformers processor choice for reproducibility.
        self.use_fast = bool(use_fast)

        # Runtime-only fields: never serialized.
        self._meta = None
        self._delegate = None
        self._timm_transform = None
        self._torchvision_transform = None

        # Build runtime objects according to the backbone type.
        self._build_runtime()

    # ============================================================
    # Runtime builders
    # ============================================================

    def _build_runtime(self):
        """
        (Re)build the runtime delegate/transform based on BACKBONE_META["type"].

        Exactly one of _timm_transform, _torchvision_transform or _delegate is set
        afterwards; the others are reset to None.
        """
        meta = BACKBONE_META[self.backbone_name_or_path]
        self._meta = meta

        # Always reset runtime fields before rebuilding.
        self._delegate = None
        self._timm_transform = None
        self._torchvision_transform = None

        t = meta["type"]
        if t == "timm_densenet":
            # timm DenseNet uses timm.data transforms resolved from the model config.
            self._timm_transform = self._build_timm_transform(
                backbone_id=self.backbone_name_or_path,
                is_training=self.is_training,
            )
            return

        if t == "torchvision_densenet":
            # torchvision DenseNet uses a torchvision-style pipeline (resize/tensor/normalize).
            self._torchvision_transform = self._build_torchvision_densenet_transform(
                is_training=self.is_training
            )
            return

        # Default: transformers backbones delegate to their official AutoImageProcessor.
        # use_fast is passed explicitly so a change of the transformers default cannot
        # silently change preprocessing.
        self._delegate = AutoImageProcessor.from_pretrained(
            self.backbone_name_or_path,
            use_fast=self.use_fast,
            # trust_remote_code = True,
        )

    @staticmethod
    def _build_timm_transform(*, backbone_id: str, is_training: bool):
        """
        Create the timm transform for `backbone_id` without storing any
        non-serializable object in the config.

        Raises:
            ImportError: If timm (or its data utilities) cannot be imported.
        """
        try:
            import timm
            from timm.data import resolve_model_data_config, create_transform
        except Exception as e:
            raise ImportError(
                "timm backbone processor requires `timm`. Install: pip install timm"
            ) from e

        # Only model metadata is needed to resolve the data config, so no weights
        # are downloaded (pretrained=False).
        m = timm.create_model(f"hf_hub:{backbone_id}", pretrained=False, num_classes=0)
        dc = resolve_model_data_config(m)

        # Callable mapping a PIL image -> torch.Tensor(C, H, W).
        # is_training=True enables timm's training-time data augmentation.
        tfm = create_transform(**dc, is_training=is_training)
        return tfm

    @staticmethod
    def _build_torchvision_densenet_transform(*, is_training: bool):
        """
        Build the torchvision preprocessing pipeline for DenseNet-121.

        Raises:
            ImportError: If torchvision cannot be imported.
        """
        try:
            from torchvision import transforms
        except Exception as e:
            raise ImportError(
                "torchvision DenseNet processor requires `torchvision`. Install: pip install torchvision"
            ) from e

        # Standard ImageNet normalization stats used by torchvision weights.
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        # NOTE(review): neither pipeline crops. transforms.Resize(int) matches the
        # shorter image edge, so non-square inputs give non-square tensors, and a
        # batch of images with different aspect ratios cannot be stacked in
        # __call__. Kept as-is because existing checkpoints were presumably trained
        # with exactly this preprocessing.
        if is_training:
            # Augmentation (RandomResizedCrop / flip) is currently disabled.
            return transforms.Compose(
                [
                    # transforms.RandomResizedCrop(224),
                    # transforms.RandomHorizontalFlip(p=0.5),
                    transforms.Resize(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=mean, std=std),
                ]
            )

        # Inference: Resize(256) WITHOUT the customary CenterCrop(224), so the
        # shorter output edge is 256.
        return transforms.Compose(
            [
                transforms.Resize(256),
                # transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

    # ============================================================
    # Serialization
    # ============================================================

    def to_dict(self) -> dict[str, Any]:
        """
        Return a JSON-serializable dict for preprocessor_config.json.

        Runtime objects must not leak into the serialized dict.
        """
        # ImageProcessingMixin.to_dict() adds metadata such as image_processor_type.
        d = super().to_dict()

        # Force the minimal stable fields for long-term compatibility.
        d["image_processor_type"] = self.__class__.__name__
        d["backbone_name_or_path"] = self.backbone_name_or_path
        d["is_training"] = self.is_training
        d["use_fast"] = self.use_fast

        # Defensively remove runtime-only fields.
        for key in self._RUNTIME_FIELDS:
            d.pop(key, None)
        return d

    @classmethod
    def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
        """
        Build a processor from a preprocessor_config.json dict.

        Standard load path used by BaseImageProcessor / AutoImageProcessor.

        Args:
            image_processor_dict: Parsed preprocessor_config.json. Only
                backbone_name_or_path, is_training and use_fast are read.
            **kwargs: Load-time keyword arguments.
                - backbone_name_or_path / is_training / use_fast: override the
                  stored value when given and not None.
                - return_unused_kwargs: if True, return (processor, unused_kwargs)
                  as ImageProcessingMixin.from_dict does.
                - Any other kwarg (loader internals, cache_dir, revision, token, ...)
                  is NOT forwarded to __init__. ImageProcessingMixin.__init__ would
                  store it as an attribute, and save_pretrained() would then write it
                  into preprocessor_config.json. Such kwargs are reported as unused.

        Returns:
            The processor, or (processor, unused_kwargs) if return_unused_kwargs.

        Raises:
            ValueError: If backbone_name_or_path is neither stored nor passed.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        def _resolve(key: str, default: Any) -> Any:
            # An explicit, non-None kwarg wins over the persisted value.
            override = kwargs.pop(key, None)
            if override is not None:
                return override
            return image_processor_dict.get(key, default)

        backbone = _resolve("backbone_name_or_path", None)
        if backbone is None:
            raise ValueError("preprocessor_config.json missing key: backbone_name_or_path")

        processor = cls(
            backbone_name_or_path=backbone,
            is_training=bool(_resolve("is_training", False)),
            use_fast=bool(_resolve("use_fast", False)),
        )

        if return_unused_kwargs:
            return processor, kwargs
        return processor

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Fallback path used when the class's from_pretrained is called directly.

        Strategy: read config.json via AutoConfig and recover
        backbone_name_or_path from it. is_training is always False (inference /
        serving default); use_fast is taken from kwargs (default False).

        Raises:
            ValueError: If config.json has no backbone_name_or_path.
        """
        use_fast = bool(kwargs.pop("use_fast", False))
        # trust_remote_code is always enabled for the config below.
        kwargs.pop("trust_remote_code", None)

        cfg = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=True,
            **kwargs,
        )
        backbone = getattr(cfg, "backbone_name_or_path", None)
        if backbone is None:
            raise ValueError("Cannot build processor: backbone_name_or_path not found in config.json")
        return cls(backbone_name_or_path=backbone, is_training=False, use_fast=use_fast)

    # ============================================================
    # Call interface
    # ============================================================

    @staticmethod
    def _ensure_list(images: Any) -> list[Any]:
        """Normalize a single image (or list/tuple of images) to a list."""
        if isinstance(images, (list, tuple)):
            return list(images)
        return [images]

    @staticmethod
    def _to_pil_rgb(x: Any):
        """
        Convert a PIL image or numpy array into an RGB PIL image.

        Accepted numpy layouts: (H, W), (H, W, 1) and (H, W, C) arrays that
        PIL.Image.fromarray understands.

        Raises:
            TypeError: For any other input type/rank.
        """
        from PIL import Image as PILImage

        if isinstance(x, PILImage.Image):
            return x.convert("RGB")
        if isinstance(x, np.ndarray):
            arr = x
            # PIL cannot build an image from a trailing singleton channel axis.
            if arr.ndim == 3 and arr.shape[-1] == 1:
                arr = arr[..., 0]
            if arr.ndim in (2, 3):
                return PILImage.fromarray(arr).convert("RGB")
        raise TypeError(f"Unsupported image type: {type(x)}")

    def _apply_transform(self, images: list[Any], transform, label: str) -> torch.Tensor:
        """
        Apply a per-image transform (PIL -> Tensor(C, H, W)) and stack into (B, C, H, W).

        Raises:
            RuntimeError: If the transform does not return a torch.Tensor.
        """
        pv: list[torch.Tensor] = []
        for im in images:
            t = transform(self._to_pil_rgb(im))
            if not isinstance(t, torch.Tensor):
                raise RuntimeError(f"Unexpected {label} transform output (expected torch.Tensor).")
            pv.append(t)
        # torch.stack requires every per-image tensor to have the same shape.
        return torch.stack(pv, dim=0)

    def __call__(
        self,
        images: Any | list[Any],
        return_tensors: str | TensorType | None = "pt",
        **kwargs,
    ) -> dict[str, Any]:
        """
        Convert images into {"pixel_values": Tensor/ndarray}.

        For timm/torchvision backbones `kwargs` are ignored and the result is a
        plain dict; for transformers backbones the call (including kwargs) is
        forwarded to the delegate processor and its output is returned as-is.
        """
        images = self._ensure_list(images)

        # Rebuild runtime if needed (e.g. right after deserialization).
        if (self._delegate is None) and (self._timm_transform is None) and (self._torchvision_transform is None):
            self._build_runtime()

        if self._timm_transform is not None:
            pixel_values = self._apply_transform(images, self._timm_transform, "timm")
            return self._format_return(pixel_values, return_tensors)

        if self._torchvision_transform is not None:
            pixel_values = self._apply_transform(images, self._torchvision_transform, "torchvision")
            return self._format_return(pixel_values, return_tensors)

        # transformers delegate path: rely on the official processor behavior.
        if self._delegate is None:
            raise RuntimeError("Processor runtime not built: delegate is None and no transforms are available.")
        return self._delegate(images, return_tensors=return_tensors, **kwargs)

    @staticmethod
    def _format_return(pixel_values: torch.Tensor, return_tensors: str | TensorType | None) -> dict[str, Any]:
        """
        Wrap pixel_values according to return_tensors.

        None / "pt" keep the torch.Tensor; "np" converts to a numpy array.

        Raises:
            ValueError: For any other return_tensors value.
        """
        if return_tensors is None or return_tensors in ("pt", TensorType.PYTORCH):
            return {"pixel_values": pixel_values}
        if return_tensors in ("np", TensorType.NUMPY):
            return {"pixel_values": pixel_values.detach().cpu().numpy()}
        raise ValueError(f"Unsupported return_tensors={return_tensors}. Use 'pt' or 'np'.")
# Make AutoImageProcessor resolve to this processor whenever the module is imported
# (including Hub dynamic-module loading); skip registration when run as a script.
if __name__ != "__main__":
    BackboneMLPHead224ImageProcessor.register_for_auto_class(
        auto_class="AutoImageProcessor"
    )