#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# src/ds_model.py

from typing import Optional, List, Any, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

# --- transformers core ---
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModel, AutoConfig
from transformers.modeling_outputs import ImageClassifierOutput

# --- torchvision ---
from torchvision import models as tv_models

try:
    from .ds_cfg import BackboneMLPHeadConfig, BACKBONE_META
except ImportError:
    from ds_cfg import BackboneMLPHeadConfig, BACKBONE_META
# from mlp_head import MLPHead

class MLPHead(nn.Module):
    """
    A simple 2-layer MLP head.

    Parameters
    ----------
    in_dim : int
        Backbone feature dimension.
    num_labels : int
        Number of classes.
    bottleneck : int
        Hidden dimension.
    p : float
        Dropout probability.
    """
    def __init__(self, in_dim: int, num_labels: int, bottleneck: int = 256, p: float = 0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, bottleneck)
        self.act = nn.GELU()
        self.drop = nn.Dropout(p)
        self.fc2 = nn.Linear(bottleneck, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.drop(self.act(self.fc1(x))))
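
# Shape sketch (hypothetical dims): MLPHead(in_dim=768, num_labels=10)(torch.randn(4, 768))
# yields logits of shape (4, 10).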

# ------------------------------------------------------------
# backbone_meta resolver
# ------------------------------------------------------------
def _resolve_backbone_meta(config: BackboneMLPHeadConfig, fallback_table: Dict[str, Dict[str, Any]] | None = None) -> Dict[str, Any]:
    """
    Resolve runtime backbone meta.

    Priority:
      1) config.backbone_meta (preferred; required for Hub runtime determinism)
      2) fallback_table[config.backbone_name_or_path] (backward compatibility for local/dev)

    Returns a dict with at least: type, feat_rule, feat_dim (and optional has_bn/unfreeze).
    """
    meta = getattr(config, "backbone_meta", None)
    if isinstance(meta, dict) and len(meta) > 0:
        return meta

    bb = getattr(config, "backbone_name_or_path", None)
    if fallback_table is not None and bb in fallback_table:
        return fallback_table[bb]

    raise ValueError(
        "config.backbone_meta is missing/empty and no fallback meta is available. "
        "Populate config.backbone_meta when saving to the Hub (single source of truth)."
    )
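
# Example meta shape (illustrative sketch; the authoritative entries live in ds_cfg.BACKBONE_META,
# and the concrete values below are assumed for a ViT-style backbone):
#   {"type": "vit", "feat_rule": "cls", "feat_dim": 768, "has_bn": False, "unfreeze": "last_n"}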


# ============================================================
# (3) Model: backbone + MLP head
# ============================================================
# Design principle: __init__ builds only a skeleton and MUST NOT load pretrained weights.
#
# Pretrained injection is allowed ONLY via an explicit call in fresh-start flows.
#
# HF from_pretrained should restore checkpoints as-is, without side effects.
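#
# Usage sketch (hypothetical kwargs; the exact BackboneMLPHeadConfig fields are assumed):
#   cfg = BackboneMLPHeadConfig(backbone_name_or_path="torchvision/densenet121", num_labels=2)
#   model = BackboneWithMLPHeadForImageClassification(cfg)   # skeleton only, random weights
#   model.load_backbone_pretrained_()                        # fresh start: explicit, one-time injection
#   model.save_pretrained("ckpt")                            # after training
#   restored = BackboneWithMLPHeadForImageClassification.from_pretrained("ckpt")  # as-is, no re-injection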
class BackboneWithMLPHeadForImageClassification(PreTrainedModel):
    # This links the model to its custom config for AutoClass usage.
    config_class = BackboneMLPHeadConfig

    def __init__(self, config: BackboneMLPHeadConfig):
        # PreTrainedModel expects a config object and stores it internally.
        super().__init__(config)

        # Fail-fast: the model is not meant to be instantiated without a valid backbone id.
        #
        # Note: Transformers may create configs with no args, but models are conventionally created with configs.
        if config.backbone_name_or_path is None:
            raise ValueError(
                "config.backbone_name_or_path is None. "
                "Provide a valid backbone id (whitelist key in BACKBONE_META)."
            )

        # Fail-fast: training/inference requires a positive number of labels.
        #
        # A config may exist in a minimal form for internal serialization paths, but the model should not.
        if int(getattr(config, "num_labels", 0)) <= 0:
            raise ValueError(
                f"config.num_labels must be > 0, got {getattr(config, 'num_labels', None)}. "
                "Set num_labels (or id2label/label2id) when creating the config."
            )

        # Meta is the single source of truth for feature-extraction and fine-tuning rules.
        # Resolve it from config.backbone_meta (preferred, keeps the Hub runtime self-contained)
        # or from the fallback table (for backward compatibility).
        self._meta = _resolve_backbone_meta(config, fallback_table=BACKBONE_META)

        # Backbone skeleton is always created without pretrained weights.
        self.backbone = self._build_backbone_skeleton(config.backbone_name_or_path)

        # Head shape is driven by meta feat_dim and config.num_labels.
        self.classifier = MLPHead(
            in_dim=int(self._meta["feat_dim"]),
            num_labels=int(config.num_labels),
            bottleneck=int(config.mlp_head_bottleneck),
            p=float(config.mlp_head_dropout),
        )

        # HF initialization hook; init_weights is overridden below to initialize the head only.
        self.post_init()

    def init_weights(self):
        """
        Initialize only the head to avoid touching the backbone skeleton.
        backbone skeleton을 건드리지 않기 위해 head만 초기화.

        HF's default init may traverse the entire module tree, which is undesirable here.
        HF 기본 init은 전체 모듈 트리를 순회할 수 있어 여기서 그대로 사용하기 부적절.
        
        초기 설계에서 __init__ 내부에서 backbone의 가중치 로드를 수행함(편리를 위해).
        이 경우, HF의 post_init()으로 인해 해당 로드가 취소되는 경우가 존재(timm, torchvision 등의 백본).
        때문에 이를 오버라이드 하여 classifier만 초기화 하도록 변경함.
        """
        if getattr(self, "classifier", None) is not None:
            self.classifier.apply(self._init_weights)
        self.tie_weights()

    # ----------------------------
    # backbone skeleton builders
    # ----------------------------
    def _build_backbone_skeleton(self, backbone_id: str) -> nn.Module:
        # Meta decides which loader path to use.
        meta = self._meta if backbone_id == self.config.backbone_name_or_path else BACKBONE_META.get(backbone_id)
        if meta is None:
            raise KeyError(f"Unknown backbone_id={backbone_id}. Provide backbone_meta in config or extend BACKBONE_META.")

        t = meta["type"]

        if t == "timm_densenet":
            return self._build_timm_densenet_skeleton(backbone_id)

        if t == "torchvision_densenet":
            return self._build_torchvision_densenet_skeleton(backbone_id)

        # For transformers backbones: build a random-weight skeleton from config only.
        bb_cfg = AutoConfig.from_pretrained(backbone_id)
        return AutoModel.from_config(bb_cfg)

    @staticmethod
    def _build_timm_densenet_skeleton(hf_repo_id: str) -> nn.Module:
        # timm is an optional dependency and should be imported lazily.
        try:
            import timm
        except Exception as e:
            raise ImportError(
                "DenseNet(timm) backbone requires `timm`. Install: pip install timm"
            ) from e

        # Build structure only (pretrained=False) and remove classifier head (num_classes=0).
        return timm.create_model(
            f"hf_hub:{hf_repo_id}",
            pretrained=False,
            num_classes=0,
        )

    @staticmethod
    def _build_torchvision_densenet_skeleton(model_id: str) -> nn.Module:
        # This project intentionally supports only torchvision/densenet121 in the 224 whitelist.
        if model_id != "torchvision/densenet121":
            raise ValueError(f"Unsupported torchvision DenseNet id (224 whitelist only): {model_id}")

        # Build structure only (weights=None) to avoid implicit pretrained loading.
        m = tv_models.densenet121(weights=None)
        return m

    # ------------------------------------------------------------
    # Pretrained loading is explicit and fresh-start only
    # ------------------------------------------------------------
    @torch.no_grad()
    def load_backbone_pretrained_(
        self,
        *,
        low_cpu_mem_usage: bool = False,
        device_map=None,
    ):
        """
        Fresh-start only: inject pretrained backbone weights into the skeleton.

        Do NOT call this after from_pretrained(), because it would overwrite checkpoint weights.
        """
        bb = self.config.backbone_name_or_path
        meta = self._meta
        t = meta["type"]

        if t == "timm_densenet":
            self._load_timm_pretrained_into_skeleton_(bb)
            return

        if t == "torchvision_densenet":
            self._load_torchvision_pretrained_into_skeleton_(bb)
            return

        # For transformers backbones, load a reference pretrained model and copy weights into our skeleton.
        ref = AutoModel.from_pretrained(
            bb,
            low_cpu_mem_usage=low_cpu_mem_usage,
            device_map=device_map,
        )

        # strict=False is used to tolerate harmless key differences across minor versions.
        self.backbone.load_state_dict(ref.state_dict(), strict=False)
        del ref

    @torch.no_grad()
    def _load_timm_pretrained_into_skeleton_(self, hf_repo_id: str):
        # timm must be present for timm backbones.
        import timm

        # Create a pretrained reference model and copy its weights strictly.
        ref = timm.create_model(
            f"hf_hub:{hf_repo_id}",
            pretrained=True,
            num_classes=0,
        ).eval()

        self.backbone.load_state_dict(ref.state_dict(), strict=True)
        del ref

    @torch.no_grad()
    def _load_torchvision_pretrained_into_skeleton_(self, model_id: str):
        # This project intentionally supports only torchvision/densenet121 in the 224 whitelist.
        if model_id != "torchvision/densenet121":
            raise ValueError(f"Unsupported torchvision DenseNet id (224 whitelist only): {model_id}")

        # Use torchvision's default pretrained weights for densenet121.
        ref = tv_models.densenet121(weights=tv_models.DenseNet121_Weights.DEFAULT).eval()

        self.backbone.load_state_dict(ref.state_dict(), strict=True)
        del ref

    # ----------------------------
    # feature extraction
    # ----------------------------
    @staticmethod
    def _pool_or_gap(outputs) -> torch.Tensor:
        # Some transformers vision CNNs provide pooler_output explicitly.
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            x = outputs.pooler_output
            if x.dim() == 2:
                return x
            if x.dim() == 4 and x.size(-1) == 1 and x.size(-2) == 1:
                return x.flatten(1)
            raise RuntimeError(f"Unexpected pooler_output shape: {tuple(x.shape)}")

        # Otherwise we expect a CNN-style last_hidden_state=(B,C,H,W) and apply GAP.
        x = outputs.last_hidden_state
        if x.dim() == 4:
            return x.mean(dim=(2, 3))

        raise RuntimeError(
            "Expected pooler_output or (B,C,H,W) last_hidden_state for CNN backbones. "
            f"Got last_hidden_state shape={tuple(x.shape)}"
        )

    def _extract_features(self, outputs, pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Feature rule is defined by BACKBONE_META and must remain stable across saves/loads.
        rule = self._meta["feat_rule"]

        if rule == "cls":
            # ViT-style: use CLS token embedding from last_hidden_state.
            return outputs.last_hidden_state[:, 0, :]

        if rule == "pool_or_mean":
            # Swin-style: prefer pooler_output if present, else mean-pool over tokens.
            if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                return outputs.pooler_output
            return outputs.last_hidden_state.mean(dim=1)

        if rule == "pool_or_gap":
            # CNN-style: use pooler_output if present, else GAP over spatial dims.
            return self._pool_or_gap(outputs)

        if rule == "timm_gap":
            # timm forward_features returns a feature map (B,C,H,W) which we GAP to (B,C).
            if not isinstance(outputs, torch.Tensor):
                raise TypeError(f"timm_gap expects Tensor features, got {type(outputs)}")
            if outputs.dim() != 4:
                raise RuntimeError(f"Expected (B,C,H,W), got {tuple(outputs.shape)}")
            return outputs.mean(dim=(2, 3))

        if rule == "torchvision_densenet_gap":
            # torchvision DenseNet features are feature maps (B,C,H,W) and require GAP.
            if not isinstance(outputs, torch.Tensor):
                raise TypeError(f"torchvision_densenet_gap expects Tensor, got {type(outputs)}")
            if outputs.dim() != 4:
                raise RuntimeError(f"Expected (B,C,H,W), got {tuple(outputs.shape)}")
            return outputs.mean(dim=(2, 3))

        raise RuntimeError(f"unknown feat_rule={rule}")

    def forward(
        self,
        pixel_values=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
        **kwargs,
    ):
        # Type decides the backbone forward path and output format.
        t = self._meta["type"]

        if t == "timm_densenet":
            # timm DenseNet consumes pixel_values as a 4D tensor (B,C,H,W).
            if pixel_values is None:
                raise ValueError("timm DenseNet backbone requires pixel_values.")
            if pixel_values.dim() != 4:
                raise ValueError(f"pixel_values must be (B,C,H,W), got {tuple(pixel_values.shape)}")

            features_map = self.backbone.forward_features(pixel_values)
            feats = self._extract_features(features_map, pixel_values=pixel_values)
            hidden_states = None
            attentions = None

        elif t == "torchvision_densenet":
            # torchvision DenseNet consumes pixel_values as a 4D tensor (B,C,H,W).
            if pixel_values is None:
                raise ValueError("torchvision DenseNet backbone requires pixel_values.")
            if pixel_values.dim() != 4:
                raise ValueError(f"pixel_values must be (B,C,H,W), got {tuple(pixel_values.shape)}")

            features_map = self.backbone.features(pixel_values)
            features_map = F.relu(features_map, inplace=False)
            feats = self._extract_features(features_map, pixel_values=pixel_values)
            hidden_states = None
            attentions = None

        else:
            # Transformers vision models are called with pixel_values and return ModelOutput.
            outputs = self.backbone(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                **kwargs,
            )
            feats = self._extract_features(outputs, pixel_values=pixel_values)
            hidden_states = getattr(outputs, "hidden_states", None)
            attentions = getattr(outputs, "attentions", None)

        # Classifier consumes (B, feat_dim) and returns logits (B, num_labels).
        logits = self.classifier(feats)

        loss = None
        if labels is not None:
            # Cross entropy expects labels as class indices in [0, num_labels).
            loss = F.cross_entropy(logits, labels)

        if not return_dict:
            out = (logits,)
            return ((loss,) + out) if loss is not None else out

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )
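
# Inference sketch (hypothetical input; any whitelisted 224x224 backbone):
#   x = torch.randn(2, 3, 224, 224)         # pixel_values, (B, C, H, W)
#   out = model(pixel_values=x)             # ImageClassifierOutput
#   probs = out.logits.softmax(dim=-1)      # (2, num_labels)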


# ============================================================
# Freeze / Unfreeze utilities
# ============================================================
def _set_requires_grad(module: nn.Module, flag: bool):
    # Toggle requires_grad for all parameters in a module.
    for p in module.parameters():
        p.requires_grad = flag


def set_bn_eval(module: nn.Module):
    # Put BatchNorm layers into eval mode to freeze running stats.
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)):
            m.eval()


def freeze_backbone(model: BackboneWithMLPHeadForImageClassification, freeze_bn: bool = True):
    # Stage1: freeze backbone and train only the head.
    _set_requires_grad(model.backbone, False)
    _set_requires_grad(model.classifier, True)

    # Guard against a missing meta so .get() below cannot raise on None.
    meta = getattr(model, "_meta", None) or getattr(model.config, "backbone_meta", None) or {}
    if freeze_bn and meta.get("has_bn", False):
        set_bn_eval(model.backbone)


def finetune_train_mode(model: BackboneWithMLPHeadForImageClassification, keep_bn_eval: bool = True):
    # Stage2: train mode, optionally keeping BN layers in eval for stability (preserves running-stat buffers).
    model.train()
    meta = getattr(model, "_meta", None) or getattr(model.config, "backbone_meta", None) or {}
    if keep_bn_eval and meta.get("has_bn", False):
        set_bn_eval(model.backbone)


def trainable_summary(model: nn.Module):
    # Print a compact summary of trainable parameters.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    ratio = trainable / total if total > 0 else 0.0
    print(f"trainable: {trainable:,} / total: {total:,} ({ratio*100:.2f}%)")
    return {"trainable": trainable, "total": total, "ratio": ratio}
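
# Example (illustrative numbers only):
#   >>> trainable_summary(model)
#   trainable: 264,202 / total: 86,389,770 (0.31%)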


def unfreeze_last_stage(
    model: BackboneWithMLPHeadForImageClassification,
    last_n: int = 2,
    keep_bn_eval: bool = True,
):
    # This utility implements BACKBONE_META['unfreeze']=="last_n" across supported backbones.
    freeze_backbone(model, freeze_bn=keep_bn_eval)

    n = int(last_n)
    if n <= 0:
        return

    meta = getattr(model, "_meta", None) or getattr(model.config, "backbone_meta", None) or {}
    if meta.get("unfreeze") != "last_n":
        raise RuntimeError(f"Unexpected unfreeze rule: {meta.get('unfreeze')} (expected 'last_n')")

    bb_type = meta["type"]

    if bb_type == "vit":
        # ViT blocks live under backbone.encoder.layer in the transformers implementation.
        blocks = list(model.backbone.encoder.layer)
        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)
        return

    if bb_type == "swin":
        # Swin blocks are nested by stages and blocks; we flatten and unfreeze last n blocks.
        stages = list(model.backbone.encoder.layers)
        blocks: List[nn.Module] = []
        for st in stages:
            blocks.extend(list(st.blocks))
        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)
        return

    if bb_type == "resnet":
        # ResNet uses layer1..layer4 stages; we unfreeze at block granularity.
        bb = model.backbone
        for name in ("layer1", "layer2", "layer3", "layer4"):
            if not hasattr(bb, name):
                raise RuntimeError(f"Unexpected ResNet structure: missing {name}")

        blocks: List[nn.Module] = []
        blocks.extend(list(bb.layer1.children()))
        blocks.extend(list(bb.layer2.children()))
        blocks.extend(list(bb.layer3.children()))
        blocks.extend(list(bb.layer4.children()))

        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)

        if keep_bn_eval:
            set_bn_eval(bb)
        return

    if bb_type == "efficientnet":
        # EfficientNet in transformers exposes features; we unfreeze from the tail blocks.
        bb = model.backbone
        if not hasattr(bb, "features"):
            raise RuntimeError("Unexpected EfficientNet structure: missing features")

        blocks: List[nn.Module] = []
        for st in bb.features.children():
            blocks.extend(list(st.children()))

        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)

        if keep_bn_eval:
            set_bn_eval(bb)
        return

    if bb_type in ("timm_densenet", "torchvision_densenet"):
        # DenseNet exposes a .features module with named blocks; we unfreeze last n submodules.
        bb = model.backbone
        if not hasattr(bb, "features"):
            raise RuntimeError("Unexpected DenseNet: missing features")
        f = bb.features

        req = [
            "conv0", "norm0", "relu0", "pool0",
            "denseblock1", "transition1",
            "denseblock2", "transition2",
            "denseblock3", "transition3",
            "denseblock4", "norm5",
        ]
        for name in req:
            if not hasattr(f, name):
                raise RuntimeError(f"Unexpected DenseNet features: missing {name}")

        def _denselayers(db: nn.Module) -> List[nn.Module]:
            # Dense blocks contain multiple DenseLayer children; we return them for fine-grained unfreezing.
            return list(db.children())

        blocks: List[nn.Module] = []
        blocks.extend([f.conv0, f.norm0, f.relu0, f.pool0])
        blocks.extend(_denselayers(f.denseblock1)); blocks.append(f.transition1)
        blocks.extend(_denselayers(f.denseblock2)); blocks.append(f.transition2)
        blocks.extend(_denselayers(f.denseblock3)); blocks.append(f.transition3)
        blocks.extend(_denselayers(f.denseblock4)); blocks.append(f.norm5)

        for blk in blocks[-n:]:
            _set_requires_grad(blk, True)

        if keep_bn_eval:
            set_bn_eval(bb)
        return

    raise RuntimeError(f"Unsupported backbone type: {bb_type}")


# -------------------------
# register
# -------------------------
# Register for AutoModelForImageClassification so from_pretrained can resolve this custom class.
BackboneWithMLPHeadForImageClassification.register_for_auto_class("AutoModelForImageClassification")
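
# Hub loading sketch: because this is custom code, loading via the Auto class is assumed to
# require trust_remote_code=True once the model is pushed to the Hub.
#   from transformers import AutoModelForImageClassification
#   model = AutoModelForImageClassification.from_pretrained(repo_id, trust_remote_code=True)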