| from transformers import PretrainedConfig | |
| from typing import Literal, Any | |
# Public API of this module: the backbone whitelist type, its metadata
# registry, and the HF config class built on top of them.
__all__ = ["BackboneID", "BACKBONE_META", "BackboneMLPHeadConfig"]
| # ============================================================ | |
| # Backbone whitelist + meta registry | |
| # ============================================================ | |
# Closed whitelist of supported backbone identifiers.
#
# Expressed as a Literal so that static type checkers reject any backbone
# string that is not explicitly registered here, and so BACKBONE_META's key
# type can be pinned to this exact set.
BackboneID = Literal[
    "google/vit-base-patch16-224",
    "microsoft/swin-tiny-patch4-window7-224",
    "microsoft/resnet-50",
    "google/efficientnet-b0",
    "timm/densenet121.tv_in1k",
    "torchvision/densenet121",
]
# ============================================================
# 2) Backbone metadata registry (fixed feature dim / feature rule / unfreeze rule)
# ============================================================
# This table is the single source of truth for the feature-extraction and
# fine-tuning rules of each backbone.
#
# Its key type is BackboneID, so the registry keys can never drift out of
# sync with the whitelist above.
BACKBONE_META: dict[BackboneID, dict[str, Any]] = {
    # -------------------------
    # Transformers (ViT/Swin)
    # -------------------------
    # These backbones come from transformers and typically expose hidden
    # states and/or a pooler output.
    "google/vit-base-patch16-224": {
        # "type" selects which loading/forward/feature-extraction pathway the
        # model code should use for this backbone.
        "type": "vit",
        # "feat_dim" is the dimensionality of the feature vector consumed by
        # the MLP head.
        "feat_dim": 768,
        # "feat_rule" defines how to obtain a (B, feat_dim) tensor from the
        # backbone outputs.
        "feat_rule": "cls",  # Use last_hidden_state[:, 0, :] as the CLS token embedding.
        # "unfreeze" is the policy for which layers to unfreeze during
        # stage-2 fine-tuning.
        "unfreeze": "last_n",  # Unfreeze the last n encoder blocks.
        # "has_bn" flags whether BatchNorm is present and therefore needs
        # special handling when freezing.
        "has_bn": False,
    },
    "microsoft/swin-tiny-patch4-window7-224": {
        # Swin Transformer; whether a pooler output exists can vary by
        # implementation, hence the fallback rule below.
        "type": "swin",
        "feat_dim": 768,
        # Prefer the pooler output if available, otherwise fall back to mean
        # pooling over tokens.
        "feat_rule": "pool_or_mean",
        # Unfreeze strategy follows transformer-style encoder blocks.
        "unfreeze": "last_n",
        "has_bn": False,
    },
    # -------------------------
    # Transformers (CNNs)
    # -------------------------
    # CNN backbones exposed via transformers; they usually produce pooled
    # feature vectors or feature maps.
    "microsoft/resnet-50": {
        # Assumes a transformers-compatible ResNet that exposes either a
        # pooler output or a final feature map.
        "type": "resnet",
        "feat_dim": 2048,
        # Use the pooler output if the model provides one, otherwise apply
        # global average pooling (GAP).
        "feat_rule": "pool_or_gap",
        # The CNN unfreeze policy is still expressed as "last_n", interpreted
        # at block/stage granularity by the model code.
        "unfreeze": "last_n",
        "has_bn": True,
    },
    "google/efficientnet-b0": {
        # Assumes a transformers-compatible EfficientNet that exposes pooled
        # features or a final feature map.
        "type": "efficientnet",
        "feat_dim": 1280,
        "feat_rule": "pool_or_gap",
        "unfreeze": "last_n",
        "has_bn": True,
    },
    # -------------------------
    # timm (DenseNet via HF Hub)
    # -------------------------
    # Loaded via timm using the "hf_hub:" prefix in the model loader.
    "timm/densenet121.tv_in1k": {
        "type": "timm_densenet",
        # The canonical DenseNet-121 architecture ends with 1024 channels.
        "feat_dim": 1024,
        # timm's forward_features typically returns a feature map that is
        # then GAP-reduced to (B, C).
        "feat_rule": "timm_gap",
        # DenseNet uses BatchNorm heavily, so freeze_bn behavior matters for
        # both stage-1 and stage-2 training.
        "unfreeze": "last_n",
        "has_bn": True,
    },
    # -------------------------
    # torchvision (DenseNet direct)
    # -------------------------
    # Intended for torchvision-style loading and feature extraction, not
    # transformers/timm.
    "torchvision/densenet121": {
        "type": "torchvision_densenet",
        "feat_dim": 1024,
        # torchvision's DenseNet usually exposes .features; GAP is applied to
        # obtain (B, C).
        "feat_rule": "torchvision_densenet_gap",
        # The unfreeze policy stays "last_n", but its interpretation must
        # match torchvision module naming.
        "unfreeze": "last_n",
        "has_bn": True,
    },
}
class BackboneMLPHeadConfig(PretrainedConfig):
    """Configuration for Backbone + MLP Head models.

    Holds the whitelisted backbone identifier, the MLP head hyperparameters,
    and normalized label mappings, while staying safe to construct with no
    arguments (a path transformers uses internally during AutoConfig
    resolution and Hub loading).
    """

    # Unique ID used by Hugging Face AutoConfig to identify this config class.
    model_type = "backbone-mlphead-224-fixed"

    def __init__(
        self,
        backbone_name_or_path: BackboneID | None = None,
        mlp_head_bottleneck: int = 256,
        mlp_head_dropout: float = 0.2,
        label2id: dict[str, int] | None = None,
        id2label: dict[int, str] | None = None,
        **kwargs,
    ) -> None:
        # ============================================================
        # 0) Guard for argument-less construction
        # ============================================================
        # Transformers may construct this config internally without arguments
        # (e.g. during AutoConfig resolution or Hub loading). In that path we
        # must NOT validate or raise; the goal is a minimal,
        # serialization-safe config.
        if backbone_name_or_path is None:
            # num_labels may be implicitly assumed by downstream code, so set
            # a safe explicit default.
            if "num_labels" not in kwargs:
                kwargs["num_labels"] = 0
            super().__init__(**kwargs)
            # The backbone is intentionally left unset on this path.
            self.backbone_name_or_path = None
            # Store the MLP head hyperparameters for completeness.
            self.mlp_head_bottleneck = int(mlp_head_bottleneck)
            self.mlp_head_dropout = float(mlp_head_dropout)
            # Empty label mappings keep save/load behavior stable.
            self.label2id = {}
            self.id2label = {}
            self.num_labels = int(kwargs.get("num_labels", 0))
            return
        # ============================================================
        # 1) Backbone whitelist validation
        # ============================================================
        # Only backbones explicitly registered in BACKBONE_META are allowed;
        # this blocks unsupported or inconsistent backbones at the source.
        if backbone_name_or_path not in BACKBONE_META:
            raise ValueError(
                f"Unsupported backbone_name_or_path={backbone_name_or_path}. "
                f"Allowed: {sorted(BACKBONE_META.keys())}"
            )
        # ============================================================
        # 2) Label mapping normalization
        # ============================================================
        # Both label2id and id2label may be None in pure loading scenarios
        # (from_pretrained); that is allowed here to keep Hub loading stable.
        # Fail-fast validation belongs at the model/training level instead.
        if label2id is None and id2label is None:
            # Respect num_labels if explicitly provided, otherwise default to 0.
            num_labels = int(kwargs.get("num_labels", 0))
            label2id_norm: dict[str, int] = {}
            id2label_norm: dict[int, str] = {}
        else:
            # If only one mapping is provided, derive the other by inversion.
            if id2label is None:
                id2label = {v: k for k, v in label2id.items()}
            if label2id is None:
                label2id = {v: k for k, v in id2label.items()}
            # The two mappings must agree in size (this also catches
            # duplicate values in the provided mapping, since inversion
            # collapses duplicates).
            if len(label2id) != len(id2label):
                raise ValueError(
                    f"label2id/id2label size mismatch: "
                    f"{len(label2id)} vs {len(id2label)}"
                )
            num_labels = len(id2label)
            label2id_norm = dict(label2id)
            id2label_norm = dict(id2label)
        # ============================================================
        # 3) num_labels consistency enforcement
        # ============================================================
        # If num_labels arrives via kwargs, it must match the value inferred
        # from the label mappings.
        if "num_labels" in kwargs:
            if (label2id is not None or id2label is not None) and int(kwargs["num_labels"]) != num_labels:
                raise ValueError(
                    f"kwargs['num_labels']={kwargs['num_labels']} "
                    f"!= inferred num_labels={num_labels}"
                )
        else:
            kwargs["num_labels"] = num_labels
        # ============================================================
        # 4) Parent initialization
        # ============================================================
        # Initialize PretrainedConfig with the normalized label mappings.
        super().__init__(
            label2id=label2id_norm,
            id2label=id2label_norm,
            **kwargs,
        )
        # ============================================================
        # 5) Explicit attribute assignment for save/load stability
        # ============================================================
        # Reassign the critical fields explicitly to avoid subtle
        # serialization drift.
        # NOTE(review): PretrainedConfig.__init__ normalizes id2label keys to
        # int; when loading from JSON the mappings passed here may carry str
        # keys, and this reassignment would restore them — confirm whether
        # the keys of id2label_norm should be coerced to int here.
        self.backbone_name_or_path = backbone_name_or_path
        self.mlp_head_bottleneck = int(mlp_head_bottleneck)
        self.mlp_head_dropout = float(mlp_head_dropout)
        self.label2id = label2id_norm
        self.id2label = id2label_norm
        self.num_labels = int(kwargs["num_labels"])

    def to_dict(self) -> dict[str, Any]:
        """Serialize the config, forcing num_labels to be present and consistent."""
        # Call the parent implementation first.
        output = super().to_dict()
        # Force num_labels to exist in the serialized form and mirror the
        # instance attribute.
        output["num_labels"] = int(
            getattr(self, "num_labels", output.get("num_labels", 0))
        )
        return output
# Register this config class so it can be resolved via AutoConfig
# (custom-code loading with trust_remote_code).
BackboneMLPHeadConfig.register_for_auto_class("AutoConfig")