""" timm model adapter Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. """ import logging from collections import OrderedDict import torch import torch.nn as nn try: import timm from timm.models.layers import Mlp, to_2tuple from timm.models.layers.attention_pool2d import RotAttentionPool2d from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d from timm.models.layers.adaptive_avgmax_pool import SelectAdaptivePool2d from einops.layers.torch import Reduce except ImportError as e: timm = None from llava.open_clip.utils import freeze_batch_norm_2d class HeadModule(nn.Module): def __init__(self, layers): super(HeadModule, self).__init__() self.pool = layers.get("pool", None) self.drop = layers.get("drop", None) self.proj = layers.get("proj", None) self.mlp = layers.get("mlp", None) def forward(self, x, pool=True): if pool and self.pool is not None: x = self.pool(x) if self.mlp is not None: x = self.mlp(x) if self.proj is not None: assert self.drop is not None x = self.drop(x) x = self.proj(x) return x class TimmModel(nn.Module): """ timm model adapter # FIXME this adapter is a work in progress, may change in ways that break weight compat """ def __init__( self, model_name, embed_dim, image_size=224, pool='avg', proj='linear', proj_bias=False, drop=0., pretrained=False): super().__init__() if timm is None: raise RuntimeError("Please `pip install timm` to use timm models.") self.image_size = to_2tuple(image_size) self.trunk = timm.create_model(model_name, pretrained=pretrained) feat_size = self.trunk.default_cfg.get('pool_size', None) feature_ndim = 1 if not feat_size else 2 if pool in ('abs_attn', 'rot_attn', 'global_max'): assert pool == 'global_max' or feature_ndim == 2 # if attn pooling used, remove both classifier and default pool self.trunk.reset_classifier(0, global_pool='') else: # reset global pool if pool config set, otherwise leave as network default reset_kwargs = dict(global_pool=pool) if pool else {} self.trunk.reset_classifier(0, **reset_kwargs) prev_chs = self.trunk.num_features head_layers = OrderedDict() if pool == 'abs_attn': head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) prev_chs = embed_dim elif pool == 'rot_attn': head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) prev_chs = embed_dim elif pool == "global_max": head_layers['pool'] = Reduce('b n c -> b c', 'max') # max pool on token dim prev_chs = embed_dim proj = "" # no proj / mlp else: assert proj, 'projection layer needed if non-attention pooling is used.' 
        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))

        self.head = HeadModule(head_layers)  # similar to nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """ lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
            freeze_bn_stats (bool): also freeze batch norm running stats in the locked groups (default: False)
        """
        if not unlocked_groups:
            # lock full model
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires latest timm (master) branch and is subject to change
            try:
                # FIXME import here until API stable and in an official release
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception as e:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x, pool=True):
        x = self.trunk(x)
        x = self.head(x, pool=pool)
        return x
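

# Minimal usage sketch (illustrative only, not part of the original module).
# The model name 'resnet50' and embed_dim=512 are arbitrary example values;
# with pool='avg' the head is Dropout + Linear, projecting the trunk's pooled
# features (2048 channels for resnet50) down to embed_dim.
if __name__ == "__main__":
    model = TimmModel(
        'resnet50',      # any timm model name should work here
        embed_dim=512,   # example projection width
        pool='avg',
        proj='linear',
        pretrained=False,
    )
    images = torch.randn(2, 3, 224, 224)  # dummy batch of RGB images
    features = model(images)
    print(features.shape)  # expected: torch.Size([2, 512])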