Feature Extraction
Transformers
Safetensors
custom_code
Files changed (2) hide show
  1. extra_timm_models.py +225 -0
  2. hf_model.py +2 -4
extra_timm_models.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import math
10
+ import warnings
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+
16
+ from timm.models import register_model, PretrainedCfg
17
+ from timm.models.vision_transformer import (
18
+ VisionTransformer,
19
+ _create_vision_transformer as _timm_create_vision_transformer,
20
+ Mlp,
21
+ Block,
22
+ LayerScale as TIMMLayerScale,
23
+ )
24
+
25
+ # Import these to also register them
26
+ from . import dinov2_arch
27
+
28
+
29
@register_model
def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Tiny with a 14x14 patch size.

    Builds the tiny configuration (192-dim embed, 12 blocks, 3 heads) through
    the local `_create_vision_transformer` wrapper; `kwargs` override defaults.
    """
    config = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3)
    config.update(kwargs)
    return _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **config)
36
+
37
+
38
@register_model
def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Small with a 14x14 patch size.

    NOTE(review): the underlying timm entrypoint name is 'vit_small_patch16_224'
    even though patch_size=14 is passed — presumably intentional reuse of the
    patch16 config; confirm against upstream usage.
    """
    config = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6)
    config.update(kwargs)
    return _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **config)
45
+
46
+
47
@register_model
def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Base (ViT-B/14) from the original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 224x224, source
    https://github.com/google-research/vision_transformer.
    """
    config = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12)
    config.update(kwargs)
    return _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **config)
55
+
56
+
57
@register_model
def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Base "v2": DINOv2-style architecture (4 register tokens, LayerScale,
    no embedded class position) rebuilt at a 16x16 patch size.

    No pretrained weights exist for this custom configuration, so requesting
    them now raises (matching `vit_so400m_patch16_224`) instead of being
    silently ignored as before, which returned randomly initialized weights.
    """
    if pretrained:
        raise ValueError('There is no pretrained weights for vit_base_patch16_v2_224')
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14,
    )
    model = _create_vision_transformer(
        'vit_base_patch14_reg4_dinov2', pretrained=False, **dict(model_args, **kwargs))
    return model
69
+
70
+
71
@register_model
def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Large "v2": DINOv2-style architecture (4 register tokens, LayerScale,
    no embedded class position) rebuilt at a 16x16 patch size.

    No pretrained weights exist for this custom configuration, so requesting
    them now raises (matching `vit_so400m_patch16_224`) instead of being
    silently ignored as before, which returned randomly initialized weights.
    """
    if pretrained:
        raise ValueError('There is no pretrained weights for vit_large_patch16_v2_224')
    name = 'vit_large_patch14_reg4_dinov2'
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14,
    )
    return _create_vision_transformer(name, pretrained=False, **dict(model_args, **kwargs))
84
+
85
+
86
@register_model
def vit_so400m_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT matching the So400M architecture from
    "Scaling Vision Transformers to 400 Million Parameters"
    (https://arxiv.org/abs/2302.05442). No pretrained weights are available.
    """
    if pretrained:
        raise ValueError('There is no pretrained weights for vit_so400m_patch16_224')

    # So400M uses a non-integer MLP ratio: hidden dim 4304 on a 1152-wide trunk.
    config = dict(patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=4304 / 1152)
    config.update(kwargs)
    return _create_vision_transformer('vit_so400m_patch16_224', pretrained=pretrained, **config)
98
+
99
+
100
@register_model
def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Huge model (ViT-H/16) from the original paper
    (https://arxiv.org/abs/2010.11929).

    There is no pretrained ViT-H/16 checkpoint, so when `pretrained` is True
    the ViT-H/14 weights are loaded and adapted to the 16x16 patch size.
    """
    config = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16)
    config.update(kwargs)
    variant = 'vit_huge_patch14_224' if pretrained else 'vit_huge_patch16_224'
    return _create_vision_transformer(variant, pretrained=pretrained, **config)
111
+
112
+
113
@register_model
def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-Huge (ViT-H/16) with a LayerNorm inserted inside every MLP block.

    Delegates construction to `vit_huge_patch16_224`, then replaces each MLP's
    internal `norm` (identity by default in timm's `Mlp`) with a LayerNorm
    sized to the hidden dimension (fc1's output features).
    """
    model = vit_huge_patch16_224(pretrained=pretrained, **kwargs)

    for module in model.modules():
        if isinstance(module, Mlp) and not isinstance(module.norm, nn.LayerNorm):
            module.norm = nn.LayerNorm(module.fc1.out_features)

    return model
124
+
125
+
126
@register_model
def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer:
    """ViT-giant model (ViT-g/16) from the original paper
    (https://arxiv.org/abs/2010.11929).

    Args:
        pretrained: no pretrained weights exist for this configuration, so True
            now raises explicitly instead of being silently ignored as before
            (which returned randomly initialized weights).
        scaled_ln: if True, swap each block's LayerNorms for depth-scaled
            variants via `_apply_scaled_ln`.
    """
    if pretrained:
        raise ValueError('There is no pretrained weights for vit_giant_patch16_224')
    model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24)
    model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs))
    if scaled_ln:
        _apply_scaled_ln(model)
    return model
135
+
136
+
137
@register_model
def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-bigG (ViT-G/14) with LayerScale (init_values=1e-6).

    No pretrained weights exist for this configuration, so requesting them now
    raises explicitly instead of being silently ignored as before (which
    returned randomly initialized weights).
    """
    if pretrained:
        raise ValueError('There is no pretrained weights for vit_bigG_patch14_224')
    model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6)
    return _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs))
142
+
143
+
144
def _create_vision_transformer(*args, **kwargs):
    """Wrapper around timm's `_create_vision_transformer`.

    Supplies an empty `PretrainedCfg` when the caller didn't provide one
    (which suppresses timm's missing-cfg warning), then swaps every TIMM
    LayerScale in the built model for the local DINOv2 variant.
    """
    if kwargs.get('pretrained_cfg', None) is None:
        # An explicit (empty) cfg prevents the warning from being emitted.
        kwargs['pretrained_cfg'] = PretrainedCfg()

    vit = _timm_create_vision_transformer(*args, **kwargs)
    _patch_layer_scale(vit)
    return vit
152
+
153
+
154
def _patch_layer_scale(model: VisionTransformer):
    """Replace every TIMM `LayerScale` in `model` with the DINOv2 variant.

    The DINOv2 `LayerScale` stores its scale under a parameter name other than
    `gamma`, so that HFHub serialization doesn't mangle it. State is copied
    across, so numerics are unchanged.

    (Also removes a stray trailing `pass` statement from the original loop.)
    """
    def replace_ls(old_ls: TIMMLayerScale):
        # Same feature count and inplace flag; weights carried over verbatim.
        new_ls = dinov2_arch.LayerScale(old_ls.gamma.shape[0], inplace=old_ls.inplace)
        new_ls.load_state_dict(old_ls.state_dict())
        return new_ls

    # Monkey patch: Replace TIMM's LayerScale with our modified DINOv2 one, that uses a param name
    # other than gamma, so that HFHub doesn't mess with it!
    for mod in model.modules():
        if isinstance(mod, Block):
            if isinstance(mod.ls1, TIMMLayerScale):
                mod.ls1 = replace_ls(mod.ls1)
            if isinstance(mod.ls2, TIMMLayerScale):
                mod.ls2 = replace_ls(mod.ls2)
169
+
170
+
171
class ScaledLayerNorm(nn.LayerNorm):
    """LayerNorm whose output is multiplied by ``1 / sqrt(depth)``.

    See https://arxiv.org/pdf/2502.05795v1.

    Args:
        ln_base: LayerNorm to copy configuration (shape, eps, affine) and
            weights from.
        depth: 1-based depth of the block this norm belongs to. Must be >= 1;
            the previous default of 0 always raised ZeroDivisionError in
            ``1.0 / math.sqrt(0)``, so the default is now 1 (scale of 1.0).
    """
    def __init__(self, ln_base: nn.LayerNorm, depth: int = 1):
        super().__init__(ln_base.normalized_shape, eps=ln_base.eps, elementwise_affine=ln_base.elementwise_affine)
        if depth < 1:
            raise ValueError(f'depth must be >= 1, got {depth}')
        self.load_state_dict(ln_base.state_dict())
        # Non-persistent buffer: follows .to()/.cuda() but is excluded from checkpoints.
        self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False)

    def forward(self, x):
        y = super().forward(x)
        y = y * self.ln_scale
        return y
184
+
185
+
186
class DyT(nn.Module):
    """Dynamic Tanh (DyT) layer: ``gamma * tanh(alpha * x) + beta``.

    A normalization-free element-wise transform with a learnable scalar
    steepness ``alpha`` and per-channel affine parameters. Used below as a
    drop-in replacement for block LayerNorms.

    Args:
        C: number of channels (size of the last dimension of ``x``).
        init_alpha: initial value for the scalar ``alpha``.
    """
    def __init__(self, C: int, init_alpha: float):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((1,), init_alpha))
        self.gamma = nn.Parameter(torch.ones(C))
        self.beta = nn.Parameter(torch.zeros(C))

    def forward(self, x: torch.Tensor):
        # torch.tanh replaces the deprecated F.tanh alias (identical numerics).
        x = torch.tanh(self.alpha * x)
        return self.gamma * x + self.beta
196
+
197
@register_model
def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Large (ViT-L/16) with every block LayerNorm replaced by DyT.

    ImageNet-1k weights fine-tuned from in21k @ 224x224, source
    https://github.com/google-research/vision_transformer.
    """
    config = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
    config.update(kwargs)
    model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **config)

    # Swap norm1/norm2 in each block for DyT; the depth argument is unused here.
    _replace_ln(model, lambda ln, depth: DyT(ln.normalized_shape[0], init_alpha=0.9))

    return model
210
+
211
+
212
def _apply_scaled_ln(model: VisionTransformer):
    """Swap each transformer block's LayerNorms for `ScaledLayerNorm`s."""
    warnings.warn('Post-LayerNorm scaling activated!')

    # ScaledLayerNorm(ln, depth=depth) already matches the callback signature,
    # so the class itself serves as the replacement function.
    _replace_ln(model, ScaledLayerNorm)
216
+
217
def _replace_ln(model: VisionTransformer, fn):
    """Apply ``fn(ln, depth=i)`` to ``norm1``/``norm2`` of every block.

    ``depth`` is 1-based. Only attributes that are currently `nn.LayerNorm`
    instances are replaced; anything already swapped out is left untouched.
    """
    for depth, block in enumerate(model.blocks, start=1):
        for key in ('norm1', 'norm2'):
            current = getattr(block, key)
            if isinstance(current, nn.LayerNorm):
                setattr(block, key, fn(current, depth=depth))
hf_model.py CHANGED
@@ -44,10 +44,8 @@ from .vit_patch_generator import ViTPatchGenerator
44
  from .vitdet import apply_vitdet_arch, VitDetArgs
45
 
46
  # Register extra models
47
- from . import extra_timm_models
48
- from . import extra_models
49
- # from .extra_timm_models import *
50
- # from .extra_models import *
51
 
52
 
53
  class RADIOConfig(PretrainedConfig):
 
44
  from .vitdet import apply_vitdet_arch, VitDetArgs
45
 
46
  # Register extra models
47
+ from .extra_timm_models import *
48
+ from .extra_models import *
 
 
49
 
50
 
51
  class RADIOConfig(PretrainedConfig):