# NOTE(review): the three lines below are Hugging Face Hub web-page residue
# (avatar caption / commit title / commit hash) that was pasted into the
# source; commented out so the file remains valid Python.
# dsaint31's picture
# Add/Update backbone checkpoints (count=6)
# e115a15 verified
from transformers import PretrainedConfig
from typing import Literal, Any
# Public names exported by `from <module> import *`.
__all__ = ["BackboneID", "BACKBONE_META", "BackboneMLPHeadConfig"]
# ============================================================
# Backbone whitelist + meta registry
# ============================================================

# Closed set of supported backbone identifiers. Any identifier outside this
# Literal is rejected by BackboneMLPHeadConfig's whitelist validation.
BackboneID = Literal[
    "google/vit-base-patch16-224",
    "microsoft/swin-tiny-patch4-window7-224",
    "microsoft/resnet-50",
    "google/efficientnet-b0",
    "timm/densenet121.tv_in1k",
    "torchvision/densenet121",
]

# Single source of truth for per-backbone feature extraction and fine-tuning
# rules. The key type is pinned to BackboneID so the registry can never drift
# from the whitelist above.
#
# Per-entry schema:
#   type      -> which loading/forward/feature-extraction pathway the model
#                code should use for this backbone
#   feat_dim  -> dimension of the feature vector consumed by the MLP head
#   feat_rule -> how to reduce backbone outputs to a (B, feat_dim) tensor
#   unfreeze  -> stage-2 fine-tuning policy ("last_n" = unfreeze the last n
#                encoder blocks / CNN stages; interpretation is up to the
#                model code and must match each library's module naming)
#   has_bn    -> whether BatchNorm is present and needs special care when
#                freezing (freeze_bn handling in stage 1/2)
BACKBONE_META: dict[BackboneID, dict[str, Any]] = {
    # --- transformers (ViT/Swin): expose hidden states and/or pooler output ---
    "google/vit-base-patch16-224": {
        "type": "vit",
        "feat_dim": 768,
        # CLS token embedding: last_hidden_state[:, 0, :].
        "feat_rule": "cls",
        "unfreeze": "last_n",
        "has_bn": False,
    },
    "microsoft/swin-tiny-patch4-window7-224": {
        # Swin may or may not provide a pooler output depending on the
        # implementation, hence the fallback rule below.
        "type": "swin",
        "feat_dim": 768,
        # Prefer pooler output when available, else mean-pool the tokens.
        "feat_rule": "pool_or_mean",
        "unfreeze": "last_n",
        "has_bn": False,
    },
    # --- transformers (CNNs): pooled feature vectors or feature maps ---
    "microsoft/resnet-50": {
        # Assumes a transformers-compatible ResNet exposing a pooler or a
        # final feature map.
        "type": "resnet",
        "feat_dim": 2048,
        # Prefer pooler output when available, else global average pooling.
        "feat_rule": "pool_or_gap",
        "unfreeze": "last_n",
        "has_bn": True,
    },
    "google/efficientnet-b0": {
        # Assumes a transformers-compatible EfficientNet exposing pooled
        # features or a final feature map.
        "type": "efficientnet",
        "feat_dim": 1280,
        "feat_rule": "pool_or_gap",
        "unfreeze": "last_n",
        "has_bn": True,
    },
    # --- timm (DenseNet via HF Hub, loaded with the "hf_hub:" prefix) ---
    "timm/densenet121.tv_in1k": {
        "type": "timm_densenet",
        # Canonical DenseNet-121 final channel dimension.
        "feat_dim": 1024,
        # timm forward_features returns a feature map; GAP it to (B, C).
        "feat_rule": "timm_gap",
        "unfreeze": "last_n",
        "has_bn": True,
    },
    # --- torchvision (DenseNet loaded directly, not via transformers/timm) ---
    "torchvision/densenet121": {
        "type": "torchvision_densenet",
        "feat_dim": 1024,
        # torchvision DenseNet exposes .features; GAP the map to (B, C).
        "feat_rule": "torchvision_densenet_gap",
        "unfreeze": "last_n",
        "has_bn": True,
    },
}
class BackboneMLPHeadConfig(PretrainedConfig):
    """Configuration for Backbone + MLP Head models.

    Holds a whitelisted backbone identifier (validated against
    ``BACKBONE_META``), the MLP-head hyperparameters, and normalized
    label2id/id2label mappings.
    """

    # Unique string used by Hugging Face AutoConfig to identify this
    # config class.
    model_type = "backbone-mlphead-224-fixed"

    def __init__(
        self,
        backbone_name_or_path: BackboneID | None = None,
        mlp_head_bottleneck: int = 256,
        mlp_head_dropout: float = 0.2,
        label2id: dict[str, int] | None = None,
        id2label: dict[int, str] | None = None,
        **kwargs,
    ) -> None:
        """Build the config.

        Args:
            backbone_name_or_path: One of the backbone IDs registered in
                ``BACKBONE_META``, or None for the argument-less
                construction path used internally by transformers.
            mlp_head_bottleneck: Bottleneck width of the MLP head.
            mlp_head_dropout: Dropout probability of the MLP head.
            label2id: Optional label-name -> id mapping.
            id2label: Optional id -> label-name mapping.
            **kwargs: Forwarded to ``PretrainedConfig`` (may include
                ``num_labels``).

        Raises:
            ValueError: If the backbone is not whitelisted, if the two
                label mappings disagree in size, or if an explicit
                ``num_labels`` kwarg contradicts the size inferred from
                the mappings.
        """
        # ============================================================
        # 0) Guard for argument-less construction.
        # Transformers may internally construct this config without
        # arguments (e.g., during AutoConfig resolution or Hub loading).
        # In that path we must NOT validate or raise; the goal is a
        # minimal, serialization-safe config.
        # ============================================================
        if backbone_name_or_path is None:
            # num_labels may be implicitly assumed by downstream code,
            # so explicitly pin a safe default before delegating.
            if "num_labels" not in kwargs:
                kwargs["num_labels"] = 0
            super().__init__(**kwargs)
            # Backbone is intentionally left unset on this path.
            self.backbone_name_or_path = None
            # Store the MLP-head hyperparameters so the object shape
            # matches the fully-initialized case.
            self.mlp_head_bottleneck = int(mlp_head_bottleneck)
            self.mlp_head_dropout = float(mlp_head_dropout)
            # Empty label mappings keep save/load behavior stable.
            self.label2id = {}
            self.id2label = {}
            self.num_labels = int(kwargs.get("num_labels", 0))
            return
        # ============================================================
        # 1) Backbone whitelist validation.
        # Only backbones explicitly registered in BACKBONE_META are
        # allowed, preventing accidental use of unsupported or
        # inconsistent backbones.
        # ============================================================
        if backbone_name_or_path not in BACKBONE_META:
            raise ValueError(
                f"Unsupported backbone_name_or_path={backbone_name_or_path}. "
                f"Allowed: {sorted(BACKBONE_META.keys())}"
            )
        # ============================================================
        # 2) Label mapping normalization.
        # Both mappings may be None during pure loading scenarios
        # (from_pretrained); that is allowed here to keep Hub loading
        # and verification stable. Fail-fast validation should happen
        # at the model or training level instead.
        # ============================================================
        if label2id is None and id2label is None:
            # Respect an explicit num_labels kwarg, otherwise default to 0.
            num_labels = int(kwargs.get("num_labels", 0))
            label2id_norm: dict[str, int] = {}
            id2label_norm: dict[int, str] = {}
        else:
            # If only one mapping is provided, derive the other by inversion.
            if id2label is None:
                id2label = {v: k for k, v in label2id.items()}
            if label2id is None:
                label2id = {v: k for k, v in id2label.items()}
            # Both mappings must agree in size (also catches non-bijective
            # input, since inversion collapses duplicate values).
            if len(label2id) != len(id2label):
                raise ValueError(
                    f"label2id/id2label size mismatch: "
                    f"{len(label2id)} vs {len(id2label)}"
                )
            num_labels = len(id2label)
            label2id_norm = dict(label2id)
            id2label_norm = dict(id2label)
        # ============================================================
        # 3) num_labels consistency enforcement.
        # If num_labels arrives via kwargs alongside label mappings, it
        # must match the size inferred from the mappings.
        # ============================================================
        if "num_labels" in kwargs:
            if (label2id is not None or id2label is not None) and int(kwargs["num_labels"]) != num_labels:
                raise ValueError(
                    f"kwargs['num_labels']={kwargs['num_labels']} "
                    f"!= inferred num_labels={num_labels}"
                )
        else:
            kwargs["num_labels"] = num_labels
        # ============================================================
        # 4) Parent initialization with the normalized label mappings.
        # ============================================================
        super().__init__(
            label2id=label2id_norm,
            id2label=id2label_norm,
            **kwargs,
        )
        # ============================================================
        # 5) Explicit attribute reassignment for save/load stability.
        # Critical fields are reassigned to avoid subtle serialization
        # issues.
        # NOTE(review): assigning self.num_labels goes through the
        # PretrainedConfig `num_labels` property setter — presumably a
        # no-op when it matches len(id2label); confirm against the
        # installed transformers version.
        # ============================================================
        self.backbone_name_or_path = backbone_name_or_path
        self.mlp_head_bottleneck = int(mlp_head_bottleneck)
        self.mlp_head_dropout = float(mlp_head_dropout)
        self.label2id = label2id_norm
        self.id2label = id2label_norm
        self.num_labels = int(kwargs["num_labels"])

    def to_dict(self):
        """Serialize the config, forcing ``num_labels`` to be present and
        consistent with the instance attribute."""
        # Delegate to the parent implementation first.
        output = super().to_dict()
        # Force num_labels into the payload even if the parent dropped it,
        # preferring the live attribute over the serialized value.
        output["num_labels"] = int(
            getattr(self, "num_labels", output.get("num_labels", 0))
        )
        return output
# Register this config class with the Auto* machinery so it can be
# resolved via AutoConfig when loaded from the Hub.
BackboneMLPHeadConfig.register_for_auto_class("AutoConfig")