# Originally from OpenCLIP (https://github.com/mlfoundations/open_clip)
import logging
import types
from collections import OrderedDict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

# timm is an optional dependency: keep the module importable without it and
# fail with a clear error only when a TimmModel is actually constructed.
try:
    import timm
    from timm.layers import RotAttentionPool2d
    from timm.layers import AttentionPool2d as AbsAttentionPool2d
    from timm.layers import Mlp, to_2tuple
    from timm.layers import AttentionRope, RotaryEmbeddingCat
except ImportError:
    timm = None


class TimmModel(nn.Module):
    """timm model adapter.

    Wraps a timm vision trunk behind a CLIP-style interface: the trunk
    produces per-patch features, an optional custom pooling layer
    (``abs_attn`` / ``rot_attn``) collapses them, and an optional projection
    head (``linear`` / ``mlp``) maps the pooled vector to ``embed_dim``.

    Beyond the stock OpenCLIP adapter, this version can monkey-patch the
    trunk at construction/setup time to support:
      * 2D rotary position embeddings (``use_rope``), see ``_setup_rope``;
      * variable-resolution absolute pos-embed interpolation,
        see ``_setup_dynamic_pos_embed``;
      * a NaFlex-style 1D patch-sequence forward (SigLIP2 style),
        see ``_setup_1d_forward``.
    """

    def __init__(
        self,
        model_name: str,
        embed_dim: int,
        image_size: Union[int, Tuple[int, int]] = 224,
        pool: str = "avg",
        proj: str = "linear",
        proj_bias: bool = False,
        drop: float = 0.0,
        drop_path: Optional[float] = None,
        patch_drop: Optional[float] = None,
        init_values: Optional[float] = None,
        qk_norm: bool = False,
        use_rope: bool = False,
        rope_keep_ape: bool = False,
        dynamic_img_size: bool = False,
        norm_pre: bool = False,
        pretrained: bool = False,
        output_tokens: bool = False,
    ):
        """Build the trunk via ``timm.create_model`` plus pooling/projection head.

        Args:
            model_name: timm model identifier passed to ``timm.create_model``.
            embed_dim: output embedding dimension of the projection head.
            image_size: input resolution; normalized to a (H, W) tuple.
            pool: trunk global pool ("avg", ...) or custom "abs_attn"/"rot_attn".
            proj: projection head type: "linear", "mlp", "none", or "" (falsy =
                keep trunk's own classifier projection to ``embed_dim``).
            proj_bias: whether the (final) projection linear layer has a bias.
            drop: dropout rate used inside the projection head.
            drop_path: stochastic depth rate forwarded to timm, if given.
            patch_drop: patch dropout rate forwarded to timm, if given.
            init_values: layer-scale init value forwarded to timm, if given.
            qk_norm: enable timm's query/key normalization.
            use_rope: replace attention with RoPE-aware attention and patch the
                trunk forward (see ``_setup_rope``).
            rope_keep_ape: when using RoPE, keep the absolute pos embed instead
                of setting ``pos_embed="none"``.
            dynamic_img_size: forwarded to timm for variable input sizes.
            norm_pre: replace an Identity ``trunk.norm_pre`` with a LayerNorm.
            pretrained: load pretrained trunk weights.
            output_tokens: if True, ``forward`` also returns per-patch tokens.

        Raises:
            RuntimeError: if timm is not installed.
        """
        super().__init__()
        if timm is None:
            raise RuntimeError(
                "Please install the latest timm (`pip install timm`) to use timm based models."
            )
        self.image_size = to_2tuple(image_size)
        self.output_tokens = output_tokens

        # Only forward kwargs the caller explicitly set, so timm's own
        # per-model defaults remain in effect otherwise.
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs["drop_path_rate"] = drop_path
        if patch_drop is not None:
            timm_kwargs["patch_drop_rate"] = patch_drop
        if init_values is not None:
            timm_kwargs["init_values"] = init_values
        if qk_norm:
            timm_kwargs["qk_norm"] = True
        if dynamic_img_size:
            timm_kwargs["dynamic_img_size"] = True

        if use_rope:
            # RoPE-aware attention layer; prefix-token count is forced to 0
            # (assumes a trunk without a cls token — see class docstring).
            class _AttentionRopeNoPrefix(AttentionRope):
                """AttentionRope with num_prefix_tokens=0 for models without cls token."""

                def __init__(self, *args, **kwargs):
                    kwargs["num_prefix_tokens"] = 0
                    super().__init__(*args, **kwargs)

            timm_kwargs["attn_layer"] = _AttentionRopeNoPrefix
            if not rope_keep_ape:
                # Drop the absolute position embedding; RoPE replaces it.
                timm_kwargs["pos_embed"] = "none"

        custom_pool = pool in ("abs_attn", "rot_attn")
        if proj:
            assert proj in ("linear", "mlp", "none")
        extra_proj = proj in ("linear", "mlp")
        if not extra_proj and not custom_pool:
            # No extra head needed: let the trunk's own classifier project
            # straight to embed_dim (or to nothing when proj == "none").
            proj_dim = 0 if proj == "none" else embed_dim
            self.trunk = timm.create_model(
                model_name,
                num_classes=proj_dim,
                global_pool=pool,
                pretrained=pretrained,
                **timm_kwargs,
            )
            prev_chs = embed_dim
        else:
            self.trunk = timm.create_model(
                model_name,
                pretrained=pretrained,
                **timm_kwargs,
            )
            feat_size = self.trunk.default_cfg.get("pool_size", None)
            # pool_size present => trunk emits a 2D feature map (ndim 2);
            # absent => 1D token sequence.
            feature_ndim = 1 if not feat_size else 2
            if custom_pool:
                # Custom attention pooling requires spatial (2D) features;
                # strip the trunk's own pooling/classifier entirely.
                assert feature_ndim == 2
                self.trunk.reset_classifier(0, global_pool="")
            else:
                reset_kwargs = dict(global_pool=pool) if pool else {}
                self.trunk.reset_classifier(0, **reset_kwargs)
            prev_chs = self.trunk.num_features

        # Build the pooling + projection head as an ordered pipeline.
        head_layers = OrderedDict()

        if pool == "abs_attn":
            head_layers["pool"] = AbsAttentionPool2d(
                prev_chs, feat_size=feat_size, out_features=embed_dim
            )
            prev_chs = embed_dim
        elif pool == "rot_attn":
            head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim

        if proj == "linear":
            head_layers["drop"] = nn.Dropout(drop)
            head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == "mlp":
            head_layers["mlp"] = Mlp(
                prev_chs,
                2 * embed_dim,
                embed_dim,
                drop=(drop, 0),
                bias=(True, proj_bias),
            )

        self.head = nn.Sequential(head_layers)

        # Optionally give the trunk a real pre-norm when it shipped with an
        # Identity placeholder (only replaces Identity, never an existing norm).
        if (
            norm_pre
            and hasattr(self.trunk, "norm_pre")
            and isinstance(self.trunk.norm_pre, nn.Identity)
        ):
            self.trunk.norm_pre = nn.LayerNorm(self.trunk.embed_dim)
            logging.info(
                f"Replaced norm_pre Identity with LayerNorm({self.trunk.embed_dim})"
            )

        self._has_rope = use_rope
        if use_rope:
            self._setup_rope()

    def _setup_rope(self):
        """Inject 2D Rotary Position Embedding into the timm trunk."""
        # Derive the per-head dim from the first block's attention; assumes a
        # ViT-style trunk with a .blocks list (TODO confirm for other trunks).
        num_heads = self.trunk.blocks[0].attn.num_heads
        head_dim = self.trunk.embed_dim // num_heads
        # Allow inputs whose size differs from the configured image size.
        self.trunk.patch_embed.strict_img_size = False
        self.rope = RotaryEmbeddingCat(
            dim=head_dim,
            max_res=max(self.image_size),
            in_pixels=True,
        )

        # Replacement block forward that threads rope/attn_mask into attention.
        def _block_forward_rope(block_self, x, rope=None, attn_mask=None):
            x = x + block_self.drop_path1(
                block_self.ls1(
                    block_self.attn(block_self.norm1(x), rope=rope, attn_mask=attn_mask)
                )
            )
            x = x + block_self.drop_path2(
                block_self.ls2(block_self.mlp(block_self.norm2(x)))
            )
            return x

        # Rebind every transformer block's forward to the RoPE-aware version.
        for blk in self.trunk.blocks:
            blk.forward = types.MethodType(_block_forward_rope, blk)

        # Close over the adapter so the patched trunk can reach self.rope.
        timm_model_ref = self
        _num_prefix = getattr(self.trunk, "num_prefix_tokens", 0)

        def _forward_features_rope(trunk_self, x, attn_mask=None):
            from torch.utils.checkpoint import checkpoint
            from timm.layers import resample_abs_pos_embed

            # Patch grid for this input resolution (input is NCHW here).
            ps = trunk_self.patch_embed.patch_size
            grid_shape = [x.shape[2] // ps[0], x.shape[3] // ps[1]]
            x = trunk_self.patch_embed(x)
            if x.ndim == 4:
                # Flatten a (B, H, W, C) patch map to (B, N, C) tokens.
                x = x.reshape(x.shape[0], -1, x.shape[-1])
            # Absolute pos embed may coexist with RoPE (rope_keep_ape=True);
            # resample it when the token count differs from the stored grid.
            if hasattr(trunk_self, "pos_embed") and trunk_self.pos_embed is not None:
                if x.shape[1] != trunk_self.pos_embed.shape[1]:
                    x = x + resample_abs_pos_embed(
                        trunk_self.pos_embed, grid_shape, num_prefix_tokens=_num_prefix
                    )
                else:
                    x = x + trunk_self.pos_embed
            x = trunk_self.pos_drop(x)
            x = trunk_self.norm_pre(x)
            # Rotary embedding computed for the actual grid of this batch.
            rot_pos_embed = timm_model_ref.rope.get_embed(shape=grid_shape)
            # Convert a boolean valid-token mask into an additive SDPA mask:
            # 0 for valid, -inf for masked, broadcast to (B, 1, 1, N).
            _sdpa_mask = None
            if attn_mask is not None:
                _sdpa_mask = torch.zeros_like(attn_mask, dtype=x.dtype)
                _sdpa_mask.masked_fill_(~attn_mask, float("-inf"))
                _sdpa_mask = _sdpa_mask.unsqueeze(1).unsqueeze(2)
            for blk in trunk_self.blocks:
                if trunk_self.grad_checkpointing and not torch.jit.is_scripting():
                    x = checkpoint(
                        blk,
                        x,
                        rope=rot_pos_embed,
                        attn_mask=_sdpa_mask,
                        use_reentrant=False,
                    )
                else:
                    x = blk(x, rope=rot_pos_embed, attn_mask=_sdpa_mask)
            x = trunk_self.norm(x)
            return x

        self.trunk.forward_features = types.MethodType(
            _forward_features_rope, self.trunk
        )

    def _setup_dynamic_pos_embed(self):
        """Patch forward_features for variable-resolution pos_embed interpolation (non-RoPE)."""
        # NOTE(review): not called from __init__; presumably invoked by
        # external setup code when variable-resolution input is needed.
        self.trunk.patch_embed.strict_img_size = False
        _num_prefix = getattr(self.trunk, "num_prefix_tokens", 0)

        def _forward_features_dynamic(trunk_self, x, patch_valid_mask=None):
            from torch.utils.checkpoint import checkpoint
            from timm.layers import resample_abs_pos_embed

            # Same patchify + pos-embed resampling as the RoPE path, but the
            # blocks keep their stock forward (no rope argument).
            ps = trunk_self.patch_embed.patch_size
            grid_shape = [x.shape[2] // ps[0], x.shape[3] // ps[1]]
            x = trunk_self.patch_embed(x)
            if x.ndim == 4:
                x = x.reshape(x.shape[0], -1, x.shape[-1])
            if hasattr(trunk_self, "pos_embed") and trunk_self.pos_embed is not None:
                if x.shape[1] != trunk_self.pos_embed.shape[1]:
                    x = x + resample_abs_pos_embed(
                        trunk_self.pos_embed, grid_shape, num_prefix_tokens=_num_prefix
                    )
                else:
                    x = x + trunk_self.pos_embed
            x = trunk_self.pos_drop(x)
            x = trunk_self.norm_pre(x)
            # Boolean valid mask -> additive (B, 1, 1, N) SDPA mask.
            _sdpa_mask = None
            if patch_valid_mask is not None:
                _sdpa_mask = torch.zeros_like(patch_valid_mask, dtype=x.dtype)
                _sdpa_mask.masked_fill_(~patch_valid_mask, float("-inf"))
                _sdpa_mask = _sdpa_mask.unsqueeze(1).unsqueeze(2)
            for blk in trunk_self.blocks:
                if trunk_self.grad_checkpointing and not torch.jit.is_scripting():
                    # Only pass attn_mask when present: stock block forwards
                    # may not accept the keyword.
                    if _sdpa_mask is not None:
                        x = checkpoint(
                            blk, x, attn_mask=_sdpa_mask, use_reentrant=False
                        )
                    else:
                        x = checkpoint(blk, x, use_reentrant=False)
                else:
                    x = blk(x, attn_mask=_sdpa_mask)
            x = trunk_self.norm(x)
            return x

        self.trunk.forward_features = types.MethodType(
            _forward_features_dynamic, self.trunk
        )

    def _setup_1d_forward(self):
        """Patch forward_features for NaFlex 1D mode (SigLIP2 style)."""
        _num_prefix = getattr(self.trunk, "num_prefix_tokens", 0)

        def _forward_features_1d(
            trunk_self, x, patch_valid_mask=None, spatial_shapes=None
        ):
            from torch.utils.checkpoint import checkpoint

            # Input here is pre-patchified: (B, seq_len, patch_pixels).
            # Apply the patch-embed conv as an equivalent linear layer by
            # flattening its kernel; assumes stride == kernel size so the
            # conv is exactly a per-patch linear map (TODO confirm).
            conv = trunk_self.patch_embed.proj
            D = conv.weight.shape[0]
            x = torch.nn.functional.linear(
                x.to(conv.weight.dtype), conv.weight.reshape(D, -1), conv.bias
            )
            if (
                hasattr(trunk_self, "pos_embed")
                and trunk_self.pos_embed is not None
                and spatial_shapes is not None
            ):
                # Per-sample pos-embed: bilinearly resize the stored square
                # grid to each sample's (gh, gw), then flatten back to 1D.
                pos_embed = trunk_self.pos_embed
                base_n = pos_embed.shape[1]
                base_grid = int(base_n**0.5)
                pos_2d = (
                    pos_embed.reshape(1, base_grid, base_grid, -1)
                    .permute(0, 3, 1, 2)
                    .float()
                )
                B, sl, D_emb = x.shape
                pos_resized = torch.zeros(B, sl, D_emb, device=x.device, dtype=x.dtype)
                for i in range(B):
                    gh, gw = spatial_shapes[i].tolist()
                    pe = torch.nn.functional.interpolate(
                        pos_2d, size=(gh, gw), mode="bilinear", align_corners=False
                    )
                    pe = pe.squeeze(0).permute(1, 2, 0).reshape(gh * gw, -1).to(x.dtype)
                    n_patches = gh * gw
                    pos_resized[i, :n_patches] = pe
                    if n_patches < sl:
                        # Padding tokens reuse the first position's embedding;
                        # they are masked out of attention below anyway.
                        pos_resized[i, n_patches:] = pe[0]
                x = x + pos_resized
            elif hasattr(trunk_self, "pos_embed") and trunk_self.pos_embed is not None:
                x = x + trunk_self.pos_embed
            x = trunk_self.pos_drop(x)
            x = trunk_self.norm_pre(x)
            # Boolean valid mask -> additive (B, 1, 1, N) SDPA mask.
            _sdpa_mask = None
            if patch_valid_mask is not None:
                _sdpa_mask = torch.zeros_like(patch_valid_mask, dtype=x.dtype)
                _sdpa_mask.masked_fill_(~patch_valid_mask, float("-inf"))
                _sdpa_mask = _sdpa_mask.unsqueeze(1).unsqueeze(2)
            for blk in trunk_self.blocks:
                if trunk_self.grad_checkpointing and not torch.jit.is_scripting():
                    if _sdpa_mask is not None:
                        x = checkpoint(
                            blk, x, attn_mask=_sdpa_mask, use_reentrant=False
                        )
                    else:
                        x = checkpoint(blk, x, use_reentrant=False)
                else:
                    x = blk(x, attn_mask=_sdpa_mask)
            x = trunk_self.norm(x)
            return x

        # Stored under a private name (does not replace forward_features);
        # forward() dispatches to it when spatial_shapes is provided.
        self.trunk._forward_features_1d = types.MethodType(
            _forward_features_1d, self.trunk
        )
        self._has_1d_forward = True

    def forward_patch_features(self, x):
        """Forward pass returning per-patch features (before pooling/projection)."""
        return self.trunk.forward_features(x)

    def forward(self, x, patch_valid_mask=None, spatial_shapes=None):
        """Run trunk + pooling + projection head.

        Args:
            x: input images (NCHW) or, in 1D/NaFlex mode, pre-patchified
                token sequences.
            patch_valid_mask: optional boolean (B, N) mask of valid patches;
                requires the corresponding _setup_* patch to have been applied
                unless RoPE is active.
            spatial_shapes: per-sample (gh, gw) grids for 1D mode; only used
                when ``_setup_1d_forward`` has been called.

        Returns:
            The pooled embedding, or ``(pooled, patch_features)`` when
            ``output_tokens`` is set.
        """
        # Dispatch to whichever trunk forward matches the configured mode.
        if spatial_shapes is not None and getattr(self, "_has_1d_forward", False):
            patch_features = self.trunk._forward_features_1d(
                x, patch_valid_mask=patch_valid_mask, spatial_shapes=spatial_shapes
            )
        elif patch_valid_mask is not None and self._has_rope:
            # RoPE forward_features converts the boolean mask internally.
            patch_features = self.trunk.forward_features(x, attn_mask=patch_valid_mask)
        elif patch_valid_mask is not None:
            # Assumes _setup_dynamic_pos_embed (or 1d) patched the trunk so
            # forward_features accepts this keyword — TODO confirm callers.
            patch_features = self.trunk.forward_features(
                x, patch_valid_mask=patch_valid_mask
            )
        else:
            patch_features = self.trunk.forward_features(x)

        if patch_valid_mask is not None:
            # Zero out invalid (padding) tokens so they cannot leak into
            # pooling below or into the returned token features.
            mask_f = patch_valid_mask.unsqueeze(-1).to(
                patch_features.dtype
            )
            patch_features = patch_features * mask_f
        # Cached for external consumers; overwritten on every forward.
        self._cached_patch_features = patch_features

        if (
            patch_valid_mask is not None
            and getattr(self.trunk, "global_pool", "") == "avg"
        ):
            # Masked mean over valid tokens only (clamp avoids div-by-zero
            # for an all-masked sample).
            pooled = patch_features.sum(dim=1) / mask_f.sum(dim=1).clamp(min=1)
            pooled = (
                self.trunk.fc_norm(pooled) if hasattr(self.trunk, "fc_norm") else pooled
            )
        elif (
            patch_valid_mask is not None
            and getattr(self.trunk, "attn_pool", None) is not None
        ):
            # Attention pooling with an additive -inf mask on invalid tokens.
            attn_mask = torch.zeros(
                patch_valid_mask.shape,
                dtype=patch_features.dtype,
                device=patch_features.device,
            )
            attn_mask.masked_fill_(~patch_valid_mask.bool(), float("-inf"))
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
            pooled = self.trunk.attn_pool(patch_features, attn_mask=attn_mask)
            pooled = (
                self.trunk.fc_norm(pooled) if hasattr(self.trunk, "fc_norm") else pooled
            )
        else:
            # No mask (or no special pooling): use the trunk's own head.
            pooled = self.trunk.forward_head(patch_features)

        pooled = self.head(pooled)

        if self.output_tokens:
            return pooled, patch_features
        return pooled