| | """ timm model adapter |
| | |
| | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. |
| | """ |
| | import logging |
| | from collections import OrderedDict |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
try:
    import timm
    from timm.models.layers import Mlp, to_2tuple
    try:
        # older timm releases keep these under timm.models.layers
        from timm.models.layers.attention_pool2d import RotAttentionPool2d
        from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
    except ImportError:
        # newer timm releases moved the layers to timm.layers
        from timm.layers import RotAttentionPool2d
        from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
    timm = None

from .utils import freeze_batch_norm_2d


class TimmModel(nn.Module):
    """ timm model adapter
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            drop_path=None,
            patch_drop=None,
            pretrained=False,
    ):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")
        self.image_size = to_2tuple(image_size)

        # set up kwargs that may not be common across all timm models
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs['drop_path_rate'] = drop_path
        if patch_drop is not None:
            timm_kwargs['patch_drop_rate'] = patch_drop

        custom_pool = pool in ('abs_attn', 'rot_attn')
        if proj:
            assert proj in ("linear", "mlp", "none")
        extra_proj = proj in ("linear", "mlp")
        if not extra_proj and not custom_pool:
            # use the network classifier head as the projection when no extra proj and no custom pooling are used;
            # if proj is explicitly 'none', the trunk output is passed through unchanged
            proj_dim = 0 if proj == 'none' else embed_dim
            self.trunk = timm.create_model(
                model_name,
                num_classes=proj_dim,
                global_pool=pool,
                pretrained=pretrained,
                **timm_kwargs,
            )
            prev_chs = embed_dim
        else:
            self.trunk = timm.create_model(
                model_name,
                pretrained=pretrained,
                **timm_kwargs,
            )
            feat_size = self.trunk.default_cfg.get('pool_size', None)
            feature_ndim = 1 if not feat_size else 2
            if custom_pool:
                assert feature_ndim == 2
                # if attn pooling is used, remove both the classifier and the default pool
                self.trunk.reset_classifier(0, global_pool='')
            else:
                # reset global pool if pool config set, otherwise leave as network default
                reset_kwargs = dict(global_pool=pool) if pool else {}
                self.trunk.reset_classifier(0, **reset_kwargs)
            prev_chs = self.trunk.num_features

        head_layers = OrderedDict()

        # add custom pooling to head
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim

        # NOTE: the attention pools above already end in a projection to embed_dim, so proj is
        # typically left unset ('' or 'none') when such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))

        self.head = nn.Sequential(head_layers)
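        # Resulting head (sketch): with the default pool='avg', proj='linear' the trunk returns pooled
        # features of width self.trunk.num_features and the head is Dropout -> Linear(num_features, embed_dim);
        # with pool='abs_attn'/'rot_attn' the head instead starts with an attention pool mapping to embed_dim.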

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """ lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
            freeze_bn_stats (bool): also freeze batch norm running stats in the locked layers (default: False)
        """
        if not unlocked_groups:
            # lock the full trunk
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires a recent timm with group_parameters / group_modules support
            try:
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x):
        x = self.trunk(x)  # backbone features (pooled, or spatial if a custom attn pool sits in the head)
        x = self.head(x)   # optional attention pool and/or projection to embed_dim
        return x