File size: 1,948 Bytes
154b6ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
timm-based backbones: EfficientNetV2-S and ConvNeXt V2 Tiny.

These are fully fine-tunable models with ImageNet pretraining.
"""

import torch
import torch.nn as nn
import timm

from ..registry import register


# Three-element normalization constants named for ImageNet; presumably the
# per-channel (RGB) mean and standard deviation — TODO confirm channel order.
# NOTE(review): neither constant is referenced in the visible part of this
# file; TimmBackbone.get_preprocess_config() reads mean/std from timm instead.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class TimmBackbone(nn.Module):
    """Adapter that exposes an arbitrary timm model as a feature extractor.

    Attributes:
        model: The network returned by ``timm.create_model``.
        embed_dim: Copy of ``model.num_features``.
    """

    def __init__(self, model_name: str, pretrained: bool = True):
        """Build the timm model identified by ``model_name``.

        Args:
            model_name: Name forwarded to ``timm.create_model``.
            pretrained: Forwarded to ``timm.create_model`` unchanged.
        """
        super().__init__()
        self._model_name = model_name
        # num_classes=0 → returns pooled features, no classifier head
        backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.model = backbone
        self.embed_dim = backbone.num_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped model and return its output."""
        features = self.model(x)
        return features

    def _set_trainable(self, trainable: bool) -> None:
        """Set ``requires_grad`` on every parameter of the wrapped model."""
        self.model.requires_grad_(trainable)

    def freeze(self) -> None:
        """Disable gradients for all parameters of the wrapped model."""
        self._set_trainable(False)

    def unfreeze(self) -> None:
        """Enable gradients for all parameters of the wrapped model."""
        self._set_trainable(True)

    def get_param_groups(self, lr: float, backbone_lr_mult: float = 0.1) -> list[dict]:
        """Return a single optimizer parameter group for the wrapped model.

        Args:
            lr: Base learning rate.
            backbone_lr_mult: Factor applied to ``lr`` for this group.

        Returns:
            A one-element list holding a dict with all model parameters under
            ``"params"`` and the scaled learning rate under ``"lr"``.
        """
        group = {
            "params": list(self.model.parameters()),
            "lr": lr * backbone_lr_mult,
        }
        return [group]

    def get_preprocess_config(self) -> dict:
        """Return the preprocessing settings timm resolves for this model.

        Returns:
            A dict with ``"mean"`` and ``"std"`` as lists and ``"input_size"``
            as the last entry of the resolved input size (e.g. 224).
        """
        resolved = timm.data.resolve_data_config(self.model.pretrained_cfg)
        mean = list(resolved["mean"])
        std = list(resolved["std"])
        side = resolved["input_size"][-1]  # e.g. 224
        return {"mean": mean, "std": std, "input_size": side}


@register("efficientnetv2_s")
class EfficientNetV2S(TimmBackbone):
    """TimmBackbone bound to the ``tf_efficientnetv2_s.in21k_ft_in1k`` timm model.

    Registered under the key ``"efficientnetv2_s"``.
    """

    def __init__(self, pretrained: bool = True):
        """Create the backbone; ``pretrained`` is passed through to the parent."""
        super().__init__(
            model_name="tf_efficientnetv2_s.in21k_ft_in1k",
            pretrained=pretrained,
        )


@register("convnextv2_tiny")
class ConvNeXtV2Tiny(TimmBackbone):
    """TimmBackbone bound to the ``convnextv2_tiny.fcmae_ft_in22k_in1k`` timm model.

    Registered under the key ``"convnextv2_tiny"``.
    """

    def __init__(self, pretrained: bool = True):
        """Create the backbone; ``pretrained`` is passed through to the parent."""
        super().__init__(
            model_name="convnextv2_tiny.fcmae_ft_in22k_in1k",
            pretrained=pretrained,
        )