# bb_mlp_224 / ds_meta.py
# Provenance: Hugging Face Hub upload by dsaint31,
# commit 2eed5eb (verified) — "Add/Update backbone checkpoints (count=6)".
from typing import Literal, Any
# ============================================================
# 1) Allowed 224px backbones (fixed whitelist)
# ============================================================
# This Literal enforces the closed set of backbone identifiers that are
# allowed in configs; any other string is rejected by static type checking.
BackboneID = Literal[
    "google/vit-base-patch16-224",
    "microsoft/swin-tiny-patch4-window7-224",
    "microsoft/resnet-50",
    "google/efficientnet-b0",
    "timm/densenet121.tv_in1k",
    "torchvision/densenet121",
]
# ============================================================
# 2) Backbone metadata registry (feature dim / feature rule / unfreeze rule)
# ============================================================
# This table is the single source of truth for per-backbone feature
# extraction and fine-tuning rules.
#
# The key type is fixed to BackboneID so registry keys can never drift
# away from the whitelist above.
#
# Per-entry schema:
#   "type"      - selects which loading/forward/feature-extraction pathway
#                 the model code should use for this backbone.
#   "feat_dim"  - dimension of the feature vector consumed by the MLP head.
#   "feat_rule" - how to obtain a (B, feat_dim) tensor from backbone outputs.
#   "unfreeze"  - policy for which layers to unfreeze during stage-2
#                 fine-tuning.
#   "has_bn"    - whether BatchNorm is present (requires careful handling
#                 when freezing, e.g. a freeze_bn step).
BACKBONE_META: dict[BackboneID, dict[str, Any]] = {
    # -------------------------
    # Transformers (ViT/Swin)
    # -------------------------
    # transformers-family backbones; they typically expose hidden states
    # and/or pooler outputs.
    "google/vit-base-patch16-224": {
        "type": "vit",
        # ViT-Base hidden size.
        "feat_dim": 768,
        # Use last_hidden_state[:, 0, :] as the CLS token embedding.
        "feat_rule": "cls",
        # Unfreeze the last n encoder blocks in stage 2.
        "unfreeze": "last_n",
        "has_bn": False,
    },
    "microsoft/swin-tiny-patch4-window7-224": {
        # Swin Transformer; whether a pooler output is provided can vary
        # by implementation, hence the fallback rule below.
        "type": "swin",
        "feat_dim": 768,
        # Prefer the pooler output if available, otherwise fall back to
        # mean pooling over tokens.
        "feat_rule": "pool_or_mean",
        # Unfreeze strategy aligned with transformer-style encoder blocks.
        "unfreeze": "last_n",
        "has_bn": False,
    },
    # -------------------------
    # Transformers (CNNs)
    # -------------------------
    # CNNs exposed via transformers; they usually produce pooled feature
    # vectors or feature maps.
    "microsoft/resnet-50": {
        # Assumes a transformers-compatible ResNet that can expose a
        # pooler output or a final feature map.
        "type": "resnet",
        "feat_dim": 2048,
        # Use the pooler output if the model provides one, otherwise
        # apply global average pooling (GAP).
        "feat_rule": "pool_or_gap",
        # CNN unfreezing is still expressed as "last_n", interpreted at
        # block/stage granularity by the model code.
        "unfreeze": "last_n",
        "has_bn": True,
    },
    "google/efficientnet-b0": {
        # Assumes a transformers-compatible EfficientNet exposing pooled
        # features or a final feature map.
        "type": "efficientnet",
        "feat_dim": 1280,
        "feat_rule": "pool_or_gap",
        "unfreeze": "last_n",
        "has_bn": True,
    },
    # -------------------------
    # timm (DenseNet via HF Hub)
    # -------------------------
    # Loaded via timm using the "hf_hub:" prefix in the model loader.
    "timm/densenet121.tv_in1k": {
        "type": "timm_densenet",
        # DenseNet-121's final channel dimension is 1024 in the
        # canonical architecture.
        "feat_dim": 1024,
        # timm's forward_features typically returns a feature map which
        # is then reduced with GAP to (B, C).
        "feat_rule": "timm_gap",
        # DenseNet uses BatchNorm heavily, so freeze_bn behavior matters
        # for both stage 1 and stage 2.
        "unfreeze": "last_n",
        "has_bn": True,
    },
    # -------------------------
    # torchvision (DenseNet direct)
    # -------------------------
    # Intended for torchvision-style loading and feature extraction,
    # not transformers/timm.
    "torchvision/densenet121": {
        "type": "torchvision_densenet",
        "feat_dim": 1024,
        # torchvision DenseNet usually exposes .features; GAP is applied
        # to obtain a (B, C) feature vector.
        "feat_rule": "torchvision_densenet_gap",
        # The unfreeze policy remains "last_n", but its interpretation
        # must match torchvision module naming.
        "unfreeze": "last_n",
        "has_bn": True,
    },
}