| """ |
| timm-based backbones: EfficientNetV2-S and ConvNeXt V2 Tiny. |
| |
| These are fully fine-tunable models with ImageNet pretraining. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import timm |
|
|
| from ..registry import register |
|
|
|
|
| IMAGENET_MEAN = [0.485, 0.456, 0.406] |
| IMAGENET_STD = [0.229, 0.224, 0.225] |
|
|
|
|
| class TimmBackbone(nn.Module): |
| """Generic wrapper around any timm model used as a feature extractor.""" |
|
|
| def __init__(self, model_name: str, pretrained: bool = True): |
| super().__init__() |
| self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0) |
| |
| self.embed_dim = self.model.num_features |
| self._model_name = model_name |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.model(x) |
|
|
| def freeze(self) -> None: |
| for param in self.model.parameters(): |
| param.requires_grad = False |
|
|
| def unfreeze(self) -> None: |
| for param in self.model.parameters(): |
| param.requires_grad = True |
|
|
| def get_param_groups(self, lr: float, backbone_lr_mult: float = 0.1) -> list[dict]: |
| return [{"params": list(self.model.parameters()), "lr": lr * backbone_lr_mult}] |
|
|
| def get_preprocess_config(self) -> dict: |
| data_cfg = timm.data.resolve_data_config(self.model.pretrained_cfg) |
| return { |
| "mean": list(data_cfg["mean"]), |
| "std": list(data_cfg["std"]), |
| "input_size": data_cfg["input_size"][-1], |
| } |
|
|
|
|
| @register("efficientnetv2_s") |
| class EfficientNetV2S(TimmBackbone): |
| def __init__(self, pretrained: bool = True): |
| super().__init__("tf_efficientnetv2_s.in21k_ft_in1k", pretrained=pretrained) |
|
|
|
|
| @register("convnextv2_tiny") |
| class ConvNeXtV2Tiny(TimmBackbone): |
| def __init__(self, pretrained: bool = True): |
| super().__init__("convnextv2_tiny.fcmae_ft_in22k_in1k", pretrained=pretrained) |
|
|