|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, List, Any, Dict |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers import AutoModel, AutoConfig |
|
|
from transformers.modeling_outputs import ImageClassifierOutput |
|
|
|
|
|
|
|
|
from torchvision import models as tv_models |
|
|
|
|
|
try: |
|
|
from .ds_cfg import BackboneMLPHeadConfig, BACKBONE_META |
|
|
except ImportError: |
|
|
from ds_cfg import BackboneMLPHeadConfig, BACKBONE_META |
|
|
|
|
|
|
|
|
class MLPHead(nn.Module):
    """Two-layer classification head: Linear -> GELU -> Dropout -> Linear.

    Parameters
    ----------
    in_dim : int
        Feature dimension produced by the backbone.
    num_labels : int
        Number of target classes.
    bottleneck : int
        Hidden-layer width.
    p : float
        Dropout probability applied before the output projection.
    """

    def __init__(self, in_dim: int, num_labels: int, bottleneck: int = 256, p: float = 0.2):
        super().__init__()
        # NOTE: the submodule names (fc1/act/drop/fc2) are part of the
        # checkpoint state-dict layout — keep them stable.
        self.fc1 = nn.Linear(in_dim, bottleneck)
        self.act = nn.GELU()
        self.drop = nn.Dropout(p)
        self.fc2 = nn.Linear(bottleneck, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features (B, in_dim) to logits (B, num_labels)."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        hidden = self.drop(hidden)
        return self.fc2(hidden)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_backbone_meta(config: BackboneMLPHeadConfig, fallback_table: Dict[str, Dict[str, Any]] | None = None) -> Dict[str, Any]: |
|
|
""" |
|
|
Resolve runtime backbone meta. |
|
|
|
|
|
Priority: |
|
|
1) config.backbone_meta (preferred; required for Hub runtime determinism) |
|
|
2) fallback_table[config.backbone_name_or_path] (backward compatibility for local/dev) |
|
|
|
|
|
Returns a dict with at least: type, feat_rule, feat_dim (and optional has_bn/unfreeze). |
|
|
""" |
|
|
meta = getattr(config, "backbone_meta", None) |
|
|
if isinstance(meta, dict) and len(meta) > 0: |
|
|
return meta |
|
|
|
|
|
bb = getattr(config, "backbone_name_or_path", None) |
|
|
if fallback_table is not None and bb in fallback_table: |
|
|
return fallback_table[bb] |
|
|
|
|
|
raise ValueError( |
|
|
"config.backbone_meta is missing/empty and no fallback meta is available. " |
|
|
"Populate config.backbone_meta when saving to the Hub (single source of truth)." |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BackboneWithMLPHeadForImageClassification(PreTrainedModel):
    """Image classifier = swappable backbone + 2-layer MLP head.

    Supported backbone families (selected by the resolved meta ``type``):
    HF AutoModel backbones, timm DenseNet (``timm_densenet``), and
    torchvision DenseNet-121 (``torchvision_densenet``).

    The backbone is built as a weight-less *skeleton* in ``__init__``;
    pretrained backbone weights are injected separately via
    ``load_backbone_pretrained_`` (fresh-start only) or restored by a
    ``from_pretrained`` checkpoint.
    """

    # Config class consumed by HF's from_pretrained/save_pretrained machinery.
    config_class = BackboneMLPHeadConfig

    def __init__(self, config: BackboneMLPHeadConfig):
        """Build the skeleton backbone and the MLP head from *config*.

        Raises
        ------
        ValueError
            If ``backbone_name_or_path`` is missing or ``num_labels`` <= 0.
        """
        super().__init__(config)

        # A backbone id is mandatory: it selects the skeleton builder and
        # (via the BACKBONE_META fallback) the feature-extraction rule.
        if config.backbone_name_or_path is None:
            raise ValueError(
                "config.backbone_name_or_path is None. "
                "Provide a valid backbone id (whitelist key in BACKBONE_META)."
            )

        # The classifier head needs a positive class count.
        if int(getattr(config, "num_labels", 0)) <= 0:
            raise ValueError(
                f"config.num_labels must be > 0, got {getattr(config, 'num_labels', None)}. "
                "Set num_labels (or id2label/label2id) when creating the config."
            )

        # Resolve runtime meta (type / feat_rule / feat_dim ...), with the
        # config's own backbone_meta taking priority over BACKBONE_META.
        self._meta = _resolve_backbone_meta(config, fallback_table=BACKBONE_META)

        # Architecture only — no pretrained weights are loaded here.
        self.backbone = self._build_backbone_skeleton(config.backbone_name_or_path)

        self.classifier = MLPHead(
            in_dim=int(self._meta["feat_dim"]),
            num_labels=int(config.num_labels),
            bottleneck=int(config.mlp_head_bottleneck),
            p=float(config.mlp_head_dropout),
        )

        # Runs the (overridden) init_weights() plus HF's standard post setup.
        self.post_init()

    def init_weights(self):
        """
        Initialize only the head to avoid touching the backbone skeleton.

        HF's default init may traverse the entire module tree, which is
        undesirable here.

        The initial design loaded the backbone's pretrained weights inside
        __init__ (for convenience). In that case HF's post_init() could undo
        that load for some backbones (timm, torchvision, etc.), so this
        override was introduced to initialize the classifier only.
        """
        # Guarded because HF may call init_weights before classifier exists.
        if getattr(self, "classifier", None) is not None:
            self.classifier.apply(self._init_weights)
        self.tie_weights()

    def _build_backbone_skeleton(self, backbone_id: str) -> nn.Module:
        """Build an architecture-only backbone module for *backbone_id*.

        Dispatches on the resolved meta ``type``; unknown ids must exist in
        BACKBONE_META. Non-DenseNet ids fall through to HF AutoModel built
        from config (random weights).
        """
        # The configured backbone's meta is already resolved; any other id
        # falls back to the local whitelist table.
        meta = self._meta if backbone_id == self.config.backbone_name_or_path else BACKBONE_META.get(backbone_id)
        if meta is None:
            raise KeyError(f"Unknown backbone_id={backbone_id}. Provide backbone_meta in config or extend BACKBONE_META.")

        t = meta["type"]

        if t == "timm_densenet":
            return self._build_timm_densenet_skeleton(backbone_id)

        if t == "torchvision_densenet":
            return self._build_torchvision_densenet_skeleton(backbone_id)

        # Default: HF backbone built from config only (no weight download of
        # the pretrained checkpoint tensors beyond the config).
        bb_cfg = AutoConfig.from_pretrained(backbone_id)
        return AutoModel.from_config(bb_cfg)

    @staticmethod
    def _build_timm_densenet_skeleton(hf_repo_id: str) -> nn.Module:
        """Create a timm DenseNet skeleton (no pretrained weights, no head)."""
        # timm is an optional dependency; fail with an actionable message.
        try:
            import timm
        except Exception as e:
            raise ImportError(
                "DenseNet(timm) backbone requires `timm`. Install: pip install timm"
            ) from e

        # num_classes=0 removes timm's own classifier so this model's MLP
        # head is the only classification layer.
        return timm.create_model(
            f"hf_hub:{hf_repo_id}",
            pretrained=False,
            num_classes=0,
        )

    @staticmethod
    def _build_torchvision_densenet_skeleton(model_id: str) -> nn.Module:
        """Create a torchvision DenseNet-121 skeleton (whitelist-only id)."""
        if model_id != "torchvision/densenet121":
            raise ValueError(f"Unsupported torchvision DenseNet id (224 whitelist only): {model_id}")

        # weights=None => randomly initialized skeleton.
        m = tv_models.densenet121(weights=None)
        return m

    @torch.no_grad()
    def load_backbone_pretrained_(
        self,
        *,
        low_cpu_mem_usage: bool = False,
        device_map=None,
    ):
        """
        Fresh-start only: inject pretrained backbone weights into the skeleton.

        Do NOT call this after from_pretrained() because it would overwrite
        checkpoint weights.
        """
        bb = self.config.backbone_name_or_path
        meta = self._meta
        t = meta["type"]

        if t == "timm_densenet":
            self._load_timm_pretrained_into_skeleton_(bb)
            return

        if t == "torchvision_densenet":
            self._load_torchvision_pretrained_into_skeleton_(bb)
            return

        # HF backbones: load a reference model with pretrained weights, then
        # copy its state dict into the skeleton.
        ref = AutoModel.from_pretrained(
            bb,
            low_cpu_mem_usage=low_cpu_mem_usage,
            device_map=device_map,
        )

        # strict=False: tolerate head/pooler keys the skeleton does not have.
        self.backbone.load_state_dict(ref.state_dict(), strict=False)
        del ref

    @torch.no_grad()
    def _load_timm_pretrained_into_skeleton_(self, hf_repo_id: str):
        """Copy timm pretrained weights into the timm skeleton (strict)."""
        import timm

        # Same architecture/config as the skeleton (num_classes=0), so a
        # strict load is expected to match key-for-key.
        ref = timm.create_model(
            f"hf_hub:{hf_repo_id}",
            pretrained=True,
            num_classes=0,
        ).eval()

        self.backbone.load_state_dict(ref.state_dict(), strict=True)
        del ref

    @torch.no_grad()
    def _load_torchvision_pretrained_into_skeleton_(self, model_id: str):
        """Copy torchvision DenseNet-121 default weights into the skeleton."""
        if model_id != "torchvision/densenet121":
            raise ValueError(f"Unsupported torchvision DenseNet id (224 whitelist only): {model_id}")

        # Identical architecture => strict load.
        ref = tv_models.densenet121(weights=tv_models.DenseNet121_Weights.DEFAULT).eval()

        self.backbone.load_state_dict(ref.state_dict(), strict=True)
        del ref

    @staticmethod
    def _pool_or_gap(outputs) -> torch.Tensor:
        """Return (B, C) features: pooler_output if usable, else GAP.

        Accepts HF model outputs whose pooler_output is (B, C) or
        (B, C, 1, 1); otherwise falls back to global average pooling of a
        (B, C, H, W) last_hidden_state.
        """
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            x = outputs.pooler_output
            if x.dim() == 2:
                return x
            # (B, C, 1, 1) -> (B, C)
            if x.dim() == 4 and x.size(-1) == 1 and x.size(-2) == 1:
                return x.flatten(1)
            raise RuntimeError(f"Unexpected pooler_output shape: {tuple(x.shape)}")

        # No pooled output: global average pool over spatial dims.
        x = outputs.last_hidden_state
        if x.dim() == 4:
            return x.mean(dim=(2, 3))

        raise RuntimeError(
            "Expected pooler_output or (B,C,H,W) last_hidden_state for CNN backbones. "
            f"Got last_hidden_state shape={tuple(x.shape)}"
        )

    def _extract_features(self, outputs, pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Extract (B, feat_dim) features from *outputs* per meta feat_rule.

        Rules: ``cls`` (ViT-style CLS token), ``pool_or_mean`` (pooler or
        token mean), ``pool_or_gap`` (pooler or spatial GAP), ``timm_gap`` /
        ``torchvision_densenet_gap`` (GAP over a raw (B,C,H,W) tensor).
        """
        rule = self._meta["feat_rule"]

        if rule == "cls":
            # CLS token at sequence position 0.
            return outputs.last_hidden_state[:, 0, :]

        if rule == "pool_or_mean":
            # Prefer the model's own pooled output; else mean over tokens.
            if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                return outputs.pooler_output
            return outputs.last_hidden_state.mean(dim=1)

        if rule == "pool_or_gap":
            return self._pool_or_gap(outputs)

        if rule == "timm_gap":
            # timm forward_features returns a raw feature map tensor.
            if not isinstance(outputs, torch.Tensor):
                raise TypeError(f"timm_gap expects Tensor features, got {type(outputs)}")
            if outputs.dim() != 4:
                raise RuntimeError(f"Expected (B,C,H,W), got {tuple(outputs.shape)}")
            return outputs.mean(dim=(2, 3))

        if rule == "torchvision_densenet_gap":
            # torchvision DenseNet features(...) returns a raw feature map.
            if not isinstance(outputs, torch.Tensor):
                raise TypeError(f"torchvision_densenet_gap expects Tensor, got {type(outputs)}")
            if outputs.dim() != 4:
                raise RuntimeError(f"Expected (B,C,H,W), got {tuple(outputs.shape)}")
            return outputs.mean(dim=(2, 3))

        raise RuntimeError(f"unknown feat_rule={rule}")

    def forward(
        self,
        pixel_values=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
        **kwargs,
    ):
        """Run backbone + MLP head.

        Returns an ImageClassifierOutput (or a plain tuple when
        return_dict=False); computes cross-entropy loss when *labels* is
        given. hidden_states/attentions are only available for HF backbones.
        """
        t = self._meta["type"]

        if t == "timm_densenet":
            # timm path: call forward_features directly on the raw tensor.
            if pixel_values is None:
                raise ValueError("timm DenseNet backbone requires pixel_values.")
            if pixel_values.dim() != 4:
                raise ValueError(f"pixel_values must be (B,C,H,W), got {tuple(pixel_values.shape)}")

            features_map = self.backbone.forward_features(pixel_values)
            feats = self._extract_features(features_map, pixel_values=pixel_values)
            hidden_states = None
            attentions = None

        elif t == "torchvision_densenet":
            # torchvision path: features(...) then ReLU, mirroring
            # torchvision DenseNet's own forward (features -> relu -> pool).
            if pixel_values is None:
                raise ValueError("torchvision DenseNet backbone requires pixel_values.")
            if pixel_values.dim() != 4:
                raise ValueError(f"pixel_values must be (B,C,H,W), got {tuple(pixel_values.shape)}")

            features_map = self.backbone.features(pixel_values)
            features_map = F.relu(features_map, inplace=False)
            feats = self._extract_features(features_map, pixel_values=pixel_values)
            hidden_states = None
            attentions = None

        else:
            # HF backbone path: forward with return_dict=True so attribute
            # access in _extract_features is well-defined.
            outputs = self.backbone(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                **kwargs,
            )
            feats = self._extract_features(outputs, pixel_values=pixel_values)
            hidden_states = getattr(outputs, "hidden_states", None)
            attentions = getattr(outputs, "attentions", None)

        logits = self.classifier(feats)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        if not return_dict:
            # Tuple convention: (loss, logits) or (logits,).
            out = (logits,)
            return ((loss,) + out) if loss is not None else out

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _set_requires_grad(module: nn.Module, flag: bool): |
|
|
|
|
|
|
|
|
for p in module.parameters(): |
|
|
p.requires_grad = flag |
|
|
|
|
|
|
|
|
def set_bn_eval(module: nn.Module):
    """Put every batch-norm layer under *module* into eval mode.

    This stops running-stat updates during training; other layers keep
    their current train/eval state.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for sub in module.modules():
        if isinstance(sub, bn_types):
            sub.eval()
|
|
|
|
|
|
|
|
def freeze_backbone(model: BackboneWithMLPHeadForImageClassification, freeze_bn: bool = True):
    """Freeze the backbone; keep only the classifier head trainable.

    Parameters
    ----------
    model : BackboneWithMLPHeadForImageClassification
        Model whose ``backbone`` is frozen and ``classifier`` unfrozen.
    freeze_bn : bool
        If True and the backbone meta declares ``has_bn``, also put the
        backbone's batch-norm layers into eval mode so their running
        statistics stop updating.
    """
    _set_requires_grad(model.backbone, False)
    _set_requires_grad(model.classifier, True)

    # Fix: previously this raised AttributeError (`None.get`) when neither
    # model._meta nor config.backbone_meta was set; treat missing meta as
    # "no BN information" instead.
    meta = getattr(model, "_meta", None) or getattr(model.config, "backbone_meta", None) or {}
    if freeze_bn and meta.get("has_bn", False):
        set_bn_eval(model.backbone)
|
|
|
|
|
|
|
|
def finetune_train_mode(model: BackboneWithMLPHeadForImageClassification, keep_bn_eval: bool = True):
    """Switch *model* to train mode for fine-tuning.

    Parameters
    ----------
    model : BackboneWithMLPHeadForImageClassification
        Model to put into train mode.
    keep_bn_eval : bool
        If True and the backbone meta declares ``has_bn``, re-freeze the
        backbone's batch-norm layers to eval mode after ``model.train()``
        (train() would otherwise flip them back to training mode).
    """
    model.train()

    # Fix: previously this raised AttributeError (`None.get`) when neither
    # model._meta nor config.backbone_meta was set.
    meta = getattr(model, "_meta", None) or getattr(model.config, "backbone_meta", None) or {}
    if keep_bn_eval and meta.get("has_bn", False):
        set_bn_eval(model.backbone)
|
|
|
|
|
|
|
|
def trainable_summary(model: nn.Module):
    """Print and return the trainable/total parameter counts of *model*.

    Returns
    -------
    dict
        ``{"trainable": int, "total": int, "ratio": float}`` where ratio is
        0.0 for a parameter-less model.
    """
    counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(numel for numel, _ in counts)
    trainable = sum(numel for numel, grad_on in counts if grad_on)
    ratio = trainable / total if total > 0 else 0.0
    print(f"trainable: {trainable:,} / total: {total:,} ({ratio*100:.2f}%)")
    return {"trainable": trainable, "total": total, "ratio": ratio}
|
|
|
|
|
|
|
|
def unfreeze_last_stage(
    model: BackboneWithMLPHeadForImageClassification,
    last_n: int = 2,
    keep_bn_eval: bool = True,
):
    """Freeze the whole backbone, then unfreeze its last *last_n* blocks.

    The definition of "block" depends on the backbone family declared in
    the resolved meta: transformer layers (vit/swin), residual blocks
    (resnet), MBConv-style sub-blocks (efficientnet), or dense layers /
    transitions (densenet variants).

    Parameters
    ----------
    model : BackboneWithMLPHeadForImageClassification
        Model whose backbone is partially unfrozen (head stays trainable).
    last_n : int
        Number of trailing blocks to unfreeze; <= 0 leaves everything frozen.
    keep_bn_eval : bool
        Keep batch-norm layers in eval mode even inside unfrozen blocks.

    Raises
    ------
    RuntimeError
        If backbone meta is unavailable, the unfreeze rule is not
        ``last_n``, the module structure is unexpected, or the backbone
        type is unsupported.
    """
    # Start from a fully frozen backbone; the classifier head stays trainable.
    freeze_backbone(model, freeze_bn=keep_bn_eval)

    n = int(last_n)
    if n <= 0:
        return

    meta = getattr(model, "_meta", None) or getattr(model.config, "backbone_meta", None)
    # Fix: previously dereferenced None (`meta.get`) when no meta was
    # available; fail with a clear, consistent RuntimeError instead.
    if not meta:
        raise RuntimeError("No backbone meta available on model or config; cannot resolve unfreeze rule.")
    if meta.get("unfreeze") != "last_n":
        raise RuntimeError(f"Unexpected unfreeze rule: {meta.get('unfreeze')} (expected 'last_n')")

    bb_type = meta["type"]

    if bb_type == "vit":
        # HF ViT: transformer blocks live at backbone.encoder.layer.
        blocks = list(model.backbone.encoder.layer)
        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)
        return

    if bb_type == "swin":
        # HF Swin: flatten the blocks of all stages, then take the tail.
        stages = list(model.backbone.encoder.layers)
        blocks: List[nn.Module] = []
        for st in stages:
            blocks.extend(list(st.blocks))
        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)
        return

    if bb_type == "resnet":
        # ResNet: residual blocks are the children of layer1..layer4.
        bb = model.backbone
        for name in ("layer1", "layer2", "layer3", "layer4"):
            if not hasattr(bb, name):
                raise RuntimeError(f"Unexpected ResNet structure: missing {name}")

        blocks: List[nn.Module] = []
        blocks.extend(list(bb.layer1.children()))
        blocks.extend(list(bb.layer2.children()))
        blocks.extend(list(bb.layer3.children()))
        blocks.extend(list(bb.layer4.children()))

        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)

        # Unfreezing re-enables grads on BN params too; re-apply eval mode.
        if keep_bn_eval:
            set_bn_eval(bb)
        return

    if bb_type == "efficientnet":
        # EfficientNet: flatten the sub-blocks of every `features` stage.
        bb = model.backbone
        if not hasattr(bb, "features"):
            raise RuntimeError("Unexpected EfficientNet structure: missing features")

        blocks: List[nn.Module] = []
        for st in bb.features.children():
            blocks.extend(list(st.children()))

        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)

        if keep_bn_eval:
            set_bn_eval(bb)
        return

    if bb_type in ("timm_densenet", "torchvision_densenet"):
        # DenseNet: validate the canonical `features` layout, then build a
        # flat block list (stem, dense layers, transitions, final norm).
        bb = model.backbone
        if not hasattr(bb, "features"):
            raise RuntimeError("Unexpected DenseNet: missing features")
        f = bb.features

        req = [
            "conv0", "norm0", "relu0", "pool0",
            "denseblock1", "transition1",
            "denseblock2", "transition2",
            "denseblock3", "transition3",
            "denseblock4", "norm5",
        ]
        for name in req:
            if not hasattr(f, name):
                raise RuntimeError(f"Unexpected DenseNet features: missing {name}")

        def _denselayers(db: nn.Module) -> List[nn.Module]:
            # A dense block's children are its individual dense layers.
            return list(db.children())

        blocks: List[nn.Module] = []
        blocks.extend([f.conv0, f.norm0, f.relu0, f.pool0])
        blocks.extend(_denselayers(f.denseblock1)); blocks.append(f.transition1)
        blocks.extend(_denselayers(f.denseblock2)); blocks.append(f.transition2)
        blocks.extend(_denselayers(f.denseblock3)); blocks.append(f.transition3)
        blocks.extend(_denselayers(f.denseblock4)); blocks.append(f.norm5)

        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)

        if keep_bn_eval:
            set_bn_eval(bb)
        return

    raise RuntimeError(f"Unsupported backbone type: {bb_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the custom class with HF's Auto API so Hub checkpoints saved with
# this code can be loaded via AutoModelForImageClassification
# (requires trust_remote_code on the loading side).
BackboneWithMLPHeadForImageClassification.register_for_auto_class("AutoModelForImageClassification")
|
|
|