from typing import Literal, Any
# ============================================================
# 1) Allowed 224x224 backbones (fixed whitelist)
# ============================================================
# This Literal is the closed set of backbone identifiers permitted in configs;
# any other string is rejected by static type checking.
BackboneID = Literal[
    "google/vit-base-patch16-224",
    "microsoft/swin-tiny-patch4-window7-224",
    "microsoft/resnet-50",
    "google/efficientnet-b0",
    "timm/densenet121.tv_in1k",
    "torchvision/densenet121",
]
# ============================================================
# 2) Backbone metadata registry (feature dim / feature rule / unfreeze rule)
# ============================================================
# This table is the single source of truth for per-backbone feature extraction
# and fine-tuning rules.
#
# The key type is BackboneID so that registry keys can never drift away from
# the whitelist above.
#
# Per-entry schema (values are plain dicts; every entry below carries exactly
# these five keys):
#   type      : which loading/forward/feature-extraction pathway the model
#               code should use for this backbone
#   feat_dim  : dimension of the (B, feat_dim) feature vector consumed by the
#               MLP head
#   feat_rule : how to reduce backbone outputs to a (B, feat_dim) tensor
#   unfreeze  : policy for unfreezing layers during stage-2 fine-tuning
#   has_bn    : whether BatchNorm is present and must be handled carefully
#               when freezing
BACKBONE_META: dict[BackboneID, dict[str, Any]] = {
    # -------------------------
    # Transformers (ViT / Swin)
    # -------------------------
    # These backbones come from `transformers` and typically expose hidden
    # states and/or a pooler output.
    "google/vit-base-patch16-224": {
        # Selects the loading/forward/feature-extraction code path.
        "type": "vit",
        # ViT-Base hidden size; the MLP head receives a (B, 768) vector.
        "feat_dim": 768,
        # Use last_hidden_state[:, 0, :] — the CLS token embedding.
        "feat_rule": "cls",
        # Stage-2 fine-tuning unfreezes the last n encoder blocks.
        "unfreeze": "last_n",
        # ViT uses LayerNorm, not BatchNorm, so no special freeze handling.
        "has_bn": False,
    },
    "microsoft/swin-tiny-patch4-window7-224": {
        # Swin Transformer; depending on the implementation it may or may not
        # provide a pooler output.
        "type": "swin",
        # Swin-Tiny final feature dimension.
        "feat_dim": 768,
        # Prefer the pooler output if available, otherwise fall back to mean
        # pooling over tokens.
        "feat_rule": "pool_or_mean",
        # Unfreeze strategy follows transformer-style encoder blocks.
        "unfreeze": "last_n",
        "has_bn": False,
    },
    # -------------------------
    # Transformers (CNNs)
    # -------------------------
    # These backbones are CNNs exposed via `transformers`, usually producing
    # pooled feature vectors or feature maps.
    "microsoft/resnet-50": {
        # Assumes a transformers-compatible ResNet that exposes a pooler
        # output or a final feature map.
        "type": "resnet",
        # ResNet-50 final channel dimension.
        "feat_dim": 2048,
        # Use the pooler output if the model provides one, otherwise apply
        # global average pooling (GAP) to the final feature map.
        "feat_rule": "pool_or_gap",
        # "last_n" is still meaningful for CNNs: the model code interprets it
        # at block/stage granularity.
        "unfreeze": "last_n",
        # ResNet uses BatchNorm throughout; freezing must account for it.
        "has_bn": True,
    },
    "google/efficientnet-b0": {
        # Assumes a transformers-compatible EfficientNet that exposes pooled
        # features or a final feature map.
        "type": "efficientnet",
        # EfficientNet-B0 final channel dimension.
        "feat_dim": 1280,
        "feat_rule": "pool_or_gap",
        "unfreeze": "last_n",
        "has_bn": True,
    },
    # -------------------------
    # timm (DenseNet via HF Hub)
    # -------------------------
    # Loaded through timm using the "hf_hub:" prefix in the model loader.
    "timm/densenet121.tv_in1k": {
        "type": "timm_densenet",
        # DenseNet-121's canonical final channel dimension is 1024.
        "feat_dim": 1024,
        # timm's forward_features typically returns a feature map, which is
        # then reduced with GAP to (B, C).
        "feat_rule": "timm_gap",
        # DenseNet uses BatchNorm heavily, so freeze_bn behavior matters for
        # both stage-1 and stage-2 training.
        "unfreeze": "last_n",
        "has_bn": True,
    },
    # -------------------------
    # torchvision (DenseNet direct)
    # -------------------------
    # Intended for torchvision-style loading and feature extraction, not
    # transformers/timm.
    "torchvision/densenet121": {
        "type": "torchvision_densenet",
        "feat_dim": 1024,
        # torchvision DenseNet usually exposes `.features`; GAP over that map
        # yields the (B, C) vector.
        "feat_rule": "torchvision_densenet_gap",
        # The unfreeze policy stays "last_n", but its interpretation must
        # match torchvision module naming.
        "unfreeze": "last_n",
        "has_bn": True,
    },
}