Tenbatsu24 committed on
Commit
a10ce46
·
1 Parent(s): 94c37e9

add: missing files

Browse files
configuration_vitv2.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class ViTv2Config(PretrainedConfig):
    """Hyper-parameter container for the ``vitv2`` model type.

    Every named argument is stored verbatim as an attribute of the same
    name.  Any additional keyword arguments are forwarded untouched to
    ``PretrainedConfig.__init__``.
    """

    model_type = "vitv2"

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        num_register_tokens=0,
        init_values=None,
        **ignored_kwargs,
    ):
        # Let the base class consume the generic configuration options first.
        super().__init__(**ignored_kwargs)

        # Attributes are listed in the same order as the signature.
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.num_register_tokens = num_register_tokens
        self.init_values = init_values
hf_src/__init__.py ADDED
File without changes
hf_src/layers/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mlp import Mlp
2
+ from .block import Block # noqa: F401
3
+ from .rms_norm import RMSNorm
4
+ from .drop_path import DropPath
5
+ from .dino_head import DINOHead
6
+ from .layer_scale import LayerScale
7
+ from .patch_embed import PatchEmbed
8
+ from .block import NestedTensorBlock
9
+ from .attention import MemEffAttention
10
+ from .rope_block import SelfAttentionBlock
11
+ from .cva_head import CVAHead, IdentityHead
12
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
13
+ from .rope_position_encoding import RopePositionEmbedding
14
+
15
# Public names re-exported by this package, kept in alphabetical order.
__all__ = [
    "Block",
    "CVAHead",
    "DINOHead",
    "DropPath",
    "IdentityHead",
    "LayerScale",
    "MemEffAttention",
    "Mlp",
    "NestedTensorBlock",
    "PatchEmbed",
    "RMSNorm",
    "RopePositionEmbedding",
    "SelfAttentionBlock",
    "SwiGLUFFN",
    "SwiGLUFFNFused",
]
hf_src/layers/attention.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
4
+
5
+ import os
6
+
7
+ from torch import Tensor
8
+ from torch import nn
9
+
10
# xFormers is opt-out: defining XFORMERS_DISABLED (with any value) turns it off.
XFORMERS_ENABLED = "XFORMERS_DISABLED" not in os.environ

# Stays False when xFormers is disabled or cannot be imported.
XFORMERS_AVAILABLE = False
if XFORMERS_ENABLED:
    try:
        from xformers.ops import memory_efficient_attention, unbind
    except ImportError:
        # Not installed: MemEffAttention falls back to the plain Attention path.
        pass
    else:
        XFORMERS_AVAILABLE = True
20
+
21
+
22
class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token tensor.

    Args:
        dim: Channel width ``C`` of the input; it is split evenly across heads.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        proj_bias: Whether the output projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Queries are multiplied by head_dim ** -0.5 before the dot product.
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, return_attn=False) -> Tensor:
        """Apply self-attention to ``x``.

        Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

        Args:
            x: Input of shape ``(B, N, C)``.
            return_attn: When true, return the (post-dropout) attention
                weights of shape ``(B, num_heads, N, N)`` instead of the
                projected output.

        Returns:
            The projected output ``(B, N, C)``, or the attention weights when
            ``return_attn`` is set.
        """
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q / k / v.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        weights = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(weights.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(out))

        # Adaptation for returning attentions.
        return weights if return_attn else out
67
+
68
+
69
class MemEffAttention(Attention):
    """Attention that uses the xFormers memory-efficient kernel when available.

    Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/
    """

    def forward(self, x: Tensor, attn_bias=None, return_attn=False) -> Tensor:
        """Apply self-attention, preferring the xFormers kernel.

        Args:
            x: Input of shape ``(B, N, C)``.
            attn_bias: Optional bias/mask handed to
                ``memory_efficient_attention``; only usable with xFormers.
            return_attn: When true, return an attention-like tensor instead
                of the projected output.

        Returns:
            The projected output ``(B, N, C)``, or the tensor described above
            when ``return_attn`` is set.
        """
        if not XFORMERS_AVAILABLE:
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            # Fall back to the reference implementation, forwarding the flag
            # so attentions can still be returned.
            return super().forward(x, return_attn)

        batch, tokens, channels = x.shape
        fused = self.qkv(x).reshape(
            batch, tokens, 3, self.num_heads, channels // self.num_heads
        )
        q, k, v = unbind(fused, 2)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        if return_attn:
            # Support for xFormers to return attention, adapted from
            # https://github.com/facebookresearch/dinov2/issues/90#issuecomment-1574001076
            # NOTE(review): this is (kernel output) @ v^T, not the softmax
            # weights the non-xFormers path returns - confirm this is intended.
            return out.permute(0, 2, 1, 3) @ v.permute(0, 2, 3, 1)

        out = self.proj(out.reshape([batch, tokens, channels]))
        return self.proj_drop(out)
98
+
99
+
100
if __name__ == "__main__":
    import torch

    # Smoke test.  Pick the best available device instead of assuming CUDA,
    # mirroring the device selection used by the demo in block.py.
    _device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    _att = MemEffAttention(dim=32, num_heads=4).to(_device)
    print(_att(torch.randn(4, 16, 32, device=_device), return_attn=True).shape)
    print(_att(torch.randn(4, 16, 32, device=_device)).shape)
hf_src/layers/block.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
4
+
5
+ import os
6
+
7
+ from typing import Callable, List, Any, Tuple, Dict
8
+
9
+ import torch
10
+
11
+ from torch import nn, Tensor
12
+
13
+ from .attention import Attention, MemEffAttention
14
+ from .drop_path import DropPath
15
+ from .layer_scale import LayerScale
16
+ from .mlp import Mlp
17
+
18
# xFormers is opt-out: defining XFORMERS_DISABLED (with any value) turns it off.
XFORMERS_ENABLED = "XFORMERS_DISABLED" not in os.environ

# Stays False when xFormers is disabled or cannot be imported.
XFORMERS_AVAILABLE = False
if XFORMERS_ENABLED:
    try:
        from xformers.ops import fmha
    except ImportError:
        # Not installed: the nested-tensor code paths are unavailable.
        pass
    else:
        XFORMERS_AVAILABLE = True
28
+
29
+
30
class Block(nn.Module):
    """Pre-norm transformer block with an attention and a feed-forward branch.

    Each branch computes ``layer_scale(sublayer(norm(x)))`` and is added back
    to ``x`` as a residual, optionally with stochastic depth during training.

    Args:
        dim: Channel width handed to the norm, attention and FFN layers.
        num_heads: Number of attention heads.
        mlp_ratio: FFN hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias: Forwarded to the attention layer.
        proj_bias: Forwarded to the attention layer.
        ffn_bias: Forwarded to the FFN layer as ``bias``.
        drop: Dropout used for the attention projection and inside the FFN.
        attn_drop: Dropout on the attention weights.
        init_values: LayerScale initial value; a falsy value (``None``/``0``)
            replaces LayerScale with ``nn.Identity``.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Activation factory for the FFN.
        norm_layer: Normalisation factory, called as ``norm_layer(dim)``.
        attn_class: Attention factory.
        ffn_layer: Feed-forward factory.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # --- attention branch ---
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # --- feed-forward branch ---
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, return_attention=False) -> Tensor:
        """Run the block on ``x``.

        Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

        Args:
            x: Input tensor; the batch-subset stochastic-depth path requires
                it to be 3-D.
            return_attention: When true, also return the attention tensor
                computed from the block input.

        Returns:
            The block output, or ``(output, attention)`` when
            ``return_attention`` is set.
        """

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        # Adaptation for returning attentions: computed from the block input
        # in a separate attention call, before the residual update below.
        if return_attention:
            attn = self.attn(self.norm1(x), return_attn=True)

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # The FFN branch uses its own DropPath module (was drop_path1).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)

        # Adaptation for returning attentions.
        if return_attention:
            return x, attn

        return x
118
+
119
+
120
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch.

    Only ``max(int(b * (1 - sample_drop_ratio)), 1)`` randomly chosen samples
    are passed through ``residual_func``; their residual is scaled by
    ``b / subset_size`` and added to the matching rows of ``x``.  Rows that
    were not selected are returned unchanged.

    Args:
        x: 3-D input tensor (its shape is unpacked into three dimensions).
        residual_func: Computes the residual for the selected samples.
        sample_drop_ratio: Fraction of the batch to skip.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    # 1) choose which samples keep their residual branch
    batch, _, _ = x.shape
    kept = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:kept]

    # 2) evaluate the branch on the selected samples only
    branch = residual_func(x[kept_rows]).flatten(1)

    # 3) scatter-add the rescaled residual (out of place)
    updated = torch.index_add(
        x.flatten(1), 0, kept_rows, branch.to(dtype=x.dtype), alpha=batch / kept
    )
    return updated.view_as(x)
144
+
145
+
146
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick the batch rows that keep their residual and the matching scale.

    Args:
        x: 3-D tensor; only its batch size and device are used.
        sample_drop_ratio: Fraction of the batch to skip.

    Returns:
        ``(brange, residual_scale_factor)``: a random index tensor holding
        ``max(int(b * (1 - sample_drop_ratio)), 1)`` distinct rows, and the
        factor ``b / len(brange)``.
    """
    batch, _, _ = x.shape
    kept = max(int(batch * (1 - sample_drop_ratio)), 1)
    shuffled = torch.randperm(batch, device=x.device)
    return shuffled[:kept], batch / kept
152
+
153
+
154
def add_residual(x, brange, residual, residual_scale_factor, ls=None):
    """Accumulate a scaled residual into the rows ``brange`` of ``x`` in place.

    Args:
        x: Tensor to update.
        brange: Indices along dim 0 that receive the residual.
        residual: Residual values for the selected rows.
        residual_scale_factor: Multiplier (``alpha``) applied to the residual.
        ls: Optional callable applied to the residual before it is added.

    Returns:
        Without ``ls``: the updated ``x.flatten(1)`` (``index_add_`` is applied
        to the flattened tensor).  With ``ls``: ``x`` itself, updated in place.
    """
    if ls is not None:
        return x.index_add_(
            dim=0,
            index=brange,
            source=ls(residual.to(dtype=x.dtype)),
            alpha=residual_scale_factor,
        )

    flat_residual = residual.flatten(1)
    return x.flatten(1).index_add_(
        dim=0,
        index=brange,
        source=flat_residual.to(dtype=x.dtype),
        alpha=residual_scale_factor,
    )
172
+
173
+
174
# Memoises attention biases, keyed by the tuple of per-tensor
# (batch size, sequence length) pairs.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """Concatenate a list of tensors and fetch the matching attention bias.

    Performs the (optional) index select, concatenates the tensors into a
    single batch-of-one sequence, and returns the block-diagonal attention
    bias for that layout, building and caching it on first use.

    Args:
        x_list: Tensors to nest together.
        branges: Optional per-tensor index tensors selecting rows along dim 0.

    Returns:
        ``(attn_bias, cat_tensors)``.
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [idx.shape[0] for idx in branges]

    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache:
        # One sequence-length entry per (selected) sample, in list order.
        seqlens = [x.shape[1] for b, x in zip(batch_sizes, x_list) for _ in range(b)]
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is None:
        # Fold each tensor's batch into its sequence axis, then join them.
        folded = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(folded, dim=1)
    else:
        picked = [
            x.flatten(1).index_select(0, idx).reshape(-1)
            for x, idx in zip(x_list, branges)
        ]
        cat_tensors = torch.cat(picked, dim=0).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[all_shapes], cat_tensors
212
+
213
+
214
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """List version of batch-subset stochastic depth.

    For every tensor a random subset of samples is selected, all subsets are
    concatenated and pushed through ``residual_func`` in one call, and the
    resulting residuals are split and added back to their source tensors.

    Args:
        x_list: Tensors to process together.
        residual_func: Called as ``residual_func(x_cat, attn_bias=attn_bias)``.
        sample_drop_ratio: Fraction of each tensor's batch to skip.
        scaling_vector: Optional callable applied to each residual before it
            is added (forwarded to ``add_residual`` as ``ls``).

    Returns:
        A list with one updated tensor per input tensor (note that the return
        annotation says ``Tensor``).
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges, residual_scale_factors = [], []
    for x in x_list:
        brange, factor = get_branges_scales(x, sample_drop_ratio=sample_drop_ratio)
        branges.append(brange)
        residual_scale_factors.append(factor)

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    return [
        add_residual(x, brange, residual, factor, scaling_vector).view_as(x)
        for x, brange, residual, factor in zip(
            x_list, branges, residual_list, residual_scale_factors
        )
    ]
243
+
244
+
245
class NestedTensorBlock(Block):
    """Block that can also process a list of tensors nested into one sequence."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run

        Requires the attention layer to be a ``MemEffAttention``.  Returns one
        output tensor per input tensor.
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            # LayerScale is not applied inside these branch functions; it is
            # passed as ``scaling_vector`` and applied when the residual is added.
            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1 if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Gate on ls2 itself (the original checked ls1 here).
                scaling_vector=self.ls2 if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list, return_attention=False):
        """Dispatch on the input type.

        Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

        Args:
            x_or_x_list: A single tensor (handled by ``Block.forward``) or a
                list of tensors (handled by ``forward_nested``).
            return_attention: Forwarded to ``Block.forward`` for tensor input.

        Raises:
            NotImplementedError: ``return_attention`` requested for a list.
            AssertionError: List input without xFormers, or an input that is
                neither a tensor nor a list.
        """
        if isinstance(x_or_x_list, Tensor):
            # Forward the flag so attentions can be returned for plain tensors.
            return super().forward(x_or_x_list, return_attention)
        elif isinstance(x_or_x_list, list):
            if return_attention:
                raise NotImplementedError(
                    "return_attention not supported for nested tensors"
                )
            assert (
                XFORMERS_AVAILABLE
            ), "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
305
+
306
+
307
if __name__ == "__main__":
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        _device = "cuda"
    elif torch.backends.mps.is_available():
        _device = "mps"
    else:
        _device = "cpu"

    # Example usage: batch size 10, sequence length 16, feature dimension 64.
    block = Block(dim=64, num_heads=8, drop_path=0.3).to(_device)
    x = torch.randn(10, 16, 64, device=_device)
    output = block(x)
    print(output.shape)  # Should be (10, 16, 64)

    # Nested usage: a list of tensors goes through the block together.
    nested_block = NestedTensorBlock(
        dim=64, num_heads=8, attn_class=MemEffAttention
    ).to(_device)
    nested_x = [torch.randn(10, 16, 64, device=_device) for _ in range(2)]
    nested_output = nested_block(nested_x)
    print([o.shape for o in nested_output])  # Shapes of the tensors in the list
hf_src/layers/cva_head.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from einops import rearrange
7
+ from torch.nn.init import trunc_normal_
8
+
9
+
10
+ def _make_lna_block(input_dim, output_dim, bias, norm_op, activation):
11
+ layers = [nn.Linear(input_dim, output_dim, bias=bias)]
12
+ if norm_op is not None:
13
+ layers.append(norm_op(output_dim))
14
+ if activation is not None:
15
+ layers.append(activation())
16
+ return nn.Sequential(*layers)
17
+
18
+
19
def _build_projector(n_layers, in_dim, out_dim, hidden_dim, activation=nn.GELU):
    """Build the projector MLP as one flat ``nn.Sequential``.

    With ``n_layers > 1`` the stack is ``n_layers - 1`` Linear/BatchNorm/activation
    stages followed by a bias-free Linear and a BatchNorm; otherwise it is just
    the bias-free Linear and BatchNorm.  All BatchNorm layers are created with
    ``track_running_stats=False``.
    """
    norm_op = partial(nn.BatchNorm1d, track_running_stats=False)

    if n_layers <= 1:
        return nn.Sequential(nn.Linear(in_dim, out_dim, bias=False), norm_op(out_dim))

    # Collect the individual layers so the result stays flat (indices 0..k).
    modules = list(_make_lna_block(in_dim, hidden_dim, True, norm_op, activation))
    for _ in range(n_layers - 2):
        modules.extend(
            _make_lna_block(hidden_dim, hidden_dim, True, norm_op, activation)
        )
    modules.append(nn.Linear(hidden_dim, out_dim, bias=False))
    modules.append(norm_op(out_dim))
    return nn.Sequential(*modules)
32
+
33
+
34
def _build_predictor(n_layers, in_out_dim, bottleneck_dim, activation=nn.GELU):
    """Build the predictor MLP (``in_out_dim -> bottleneck_dim -> in_out_dim``).

    The first Linear/BatchNorm/activation stage is kept as a nested
    ``nn.Sequential`` (index 0) while every later layer is appended flat;
    this layout determines the parameter names and must not change.
    """
    norm_op = partial(nn.BatchNorm1d, track_running_stats=False)

    entry_stage = _make_lna_block(in_out_dim, bottleneck_dim, True, norm_op, activation)
    modules = [entry_stage]  # nested on purpose, see docstring

    for _ in range(n_layers - 1):
        modules.extend(
            _make_lna_block(bottleneck_dim, bottleneck_dim, True, norm_op, activation)
        )

    # Final bias-free Linear back to the input width, without norm/activation.
    modules.extend(_make_lna_block(bottleneck_dim, in_out_dim, False, None, None))
    return nn.Sequential(*modules)
45
+
46
+
47
class CVAHead(nn.Module):
    """Projector (and optional predictor) head for 2-D, 3-D or 4-D latents.

    Inputs are flattened to ``(rows, C)`` before going through the MLPs and
    restored to their original layout afterwards:

    * 2-D ``(B, C)`` is used as is;
    * 3-D ``(B, N, C)`` is flattened over ``B`` and ``N``;
    * 4-D ``(B, C, H, W)`` is flattened over ``B``, ``H`` and ``W``.

    Args:
        in_dim: Channel width of the incoming latent.
        out_dim: Channel width produced by the projector (and the predictor).
        projector_layers: Number of projector layers (clamped to at least 1).
        predictor_layers: Number of bottleneck stages in the predictor.
        hidden_dim: Hidden width of the projector.
        bottleneck_dim: Hidden width of the predictor.
        act_op: Activation factory used by both MLPs.
        use_predictor: When false no predictor is built, so only
            ``project`` / ``forward(..., project_only=True)`` are usable.
    """

    def __init__(
        self,
        in_dim,
        out_dim=1024,
        projector_layers=3,
        predictor_layers=1,
        hidden_dim=2048,
        bottleneck_dim=256,
        act_op=nn.GELU,
        use_predictor=True,
    ):
        super().__init__()
        projector_layers = max(projector_layers, 1)

        self.projector = _build_projector(
            projector_layers,
            in_dim,
            out_dim,
            hidden_dim=hidden_dim,
            activation=act_op,
        )

        if use_predictor:
            self.predictor = _build_predictor(
                predictor_layers,
                out_dim,
                bottleneck_dim,
                activation=act_op,
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for every ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _to_rows(latent):
        """Flatten ``latent`` to ``(rows, C)`` and return it with its inverse.

        Returns:
            ``(flat, restore)`` where ``restore`` maps a ``(rows, C')`` tensor
            back to the layout of ``latent``.

        Raises:
            ValueError: If ``latent`` is not 2-D, 3-D or 4-D.
        """
        if latent.ndim == 2:
            return latent, lambda rows: rows

        if latent.ndim == 3:
            # (B, N, C)
            b, n, _ = latent.shape
            return latent.flatten(0, 1), lambda rows: rows.unflatten(0, (b, n))

        if latent.ndim == 4:
            # spatial latent: (B, C, H, W)
            b, _, h, w = latent.shape
            flat = rearrange(latent, "b c h w -> (b h w) c").contiguous()

            def restore(rows):
                # make it spatial again
                return rearrange(
                    rows, "(b h w) c -> b c h w", b=b, h=h, w=w
                ).contiguous()

            return flat, restore

        raise ValueError(f"{latent.ndim=}D latent input is not supported")

    def project(self, latent):
        """Return the projector output in the layout of ``latent``."""
        rows, restore = self._to_rows(latent)
        return restore(self.projector(rows))

    def predict(self, latent):
        """Return ``predictor(projector(latent))`` in the layout of ``latent``."""
        rows, restore = self._to_rows(latent)
        return restore(self.predictor(self.projector(rows)))

    def project_predict(self, latent):
        """Return ``(projection, prediction)``, both in the layout of ``latent``.

        The predictor is applied to the flattened projection so that 3-D and
        4-D inputs are handled the same way as in ``predict``.
        """
        rows, restore = self._to_rows(latent)
        projected = self.projector(rows)
        predicted = self.predictor(projected)
        return restore(projected), restore(predicted)

    def forward(self, latent, project_only=False):
        """Project ``latent``; also run the predictor unless ``project_only``."""
        if project_only:
            return self.project(latent)

        return self.predict(latent)
142
+
143
+
144
class IdentityHead(torch.nn.Module):
    """Parameter-free stand-in exposing the ``CVAHead`` methods as no-ops."""

    def __init__(self):
        super().__init__()

    def project(self, x):
        """Return ``x`` unchanged."""
        return x

    def predict(self, x):
        """Return ``x`` unchanged."""
        return x

    def project_predict(self, x):
        """Return ``x`` twice, as the (projection, prediction) pair."""
        return x, x

    def forward(self, x, **kwargs):
        """Return ``x`` unchanged; keyword arguments are accepted and ignored."""
        return x
159
+
160
+
161
class CVAHeadList(torch.nn.Module):
    """A bank of independent ``CVAHead`` modules, one per scale.

    Args:
        num_scales: Number of heads to create.
        **params: Keyword arguments forwarded to every ``CVAHead``.
    """

    def __init__(self, num_scales=2, **params):
        super().__init__()
        heads = [CVAHead(**params) for _ in range(num_scales)]
        self.heads = torch.nn.ModuleList(heads)

    def forward(self, x, scale_idx, project_only=False):
        """Run the head selected by ``scale_idx`` on ``x``."""
        head = self.heads[scale_idx]
        return head(x, project_only=project_only)
168
+
169
+
170
if __name__ == "__main__":
    # Smoke test: token input (B=2, N=36, C=768) through a 768 -> 512 head.
    model = CVAHead(
        768,
        512,
        hidden_dim=2048,
        bottleneck_dim=256,
        act_op=nn.GELU,
    )
    print(model)
    x = torch.randn(2, 36, 768)
    out = model(x, project_only=True)

    print("Output shape:", out.shape)  # Expected: (2, 36, 512)
    out2 = model(x, project_only=False)
    print("Output shape after prediction:", out2.shape)  # Expected: (2, 36, 512)
hf_src/layers/dino_head.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from torch.nn.init import trunc_normal_
5
+ from torch.nn.utils.parametrizations import weight_norm
6
+
7
+
8
class DINOHead(nn.Module):
    """MLP head with an optional weight-normalised output layer.

    Args:
        in_dim: Input feature width.
        out_dim: Output width of the last layer.
        use_bn: Insert BatchNorm layers in the MLP.
        nlayers: Number of MLP layers (clamped to at least 1).
        hidden_dim: Hidden width of the MLP.
        bottleneck_dim: Output width of the MLP / input of the last layer.
        mlp_bias: Bias flag for the MLP's hidden linear layers.
        use_last_layer: When false the head stops after the MLP.
    """

    def __init__(
        self,
        in_dim,
        out_dim=2**16,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
        use_last_layer=True,
    ):
        super().__init__()
        self.use_last_layer = use_last_layer

        self.mlp = _build_mlp(
            max(nlayers, 1),
            in_dim,
            bottleneck_dim,
            hidden_dim=hidden_dim,
            use_bn=use_bn,
            bias=mlp_bias,
        )

        if use_last_layer:
            output_linear = nn.Linear(bottleneck_dim, out_dim, bias=False)
            self.last_layer = weight_norm(output_linear)
            # Start the first weight-norm parameter (original0) at 1.
            self.last_layer.parametrizations.weight.original0.data.fill_(1)

    def init_weights(self) -> None:
        """Apply ``_init_weights`` to every sub-module."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for ``nn.Linear`` layers."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, **kwargs):
        """Run the MLP, then (optionally) L2-normalise and apply the last layer."""
        x = self.mlp(x)
        if not self.use_last_layer:
            return x

        x = nn.functional.normalize(x, dim=-1, p=2, eps=torch.finfo(x.dtype).eps)
        return self.last_layer(x)
58
+
59
+
60
+ def _build_mlp(
61
+ nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
62
+ ):
63
+ if nlayers == 1:
64
+ return nn.Linear(in_dim, bottleneck_dim, bias=not use_bn)
65
+ else:
66
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
67
+ if use_bn:
68
+ layers.append(nn.BatchNorm1d(hidden_dim, track_running_stats=False))
69
+ layers.append(nn.GELU())
70
+ for _ in range(nlayers - 2):
71
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
72
+ if use_bn:
73
+ layers.append(nn.BatchNorm1d(hidden_dim, track_running_stats=False))
74
+ layers.append(nn.GELU())
75
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=not use_bn))
76
+ return nn.Sequential(*layers)
hf_src/layers/drop_path.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
4
+
5
+
6
+ from torch import nn
7
+
8
+
9
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    Args:
        x: Input tensor; dimension 0 is treated as the sample dimension.
        drop_prob: Probability of dropping a sample.  ``0`` or ``None``
            disables dropping.
        training: Dropping only happens when this is true.

    Returns:
        ``x`` itself when nothing is dropped; otherwise a new tensor in which
        each sample is either zeroed or scaled by ``1 / (1 - drop_prob)``.
    """
    # ``not drop_prob`` also covers None, the default of the DropPath module,
    # which used to raise a TypeError in training mode.
    if not training or not drop_prob:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions
    # (work with diff dim tensors, not just 2D ConvNets).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale the kept samples; skipped when everything is dropped to
        # avoid dividing by zero.
        random_tensor.div_(keep_prob)
    return x * random_tensor
21
+
22
+
23
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: Probability handed to ``drop_path`` on every forward call.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` using this module's training flag."""
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
hf_src/layers/fp8_linear.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ import re
7
+
8
+ import torch
9
+
10
+ from hf_src.utils import named_replace
11
+ from hf_src.layers.rope_attention import LinearKMaskedBias
12
+
13
+ # avoid division by zero when calculating scale
14
+ EPS = 1e-12
15
+
16
+
17
+ def scale(t, amax_t):
18
+ max_v = torch.finfo(torch.float8_e4m3fn).max
19
+ scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v
20
+ t_fp8 = (t / scale_t).to(torch.float8_e4m3fn)
21
+ return t_fp8, scale_t
22
+
23
+
24
def matmul(first, amax_first, second_t, amax_second_t, bias):
    """Compute ``first @ second_t.T`` (plus ``bias``) through an fp8 matmul.

    Both operands are quantised with ``scale``; the product is computed by
    ``torch._scaled_mm`` with unit scales and the real scales are applied to
    the bfloat16 result afterwards.

    Args:
        first: Left operand.
        amax_first: Absolute maxima used to quantise ``first``.
        second_t: Right operand, stored transposed.
        amax_second_t: Absolute maxima used to quantise ``second_t``.
        bias: Optional bias added to the rescaled product.

    Returns:
        The product; cast to ``torch.bfloat16`` before the bias is added.
    """
    first_fp8, scale_first = scale(first, amax_first)
    second_t_fp8, scale_second_t = scale(second_t, amax_second_t)

    # PyTorch's row-wise scaled matmul kernel is based on CUTLASS and is quite
    # slow. Hence we fall back to an "unscaled" matmul, which uses cuBLAS, and
    # apply the scale manually afterwards.
    unit_scale_a = scale_first.new_ones((1, 1))
    unit_scale_b = scale_second_t.t().new_ones((1, 1))
    product = torch._scaled_mm(
        first_fp8,
        second_t_fp8.t(),
        scale_a=unit_scale_a,
        scale_b=unit_scale_b,
        bias=None,
        out_dtype=torch.bfloat16,
        use_fast_accum=False,
    )
    product = (product * scale_first * scale_second_t.t()).to(torch.bfloat16)
    return product if bias is None else product + bias
43
+
44
+
45
@torch.compiler.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
    """Autograd function for a linear layer whose matmuls go through fp8.

    The forward product and the input gradient use ``matmul`` (fp8); the
    weight gradient is computed with a regular ``@``.
    """

    @staticmethod
    def forward(ctx, a, b_t, bias):
        """Return ``a @ b_t.T + bias`` using per-row absolute maxima as scales."""
        amax_a = a.abs().amax(dim=-1, keepdim=True)
        amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
        out = matmul(a, amax_a, b_t, amax_b_t, bias)

        # Remember which gradients backward has to produce.
        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = False if bias is None else bias.requires_grad

        # Backward quantises the transposed weight with one global maximum.
        ctx.save_for_backward(a, b_t, amax_b_t.max())

        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Return ``(grad_a, grad_b, grad_bias)``; unneeded entries are ``None``."""
        a, b_t, amax_b = ctx.saved_tensors

        grad_a = grad_b = grad_bias = None

        if ctx.a_requires_grad:
            b = b_t.t().contiguous()
            amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True)
            # Broadcast the saved scalar maximum to one entry per row of b.
            amax_b = amax_b.repeat(b.shape[0], 1)
            grad_a = matmul(grad_out, amax_grad_out, b, amax_b, None)

        if ctx.b_requires_grad:
            grad_b = grad_out.t() @ a

        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)

        return grad_a, grad_b, grad_bias
82
+
83
+
84
class Fp8Linear(torch.nn.Linear):
    """``nn.Linear`` whose forward pass is computed by ``Fp8LinearFn``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Collapse all leading dimensions into one, run the 2-D kernel, then
        # restore the original leading shape.
        leading_shape = input.shape[:-1]
        flat_out = Fp8LinearFn.apply(
            input.flatten(end_dim=-2), self.weight, self.bias
        )
        return flat_out.unflatten(0, leading_shape)
89
+
90
+
91
class Fp8LinearKMaskedBias(LinearKMaskedBias):
    """fp8 variant of ``LinearKMaskedBias``: the bias is multiplied by ``bias_mask``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.bias is None:
            masked_bias = None
        else:
            masked_bias = self.bias * self.bias_mask

        # Collapse all leading dimensions into one, run the 2-D kernel, then
        # restore the original leading shape.
        leading_shape = input.shape[:-1]
        flat_out = Fp8LinearFn.apply(
            input.flatten(end_dim=-2), self.weight, masked_bias
        )
        return flat_out.unflatten(0, leading_shape)
97
+
98
+
99
def convert_linears_to_fp8(
    root_module: torch.nn.Module, *, filter: str
) -> torch.nn.Module:
    """Swap matching linear layers of ``root_module`` for their fp8 versions.

    Args:
        root_module: Module tree handed to ``named_replace``.
        filter: Regular expression searched in each module name; only matching
            ``nn.Linear`` / ``LinearKMaskedBias`` modules are converted.

    Returns:
        Whatever ``named_replace`` returns for the converted tree.

    Raises:
        RuntimeError: A selected layer has a dimension that is not a multiple of 64.
        AssertionError: A selected layer is of an unsupported ``nn.Linear``
            subclass, or no layer was converted at all.
    """
    filter_re = re.compile(filter)
    total_count = 0

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        nonlocal total_count
        if not (isinstance(module, torch.nn.Linear) and filter_re.search(name)):
            return module

        # Exact type match on purpose: LinearKMaskedBias is handled separately
        # and any other Linear subclass is rejected.
        module_type = type(module)
        if module_type is torch.nn.Linear:
            new_cls = Fp8Linear
        elif module_type is LinearKMaskedBias:
            new_cls = Fp8LinearKMaskedBias
        else:
            assert False, str(type(module))

        if module.in_features % 64 != 0 or module.out_features % 64 != 0:
            # This is not a strict requirement, but H100 TensorCores for fp8
            # operate on tiles of 64 elements anyways, and Inductor sometimes
            # pads inner dims to become multiples of 64. Also, if one day we
            # switch back to cuBLAS, it artificially requires dims to be
            # multiples of 16.
            raise RuntimeError(
                "fp8 requires all dimensions to be multiples of 64 "
                "(consider using ffn_layer=swiglu64 or higher)"
            )

        new_module = new_cls(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Reuse the existing parameters instead of the freshly created ones.
        new_module.weight = module.weight
        new_module.bias = module.bias
        total_count += 1
        return new_module

    out = named_replace(replace, root_module)
    assert total_count > 0, "fp8: no layer found to convert"

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return out
hf_src/layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
2
+
3
+ from typing import Union
4
+
5
+ import torch
6
+
7
+ from torch import nn
8
+ from torch import Tensor
9
+
10
+
11
class LayerScale(nn.Module):
    """Learnable per-channel scaling: ``x * gamma``.

    Args:
        dim: Number of channels (length of ``gamma``).
        init_values: Initial value of ``gamma``; a float, or a tensor that is
            broadcastable to ``(dim,)``.
        inplace: Multiply ``x`` in place instead of allocating a new tensor.
        device: Device on which ``gamma`` is created.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.empty(dim, device=device))
        self.init_values = init_values
        # ``torch.empty`` leaves gamma holding arbitrary memory; initialise it
        # now so the module is usable even if nobody calls reset_parameters().
        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise ``gamma`` from ``init_values``."""
        if isinstance(self.init_values, Tensor):
            # constant_ only accepts scalars; copy_ broadcasts tensor values.
            with torch.no_grad():
                self.gamma.copy_(self.init_values)
        else:
            nn.init.constant_(self.gamma, self.init_values)

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` by ``gamma`` (modifying ``x`` itself when ``inplace``)."""
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
hf_src/layers/mlp.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
4
+
5
+
6
+ from typing import Callable, List, Optional
7
+
8
+
9
+ from torch import Tensor, nn
10
+
11
+ from hf_src.utils import cat_keep_shapes, uncat_with_shapes
12
+
13
+
14
class ListForwardMixin(object):
    """Mixin that adds a list-based forward pass on top of ``forward``.

    Subclasses provide ``forward``; ``forward_list`` concatenates the inputs
    with ``cat_keep_shapes``, runs ``forward`` once on the result and splits
    the output back with ``uncat_with_shapes``.
    """

    def forward(self, x: Tensor):
        # Must be supplied by the concrete module.
        raise NotImplementedError

    def forward_list(self, x_list: List[Tensor]) -> List[Tensor]:
        """Apply ``forward`` to all tensors of ``x_list`` in a single call."""
        merged, original_shapes, token_counts = cat_keep_shapes(x_list)
        merged_out = self.forward(merged)
        return uncat_with_shapes(merged_out, original_shapes, token_counts)
22
+
23
+
24
class Mlp(nn.Module, ListForwardMixin):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features: Size of the input's last dimension.
        hidden_features: Width of the hidden layer; falls back to ``in_features``.
        out_features: Size of the output's last dimension; falls back to ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability used after the activation and after fc2.
        bias: Whether both linear layers carry a bias.
        device: Device on which the linear layers are created.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
        device=None,
    ) -> None:
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features
        linear_kwargs = {"bias": bias, "device": device}
        self.fc1 = nn.Linear(in_features, hidden_features, **linear_kwargs)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, **linear_kwargs)
        # A single Dropout module is shared by both dropout sites.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
hf_src/layers/patch_embed.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # References:
2
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
4
+
5
+ import math
6
+ from typing import Callable, Tuple, Union
7
+
8
+ from torch import Tensor, nn
9
+
10
+
11
def make_2tuple(x):
    """Return ``x`` as a 2-tuple: a length-2 tuple passes through, an int is duplicated."""
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer factory, called as ``norm_layer(embed_dim)``;
            ``nn.Identity`` is used when it is None.
        flatten_embedding: When False, ``forward`` returns (B,H',W',D) instead of (B,N,D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Union[Callable, None] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel and stride are both the patch size.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed a (B,C,H,W) batch; the input size is not checked against ``img_size``."""
        x = self.proj(x)  # B D H' W'
        _, _, H, W = x.shape  # patch-grid size of this particular input
        x = x.flatten(2).transpose(1, 2)  # B H'W' D
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H' W' D
        return x

    def flops(self) -> float:
        """Multiply-accumulate count of the projection (plus norm) at ``img_size``."""
        Ho, Wo = self.patches_resolution
        flops = (
            Ho
            * Wo
            * self.embed_dim
            * self.in_chans
            * (self.patch_size[0] * self.patch_size[1])
        )
        # self.norm is never None (nn.Identity stands in for "no norm"), so
        # only count the normalisation when a real layer was configured.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops

    def reset_parameters(self):
        """Uniformly re-initialise the projection in [-sqrt(k), sqrt(k)], k = 1 / fan_in."""
        # fan_in of the conv kernel; the height*width product keeps this
        # correct for non-square patches as well.
        fan_in = self.in_chans * self.patch_size[0] * self.patch_size[1]
        bound = math.sqrt(1 / fan_in)
        nn.init.uniform_(self.proj.weight, -bound, bound)
        if self.proj.bias is not None:
            nn.init.uniform_(self.proj.bias, -bound, bound)
hf_src/layers/rms_norm.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+
9
+
10
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        dim: Length of the ``weight`` vector.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def reset_parameters(self) -> None:
        """Set every entry of ``weight`` back to one."""
        nn.init.ones_(self.weight)

    def _norm(self, x: Tensor) -> Tensor:
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        # The statistics are computed in float32, then cast back to the input
        # dtype before the gain is applied.
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight
hf_src/layers/rope_attention.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ import math
7
+ from typing import List, Tuple
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from torch import Tensor, nn
13
+
14
+ from hf_src.utils import cat_keep_shapes, uncat_with_shapes
15
+
16
+
17
+ # RoPE-related functions:
18
def rope_rotate_half(x: Tensor) -> Tensor:
    """Swap the two halves of the last dimension, negating the one moved to the front.

    Example: ``[x0 x1 x2 x3 x4 x5]`` becomes ``[-x3 -x4 -x5 x0 x1 x2]``.
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
23
+
24
+
25
def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Rotate ``x`` by the angles encoded in ``sin`` / ``cos``.

    x:   [..., D], eg [x0, x1, x2, x3, x4, x5]
    sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2]
    cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2]
    """
    # Half-rotation of x: [-x3, -x4, -x5, x0, x1, x2].
    first_half, second_half = x.chunk(2, dim=-1)
    rotated = torch.cat([-second_half, first_half], dim=-1)
    return (x * cos) + (rotated * sin)
30
+
31
+
32
class LinearKMaskedBias(nn.Linear):
    """``nn.Linear`` whose bias is multiplied element-wise by a ``bias_mask`` buffer.

    The buffer is created filled with NaN, so it has to be overwritten (for
    instance by loading a state dict) before the layer yields finite outputs.
    ``out_features`` must be divisible by 3.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.out_features % 3 == 0
        if self.bias is None:
            # No bias, hence nothing to mask and no buffer to register.
            return
        nan_mask = torch.full_like(self.bias, fill_value=math.nan)
        self.register_buffer("bias_mask", nan_mask)

    def forward(self, input: Tensor) -> Tensor:
        if self.bias is None:
            return F.linear(input, self.weight, None)
        mask = self.bias_mask.to(self.bias.dtype)
        return F.linear(input, self.weight, self.bias * mask)
49
+
50
+
51
+ class SelfAttention(nn.Module):
52
+ def __init__(
53
+ self,
54
+ dim: int,
55
+ num_heads: int = 8,
56
+ qkv_bias: bool = False,
57
+ proj_bias: bool = True,
58
+ attn_drop: float = 0.0,
59
+ proj_drop: float = 0.0,
60
+ mask_k_bias: bool = False,
61
+ device=None,
62
+ ) -> None:
63
+ super().__init__()
64
+ self.num_heads = num_heads
65
+ head_dim = dim // num_heads
66
+ self.scale = head_dim**-0.5
67
+
68
+ linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear
69
+ self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device)
70
+ self.attn_drop = nn.Dropout(attn_drop)
71
+ self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
72
+ self.proj_drop = nn.Dropout(proj_drop)
73
+
74
+ def apply_rope(
75
+ self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]
76
+ ) -> Tuple[Tensor, Tensor]:
77
+ # All operations will use the dtype of rope, the output is cast back to the dtype of q and k
78
+ q_dtype = q.dtype
79
+ k_dtype = k.dtype
80
+ sin, cos = rope
81
+ rope_dtype = sin.dtype
82
+ q = q.to(dtype=rope_dtype)
83
+ k = k.to(dtype=rope_dtype)
84
+ N = q.shape[-2]
85
+ prefix = N - sin.shape[-2]
86
+ assert prefix >= 0
87
+ q_prefix = q[:, :, :prefix, :]
88
+ q = rope_apply(q[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head]
89
+ q = torch.cat((q_prefix, q), dim=-2) # [B, head, N, D//head]
90
+ k_prefix = k[:, :, :prefix, :]
91
+ k = rope_apply(k[:, :, prefix:, :], sin, cos) # [B, head, hw, D//head]
92
+ k = torch.cat((k_prefix, k), dim=-2) # [B, head, N, D//head]
93
+ q = q.to(dtype=q_dtype)
94
+ k = k.to(dtype=k_dtype)
95
+ return q, k
96
+
97
+ def forward(self, x: Tensor, attn_bias=None, rope: Tensor = None) -> Tensor:
98
+ qkv = self.qkv(x)
99
+ attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope)
100
+ x = self.proj(attn_v)
101
+ x = self.proj_drop(x)
102
+ return x
103
+
104
+ def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]:
105
+ assert len(x_list) == len(rope_list) # should be enforced by the Block
106
+ x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
107
+ qkv_flat = self.qkv(x_flat)
108
+ qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens)
109
+ att_out = []
110
+ for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)):
111
+ att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope))
112
+ x_flat, shapes, num_tokens = cat_keep_shapes(att_out)
113
+ x_flat = self.proj(x_flat)
114
+ return uncat_with_shapes(x_flat, shapes, num_tokens)
115
+
116
+ def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor:
117
+ assert attn_bias is None
118
+ B, N, _ = qkv.shape
119
+ C = self.qkv.in_features
120
+
121
+ qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
122
+ q, k, v = torch.unbind(qkv, 2)
123
+ q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
124
+ if rope is not None:
125
+ q, k = self.apply_rope(q, k, rope)
126
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
127
+ x = x.transpose(1, 2)
128
+ return x.reshape([B, N, C])
129
+
130
+
131
+ class CausalSelfAttention(nn.Module):
132
+ def __init__(
133
+ self,
134
+ dim: int,
135
+ num_heads: int = 8,
136
+ qkv_bias: bool = False,
137
+ proj_bias: bool = True,
138
+ attn_drop: float = 0.0,
139
+ proj_drop: float = 0.0,
140
+ ) -> None:
141
+ super().__init__()
142
+ self.dim = dim
143
+ self.num_heads = num_heads
144
+ head_dim = dim // num_heads
145
+ self.scale = head_dim**-0.5
146
+
147
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
148
+ self.attn_drop = attn_drop
149
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
150
+ self.proj_drop = nn.Dropout(proj_drop)
151
+
152
+ def init_weights(
153
+ self,
154
+ init_attn_std: float | None = None,
155
+ init_proj_std: float | None = None,
156
+ factor: float = 1.0,
157
+ ) -> None:
158
+ init_attn_std = init_attn_std or (self.dim**-0.5)
159
+ init_proj_std = init_proj_std or init_attn_std * factor
160
+ nn.init.normal_(self.qkv.weight, std=init_attn_std)
161
+ nn.init.normal_(self.proj.weight, std=init_proj_std)
162
+ if self.qkv.bias is not None:
163
+ nn.init.zeros_(self.qkv.bias)
164
+ if self.proj.bias is not None:
165
+ nn.init.zeros_(self.proj.bias)
166
+
167
+ def forward(self, x: Tensor, is_causal: bool = True) -> Tensor:
168
+ B, N, C = x.shape
169
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
170
+ q, k, v = torch.unbind(qkv, 2)
171
+ q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
172
+ x = torch.nn.functional.scaled_dot_product_attention(
173
+ q,
174
+ k,
175
+ v,
176
+ attn_mask=None,
177
+ dropout_p=self.attn_drop if self.training else 0,
178
+ is_causal=is_causal,
179
+ )
180
+ x = x.transpose(1, 2).contiguous().view(B, N, C)
181
+ x = self.proj_drop(self.proj(x))
182
+ return x
hf_src/layers/rope_block.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ from typing import Callable, List, Optional
7
+
8
+ import torch
9
+ from torch import Tensor, nn
10
+
11
+ from hf_src.utils import cat_keep_shapes, uncat_with_shapes
12
+
13
+ from .mlp import Mlp
14
+ from .layer_scale import LayerScale # , DropPath
15
+ from .rope_attention import CausalSelfAttention, SelfAttention
16
+
17
+ torch._dynamo.config.automatic_dynamic_shapes = False
18
+ torch._dynamo.config.accumulated_cache_size_limit = 1024
19
+
20
+
21
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: attention and FFN residual branches.

    Each branch is optionally scaled by ``LayerScale`` (when ``init_values`` is
    truthy). With ``drop_path > 0`` in training mode, each branch is evaluated
    only on a random subset of ``max(int(b * (1 - drop_path)), 1)`` samples of
    the batch and its residual is scaled by ``b / subset_size``.

    Args:
        dim: Embedding dimension.
        num_heads: Number of attention heads.
        ffn_ratio: Hidden width of the FFN as a multiple of ``dim``.
        qkv_bias, proj_bias, ffn_bias: Bias switches of the respective layers.
        drop: Dropout probability passed to the attention projection and the FFN.
        attn_drop: Dropout probability passed to the attention module.
        init_values: LayerScale initial value; falsy disables LayerScale.
        drop_path: Fraction of samples dropped per residual branch.
        act_layer, norm_layer, attn_class, ffn_layer: Module factories.
        mask_k_bias: Forwarded to ``attn_class``.
        device: Forwarded to the attention, FFN and LayerScale constructors.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            mask_k_bias=mask_k_bias,
            device=device,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values, device=device)
            if init_values
            else nn.Identity()
        )

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
            device=device,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values, device=device)
            if init_values
            else nn.Identity()
        )

        self.sample_drop_ratio = drop_path

    @staticmethod
    def _maybe_index_rope(
        rope: tuple[Tensor, Tensor] | None, indices: Tensor
    ) -> tuple[Tensor, Tensor] | None:
        """Select the batch entries ``indices`` from ``rope`` when it is batched."""
        if rope is None:
            return None

        sin, cos = rope
        assert sin.ndim == cos.ndim
        if sin.ndim == 4:
            # If the rope embedding has a batch dimension (is different for each batch element), index into it
            return sin[indices], cos[indices]  # [batch, heads, patches, embed_dim]
        else:
            # No batch dimension, do not index
            return sin, cos  # [heads, patches, embed_dim] or [patches, embed_dim]

    def _forward(self, x: Tensor, rope=None) -> Tensor:
        """
        This is the reference implementation for a single tensor, matching what is done below for a list.
        We call the list op on [x] instead of this function.
        """
        b, _, _ = x.shape
        sample_subset_size = max(int(b * (1 - self.sample_drop_ratio)), 1)
        residual_scale_factor = b / sample_subset_size

        if self.training and self.sample_drop_ratio > 0.0:
            indices_1 = (torch.randperm(b, device=x.device))[:sample_subset_size]

            x_subset_1 = x[indices_1]
            rope_subset = self._maybe_index_rope(rope, indices_1)
            residual_1 = self.attn(self.norm1(x_subset_1), rope=rope_subset)

            x_attn = torch.index_add(
                x,
                dim=0,
                source=self.ls1(residual_1),
                index=indices_1,
                alpha=residual_scale_factor,
            )

            # The FFN branch draws its own, independent subset.
            indices_2 = (torch.randperm(b, device=x.device))[:sample_subset_size]

            x_subset_2 = x_attn[indices_2]
            residual_2 = self.mlp(self.norm2(x_subset_2))

            x_ffn = torch.index_add(
                x_attn,
                dim=0,
                source=self.ls2(residual_2),
                index=indices_2,
                alpha=residual_scale_factor,
            )
        else:
            x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
            x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))

        return x_ffn

    def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
        """
        This list operator concatenates the tokens from the list of inputs together to save
        on the elementwise operations. Torch-compile memory-planning allows hiding the overhead
        related to concat ops.

        ``rope_list`` holds one rope entry (or None) per input; passing None
        disables RoPE for every input.
        """
        if rope_list is None:
            # Normalise the default so both branches below can zip over it.
            rope_list = [None] * len(x_list)

        b_list = [x.shape[0] for x in x_list]
        sample_subset_sizes = [
            max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list
        ]
        residual_scale_factors = [
            b / sample_subset_size
            for b, sample_subset_size in zip(b_list, sample_subset_sizes)
        ]

        if self.training and self.sample_drop_ratio > 0.0:
            indices_1_list = [
                (torch.randperm(b, device=x.device))[:sample_subset_size]
                for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_1_list = [
                x[indices_1] for x, indices_1 in zip(x_list, indices_1_list)
            ]

            # _maybe_index_rope passes None entries through unchanged.
            rope_subset_list = [
                self._maybe_index_rope(rope, indices_1)
                for rope, indices_1 in zip(rope_list, indices_1_list)
            ]

            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
            norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
            residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)

            x_attn_list = [
                torch.index_add(
                    x,
                    dim=0,
                    source=self.ls1(residual_1),
                    index=indices_1,
                    alpha=residual_scale_factor,
                )
                for x, residual_1, indices_1, residual_scale_factor in zip(
                    x_list, residual_1_list, indices_1_list, residual_scale_factors
                )
            ]

            indices_2_list = [
                (torch.randperm(b, device=x.device))[:sample_subset_size]
                for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_2_list = [
                x[indices_2] for x, indices_2 in zip(x_attn_list, indices_2_list)
            ]
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
            norm2_flat = self.norm2(flattened)
            norm2_list = uncat_with_shapes(norm2_flat, shapes, num_tokens)

            residual_2_list = self.mlp.forward_list(norm2_list)

            x_ffn = [
                torch.index_add(
                    x_attn,
                    dim=0,
                    source=self.ls2(residual_2),
                    index=indices_2,
                    alpha=residual_scale_factor,
                )
                for x_attn, residual_2, indices_2, residual_scale_factor in zip(
                    x_attn_list, residual_2_list, indices_2_list, residual_scale_factors
                )
            ]
        else:
            x_ffn = []
            for x, rope in zip(x_list, rope_list):
                x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
                x_ffn.append(x_attn + self.ls2(self.mlp(self.norm2(x_attn))))

        return x_ffn

    def forward(self, x_or_x_list, rope_or_rope_list=None) -> Tensor | List[Tensor]:
        """Run the block on a tensor (returns a tensor) or a list of tensors (returns a list)."""
        if isinstance(x_or_x_list, Tensor):
            # for reference:
            # return self._forward(x_or_x_list, rope=rope_or_rope_list)
            # in order to match implementations we call the list op:
            return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0]
        elif isinstance(x_or_x_list, list):
            if rope_or_rope_list is None:
                rope_or_rope_list = [None for x in x_or_x_list]
            # return [self._forward(x, rope=rope) for x, rope in zip(x_or_x_list, rope_or_rope_list)]
            return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list)
        else:
            raise AssertionError(
                f"expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}"
            )
233
+
234
+
235
class CausalSelfAttentionBlock(nn.Module):
    """Pre-norm transformer block built on ``CausalSelfAttention`` and ``Mlp``.

    Args:
        dim: Embedding dimension.
        num_heads: Number of attention heads.
        ffn_ratio: Hidden width of the FFN as a multiple of ``dim``.
        ls_init_value: LayerScale initial value; falsy disables LayerScale.
        is_causal: Passed to the attention module on every forward call.
        act_layer: Activation factory for the FFN.
        norm_layer: Normalisation factory, called as ``norm_layer(dim)``.
        dropout_prob: Dropout probability for the attention weights, the
            attention projection and the FFN.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        is_causal: bool = True,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        dropout_prob: float = 0.0,
    ):
        super().__init__()

        self.dim = dim
        self.is_causal = is_causal
        self.ls1 = (
            LayerScale(dim, init_values=ls_init_value)
            if ls_init_value
            else nn.Identity()
        )
        self.attention_norm = norm_layer(dim)
        self.attention = CausalSelfAttention(
            dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob
        )

        self.ffn_norm = norm_layer(dim)
        ffn_hidden_dim = int(dim * ffn_ratio)
        self.feed_forward = Mlp(
            in_features=dim,
            hidden_features=ffn_hidden_dim,
            drop=dropout_prob,
            act_layer=act_layer,
        )

        self.ls2 = (
            LayerScale(dim, init_values=ls_init_value)
            if ls_init_value
            else nn.Identity()
        )

    def init_weights(
        self,
        init_attn_std: float | None = None,
        init_proj_std: float | None = None,
        init_fc_std: float | None = None,
        factor: float = 1.0,
    ) -> None:
        """Initialise attention, FFN, norm and LayerScale parameters.

        Defaults: ``init_attn_std = dim ** -0.5``,
        ``init_proj_std = init_attn_std * factor`` and
        ``init_fc_std = (2 * dim) ** -0.5``.
        """
        init_attn_std = init_attn_std or (self.dim**-0.5)
        init_proj_std = init_proj_std or init_attn_std * factor
        init_fc_std = init_fc_std or (2 * self.dim) ** -0.5
        self.attention.init_weights(init_attn_std, init_proj_std)
        self.attention_norm.reset_parameters()
        nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std)
        nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std)
        self.ffn_norm.reset_parameters()
        # Make sure the LayerScale gains hold their configured initial value
        # after init_weights(); nn.Identity placeholders have nothing to reset.
        for layer_scale in (self.ls1, self.ls2):
            if isinstance(layer_scale, LayerScale):
                layer_scale.reset_parameters()

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Apply the attention branch, then the FFN branch, each as a residual."""
        x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal))
        x_ffn = x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn)))
        return x_ffn
hf_src/layers/rope_position_encoding.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ import math
7
+ from typing import Literal
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch import Tensor, nn
12
+
13
+
14
+ # RoPE positional embedding with no mixing of coordinates (axial) and no learnable weights
15
+ # Supports two parametrizations of the rope parameters: either using `base` or `min_period` and `max_period`.
16
+ class RopePositionEmbedding(nn.Module):
17
+ def __init__(
18
+ self,
19
+ embed_dim: int,
20
+ *,
21
+ num_heads: int,
22
+ base: float | None = 100.0,
23
+ min_period: float | None = None,
24
+ max_period: float | None = None,
25
+ normalize_coords: Literal["min", "max", "separate"] = "separate",
26
+ shift_coords: float | None = None,
27
+ jitter_coords: float | None = None,
28
+ rescale_coords: float | None = None,
29
+ dtype: torch.dtype | None = None,
30
+ device: torch.device | None = None,
31
+ ):
32
+ super().__init__()
33
+ assert embed_dim % (4 * num_heads) == 0
34
+ both_periods = min_period is not None and max_period is not None
35
+ if (base is None and not both_periods) or (base is not None and both_periods):
36
+ raise ValueError(
37
+ "Either `base` or `min_period`+`max_period` must be provided."
38
+ )
39
+
40
+ D_head = embed_dim // num_heads
41
+ self.base = base
42
+ self.min_period = min_period
43
+ self.max_period = max_period
44
+ self.D_head = D_head
45
+ self.normalize_coords = normalize_coords
46
+ self.shift_coords = shift_coords
47
+ self.jitter_coords = jitter_coords
48
+ self.rescale_coords = rescale_coords
49
+
50
+ # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
51
+ self.dtype = dtype # Don't rely on self.periods.dtype
52
+ self.register_buffer(
53
+ "periods",
54
+ torch.empty(D_head // 4, device=device, dtype=dtype),
55
+ persistent=True,
56
+ )
57
+ self._init_weights()
58
+
59
+ def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
60
+ device = self.periods.device
61
+ dtype = self.dtype
62
+ dd = {"device": device, "dtype": dtype}
63
+
64
+ # Prepare coords in range [-1, +1]
65
+ if self.normalize_coords == "max":
66
+ max_HW = max(H, W)
67
+ coords_h = torch.arange(0.5, H, **dd) / max_HW # [H]
68
+ coords_w = torch.arange(0.5, W, **dd) / max_HW # [W]
69
+ elif self.normalize_coords == "min":
70
+ min_HW = min(H, W)
71
+ coords_h = torch.arange(0.5, H, **dd) / min_HW # [H]
72
+ coords_w = torch.arange(0.5, W, **dd) / min_HW # [W]
73
+ elif self.normalize_coords == "separate":
74
+ coords_h = torch.arange(0.5, H, **dd) / H # [H]
75
+ coords_w = torch.arange(0.5, W, **dd) / W # [W]
76
+ else:
77
+ raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
78
+ coords = torch.stack(
79
+ torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1
80
+ ) # [H, W, 2]
81
+ coords = coords.flatten(0, 1) # [HW, 2]
82
+ coords = 2.0 * coords - 1.0 # Shift range [0, 1] to [-1, +1]
83
+
84
+ # Shift coords by adding a uniform value in [-shift, shift]
85
+ if self.training and self.shift_coords is not None:
86
+ shift_hw = torch.empty(2, **dd).uniform_(
87
+ -self.shift_coords, self.shift_coords
88
+ )
89
+ coords += shift_hw[None, :]
90
+
91
+ # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
92
+ if self.training and self.jitter_coords is not None:
93
+ jitter_max = np.log(self.jitter_coords)
94
+ jitter_min = -jitter_max
95
+ jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
96
+ coords *= jitter_hw[None, :]
97
+
98
+ # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
99
+ if self.training and self.rescale_coords is not None:
100
+ rescale_max = np.log(self.rescale_coords)
101
+ rescale_min = -rescale_max
102
+ rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
103
+ coords *= rescale_hw
104
+
105
+ # Prepare angles and sin/cos
106
+ angles = (
107
+ 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
108
+ ) # [HW, 2, D//4]
109
+ angles = angles.flatten(1, 2) # [HW, D//2]
110
+ angles = angles.tile(2) # [HW, D]
111
+ cos = torch.cos(angles) # [HW, D]
112
+ sin = torch.sin(angles) # [HW, D]
113
+
114
+ return sin, cos # 2 * [HW, D]
115
+
116
+ def _init_weights(self):
117
+ device = self.periods.device
118
+ dtype = self.dtype
119
+ if self.base is not None:
120
+ periods = self.base ** (
121
+ 2
122
+ * torch.arange(self.D_head // 4, device=device, dtype=dtype)
123
+ / (self.D_head // 2)
124
+ ) # [D//4]
125
+ else:
126
+ base = self.max_period / self.min_period
127
+ exponents = torch.linspace(
128
+ 0, 1, self.D_head // 4, device=device, dtype=dtype
129
+ ) # [D//4] range [0, 1]
130
+ periods = base**exponents # range [1, max_period / min_period]
131
+ periods = periods / base # range [min_period / max_period, 1]
132
+ periods = periods * self.max_period # range [min_period, max_period]
133
+ self.periods.data = periods
134
+
135
+
136
if __name__ == "__main__":
    # Visual sanity check: plot one sine component of the "separate"-normalised
    # embedding for a few bases and grid sizes.
    import torch
    import numpy as np
    import matplotlib.pyplot as plt

    def get_rope_values(H, W, embed_dim, num_heads, base):
        """Return an [H, W] array of sin(2 * pi * h / period) for the period at index 3."""
        D_head = embed_dim // num_heads
        num_periods, half_dim = D_head // 4, D_head // 2
        print(num_periods, half_dim, num_periods / half_dim)
        all_periods = base ** (2 * torch.arange(num_periods) / half_dim)

        # Index 3 selects the fourth period of the schedule (not the first one).
        period = all_periods[3]

        # Patch-centre coordinates per axis, normalised by that axis' size.
        rows = torch.arange(0.5, H) / H
        cols = torch.arange(0.5, W) / W
        grid_h, grid_w = torch.meshgrid(rows, cols, indexing="ij")

        # Map [0, 1] to [-1, 1]; only the height coordinate is visualised.
        grid_h = 2.0 * grid_h - 1.0
        grid_w = 2.0 * grid_w - 1.0

        return torch.sin(2 * np.pi * grid_h / period).numpy()

    # Settings
    embed_dim = 768
    num_heads = 12
    bases = [100, 10000]
    sizes = [14, 28]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # One row of plots per base, one column per grid size.
    for row, base in enumerate(bases):
        for col, size in enumerate(sizes):
            vals = get_rope_values(size, size, embed_dim, num_heads, base)

            ax = axes[row, col]
            im = ax.imshow(vals, cmap="RdBu", extent=[-1, 1, -1, 1])
            ax.set_title(f"Base: {base} | Grid: {size}x{size}")
            ax.set_xlabel("Width (Normalized)")
            ax.set_ylabel("Height (Normalized)")
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()
hf_src/layers/sparse_linear.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ import logging
7
+
8
+ from typing import Callable
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import xformers.ops as xops
14
+
15
+ from hf_src.utils import named_apply, named_replace
16
+
17
+
18
class LinearW24(torch.nn.Linear):
    """``nn.Linear`` that can switch to a 2:4-sparsified weight at run time.

    While ``sparsity_enabled`` is False (the default) the layer behaves like a
    plain ``nn.Linear``. When enabled, the weight is passed through
    ``xops.sparsify24`` on every forward call.
    """

    ALGO = "largest_abs_values_greedy"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.sparsity_enabled = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.sparsity_enabled:
            return super().forward(input)

        leading_shape = input.shape[:-1]
        rows = input.flatten(end_dim=-2)
        num_rows = rows.shape[0]
        # Pad the row count up to the next multiple of 8; the padding rows are
        # sliced off again below.
        missing_rows = -num_rows % 8
        if missing_rows != 0:
            # NOTE: This should be torch-compiled away
            rows = F.pad(rows, [0, 0, 0, missing_rows])
        w_sparse = xops.sparsify24(
            self.weight,
            algo=self.ALGO,
            gradient="ste",
            backend="cusparselt",
        )
        out = F.linear(rows, w_sparse, self.bias)
        return out[:num_rows].unflatten(dim=0, sizes=leading_shape)
48
+
49
+
50
def replace_linears_with_sparse_linear(
    root_module: nn.Module, *, filter_fn: Callable[[str], bool]
) -> nn.Module:
    """Swap selected ``nn.Linear`` layers of ``root_module`` for ``LinearW24``.

    A layer is converted when it is an ``nn.Linear`` and ``filter_fn(name)``
    is true; subclasses of ``nn.Linear`` are rejected by an assertion. The new
    layer reuses the original weight and bias parameters. Returns the result
    of ``named_replace`` and asserts that at least one layer was converted.
    """
    converted_names = []

    def replace(module: nn.Module, name: str) -> nn.Module:
        is_selected = isinstance(module, nn.Linear) and filter_fn(name)
        if not is_selected:
            return module
        assert type(module) is nn.Linear, "Subtypes not supported"
        sparse_linear = LinearW24(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share the existing parameters instead of the freshly created ones.
        sparse_linear.weight = module.weight
        sparse_linear.bias = module.bias
        converted_names.append(name)
        return sparse_linear

    out = named_replace(replace, root_module)
    assert len(converted_names) > 0, "2:4 sparsity: no layer found to sparsify"
    return out
75
+
76
+
77
def update_24sparsity(root_module: nn.Module, enabled: bool) -> int:
    """Set ``sparsity_enabled`` on every ``LinearW24`` visited by ``named_apply``.

    Afterwards the dynamo code caches and the cudagraph trees are reset.
    Returns the number of ``LinearW24`` modules that were updated.
    """
    updated_names = []

    def maybe_apply_sparsity(module: nn.Module, name: str) -> nn.Module:
        if isinstance(module, LinearW24):
            module.sparsity_enabled = enabled
            updated_names.append(name)
        return module

    named_apply(maybe_apply_sparsity, root_module)
    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()
    return len(updated_names)
hf_src/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from typing import Callable, Optional
4
+
5
+ from torch import Tensor, nn
6
+
7
+ import torch.nn.functional as F
8
+
9
+
10
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer: ``w3(silu(x1) * x2)`` with ``x1, x2 = w12(x)`` halves.

    ``act_layer`` and ``drop`` are accepted by the constructor but are not
    used by this class.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """Create the two projections; missing sizes fall back to ``in_features``."""
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # One projection produces both halves that are combined in forward().
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x``, split the result in two along the last dim, gate and project back."""
        first_half, second_half = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(first_half) * second_half)
31
+
32
+
33
# xFormers is considered enabled unless the XFORMERS_DISABLED environment
# variable is set (to any value).
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None

# Default to the pure-PyTorch implementation; upgraded below when the
# xFormers SwiGLU can be imported.
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
if XFORMERS_ENABLED:
    try:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
    except ImportError:
        pass
44
+
45
+
46
class SwiGLUFFNFused(SwiGLU):
    """``SwiGLU`` variant whose hidden width is scaled by 2/3 and rounded up to a multiple of 8.

    ``act_layer`` and ``drop`` are accepted but are not forwarded to the base class.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """Derive the hidden width and delegate construction to ``SwiGLU``."""
        width = hidden_features or in_features
        # Two thirds of the requested width, rounded up to the next multiple of 8.
        width = (int(width * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=width,
            out_features=out_features or in_features,
            bias=bias,
        )
hf_src/model/__init__.py ADDED
File without changes
hf_src/model/image/__init__.py ADDED
File without changes
hf_src/model/image/vitv2/__init__.py ADDED
File without changes
hf_src/model/image/vitv2/transformer.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from: https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
2
+ # References:
3
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
4
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
5
+
6
+ import math
7
+
8
+ from functools import partial
9
+ from typing import Sequence, Tuple, Union, Callable
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.utils.checkpoint
14
+
15
+ from torch.nn.init import trunc_normal_
16
+ from torch.nn.functional import interpolate
17
+
18
+ from hf_src.layers import (
19
+ Mlp,
20
+ PatchEmbed,
21
+ SwiGLUFFNFused,
22
+ MemEffAttention,
23
+ NestedTensorBlock as Block,
24
+ LayerScale,
25
+ RMSNorm,
26
+ )
27
+
28
+
29
def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on the descendants of ``module``.

    Args:
        fn: Callback invoked with keyword arguments ``module`` and ``name``.
        module: Root of the traversal.
        name: Dotted prefix for child names.
        depth_first: If true a module is visited after its children,
            otherwise before them.
        include_root: Whether ``module`` itself is visited; descendants
            always are.

    Returns:
        ``module`` itself.
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
46
+
47
+
48
class BlockChunk(nn.ModuleList):
    """A ``ModuleList`` that can be called like a sequential stack of blocks."""

    def forward(self, x, return_attention=False):
        """Run ``x`` through every block in order.

        Only the final block receives ``return_attention``; its return value
        is returned as-is. An empty chunk returns ``x`` unchanged.
        """
        final_index = len(self) - 1
        for index, block in enumerate(self):
            if index == final_index:
                return block(x, return_attention=return_attention)
            x = block(x)
        return x
57
+
58
+
59
+ class ViTv2(nn.Module):
60
def __init__(
    self,
    *,
    img_size=518,
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.0,
    qkv_bias=True,
    ffn_bias=True,
    proj_bias=True,
    drop_path_rate=0.0,
    drop_path_uniform=True,
    init_values=None,  # for layerscale: None or 0 => no layerscale
    embed_layer=PatchEmbed,
    act_layer=nn.GELU,
    block_fn=Block,
    ffn_layer="mlp",
    block_chunks=0,
    num_register_tokens=0,
    interpolate_antialias=False,
    interpolate_offset=0.1,
    num_classes=None,
    **ignored_kwargs,
):
    """
    Args:
        img_size (int, tuple): input image size
        patch_size (int, tuple): patch size
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        depth (int): depth of transformer
        num_heads (int): number of attention heads
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim
        qkv_bias (bool): enable bias for qkv if True
        proj_bias (bool): enable bias for proj in attn if True
        ffn_bias (bool): enable bias for ffn if True
        drop_path_rate (float): stochastic depth rate
        drop_path_uniform (bool): apply uniform drop rate across blocks
        init_values (float): layer-scale init values
        embed_layer (nn.Module): patch embedding layer
        act_layer (nn.Module): MLP activation layer
        block_fn (nn.Module): transformer block class
        ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
        block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
        num_register_tokens: (int) number of extra cls tokens (so-called "registers")
        interpolate_antialias: (bool) stored on the instance as ``interpolate_antialias``
        interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        num_classes: (int, optional) if given, ``head`` is a linear layer with this many outputs;
            otherwise ``head`` is an identity
        ignored_kwargs: extra keyword arguments, accepted and discarded

    Raises:
        NotImplementedError: if ``ffn_layer`` is not one of the supported names.
    """
    # nn.Module.__init__ does not accept keyword arguments; forwarding the
    # extras would raise TypeError instead of ignoring them.
    super().__init__()

    norm_layer = partial(nn.LayerNorm, eps=1e-6)
    self.img_size = img_size

    self.num_features = self.embed_dim = embed_dim

    self.num_tokens = 1
    self.n_blocks = depth
    self.num_heads = num_heads
    self.patch_size = patch_size
    self.num_register_tokens = num_register_tokens
    self.interpolate_antialias = interpolate_antialias
    self.interpolate_offset = interpolate_offset

    self.patch_embed = embed_layer(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
    )
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(
        torch.zeros(1, num_patches + self.num_tokens, embed_dim)
    )
    assert num_register_tokens >= 0
    self.register_tokens = (
        nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
        if num_register_tokens
        else None
    )

    if drop_path_uniform is True:
        dpr = [drop_path_rate] * depth
    else:
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

    if ffn_layer == "mlp":
        ffn_layer = Mlp
    elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
        ffn_layer = SwiGLUFFNFused
    elif ffn_layer == "identity":

        def f(*args, **kwargs):
            return nn.Identity()

        ffn_layer = f
    else:
        raise NotImplementedError

    blocks_list = [
        block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            act_layer=act_layer,
            ffn_layer=ffn_layer,
            init_values=init_values,
        )
        for i in range(depth)
    ]
    if block_chunks > 0:
        self.chunked_blocks = True
        chunked_blocks = []
        chunksize = depth // block_chunks
        for i in range(0, depth, chunksize):
            # this is to keep the block index consistent if we chunk the block list
            chunked_blocks.append(
                [nn.Identity()] * i + blocks_list[i : i + chunksize]
            )
        self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
    else:
        self.chunked_blocks = False
        self.blocks = nn.ModuleList(blocks_list)

    self.mask_token = None
    self.norm = norm_layer(embed_dim)
    self.norm_patch = None

    self.head = (
        nn.Identity() if num_classes is None else nn.Linear(embed_dim, num_classes)
    )

    # Initialize the model's weights
    self.init_weights()
206
+
207
def init_weights(self):
    """Initialise the embedding/token parameters, then every submodule.

    ``pos_embed`` gets a truncated normal (std 0.02), the class and register
    tokens a normal with std 1e-6, ``mask_token`` zeros (when present), and
    ``init_weights_vit`` is applied to all descendants via ``named_apply``.
    """
    trunc_normal_(self.pos_embed, std=0.02)
    # Same near-zero init for the class token and, when present, the registers.
    for token in (self.cls_token, self.register_tokens):
        if token is not None:
            nn.init.normal_(token, std=1e-6)
    if self.mask_token is not None:
        nn.init.zeros_(self.mask_token)
    named_apply(init_weights_vit, self)
215
+
216
def interpolate_pos_encoding(self, x, w, h):
    """Return positional embeddings matching the token grid of the current input.

    When the number of non-class tokens in ``x`` equals the stored table size
    and ``w == h``, ``self.pos_embed`` is returned as-is. Otherwise the patch
    part of the table is resized with bicubic interpolation and re-joined with
    the class position, cast back to ``x.dtype``.

    Args:
        x: Token tensor; ``x.shape[1] - 1`` is taken as the patch count and
            ``x.shape[-1]`` as the embedding width.
        w, h: The two spatial sizes of the input image, as unpacked by the caller.
    """
    out_dtype = x.dtype
    num_stored = self.pos_embed.shape[1] - 1
    if x.shape[1] - 1 == num_stored and w == h:
        return self.pos_embed

    table = self.pos_embed.float()
    cls_pos, grid_pos = table[:, 0], table[:, 1:]
    dim = x.shape[-1]
    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    w0 = w // self.patch_size + self.interpolate_offset
    h0 = h // self.patch_size + self.interpolate_offset

    side = math.sqrt(num_stored)
    grid = grid_pos.reshape(1, int(side), int(side), dim).permute(0, 3, 1, 2)
    # NOTE: ``self.interpolate_antialias`` is not forwarded to ``interpolate``.
    grid = interpolate(
        grid,
        scale_factor=(float(w0) / side, float(h0) / side),
        mode="bicubic",
    )

    assert int(w0) == grid.shape[-2]
    assert int(h0) == grid.shape[-1]
    grid = grid.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((cls_pos.unsqueeze(0), grid), dim=1).to(out_dtype)
249
+
250
def prepare_tokens_with_masks(self, x, masks=None):
    """Turn an image batch into the token sequence fed to the blocks.

    Steps: patch embedding, optional replacement of masked patches by
    ``mask_token``, class token prepended, positional embedding added, and
    register tokens (if any) inserted right after the class token.

    NOTE(review): ``__init__`` leaves ``self.mask_token`` as ``None``, so
    passing ``masks`` fails unless ``mask_token`` has been assigned elsewhere.
    """
    _, _, w, h = x.shape
    tokens = self.patch_embed(x)
    if masks is not None:
        replacement = self.mask_token.to(tokens.dtype).unsqueeze(0)
        tokens = torch.where(masks.unsqueeze(-1), replacement, tokens)

    batch = tokens.shape[0]
    tokens = torch.cat((self.cls_token.expand(batch, -1, -1), tokens), dim=1)
    tokens = tokens + self.interpolate_pos_encoding(tokens, w, h)

    if self.register_tokens is None:
        return tokens

    registers = self.register_tokens.expand(tokens.shape[0], -1, -1)
    return torch.cat((tokens[:, :1], registers, tokens[:, 1:]), dim=1)
272
+
273
def forward_features_list(self, x_list, masks_list):
    """Run several inputs through the blocks together.

    Each input is tokenised with its mask, the list of token tensors is passed
    through every block as a list, and one dict per input is returned with the
    keys ``"latent"``, ``"patch_latent"`` and ``"raw_latent"``.
    """
    hidden = [
        self.prepare_tokens_with_masks(image, mask)
        for image, mask in zip(x_list, masks_list)
    ]
    for blk in self.blocks:
        hidden = blk(hidden)

    first_patch = self.num_register_tokens + 1
    # ``norm_patch`` overrides the shared norm for patch tokens when it is set.
    patch_norm = self.norm if self.norm_patch is None else self.norm_patch

    results = []
    for tokens, _ in zip(hidden, masks_list):
        prefix = self.norm(tokens[:, :first_patch])
        patches = patch_norm(tokens[:, first_patch:])
        results.append(
            {
                "latent": prefix[:, 0],
                "patch_latent": patches,
                "raw_latent": tokens[:, 0],
            }
        )
    return results
299
+
300
def forward_features(self, x, masks=None, last_self_attention=False):
    """Compute the feature dict for one batch, or delegate for a list of batches.

    Args:
        x: An image batch tensor, or a list/tuple of them (handled by
            ``forward_features_list``).
        masks: Optional patch mask(s); for list input a missing ``masks``
            means "no mask for any element".
        last_self_attention: When true, the last block is asked to return its
            attention as well.

    Returns:
        For tensor input a dict with ``"latent"``, ``"patch_latent"``,
        ``"raw_latent"``, ``"last_self_attention"`` (``None`` unless requested)
        and ``"logits"``. For list input, whatever ``forward_features_list``
        returns.
    """
    if isinstance(x, (list, tuple)):
        # Without this default, zip(x, None) in forward_features_list raises.
        if masks is None:
            masks = [None] * len(x)
        return self.forward_features_list(x, masks)

    x = self.prepare_tokens_with_masks(x, masks)

    for i, blk in enumerate(self.blocks):
        if i < len(self.blocks) - 1:
            x = blk(x)
        else:
            x = blk(x, return_attention=last_self_attention)

    attn = None
    if last_self_attention:
        x, attn = x
        # Attention is selected from the cls token to the patch tokens only
        # Thus, we ignore the cls from the patch tokens (i.e., start from 1)
        attn = attn[:, :, 0, self.num_register_tokens + 1 :]

    cls_tokens = self.norm(x[:, : self.num_register_tokens + 1])

    if self.norm_patch is None:
        patch_tokens = self.norm(x[:, self.num_register_tokens + 1 :])
    else:
        patch_tokens = self.norm_patch(x[:, self.num_register_tokens + 1 :])

    return {
        "latent": cls_tokens[:, 0],
        "patch_latent": patch_tokens,
        "raw_latent": x[:, 0],
        "last_self_attention": attn,
        "logits": self.head(cls_tokens[:, 0]),
    }
333
+
334
def forward_head(self, x):
    """Apply ``self.projection_head`` and, if ``self.l2_norm`` is truthy, L2-normalise along dim 1.

    NOTE(review): neither ``projection_head`` nor ``l2_norm`` is assigned in
    this class's ``__init__``; presumably they are set by a subclass or by the
    caller before this method is used -- confirm, otherwise this raises
    ``AttributeError``.
    """
    # Projection with l2-norm bottleneck
    x = self.projection_head(x)
    if self.l2_norm:
        x = nn.functional.normalize(x, dim=1, p=2)
    return x
340
+
341
def _get_intermediate_layers_not_chunked(self, x, n=1):
    """Collect block outputs when ``self.blocks`` is a flat list of blocks.

    Args:
        x: Image batch, tokenised here without masks.
        n: An int selects the last ``n`` blocks; any other container is used
            directly as the collection of block indices to keep.

    Returns:
        List of the selected blocks' outputs, in block order.
    """
    tokens = self.prepare_tokens_with_masks(x)
    depth = len(self.blocks)
    # If n is an int, take the n last blocks. If it's a list, take them
    wanted = range(depth - n, depth) if isinstance(n, int) else n

    collected = []
    for index, blk in enumerate(self.blocks):
        tokens = blk(tokens)
        if index in wanted:
            collected.append(tokens)

    assert len(collected) == len(
        wanted
    ), f"only {len(collected)} / {len(wanted)} blocks found"
    return collected
356
+
357
def _get_intermediate_layers_chunked(self, x, n=1):
    """Collect block outputs when ``self.blocks`` holds ``BlockChunk`` groups.

    A running block index is kept across chunks; each chunk is sliced from
    that index so its leading ``nn.Identity`` placeholders are skipped.

    Args:
        x: Image batch, tokenised here without masks.
        n: An int selects the last ``n`` blocks (counted against the length
            of the final chunk); any other container is used as block indices.

    Returns:
        List of the selected blocks' outputs, in block order.
    """
    tokens = self.prepare_tokens_with_masks(x)
    depth = len(self.blocks[-1])
    # If n is an int, take the n last blocks. If it's a list, take them
    wanted = range(depth - n, depth) if isinstance(n, int) else n

    collected = []
    position = 0
    for chunk in self.blocks:
        for blk in chunk[position:]:  # Passing the nn.Identity()
            tokens = blk(tokens)
            if position in wanted:
                collected.append(tokens)
            position += 1

    assert len(collected) == len(
        wanted
    ), f"only {len(collected)} / {len(wanted)} blocks found"
    return collected
374
+
375
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
    reshape: bool = False,
    return_class_token: bool = False,
    norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
    """Return patch tokens (and optionally class tokens) of selected blocks.

    Args:
        x: Image batch.
        n: Number of last blocks, or explicit block indices.
        reshape: If true, patch tokens are reshaped to
            ``(B, -1, w // patch_size, h // patch_size)`` using ``x.shape``.
        return_class_token: If true, each entry is a
            ``(patch_tokens, class_token)`` pair.
        norm: If true, ``self.norm`` (or ``self.norm_patch`` for patch tokens
            when it is set) is applied before slicing.

    Returns:
        Tuple with one entry per selected block.
    """
    collect = (
        self._get_intermediate_layers_chunked
        if self.chunked_blocks
        else self._get_intermediate_layers_not_chunked
    )
    raw_outputs = collect(x, n)
    first_patch = 1 + self.num_register_tokens

    def class_token_of(out):
        if not norm:
            return out[:, 0]
        return self.norm(out[:, :first_patch])[:, 0]

    def patch_tokens_of(out):
        patches = out[:, first_patch:]
        if not norm:
            return patches
        if self.norm_patch is None:
            return self.norm(patches)
        return self.norm_patch(patches)

    class_tokens = [class_token_of(out) for out in raw_outputs]
    outputs = [patch_tokens_of(out) for out in raw_outputs]

    if reshape:
        B, _, w, h = x.shape
        grid_w, grid_h = w // self.patch_size, h // self.patch_size
        outputs = [
            out.reshape(B, grid_w, grid_h, -1).permute(0, 3, 1, 2).contiguous()
            for out in outputs
        ]

    if return_class_token:
        return tuple(zip(outputs, class_tokens))
    return tuple(outputs)
420
+
421
def forward(self, xs, masks=None, last_self_attention=False, **kwargs):
    """Dispatch to the single-batch or the multi-batch feature path.

    A list or tuple ``xs`` goes to ``forward_features_list`` (with ``masks``
    defaulting to one ``None`` per element); anything else goes to
    ``forward_features``. Extra keyword arguments are ignored.
    """
    if isinstance(xs, (list, tuple)):
        if masks is None:
            masks = [None] * len(xs)
        return self.forward_features_list(xs, masks)
    return self.forward_features(xs, masks, last_self_attention)
429
+
430
def forward_backbone(self, x, last_self_attention=False):
    """Return the ``"latent"`` token followed by ``"patch_latent"`` as one sequence.

    When ``last_self_attention`` is true, a ``(tokens, attention)`` pair is
    returned, where ``attention`` is the ``"last_self_attention"`` entry of
    the feature dict.
    """
    features = self.forward_features(x, last_self_attention=last_self_attention)
    # Combine the cls token and the patch tokens
    tokens = torch.cat(
        (features["latent"].unsqueeze(1), features["patch_latent"]), dim=1
    )
    if not last_self_attention:
        return tokens
    return tokens, features["last_self_attention"]
439
+
440
def get_last_selfattention(self, x, masks=None):
    """Return the attention produced by the last entry of ``self.blocks``.

    Adapted from https://gitlab.com/ziegleto-machine-learning/dino/-/tree/main/

    Raises:
        NotImplementedError: If ``x`` is a list.
    """
    if isinstance(x, list):
        raise NotImplementedError("Not implemented for list of inputs")

    tokens = self.prepare_tokens_with_masks(x, masks)

    # Run through model, at the last block just return the attention.
    final_index = len(self.blocks) - 1
    for index, blk in enumerate(self.blocks):
        if index == final_index:
            _, attn = blk(tokens, return_attention=True)
            return attn
        tokens = blk(tokens)
    return None
457
+
458
+
459
def init_weights_vit(module: nn.Module, name: str = ""):
    """Per-module initialiser intended for use with ``named_apply``.

    ``nn.Linear``: truncated-normal weights (std 0.02), zero bias, and, when
    the layer has a non-``None`` ``bias_mask``, that mask is set to 1 except
    for its middle third, which is set to 0. ``nn.LayerNorm``, ``LayerScale``,
    ``PatchEmbed`` and ``RMSNorm`` get ``reset_parameters()`` called.
    ``name`` is unused.
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        bias_mask = getattr(module, "bias_mask", None)
        if bias_mask is not None:
            third = module.out_features // 3
            two_thirds = 2 * module.out_features // 3
            bias_mask.fill_(1)
            bias_mask[third:two_thirds].fill_(0)

    for resettable in (nn.LayerNorm, LayerScale, PatchEmbed, RMSNorm):
        if isinstance(module, resettable):
            module.reset_parameters()
hf_src/utils/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ from .dtype import as_torch_dtype
7
+ from .utils import (
8
+ cat_keep_shapes,
9
+ count_parameters,
10
+ fix_random_seeds,
11
+ get_conda_env,
12
+ get_sha,
13
+ named_apply,
14
+ named_replace,
15
+ uncat_with_shapes,
16
+ )
hf_src/utils/download.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import hashlib
4
+ from tqdm import tqdm
5
+
6
+
7
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    digest = hashlib.sha1()
    with open(filename, "rb") as handle:
        # Hash in 1 MiB pieces so the whole file is never held in memory.
        for piece in iter(lambda: handle.read(1048576), b""):
            digest.update(piece)
    return digest.hexdigest() == sha1_hash
29
+
30
+
31
def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL.

    Adapted from
    https://github.com/junfu1115/DANet/blob/master/encoding/utils/files.py

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server does not answer with status 200.
    UserWarning
        If ``sha1_hash`` is given and the downloaded content does not match it.
    """
    if path is None:
        fname = url.split("/")[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path

    if (
        overwrite
        or not os.path.exists(fname)
        or (sha1_hash and not check_sha1(fname, sha1_hash))
    ):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids a race between the existence check and the creation.
        os.makedirs(dirname, exist_ok=True)

        print("Downloading %s from %s..." % (fname, url))
        # The context manager releases the streamed connection on every path,
        # including the error raised for a non-200 status.
        with requests.get(url, stream=True) as r:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s" % url)
            total_length = r.headers.get("content-length")
            with open(fname, "wb") as f:
                if total_length is None:  # no content length header
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
                else:
                    total_length = int(total_length)
                    for chunk in tqdm(
                        r.iter_content(chunk_size=1024),
                        total=int(total_length / 1024.0 + 0.5),
                        unit="KB",
                        unit_scale=False,
                        dynamic_ncols=True,
                    ):
                        f.write(chunk)

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning(
                "File {} is downloaded but the content hash does not match. "
                "The repo may be outdated or download may be incomplete. "
                'If the "repo_url" is overridden, consider switching to '
                "the default repo.".format(fname)
            )

    return fname
hf_src/utils/dtype.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ from typing import Dict, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
# Anything ``as_torch_dtype`` accepts: a numpy dtype name, a numpy dtype, or a torch dtype.
TypeSpec = Union[str, np.dtype, torch.dtype]


# numpy dtypes that have a torch counterpart in this mapping.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Convert a dtype specification to a ``torch.dtype``.

    Args:
        dtype: A ``torch.dtype`` (returned unchanged), a string understood by
            ``np.dtype``, or a ``np.dtype``.

    Returns:
        The matching entry of ``_NUMPY_TO_TORCH_DTYPE``.

    Raises:
        AssertionError: If ``dtype`` is none of the accepted kinds.
        KeyError: If the numpy dtype has no entry in ``_NUMPY_TO_TORCH_DTYPE``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    assert isinstance(
        dtype, np.dtype
    ), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
hf_src/utils/masking.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+
7
def complete_mask_randomly_np(mask, num_masking_patches, rng):
    """Top up a boolean mask in place until it has ``num_masking_patches`` true entries.

    Args:
        mask: Boolean numpy array (``~`` is applied to it); modified in place.
        num_masking_patches: Target number of true entries.
        rng: Object providing ``choice(a, size=..., replace=False)``, e.g. a
            ``numpy.random.Generator``.

    Returns:
        ``mask`` itself. It is left untouched when it already has at least
        ``num_masking_patches`` true entries.
    """
    flat = mask.reshape(-1)
    missing = num_masking_patches - flat.sum()

    if missing <= 0:
        return mask

    available = np.flatnonzero(~flat)
    chosen = rng.choice(available, size=missing, replace=False)
    # ``reshape`` returns a copy for non-contiguous arrays, so write through
    # ``mask`` itself; ``chosen`` holds C-order flat indices.
    mask[np.unravel_index(chosen, mask.shape)] = True

    return mask
19
+
20
+
21
class IBotMasker:
    """Builds boolean patch masks by stamping random rectangles onto a grid.

    Rectangles are drawn with a random area and aspect ratio until the
    requested number of masked patches is reached or no rectangle could be
    placed; ``complete_mask_randomly_np`` then fills any remainder.
    """

    def __init__(
        self,
        input_size,
        num_masking_patches=None,
        min_num_patches=0,
        max_num_patches=None,
        min_aspect=0.3,
        max_aspect=3.33,
        max_tries=10,
    ):
        """Store grid size and sampling limits.

        ``input_size`` is an int (square grid) or an ``(h, w)`` pair.
        ``max_num_patches`` falls back to ``num_masking_patches`` and
        ``max_aspect`` to ``1 / min_aspect`` when they are falsy.
        """
        grid = (input_size, input_size) if isinstance(input_size, int) else input_size
        self.h, self.w = grid
        self.num_patches = self.h * self.w

        self.min_num_patches = min_num_patches
        self.num_masking_patches = num_masking_patches
        self.max_num_patches = max_num_patches or num_masking_patches

        # Aspect ratios are sampled uniformly in log space (see ``_mask``).
        self.log_min_aspect = np.log(min_aspect)
        self.log_max_aspect = np.log(max_aspect or 1 / min_aspect)

        self.max_tries = max_tries

    def __call__(self, num_masking_patches, starting_mask=None, rng=None):
        """Return an ``(h, w)`` boolean mask with ``num_masking_patches`` true entries.

        ``starting_mask`` (copied, never modified) seeds the mask; ``rng``
        defaults to a fresh ``np.random.default_rng()``.
        """
        generator = np.random.default_rng() if rng is None else rng

        if starting_mask is None:
            mask = np.zeros((self.h, self.w), dtype=np.bool_)
        else:
            mask = starting_mask.copy()

        masked_so_far = mask.sum()
        while masked_so_far < num_masking_patches:
            budget = num_masking_patches - masked_so_far
            if self.max_num_patches is not None:
                budget = min(budget, self.max_num_patches)

            added = self._mask(mask, budget, generator)
            if added == 0:
                # No rectangle could be placed within max_tries attempts.
                break
            masked_so_far += added

        return complete_mask_randomly_np(mask, num_masking_patches, generator)

    def _mask(self, mask, max_mask_patches, rng):
        """Try to stamp one rectangle onto ``mask`` in place.

        Returns the number of newly masked patches, or 0 if none of the
        ``max_tries`` attempts produced an acceptable rectangle.
        """
        attempts_left = self.max_tries
        while attempts_left > 0:
            attempts_left -= 1
            target = rng.uniform(self.min_num_patches, max_mask_patches)
            aspect = np.exp(rng.uniform(self.log_min_aspect, self.log_max_aspect))

            h = int(round(np.sqrt(target * aspect)))
            w = int(round(np.sqrt(target / aspect)))

            # The rectangle must be non-empty and strictly smaller than the grid.
            if not (0 < h < self.h and 0 < w < self.w):
                continue

            top = rng.integers(0, self.h - h + 1)
            left = rng.integers(0, self.w - w + 1)

            window = mask[top : top + h, left : left + w]
            fresh = (~window).sum()

            if 0 < fresh <= max_mask_patches:
                window[:] = True
                return fresh

        return 0
93
+
94
+
95
+ def generate_masks(
96
+ mask_generator, number_of_samples, mask_prob=0.5, per_sample_range=(0.1, 0.5)
97
+ ):
98
+ num_masks = int(number_of_samples * mask_prob)
99
+ num_tokens = mask_generator.num_patches
100
+ prob_per_sample = np.linspace(*per_sample_range, num=num_masks)
101
+ masks = [
102
+ (
103
+ mask_generator(num_masking_patches=int(prob_per_sample[i] * num_tokens))
104
+ if i < num_masks
105
+ else mask_generator(num_masking_patches=0)
106
+ )
107
+ for i in range(number_of_samples)
108
+ ]
109
+ random.shuffle(masks)
110
+ masks = np.stack(masks, dtype=bool)
111
+ masks = torch.from_numpy(masks).flatten(1, -1)
112
+
113
+ return masks
hf_src/utils/seedlet_masking.py ADDED
File without changes
hf_src/utils/utils.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This software may be used and distributed in accordance with
4
+ # the terms of the DINOv3 License Agreement.
5
+
6
+ import logging
7
+ import os
8
+ import random
9
+ import subprocess
10
+ from typing import Callable, List, Optional, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import Tensor, nn
15
+
16
+ logger = logging.getLogger("dinov3")
17
+
18
+
19
def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]:
    """Flatten each tensor to 2-D (keeping the last dim) and concatenate them.

    Returns:
        ``(flattened, shapes, num_tokens)`` where ``shapes`` are the original
        shapes and ``num_tokens`` the number of rows each tensor contributed.
    """
    shapes, num_tokens, rows = [], [], []
    for tensor in x_list:
        shapes.append(tensor.shape)
        # Element count of one slice along the last dim == number of rows.
        num_tokens.append(tensor.select(dim=-1, index=0).numel())
        rows.append(tensor.flatten(0, -2))
    return torch.cat(rows), shapes, num_tokens
24
+
25
+
26
def uncat_with_shapes(
    flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]
) -> List[Tensor]:
    """Split ``flattened`` along dim 0 and restore the leading dims of each piece.

    Each piece takes its leading dims from the matching entry of ``shapes``
    and its last dim from ``flattened`` (which may differ from the original).
    """
    last_dim = torch.Size([flattened.shape[-1]])
    pieces = torch.split_with_sizes(flattened, num_tokens, dim=0)
    restored = []
    for piece, original_shape in zip(pieces, shapes):
        restored.append(piece.reshape(original_shape[:-1] + last_dim))
    return restored
37
+
38
+
39
def named_replace(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Rebuild a module tree by replacing each descendant with ``fn``'s result.

    Args:
        fn: Called as ``fn(module=..., name=...)``; its return value takes the
            place of the visited module.
        module: Root of the traversal.
        name: Dotted prefix for child names.
        depth_first: If true a module is passed to ``fn`` after its children
            were processed, otherwise before.
        include_root: Whether ``module`` itself is passed to ``fn``;
            descendants always are.

    Returns:
        The (possibly replaced) root module.
    """
    if include_root and not depth_first:
        module = fn(module=module, name=name)

    # Snapshot the children first: they are re-assigned while iterating.
    for local_name, child in list(module.named_children()):
        qualified = f"{name}.{local_name}" if name else local_name
        replacement = named_replace(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )
        setattr(module, local_name, replacement)

    if include_root and depth_first:
        module = fn(module=module, name=name)
    return module
62
+
63
+
64
def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on the descendants of ``module``.

    Args:
        fn: Callback invoked with keyword arguments ``module`` and ``name``;
            its return value is ignored.
        module: Root of the traversal.
        name: Dotted prefix for child names.
        depth_first: If true a module is visited after its children,
            otherwise before them.
        include_root: Whether ``module`` itself is visited; descendants
            always are.

    Returns:
        ``module`` itself.
    """
    visit_before_children = include_root and not depth_first
    visit_after_children = include_root and depth_first

    if visit_before_children:
        fn(module=module, name=name)
    for local_name, child in module.named_children():
        named_apply(
            fn=fn,
            module=child,
            name=f"{name}.{local_name}" if name else local_name,
            depth_first=depth_first,
            include_root=True,
        )
    if visit_after_children:
        fn(module=module, name=name)
    return module
85
+
86
+
87
def fix_random_seeds(seed: int = 31):
    """
    Seed Python's ``random``, NumPy's global generator and torch (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
95
+
96
+
97
def get_sha() -> str:
    """Describe the git state of the directory containing this file.

    Runs ``git`` in that directory to obtain the HEAD commit, whether
    ``git diff-index HEAD`` reports anything, and the current branch name.
    Any failure is swallowed; fields not obtained keep their defaults
    (``"N/A"`` for sha and branch, ``"clean"`` for status).

    Returns:
        A string of the form ``"sha: ..., status: ..., branch: ..."``.
    """
    here = os.path.dirname(os.path.abspath(__file__))

    def _git(*arguments):
        raw = subprocess.check_output(["git", *arguments], cwd=here)
        return raw.decode("ascii").strip()

    info = {"sha": "N/A", "status": "clean", "branch": "N/A"}
    try:
        info["sha"] = _git("rev-parse", "HEAD")
        # Output intentionally discarded; only run for its effect/exit status.
        subprocess.check_output(["git", "diff"], cwd=here)
        changed = _git("diff-index", "HEAD")
        info["status"] = "has uncommited changes" if changed else "clean"
        info["branch"] = _git("rev-parse", "--abbrev-ref", "HEAD")
    except Exception:
        # Best effort: report whatever was gathered before the failure.
        pass
    return "sha: {sha}, status: {status}, branch: {branch}".format(**info)
116
+
117
+
118
def get_conda_env() -> Tuple[Optional[str], Optional[str]]:
    """Return ``(CONDA_DEFAULT_ENV, CONDA_PREFIX)`` from the environment, ``None`` where unset."""
    env = os.environ
    return env.get("CONDA_DEFAULT_ENV"), env.get("CONDA_PREFIX")
122
+
123
+
124
def count_parameters(module: nn.Module) -> int:
    """Return the total number of elements over ``module.parameters()``."""
    return sum(parameter.nelement() for parameter in module.parameters())
129
+
130
+
131
def has_batchnorms(model: nn.Module) -> bool:
    """Return whether ``model`` or any of its submodules is a batch-norm layer.

    Checked types: ``BatchNorm1d``, ``BatchNorm2d``, ``BatchNorm3d`` and
    ``SyncBatchNorm``.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(submodule, bn_types) for submodule in model.modules())
modelling_vitv2.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import torch
4
+
5
+ from transformers import PreTrainedModel
6
+
7
+ from configuration_vitv2 import ViTv2Config
8
+ from hf_src.model.image.vitv2.transformer import ViTv2
9
+
10
+
11
class ViTv2PretrainedModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that builds a ``ViTv2`` backbone from a ``ViTv2Config``."""

    config_class = ViTv2Config

    # Config attributes passed to ``ViTv2`` under the same keyword name.
    _BACKBONE_FIELDS = (
        "img_size",
        "patch_size",
        "embed_dim",
        "depth",
        "num_heads",
        "mlp_ratio",
        "init_values",
        "num_register_tokens",
    )

    def __init__(self, config: ViTv2Config):
        """Create the backbone from ``config`` and run ``post_init``."""
        super().__init__(config)

        backbone_kwargs = {
            field: getattr(config, field) for field in self._BACKBONE_FIELDS
        }
        self.backbone = ViTv2(**backbone_kwargs)

        self.post_init()

    def forward(self, *args, **kwargs) -> dict[str, Union[torch.Tensor, None]]:
        """Forward all arguments to the backbone and return its output."""
        return self.backbone(*args, **kwargs)
requirements.txt CHANGED
@@ -5,3 +5,5 @@ transformers>=4.38.0
5
  scikit-learn>=1.3.0
6
  Pillow>=9.0.0
7
  numpy>=1.24.0
 
 
 
5
  scikit-learn>=1.3.0
6
  Pillow>=9.0.0
7
  numpy>=1.24.0
8
+ einops
9
+ opt_einsum